Add separate loading of model parameters for each pipeline stage
parent 5fd1e2fd85
commit d4d23068e5
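Each stage of the pipeline (topic generation, content generation, content review) can now run with its own sampling parameters instead of sharing the single client-wide set: the stage configs gain optional model / judger_model blocks, AIAgent.generate_text accepts per-call temperature / top_p / presence_penalty overrides plus a stage label for log lines, and TopicGenerator, ContentGenerator, and ContentJudger read their blocks and pass along only the values that are actually set.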
@@ -2,9 +2,9 @@
     "model": "qwen-plus",
     "api_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
     "api_key": "sk-bd5ee62703bc41fc9b8a55d748dc1eb8",
-    "temperature": 0.8,
-    "top_p": 0.5,
-    "presence_penalty": 1.1,
+    "temperature": 0.3,
+    "top_p": 0.4,
+    "presence_penalty": 1.2,
     "timeout": 120,
     "max_retries": 3
 }
@@ -3,5 +3,15 @@
     "content_user_prompt": "resource/prompt/generateContent/user.txt",
     "judger_system_prompt": "resource/prompt/judgeContent/system.txt",
     "judger_user_prompt": "resource/prompt/judgeContent/user.txt",
-    "enable_content_judge": true
+    "enable_content_judge": true,
+    "model": {
+        "temperature": 0.3,
+        "top_p": 0.5,
+        "presence_penalty": 1.2
+    },
+    "judger_model": {
+        "temperature": 0.2,
+        "top_p": 0.3,
+        "presence_penalty": 0.8
+    }
 }
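The intended precedence for the new blocks: a value set in a stage block (model for generation, judger_model for review) overrides the client-wide default; anything left unset falls through to the client config. A minimal Python sketch of that merge, equivalent to the fallback generate_text performs further down (illustrative values taken from this diff, not code from the commit):

    # A partial stage block vs. the updated client defaults from the first hunk.
    stage_block = {"temperature": 0.2, "top_p": 0.3}  # e.g. part of "judger_model"
    client_defaults = {"temperature": 0.3, "top_p": 0.4, "presence_penalty": 1.2}

    # Keys present in the stage block win; missing keys keep the client default.
    effective = {**client_defaults, **stage_block}
    print(effective)  # {'temperature': 0.2, 'top_p': 0.3, 'presence_penalty': 1.2}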
@@ -16,7 +16,7 @@
         "refer_list": [
             { "path": "resource/prompt/Refer/2025各月节日宣传节点时间表.md", "sampling_rate": 1, "step": "topic" },
             { "path": "resource/prompt/Refer/标题参考格式.json", "sampling_rate": 0.25, "step": "content" },
-            { "path": "resource/prompt/Refer/正文范文参考.json", "sampling_rate": 0.25, "step": "content" }
+            { "path": "resource/prompt/Refer/正文范文参考.json", "sampling_rate": 0.5, "step": "content" }
         ]
     },
     "object": {
@@ -1,7 +1,11 @@
 {
     "topic_system_prompt": "resource/prompt/generateTopics/system.txt",
     "topic_user_prompt": "resource/prompt/generateTopics/user.txt",
-    "model": {},
+    "model": {
+        "temperature": 0.2,
+        "top_p": 0.3,
+        "presence_penalty": 1.5
+    },
     "topic": {
         "date": "2024-07-20",
         "num": 5,
@@ -45,7 +45,9 @@ class AIAgent:
         # self.tokenizer = tiktoken.get_encoding("cl100k_base")
 
     async def generate_text(
-        self, system_prompt: str, user_prompt: str, use_stream: bool = False
+        self, system_prompt: str, user_prompt: str, use_stream: bool = False,
+        temperature: Optional[float] = None, top_p: Optional[float] = None,
+        presence_penalty: Optional[float] = None, stage: str = ""
     ) -> Tuple[str, int, int, float]:
         """
         Generate text (supports both streaming and non-streaming modes)
@@ -54,6 +56,10 @@ class AIAgent:
             system_prompt: the system prompt
             user_prompt: the user prompt
             use_stream: whether to stream the response
+            temperature: sampling temperature; controls randomness
+            top_p: top-p (nucleus) sampling parameter
+            presence_penalty: presence penalty parameter
+            stage: name of the current pipeline stage, used in log messages
 
         Returns:
             A tuple (generated_text, input_tokens, output_tokens, time_cost)
@@ -63,8 +69,17 @@
             {"role": "user", "content": user_prompt}
         ]
 
+        # Use the explicitly passed parameters, or fall back to the config defaults
+        temp = temperature if temperature is not None else self.config.temperature
+        tp = top_p if top_p is not None else self.config.top_p
+        pp = presence_penalty if presence_penalty is not None else self.config.presence_penalty
+
+        # Log the model parameters actually in use
+        stage_info = f"[{stage}]" if stage else ""
+        logger.info(f"{stage_info} Using model parameters: temperature={temp:.2f}, top_p={tp:.2f}, presence_penalty={pp:.2f}")
+
         input_tokens = self.count_tokens(system_prompt + user_prompt)
-        logger.info(f"Starting generation task... Input tokens: {input_tokens}")
+        logger.info(f"{stage_info} Starting generation task... Input tokens: {input_tokens}")
 
         last_exception = None
         backoff_time = 1.0  # Start with 1 second
@@ -75,9 +90,9 @@
                 response = await self.client.chat.completions.create(
                     model=self.config.model,
                     messages=messages,
-                    temperature=self.config.temperature,
-                    top_p=self.config.top_p,
-                    presence_penalty=self.config.presence_penalty,
+                    temperature=temp,
+                    top_p=tp,
+                    presence_penalty=pp,
                     stream=use_stream
                 )
 
@@ -90,26 +105,27 @@
                     # Simplified handling: in streaming mode, return the concatenated full text
                     full_text = "".join([chunk for chunk in self._process_stream(response)])
                     output_tokens = self.count_tokens(full_text)
+                    logger.info(f"{stage_info} Task finished in {time_cost:.2f} s. Output tokens: {output_tokens}")
                     return full_text, input_tokens, output_tokens, time_cost
                 else:
                     output_text = response.choices[0].message.content.strip()
                     output_tokens = self.count_tokens(output_text)
-                    logger.info(f"Task finished in {time_cost:.2f} s. Output tokens: {output_tokens}")
+                    logger.info(f"{stage_info} Task finished in {time_cost:.2f} s. Output tokens: {output_tokens}")
                     return output_text, input_tokens, output_tokens, time_cost
 
             except (APITimeoutError, APIConnectionError) as e:
                 last_exception = RetryableError(f"AI model connection or timeout error: {e}")
-                logger.warning(f"Attempt {attempt + 1}/{self.config.max_retries} failed: {last_exception}. "
+                logger.warning(f"{stage_info} Attempt {attempt + 1}/{self.config.max_retries} failed: {last_exception}. "
                                f"Retrying in {backoff_time:.1f} s...")
                 time.sleep(backoff_time)
                 backoff_time *= 2  # Exponential backoff
             except (RateLimitError, APIStatusError) as e:
                 last_exception = NonRetryableError(f"AI model API error (non-retryable): {e}")
-                logger.error(f"Non-retryable API error occurred: {last_exception}")
+                logger.error(f"{stage_info} Non-retryable API error occurred: {last_exception}")
                 break  # Do not retry on these errors
             except Exception as e:
                 last_exception = AIModelError(f"Unknown error while calling the AI model: {e}")
-                logger.error(f"Unknown error occurred: {last_exception}\n{traceback.format_exc()}")
+                logger.error(f"{stage_info} Unknown error occurred: {last_exception}\n{traceback.format_exc()}")
                 break
 
         raise AIModelError(f"AI model call failed after {self.config.max_retries} retries") from last_exception
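With the new signature, callers can override any subset of the sampling parameters per call; omitted ones keep the client-wide defaults, and stage only tags the log lines. A minimal call-site sketch, to be read inside an async function (the agent construction and prompts are assumed, not part of this commit):

    # Hypothetical usage: override temperature only; top_p and presence_penalty
    # fall back to self.config inside generate_text.
    text, in_tokens, out_tokens, seconds = await agent.generate_text(
        system_prompt="You are a copywriting assistant.",
        user_prompt="Draft one title for a summer festival post.",
        use_stream=False,
        temperature=0.2,            # per-call override
        stage="topic generation",   # rendered as "[topic generation]" in logs
    )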
@@ -10,7 +10,7 @@ import json
 from typing import Dict, Any
 
 from core.ai import AIAgent
-from core.config import ConfigManager, GenerateTopicConfig
+from core.config import ConfigManager, GenerateTopicConfig, GenerateContentConfig
 from utils.prompts import ContentPromptBuilder
 from utils.file_io import OutputManager
 
@@ -22,7 +22,9 @@ class ContentGenerator:
 
     def __init__(self, ai_agent: AIAgent, config_manager: ConfigManager, output_manager: OutputManager):
         self.ai_agent = ai_agent
-        self.config: GenerateTopicConfig = config_manager.get_config('topic_gen', GenerateTopicConfig)
+        self.config_manager = config_manager
+        self.topic_config = config_manager.get_config('topic_gen', GenerateTopicConfig)
+        self.content_config = config_manager.get_config('content_gen', GenerateContentConfig)
         self.output_manager = output_manager
         self.prompt_builder = ContentPromptBuilder(config_manager)
 
@@ -49,12 +51,25 @@
         self.output_manager.save_text(system_prompt, "content_system_prompt.txt", subdir=output_dir.name)
         self.output_manager.save_text(user_prompt, "content_user_prompt.txt", subdir=output_dir.name)
 
+        # Read the per-stage model parameters
+        model_params = {}
+        if hasattr(self.content_config, 'model') and isinstance(self.content_config.model, dict):
+            model_params = {
+                'temperature': self.content_config.model.get('temperature'),
+                'top_p': self.content_config.model.get('top_p'),
+                'presence_penalty': self.content_config.model.get('presence_penalty')
+            }
+            # Drop None values
+            model_params = {k: v for k, v in model_params.items() if v is not None}
+
         # 2. Call the AI
         try:
             raw_result, _, _, _ = await self.ai_agent.generate_text(
                 system_prompt=system_prompt,
                 user_prompt=user_prompt,
-                use_stream=False
+                use_stream=False,
+                stage="content generation",
+                **model_params
             )
             self.output_manager.save_text(raw_result, "content_raw_response.txt", subdir=output_dir.name)
         except Exception as e:
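The None-stripping before unpacking matters: keys absent from the stage block are simply not passed, so generate_text falls back to the client defaults instead of receiving temperature=None. A small illustration (not code from this commit):

    # e.g. "model": {"temperature": 0.3} with no top_p or presence_penalty configured
    model_params = {'temperature': 0.3, 'top_p': None, 'presence_penalty': None}
    model_params = {k: v for k, v in model_params.items() if v is not None}
    assert model_params == {'temperature': 0.3}
    # generate_text(..., **model_params) is then just generate_text(..., temperature=0.3)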
@@ -11,7 +11,7 @@ from json_repair import loads as json_repair_loads
 from typing import Dict, Any
 
 from core.ai import AIAgent
-from core.config import ConfigManager, GenerateTopicConfig
+from core.config import ConfigManager, GenerateTopicConfig, GenerateContentConfig
 from utils.prompts import JudgerPromptBuilder
 
 logger = logging.getLogger(__name__)
@@ -30,7 +30,9 @@ class ContentJudger:
             output_manager: output manager used to save prompts and responses
         """
         self.ai_agent = ai_agent
-        self.config: GenerateTopicConfig = config_manager.get_config('topic_gen', GenerateTopicConfig)
+        self.config_manager = config_manager
+        self.topic_config = config_manager.get_config('topic_gen', GenerateTopicConfig)
+        self.content_config = config_manager.get_config('content_gen', GenerateContentConfig)
         self.prompt_builder = JudgerPromptBuilder(config_manager)
         self.output_manager = output_manager
 
@@ -63,12 +65,25 @@
         self.output_manager.save_text(system_prompt, f"{topic_dir}/judger_system_prompt.txt")
         self.output_manager.save_text(user_prompt, f"{topic_dir}/judger_user_prompt.txt")
 
+        # Read the per-stage model parameters
+        model_params = {}
+        if hasattr(self.content_config, 'judger_model') and isinstance(self.content_config.judger_model, dict):
+            model_params = {
+                'temperature': self.content_config.judger_model.get('temperature'),
+                'top_p': self.content_config.judger_model.get('top_p'),
+                'presence_penalty': self.content_config.judger_model.get('presence_penalty')
+            }
+            # Drop None values
+            model_params = {k: v for k, v in model_params.items() if v is not None}
+
         # 2. Call the AI to review the content
         try:
             raw_result, _, _, _ = await self.ai_agent.generate_text(
                 system_prompt=system_prompt,
                 user_prompt=user_prompt,
-                use_stream=False
+                use_stream=False,
+                stage="content review",
+                **model_params
             )
 
             # Save the raw response
@@ -55,12 +55,25 @@ class TopicGenerator:
         self.output_manager.save_text(system_prompt, "topic_system_prompt.txt")
         self.output_manager.save_text(user_prompt, "topic_user_prompt.txt")
 
+        # Read the per-stage model parameters
+        model_params = {}
+        if hasattr(self.config, 'model') and isinstance(self.config.model, dict):
+            model_params = {
+                'temperature': self.config.model.get('temperature'),
+                'top_p': self.config.model.get('top_p'),
+                'presence_penalty': self.config.model.get('presence_penalty')
+            }
+            # Drop None values
+            model_params = {k: v for k, v in model_params.items() if v is not None}
+
         # 2. Call the AI to generate topics
         try:
             raw_result, _, _, _ = await self.ai_agent.generate_text(
                 system_prompt=system_prompt,
                 user_prompt=user_prompt,
-                use_stream=False  # topic generation usually doesn't need streaming
+                use_stream=False,  # topic generation usually doesn't need streaming
+                stage="topic generation",
+                **model_params
             )
             self.output_manager.save_text(raw_result, "topics_raw_response.txt")
         except Exception as e:
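Given the topic_gen config added in this commit, the call above effectively expands to the following (a sketch of the unpacked keyword arguments, not literal code from the diff):

    # "model": {"temperature": 0.2, "top_p": 0.3, "presence_penalty": 1.5}
    raw_result, _, _, _ = await self.ai_agent.generate_text(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        use_stream=False,
        stage="topic generation",
        temperature=0.2,
        top_p=0.3,
        presence_penalty=1.5,
    )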