mainpy load审核器
This commit is contained in:
parent
0753c733db
commit
8dc8f32b87
6
main.py
6
main.py
@ -89,7 +89,8 @@ def generate_content_and_posters_step(config, run_id, topics_list, output_handle
|
||||
prompts_config=config.get("prompts_config"), # 新的配置方式
|
||||
resource_dir_config=config.get("resource_dir", []),
|
||||
topic_gen_num=config.get("num", 1), # Topic gen num/date used by topic prompts
|
||||
topic_gen_date=config.get("date", "")
|
||||
topic_gen_date=config.get("date", ""),
|
||||
content_judger_system_prompt_path=config.get("content_judger_system_prompt") # 添加内容审核系统提示词路径
|
||||
)
|
||||
logging.info("PromptManager instance created for Step 2.")
|
||||
except KeyError as e:
|
||||
@ -137,7 +138,8 @@ def generate_content_and_posters_step(config, run_id, topics_list, output_handle
|
||||
variants=content_variants,
|
||||
temperature=content_temp,
|
||||
top_p=content_top_p,
|
||||
presence_penalty=content_presence_penalty
|
||||
presence_penalty=content_presence_penalty,
|
||||
enable_content_judge=config.get("enable_content_judge", False)
|
||||
)
|
||||
|
||||
# if tweet_content_list: # generate_content_for_topic 现在返回 bool
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
{
|
||||
"date": "5月15日",
|
||||
"num": 5,
|
||||
"num": 2,
|
||||
"variants": 20,
|
||||
"topic_temperature": 0.2,
|
||||
"topic_top_p": 0.3,
|
||||
@ -8,13 +8,15 @@
|
||||
"content_temperature": 0.3,
|
||||
"content_top_p": 0.5,
|
||||
"content_presence_penalty": 1.2,
|
||||
"model": "qwen3-30B-A3B",
|
||||
"model": "qwenQWQ",
|
||||
"api_url": "http://localhost:8000/v1/",
|
||||
"api_key": "EMPTY",
|
||||
"enable_content_judge":true,
|
||||
"topic_system_prompt": "./SelectPrompt/systemPrompt.txt",
|
||||
"topic_user_prompt": "./SelectPrompt/userPrompt.txt",
|
||||
"content_system_prompt": "./genPrompts/systemPrompt.txt",
|
||||
"poster_content_system_prompt": "./genPrompts/poster_content_systemPrompt.txt",
|
||||
"content_judger_system_prompt": "./genPrompts/judgerSystemPrompt.txt",
|
||||
"prompts_config": [
|
||||
{
|
||||
"type": "Style",
|
||||
|
||||
@ -28,6 +28,7 @@ from utils import content_generator as core_contentGen
|
||||
from core import poster_gen as core_posterGen
|
||||
from core import simple_collage as core_simple_collage
|
||||
from .output_handler import OutputHandler # <-- 添加导入
|
||||
from utils.content_judger import ContentJudger # <-- 添加ContentJudger导入
|
||||
|
||||
class tweetTopic:
|
||||
def __init__(self, index, date, logic, object, product, product_logic, style, style_logic, target_audience, target_audience_logic):
|
||||
@ -372,7 +373,8 @@ def generate_content_for_topic(ai_agent: AI_Agent,
|
||||
variants: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
presence_penalty: float):
|
||||
presence_penalty: float,
|
||||
enable_content_judge: bool):
|
||||
"""Generates all content variants for a single topic item and uses OutputHandler.
|
||||
|
||||
Args:
|
||||
@ -384,6 +386,7 @@ def generate_content_for_topic(ai_agent: AI_Agent,
|
||||
output_handler: An instance of OutputHandler to process results.
|
||||
variants: Number of variants to generate.
|
||||
temperature, top_p, presence_penalty: AI generation parameters.
|
||||
enable_content_judge: Whether to enable content judge.
|
||||
Returns:
|
||||
bool: True if at least one variant was successfully generated and handled, False otherwise.
|
||||
"""
|
||||
@ -391,6 +394,81 @@ def generate_content_for_topic(ai_agent: AI_Agent,
|
||||
success_flag = False # Track if any variant succeeded
|
||||
# 使用传入的 variants 参数
|
||||
# variants = config.get("variants", 1)
|
||||
|
||||
# 如果启用了内容审核,获取产品资料
|
||||
product_info = None
|
||||
content_judger = None
|
||||
if enable_content_judge:
|
||||
logging.info(f"内容审核功能已启用,准备获取产品资料...")
|
||||
# 从topic_item中获取产品名称和对象名称
|
||||
product_name = topic_item.get("product", "")
|
||||
object_name = topic_item.get("object", "")
|
||||
|
||||
# 组合获取产品资料
|
||||
product_info = ""
|
||||
|
||||
# 获取对象信息
|
||||
if object_name:
|
||||
# 通过PromptManager获取对象和产品资料
|
||||
# 这一部分逻辑来自PromptManager.get_content_prompts中对object_info的构建
|
||||
found_object_info = False
|
||||
all_description_files = []
|
||||
|
||||
# 从resource_dir_config搜集所有可能的资源文件
|
||||
for dir_info in prompt_manager.resource_dir_config:
|
||||
if dir_info.get("type") in ["Object", "Description"]:
|
||||
all_description_files.extend(dir_info.get("file_path", []))
|
||||
|
||||
# 尝试精确匹配对象资料
|
||||
for file_path in all_description_files:
|
||||
if object_name in os.path.basename(file_path):
|
||||
from utils.resource_loader import ResourceLoader
|
||||
info = ResourceLoader.load_file_content(file_path)
|
||||
if info:
|
||||
product_info += f"Object: {object_name}\n{info}\n\n"
|
||||
logging.info(f"为内容审核找到对象'{object_name}'的资源文件: {file_path}")
|
||||
found_object_info = True
|
||||
break
|
||||
|
||||
# 如果未找到对象资料,记录警告但继续处理
|
||||
if not found_object_info:
|
||||
logging.warning(f"未能为内容审核找到对象'{object_name}'的资源文件")
|
||||
|
||||
# 获取产品信息
|
||||
if product_name:
|
||||
found_product_info = False
|
||||
all_product_files = []
|
||||
|
||||
# 搜集所有可能的产品资源文件
|
||||
for dir_info in prompt_manager.resource_dir_config:
|
||||
if dir_info.get("type") == "Product":
|
||||
all_product_files.extend(dir_info.get("file_path", []))
|
||||
|
||||
# 尝试精确匹配产品资料
|
||||
for file_path in all_product_files:
|
||||
if product_name in os.path.basename(file_path):
|
||||
from utils.resource_loader import ResourceLoader
|
||||
info = ResourceLoader.load_file_content(file_path)
|
||||
if info:
|
||||
product_info += f"Product: {product_name}\n{info}\n\n"
|
||||
logging.info(f"为内容审核找到产品'{product_name}'的资源文件: {file_path}")
|
||||
found_product_info = True
|
||||
break
|
||||
|
||||
# 如果未找到产品资料,记录警告但继续处理
|
||||
if not found_product_info:
|
||||
logging.warning(f"未能为内容审核找到产品'{product_name}'的资源文件")
|
||||
|
||||
# 如果成功获取产品资料,初始化ContentJudger
|
||||
if product_info:
|
||||
logging.info("成功获取产品资料,初始化ContentJudger...")
|
||||
# 从配置中读取系统提示词路径(脚本级别无法直接获取,需要传递)
|
||||
# 使用ai_agent的model_name或api_url判断是否使用主AI模型,避免额外资源占用
|
||||
content_judger_system_prompt_path = prompt_manager._system_prompt_cache.get("judger_system_prompt")
|
||||
content_judger = ContentJudger(ai_agent, system_prompt_path=content_judger_system_prompt_path)
|
||||
else:
|
||||
logging.warning("未能获取产品资料,内容审核功能将被跳过")
|
||||
enable_content_judge = False
|
||||
|
||||
for j in range(variants):
|
||||
variant_index = j + 1
|
||||
@ -417,6 +495,37 @@ def generate_content_for_topic(ai_agent: AI_Agent,
|
||||
# 简化检查,只要content_json不是None就处理它
|
||||
# 即使是空标题和内容也是有效的结果
|
||||
if content_json is not None:
|
||||
# 进行内容审核(如果启用且ContentJudger已初始化)
|
||||
if enable_content_judge and content_judger and product_info:
|
||||
logging.info(f" 对Topic {topic_index}, Variant {variant_index}进行内容审核...")
|
||||
|
||||
# 准备审核内容
|
||||
content_to_judge = f"""title: {content_json.get('title', '')}
|
||||
|
||||
content: {content_json.get('content', '')}
|
||||
"""
|
||||
|
||||
# 调用ContentJudger进行审核
|
||||
try:
|
||||
judged_result = content_judger.judge_content(product_info, content_to_judge)
|
||||
if judged_result and isinstance(judged_result, dict):
|
||||
if "title" in judged_result and "content" in judged_result:
|
||||
# 使用审核后的内容替换原内容
|
||||
logging.info(f" 内容审核成功,使用审核后的内容替换原内容")
|
||||
content_json["title"] = judged_result["title"]
|
||||
content_json["content"] = judged_result["content"]
|
||||
# 添加审核标记
|
||||
content_json["judged"] = True
|
||||
# 可选:保存审核分析结果
|
||||
if "不良内容分析" in judged_result:
|
||||
content_json["judge_analysis"] = judged_result["不良内容分析"]
|
||||
else:
|
||||
logging.warning(f" 审核结果缺少title或content字段,保留原内容")
|
||||
else:
|
||||
logging.warning(f" 内容审核返回无效结果,保留原内容")
|
||||
except Exception as judge_err:
|
||||
logging.exception(f" 内容审核过程出错: {judge_err},保留原内容")
|
||||
|
||||
# Use the output handler to process/save the result
|
||||
output_handler.handle_content_variant(
|
||||
run_id, topic_index, variant_index, content_json, prompt_data or ""
|
||||
@ -528,8 +637,7 @@ def generate_posters_for_topic(topic_item: dict,
|
||||
poster_gen_instance.set_text_bg_possibility(text_bg_possibility)
|
||||
except Exception as e:
|
||||
logging.exception("Error initializing generators for poster creation:")
|
||||
return False
|
||||
|
||||
return False
|
||||
# --- Setup: Paths and Object Name ---
|
||||
object_name = topic_item.get("object", "")
|
||||
if not object_name:
|
||||
@ -805,12 +913,4 @@ def generate_posters_for_topic(topic_item: dict,
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
return any_poster_attempted
|
||||
|
||||
# main 函数不再使用,注释掉或移除
|
||||
# def main():
|
||||
# """主函数入口"""
|
||||
# # ... (旧的 main 函数逻辑)
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
return any_poster_attempted
|
||||
Loading…
x
Reference in New Issue
Block a user