Added separate per-stage loading of model parameters

jinye_huang 2025-07-09 14:51:02 +08:00
parent 5fd1e2fd85
commit d4d23068e5
12 changed files with 95 additions and 22 deletions

View File

@@ -2,9 +2,9 @@
     "model": "qwen-plus",
     "api_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
     "api_key": "sk-bd5ee62703bc41fc9b8a55d748dc1eb8",
-    "temperature": 0.8,
-    "top_p": 0.5,
-    "presence_penalty": 1.1,
+    "temperature": 0.3,
+    "top_p": 0.4,
+    "presence_penalty": 1.2,
     "timeout": 120,
     "max_retries": 3
 }
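
These are the global client defaults that every stage falls back to. As a minimal sketch of how such a config is typically consumed by an OpenAI-compatible client (the ModelConfig class, load_model_config helper, and file path below are assumptions for illustration, not names from this repo):

# Minimal sketch, assuming the JSON above deserializes 1:1 into a config object.
import json
from dataclasses import dataclass
from openai import AsyncOpenAI

@dataclass
class ModelConfig:
    model: str
    api_url: str
    api_key: str
    temperature: float
    top_p: float
    presence_penalty: float
    timeout: int
    max_retries: int

def load_model_config(path: str) -> ModelConfig:
    with open(path, encoding="utf-8") as f:
        return ModelConfig(**json.load(f))

cfg = load_model_config("resource/config/model.json")  # hypothetical path
client = AsyncOpenAI(base_url=cfg.api_url, api_key=cfg.api_key, timeout=cfg.timeout)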

View File

@@ -3,5 +3,15 @@
     "content_user_prompt": "resource/prompt/generateContent/user.txt",
     "judger_system_prompt": "resource/prompt/judgeContent/system.txt",
     "judger_user_prompt": "resource/prompt/judgeContent/user.txt",
-    "enable_content_judge": true
+    "enable_content_judge": true,
+    "model": {
+        "temperature": 0.3,
+        "top_p": 0.5,
+        "presence_penalty": 1.2
+    },
+    "judger_model": {
+        "temperature": 0.2,
+        "top_p": 0.3,
+        "presence_penalty": 0.8
+    }
 }
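
The new "model" and "judger_model" blocks carry per-stage sampling parameters; any key left out falls back to the global client defaults (this is the `temperature if temperature is not None else self.config.temperature` logic in AIAgent.generate_text further down). A minimal sketch of that resolution order, with illustrative values:

# Sketch of the override resolution: stage-specific values win,
# missing or None keys fall back to the global defaults.
GLOBAL_DEFAULTS = {"temperature": 0.3, "top_p": 0.4, "presence_penalty": 1.2}

def resolve_params(stage_overrides: dict) -> dict:
    # Drop None values so they do not mask the defaults.
    overrides = {k: v for k, v in stage_overrides.items() if v is not None}
    return {**GLOBAL_DEFAULTS, **overrides}

print(resolve_params({"temperature": 0.2}))
# -> {'temperature': 0.2, 'top_p': 0.4, 'presence_penalty': 1.2}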

View File

@@ -16,7 +16,7 @@
         "refer_list": [
             { "path": "resource/prompt/Refer/2025各月节日宣传节点时间表.md", "sampling_rate": 1, "step": "topic" },
             { "path": "resource/prompt/Refer/标题参考格式.json", "sampling_rate": 0.25, "step": "content" },
-            { "path": "resource/prompt/Refer/正文范文参考.json", "sampling_rate": 0.25, "step": "content" }
+            { "path": "resource/prompt/Refer/正文范文参考.json", "sampling_rate": 0.5, "step": "content" }
         ]
     },
     "object": {

View File

@@ -1,7 +1,11 @@
 {
     "topic_system_prompt": "resource/prompt/generateTopics/system.txt",
     "topic_user_prompt": "resource/prompt/generateTopics/user.txt",
-    "model": {},
+    "model": {
+        "temperature": 0.2,
+        "top_p": 0.3,
+        "presence_penalty": 1.5
+    },
     "topic": {
         "date": "2024-07-20",
         "num": 5,

View File

@@ -45,7 +45,9 @@ class AIAgent:
         # self.tokenizer = tiktoken.get_encoding("cl100k_base")

     async def generate_text(
-        self, system_prompt: str, user_prompt: str, use_stream: bool = False
+        self, system_prompt: str, user_prompt: str, use_stream: bool = False,
+        temperature: Optional[float] = None, top_p: Optional[float] = None,
+        presence_penalty: Optional[float] = None, stage: str = ""
     ) -> Tuple[str, int, int, float]:
         """
         Generate text (supports both streaming and non-streaming modes)
@@ -54,6 +56,10 @@
             system_prompt: the system prompt
             user_prompt: the user prompt
             use_stream: whether to return a streaming response
+            temperature: sampling temperature, controls randomness
+            top_p: top-p (nucleus) sampling parameter
+            presence_penalty: presence penalty parameter
+            stage: name of the current pipeline stage, used for logging

         Returns:
             A tuple (generated_text, input_tokens, output_tokens, time_cost)
@@ -63,8 +69,17 @@
             {"role": "user", "content": user_prompt}
         ]

+        # Use the passed-in parameters, or fall back to the configured defaults
+        temp = temperature if temperature is not None else self.config.temperature
+        tp = top_p if top_p is not None else self.config.top_p
+        pp = presence_penalty if presence_penalty is not None else self.config.presence_penalty
+
+        # Log the model parameters actually in use
+        stage_info = f"[{stage}]" if stage else ""
+        logger.info(f"{stage_info} Using model parameters: temperature={temp:.2f}, top_p={tp:.2f}, presence_penalty={pp:.2f}")
+
         input_tokens = self.count_tokens(system_prompt + user_prompt)
-        logger.info(f"Starting generation task... input tokens: {input_tokens}")
+        logger.info(f"{stage_info} Starting generation task... input tokens: {input_tokens}")

         last_exception = None
         backoff_time = 1.0  # Start with 1 second
@@ -75,9 +90,9 @@
                 response = await self.client.chat.completions.create(
                     model=self.config.model,
                     messages=messages,
-                    temperature=self.config.temperature,
-                    top_p=self.config.top_p,
-                    presence_penalty=self.config.presence_penalty,
+                    temperature=temp,
+                    top_p=tp,
+                    presence_penalty=pp,
                     stream=use_stream
                 )

@@ -90,26 +105,27 @@
                     # Simplified handling: in streaming mode, return the concatenated full text
                     full_text = "".join([chunk for chunk in self._process_stream(response)])
                     output_tokens = self.count_tokens(full_text)
+                    logger.info(f"{stage_info} Task finished in {time_cost:.2f}s. Output tokens: {output_tokens}")
                     return full_text, input_tokens, output_tokens, time_cost
                 else:
                     output_text = response.choices[0].message.content.strip()
                     output_tokens = self.count_tokens(output_text)
-                    logger.info(f"Task finished in {time_cost:.2f}s. Output tokens: {output_tokens}")
+                    logger.info(f"{stage_info} Task finished in {time_cost:.2f}s. Output tokens: {output_tokens}")
                     return output_text, input_tokens, output_tokens, time_cost

             except (APITimeoutError, APIConnectionError) as e:
                 last_exception = RetryableError(f"AI model connection or timeout error: {e}")
-                logger.warning(f"Attempt {attempt + 1}/{self.config.max_retries} failed: {last_exception}. "
+                logger.warning(f"{stage_info} Attempt {attempt + 1}/{self.config.max_retries} failed: {last_exception}. "
                                f"Retrying in {backoff_time:.1f}s...")
                 time.sleep(backoff_time)
                 backoff_time *= 2  # Exponential backoff
             except (RateLimitError, APIStatusError) as e:
                 last_exception = NonRetryableError(f"AI model API error (non-retryable): {e}")
-                logger.error(f"A non-retryable API error occurred: {last_exception}")
+                logger.error(f"{stage_info} A non-retryable API error occurred: {last_exception}")
                 break  # Do not retry on these errors
             except Exception as e:
                 last_exception = AIModelError(f"Unknown error while calling the AI model: {e}")
-                logger.error(f"An unknown error occurred: {last_exception}\n{traceback.format_exc()}")
+                logger.error(f"{stage_info} An unknown error occurred: {last_exception}\n{traceback.format_exc()}")
                 break

         raise AIModelError(f"AI model call failed after {self.config.max_retries} retries") from last_exception
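
With these hooks in place, a stage passes only the keys it wants to override; everything else falls back to the client defaults. A minimal usage sketch (agent construction elided, parameter values illustrative):

# Hypothetical call site for the new per-stage parameters.
async def draft_topics(agent):
    text, in_tok, out_tok, cost = await agent.generate_text(
        system_prompt="You are an editorial planner.",
        user_prompt="Draft five topics.",
        temperature=0.2,           # stage override
        presence_penalty=1.5,      # stage override
        stage="topic generation",  # only used as a log prefix
    )
    return text

One caveat worth noting in the retry path above: time.sleep(backoff_time) inside an async method blocks the event loop for the whole backoff; await asyncio.sleep(backoff_time) would be the non-blocking equivalent.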

View File

@@ -10,7 +10,7 @@ import json
 from typing import Dict, Any

 from core.ai import AIAgent
-from core.config import ConfigManager, GenerateTopicConfig
+from core.config import ConfigManager, GenerateTopicConfig, GenerateContentConfig
 from utils.prompts import ContentPromptBuilder
 from utils.file_io import OutputManager

@@ -22,7 +22,9 @@

     def __init__(self, ai_agent: AIAgent, config_manager: ConfigManager, output_manager: OutputManager):
         self.ai_agent = ai_agent
-        self.config: GenerateTopicConfig = config_manager.get_config('topic_gen', GenerateTopicConfig)
+        self.config_manager = config_manager
+        self.topic_config = config_manager.get_config('topic_gen', GenerateTopicConfig)
+        self.content_config = config_manager.get_config('content_gen', GenerateContentConfig)
         self.output_manager = output_manager
         self.prompt_builder = ContentPromptBuilder(config_manager)
@@ -49,12 +51,25 @@
         self.output_manager.save_text(system_prompt, "content_system_prompt.txt", subdir=output_dir.name)
         self.output_manager.save_text(user_prompt, "content_user_prompt.txt", subdir=output_dir.name)

+        # Read the stage-specific model parameters
+        model_params = {}
+        if hasattr(self.content_config, 'model') and isinstance(self.content_config.model, dict):
+            model_params = {
+                'temperature': self.content_config.model.get('temperature'),
+                'top_p': self.content_config.model.get('top_p'),
+                'presence_penalty': self.content_config.model.get('presence_penalty')
+            }
+            # Drop None values
+            model_params = {k: v for k, v in model_params.items() if v is not None}
+
         # 2. Call the AI
         try:
             raw_result, _, _, _ = await self.ai_agent.generate_text(
                 system_prompt=system_prompt,
                 user_prompt=user_prompt,
-                use_stream=False
+                use_stream=False,
+                stage="content generation",
+                **model_params
             )
             self.output_manager.save_text(raw_result, "content_raw_response.txt", subdir=output_dir.name)
         except Exception as e:
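
The same parameter-extraction block appears nearly verbatim in ContentGenerator, ContentJudger, and TopicGenerator below. A possible follow-up (not part of this commit) would be to factor it into a shared helper, sketched here:

# Hypothetical shared helper; 'extract_model_params' does not exist
# in this repo, this is only a deduplication sketch.
from typing import Any

def extract_model_params(config: Any, attr: str = "model") -> dict:
    """Return the non-None sampling parameters from config.<attr>, if any."""
    block = getattr(config, attr, None)
    if not isinstance(block, dict):
        return {}
    keys = ("temperature", "top_p", "presence_penalty")
    return {k: block[k] for k in keys if block.get(k) is not None}

# Usage inside a generator:
#     model_params = extract_model_params(self.content_config, "model")
#     ... await self.ai_agent.generate_text(..., **model_params)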

View File

@@ -11,7 +11,7 @@ from json_repair import loads as json_repair_loads
 from typing import Dict, Any

 from core.ai import AIAgent
-from core.config import ConfigManager, GenerateTopicConfig
+from core.config import ConfigManager, GenerateTopicConfig, GenerateContentConfig
 from utils.prompts import JudgerPromptBuilder

 logger = logging.getLogger(__name__)
@@ -30,7 +30,9 @@
             output_manager: output manager used to save prompts and responses
         """
         self.ai_agent = ai_agent
-        self.config: GenerateTopicConfig = config_manager.get_config('topic_gen', GenerateTopicConfig)
+        self.config_manager = config_manager
+        self.topic_config = config_manager.get_config('topic_gen', GenerateTopicConfig)
+        self.content_config = config_manager.get_config('content_gen', GenerateContentConfig)
         self.prompt_builder = JudgerPromptBuilder(config_manager)
         self.output_manager = output_manager
@@ -63,12 +65,25 @@
         self.output_manager.save_text(system_prompt, f"{topic_dir}/judger_system_prompt.txt")
         self.output_manager.save_text(user_prompt, f"{topic_dir}/judger_user_prompt.txt")

+        # Read the stage-specific model parameters
+        model_params = {}
+        if hasattr(self.content_config, 'judger_model') and isinstance(self.content_config.judger_model, dict):
+            model_params = {
+                'temperature': self.content_config.judger_model.get('temperature'),
+                'top_p': self.content_config.judger_model.get('top_p'),
+                'presence_penalty': self.content_config.judger_model.get('presence_penalty')
+            }
+            # Drop None values
+            model_params = {k: v for k, v in model_params.items() if v is not None}
+
         # 2. Call the AI to review the content
         try:
             raw_result, _, _, _ = await self.ai_agent.generate_text(
                 system_prompt=system_prompt,
                 user_prompt=user_prompt,
-                use_stream=False
+                use_stream=False,
+                stage="content review",
+                **model_params
             )

             # Save the raw response

View File

@@ -55,12 +55,25 @@ class TopicGenerator:
         self.output_manager.save_text(system_prompt, "topic_system_prompt.txt")
         self.output_manager.save_text(user_prompt, "topic_user_prompt.txt")

+        # Read the stage-specific model parameters
+        model_params = {}
+        if hasattr(self.config, 'model') and isinstance(self.config.model, dict):
+            model_params = {
+                'temperature': self.config.model.get('temperature'),
+                'top_p': self.config.model.get('top_p'),
+                'presence_penalty': self.config.model.get('presence_penalty')
+            }
+            # Drop None values
+            model_params = {k: v for k, v in model_params.items() if v is not None}
+
         # 2. Call the AI to generate topics
         try:
             raw_result, _, _, _ = await self.ai_agent.generate_text(
                 system_prompt=system_prompt,
                 user_prompt=user_prompt,
-                use_stream=False  # topic generation usually does not need streaming output
+                use_stream=False,  # topic generation usually does not need streaming output
+                stage="topic generation",
+                **model_params
             )
             self.output_manager.save_text(raw_result, "topics_raw_response.txt")
         except Exception as e:
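
With all three call sites wired up, each stage should announce its effective parameters before every request. Given the logger format string in AIAgent.generate_text and the topic_gen config above, the topic stage would emit a line like:

[topic generation] Using model parameters: temperature=0.20, top_p=0.30, presence_penalty=1.50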