315 lines
36 KiB
JSON
315 lines
36 KiB
JSON
{
|
||
"file_path": "travel-algorithms/travel_algorithms/content_generation/topic_generator.py",
|
||
"file_size": 10140,
|
||
"line_count": 312,
|
||
"functions": [
|
||
{
|
||
"name": "__init__",
|
||
"line_start": 26,
|
||
"line_end": 48,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
},
|
||
{
|
||
"name": "config",
|
||
"type_hint": "AlgorithmConfig"
|
||
}
|
||
],
|
||
"return_type": null,
|
||
"docstring": "初始化主题生成器\n\nArgs:\n config: 算法配置",
|
||
"is_async": false,
|
||
"decorators": [],
|
||
"code": " def __init__(self, config: AlgorithmConfig):\n \"\"\"\n 初始化主题生成器\n\n Args:\n config: 算法配置\n \"\"\"\n self.config = config\n self.ai_service = AIService(config.ai_model)\n self.output_manager = OutputManager(config.output)\n self.prompt_manager = PromptManager(config.prompts, config.resources)\n \n # 初始化JSON处理器\n self.json_processor = JSONProcessor(\n enable_repair=config.content_generation.enable_json_repair,\n max_repair_attempts=config.content_generation.json_repair_attempts\n )\n \n # 获取任务特定的模型配置和字段配置\n self.task_model_config = config.ai_model.get_task_config(\"topic_generation\")\n self.field_config = config.content_generation.result_field_mapping.get(\"topic_generation\", {})\n \n logger.info(f\"主题生成器初始化完成,使用模型参数: {self.task_model_config}\")",
|
||
"code_hash": "a63ef7f57ab7947ba5b330c3170b45ae"
|
||
},
|
||
{
|
||
"name": "_parse_topics_result",
|
||
"line_start": 137,
|
||
"line_end": 176,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
},
|
||
{
|
||
"name": "raw_output",
|
||
"type_hint": "str"
|
||
}
|
||
],
|
||
"return_type": "List[Dict[str, Any]]",
|
||
"docstring": "解析主题生成结果\n\nArgs:\n raw_output: AI原始输出\n\nReturns:\n 解析后的主题列表",
|
||
"is_async": false,
|
||
"decorators": [],
|
||
"code": " def _parse_topics_result(self, raw_output: str) -> List[Dict[str, Any]]:\n \"\"\"\n 解析主题生成结果\n\n Args:\n raw_output: AI原始输出\n\n Returns:\n 解析后的主题列表\n \"\"\"\n try:\n # 使用JSON处理器解析\n parsed_data = self.json_processor.parse_llm_output(\n raw_output=raw_output,\n expected_fields=self.field_config.get(\"expected_fields\"),\n required_fields=self.field_config.get(\"required_fields\")\n )\n\n # 确保返回列表格式\n if isinstance(parsed_data, dict):\n # 如果是单个对象,转为列表\n topics = [parsed_data]\n elif isinstance(parsed_data, list):\n topics = parsed_data\n else:\n logger.error(f\"解析结果格式错误: {type(parsed_data)}\")\n return []\n\n # 验证和标准化主题数据\n validated_topics = []\n for i, topic in enumerate(topics):\n validated_topic = self._validate_and_normalize_topic(topic, i)\n if validated_topic:\n validated_topics.append(validated_topic)\n\n return validated_topics\n\n except Exception as e:\n logger.error(f\"主题结果解析失败: {e}\")\n return []",
|
||
"code_hash": "8e540e0c278bec397e75c009f1a73216"
|
||
},
|
||
{
|
||
"name": "_validate_and_normalize_topic",
|
||
"line_start": 178,
|
||
"line_end": 235,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
},
|
||
{
|
||
"name": "topic",
|
||
"type_hint": "Dict[str, Any]"
|
||
},
|
||
{
|
||
"name": "index",
|
||
"type_hint": "int"
|
||
}
|
||
],
|
||
"return_type": "Optional[Dict[str, Any]]",
|
||
"docstring": "验证和标准化单个主题数据\n\nArgs:\n topic: 原始主题数据\n index: 主题索引\n\nReturns:\n 标准化后的主题数据",
|
||
"is_async": false,
|
||
"decorators": [],
|
||
"code": " def _validate_and_normalize_topic(self, topic: Dict[str, Any], index: int) -> Optional[Dict[str, Any]]:\n \"\"\"\n 验证和标准化单个主题数据\n\n Args:\n topic: 原始主题数据\n index: 主题索引\n\n Returns:\n 标准化后的主题数据\n \"\"\"\n try:\n if not isinstance(topic, dict):\n logger.warning(f\"主题 {index + 1} 不是字典格式,跳过\")\n return None\n\n # 根据实际提示词的字段进行标准化\n normalized_topic = {\n \"index\": topic.get(\"index\", str(index + 1)),\n \"date\": topic.get(\"date\", \"\"),\n \"logic\": topic.get(\"logic\", \"\"),\n \"object\": topic.get(\"object\", \"None\"),\n \"product\": topic.get(\"product\", \"None\"),\n \"productLogic\": topic.get(\"productLogic\", \"\"),\n \"style\": topic.get(\"style\", \"\"),\n \"styleLogic\": topic.get(\"styleLogic\", \"\"),\n \"targetAudience\": topic.get(\"targetAudience\", \"\"),\n \"targetAudienceLogic\": topic.get(\"targetAudienceLogic\", \"\"),\n # 额外的元数据\n \"metadata\": {\n \"original_data\": topic,\n \"generated_at\": self.output_manager.run_id,\n \"validation_passed\": True\n }\n }\n\n # 验证必需字段\n required_fields = self.field_config.get(\"required_fields\", [\"index\", \"date\", \"logic\", \"object\"])\n missing_fields = [field for field in required_fields if not normalized_topic.get(field)]\n \n if missing_fields:\n logger.warning(f\"主题 {index + 1} 缺少必需字段: {missing_fields}\")\n # 尝试填充默认值\n for field in missing_fields:\n if field == \"index\":\n normalized_topic[\"index\"] = str(index + 1)\n elif field == \"date\":\n normalized_topic[\"date\"] = \"未指定\"\n elif field == \"logic\":\n normalized_topic[\"logic\"] = f\"主题{index + 1}的策划逻辑\"\n elif field == \"object\":\n normalized_topic[\"object\"] = \"None\"\n\n return normalized_topic\n\n except Exception as e:\n logger.error(f\"主题 {index + 1} 验证失败: {e}\")\n return None",
|
||
"code_hash": "d5d015fd1c3a07903e991889995faa24"
|
||
},
|
||
{
|
||
"name": "get_generation_stats",
|
||
"line_start": 274,
|
||
"line_end": 288,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
}
|
||
],
|
||
"return_type": "Dict[str, Any]",
|
||
"docstring": "获取生成统计信息\n\nReturns:\n 统计信息字典",
|
||
"is_async": false,
|
||
"decorators": [],
|
||
"code": " def get_generation_stats(self) -> Dict[str, Any]:\n \"\"\"\n 获取生成统计信息\n\n Returns:\n 统计信息字典\n \"\"\"\n return {\n \"task_model_config\": self.task_model_config,\n \"field_config\": self.field_config,\n \"output_directory\": str(self.output_manager.run_output_dir),\n \"ai_model_info\": self.ai_service.get_model_info(),\n \"prompt_templates\": self.prompt_manager.get_available_templates().get(\"topic_generation\", {}),\n \"json_processor_enabled\": self.json_processor.enable_repair\n }",
|
||
"code_hash": "3b89da3c33befd464dfeb3ab5c151526"
|
||
}
|
||
],
|
||
"classes": [
|
||
{
|
||
"name": "TopicGenerator",
|
||
"line_start": 20,
|
||
"line_end": 313,
|
||
"bases": [],
|
||
"methods": [
|
||
{
|
||
"name": "__init__",
|
||
"line_start": 26,
|
||
"line_end": 48,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
},
|
||
{
|
||
"name": "config",
|
||
"type_hint": "AlgorithmConfig"
|
||
}
|
||
],
|
||
"return_type": null,
|
||
"docstring": "初始化主题生成器\n\nArgs:\n config: 算法配置",
|
||
"is_async": false,
|
||
"decorators": [],
|
||
"code": " def __init__(self, config: AlgorithmConfig):\n \"\"\"\n 初始化主题生成器\n\n Args:\n config: 算法配置\n \"\"\"\n self.config = config\n self.ai_service = AIService(config.ai_model)\n self.output_manager = OutputManager(config.output)\n self.prompt_manager = PromptManager(config.prompts, config.resources)\n \n # 初始化JSON处理器\n self.json_processor = JSONProcessor(\n enable_repair=config.content_generation.enable_json_repair,\n max_repair_attempts=config.content_generation.json_repair_attempts\n )\n \n # 获取任务特定的模型配置和字段配置\n self.task_model_config = config.ai_model.get_task_config(\"topic_generation\")\n self.field_config = config.content_generation.result_field_mapping.get(\"topic_generation\", {})\n \n logger.info(f\"主题生成器初始化完成,使用模型参数: {self.task_model_config}\")",
|
||
"code_hash": "a63ef7f57ab7947ba5b330c3170b45ae"
|
||
},
|
||
{
|
||
"name": "generate_topics",
|
||
"line_start": 50,
|
||
"line_end": 135,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
},
|
||
{
|
||
"name": "creative_materials",
|
||
"type_hint": "str"
|
||
},
|
||
{
|
||
"name": "num_topics",
|
||
"type_hint": "int"
|
||
},
|
||
{
|
||
"name": "month",
|
||
"type_hint": "str"
|
||
}
|
||
],
|
||
"return_type": "Tuple[str, List[Dict[str, Any]]]",
|
||
"docstring": "生成主题列表\n\nArgs:\n creative_materials: 创作材料\n num_topics: 生成数量\n month: 选题月份\n **kwargs: 其他参数\n\nReturns:\n Tuple[请求ID, 主题列表]\n\nRaises:\n ContentGenerationError: 生成失败时抛出",
|
||
"is_async": true,
|
||
"decorators": [],
|
||
"code": " async def generate_topics(\n self, \n creative_materials: str,\n num_topics: int,\n month: str,\n **kwargs\n ) -> Tuple[str, List[Dict[str, Any]]]:\n \"\"\"\n 生成主题列表\n\n Args:\n creative_materials: 创作材料\n num_topics: 生成数量\n month: 选题月份\n **kwargs: 其他参数\n\n Returns:\n Tuple[请求ID, 主题列表]\n\n Raises:\n ContentGenerationError: 生成失败时抛出\n \"\"\"\n try:\n logger.info(f\"开始执行主题生成流程,数量: {num_topics}, 月份: {month}\")\n\n # 1. 构建提示词\n system_prompt = self.prompt_manager.get_prompt(\"topic_generation\", \"system\")\n user_prompt_template = self.prompt_manager.get_prompt(\"topic_generation\", \"user\")\n \n # 格式化用户提示词\n user_prompt = self.prompt_manager.format_prompt(\n user_prompt_template,\n creative_materials=creative_materials,\n numTopics=num_topics,\n month=month,\n **kwargs\n )\n\n # 保存提示词(如果配置允许)\n if self.config.output.save_prompts:\n self.output_manager.save_text(system_prompt, \"system_prompt\", \"topic_generation\")\n self.output_manager.save_text(user_prompt, \"user_prompt\", \"topic_generation\")\n\n # 2. 调用AI生成\n content, input_tokens, output_tokens, elapsed_time = await self.ai_service.generate_text(\n system_prompt=system_prompt,\n user_prompt=user_prompt,\n stage=\"主题生成\",\n **self.task_model_config\n )\n\n # 保存原始响应(如果配置允许)\n if self.config.output.save_raw_responses:\n self.output_manager.save_text(content, \"raw_response\", \"topic_generation\")\n\n # 3. 解析结果(使用JSON处理器)\n topics = self._parse_topics_result(content)\n if not topics:\n raise ContentGenerationError(\"未能从AI响应中解析出任何有效主题\")\n\n # 4. 保存结果\n self.output_manager.save_json(topics, \"topics\")\n \n # 5. 保存元数据\n metadata = {\n \"creative_materials\": creative_materials[:200] + \"...\" if len(creative_materials) > 200 else creative_materials,\n \"num_topics\": num_topics,\n \"month\": month,\n \"generated_count\": len(topics),\n \"model_config\": self.task_model_config,\n \"field_config\": self.field_config,\n \"tokens\": {\n \"input\": input_tokens,\n \"output\": output_tokens\n },\n \"elapsed_time\": elapsed_time\n }\n self.output_manager.save_metadata(metadata, \"topic_generation\")\n\n logger.info(f\"成功生成并保存 {len(topics)} 个主题\")\n return self.output_manager.run_id, topics\n\n except Exception as e:\n error_msg = f\"主题生成失败: {str(e)}\"\n logger.error(error_msg, exc_info=True)\n raise ContentGenerationError(error_msg)",
|
||
"code_hash": "fa20b6068c4dc2d1364b92e01ca44669"
|
||
},
|
||
{
|
||
"name": "_parse_topics_result",
|
||
"line_start": 137,
|
||
"line_end": 176,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
},
|
||
{
|
||
"name": "raw_output",
|
||
"type_hint": "str"
|
||
}
|
||
],
|
||
"return_type": "List[Dict[str, Any]]",
|
||
"docstring": "解析主题生成结果\n\nArgs:\n raw_output: AI原始输出\n\nReturns:\n 解析后的主题列表",
|
||
"is_async": false,
|
||
"decorators": [],
|
||
"code": " def _parse_topics_result(self, raw_output: str) -> List[Dict[str, Any]]:\n \"\"\"\n 解析主题生成结果\n\n Args:\n raw_output: AI原始输出\n\n Returns:\n 解析后的主题列表\n \"\"\"\n try:\n # 使用JSON处理器解析\n parsed_data = self.json_processor.parse_llm_output(\n raw_output=raw_output,\n expected_fields=self.field_config.get(\"expected_fields\"),\n required_fields=self.field_config.get(\"required_fields\")\n )\n\n # 确保返回列表格式\n if isinstance(parsed_data, dict):\n # 如果是单个对象,转为列表\n topics = [parsed_data]\n elif isinstance(parsed_data, list):\n topics = parsed_data\n else:\n logger.error(f\"解析结果格式错误: {type(parsed_data)}\")\n return []\n\n # 验证和标准化主题数据\n validated_topics = []\n for i, topic in enumerate(topics):\n validated_topic = self._validate_and_normalize_topic(topic, i)\n if validated_topic:\n validated_topics.append(validated_topic)\n\n return validated_topics\n\n except Exception as e:\n logger.error(f\"主题结果解析失败: {e}\")\n return []",
|
||
"code_hash": "8e540e0c278bec397e75c009f1a73216"
|
||
},
|
||
{
|
||
"name": "_validate_and_normalize_topic",
|
||
"line_start": 178,
|
||
"line_end": 235,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
},
|
||
{
|
||
"name": "topic",
|
||
"type_hint": "Dict[str, Any]"
|
||
},
|
||
{
|
||
"name": "index",
|
||
"type_hint": "int"
|
||
}
|
||
],
|
||
"return_type": "Optional[Dict[str, Any]]",
|
||
"docstring": "验证和标准化单个主题数据\n\nArgs:\n topic: 原始主题数据\n index: 主题索引\n\nReturns:\n 标准化后的主题数据",
|
||
"is_async": false,
|
||
"decorators": [],
|
||
"code": " def _validate_and_normalize_topic(self, topic: Dict[str, Any], index: int) -> Optional[Dict[str, Any]]:\n \"\"\"\n 验证和标准化单个主题数据\n\n Args:\n topic: 原始主题数据\n index: 主题索引\n\n Returns:\n 标准化后的主题数据\n \"\"\"\n try:\n if not isinstance(topic, dict):\n logger.warning(f\"主题 {index + 1} 不是字典格式,跳过\")\n return None\n\n # 根据实际提示词的字段进行标准化\n normalized_topic = {\n \"index\": topic.get(\"index\", str(index + 1)),\n \"date\": topic.get(\"date\", \"\"),\n \"logic\": topic.get(\"logic\", \"\"),\n \"object\": topic.get(\"object\", \"None\"),\n \"product\": topic.get(\"product\", \"None\"),\n \"productLogic\": topic.get(\"productLogic\", \"\"),\n \"style\": topic.get(\"style\", \"\"),\n \"styleLogic\": topic.get(\"styleLogic\", \"\"),\n \"targetAudience\": topic.get(\"targetAudience\", \"\"),\n \"targetAudienceLogic\": topic.get(\"targetAudienceLogic\", \"\"),\n # 额外的元数据\n \"metadata\": {\n \"original_data\": topic,\n \"generated_at\": self.output_manager.run_id,\n \"validation_passed\": True\n }\n }\n\n # 验证必需字段\n required_fields = self.field_config.get(\"required_fields\", [\"index\", \"date\", \"logic\", \"object\"])\n missing_fields = [field for field in required_fields if not normalized_topic.get(field)]\n \n if missing_fields:\n logger.warning(f\"主题 {index + 1} 缺少必需字段: {missing_fields}\")\n # 尝试填充默认值\n for field in missing_fields:\n if field == \"index\":\n normalized_topic[\"index\"] = str(index + 1)\n elif field == \"date\":\n normalized_topic[\"date\"] = \"未指定\"\n elif field == \"logic\":\n normalized_topic[\"logic\"] = f\"主题{index + 1}的策划逻辑\"\n elif field == \"object\":\n normalized_topic[\"object\"] = \"None\"\n\n return normalized_topic\n\n except Exception as e:\n logger.error(f\"主题 {index + 1} 验证失败: {e}\")\n return None",
|
||
"code_hash": "d5d015fd1c3a07903e991889995faa24"
|
||
},
|
||
{
|
||
"name": "generate_topics_batch",
|
||
"line_start": 237,
|
||
"line_end": 272,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
},
|
||
{
|
||
"name": "materials_list",
|
||
"type_hint": "List[str]"
|
||
},
|
||
{
|
||
"name": "num_topics_per_batch",
|
||
"type_hint": "int"
|
||
},
|
||
{
|
||
"name": "month",
|
||
"type_hint": "str"
|
||
}
|
||
],
|
||
"return_type": "Dict[str, List[Dict[str, Any]]]",
|
||
"docstring": "批量生成主题\n\nArgs:\n materials_list: 创作材料列表\n num_topics_per_batch: 每批生成数量\n month: 月份\n\nReturns:\n 材料索引->主题列表的字典",
|
||
"is_async": true,
|
||
"decorators": [],
|
||
"code": " async def generate_topics_batch(\n self, \n materials_list: List[str],\n num_topics_per_batch: int,\n month: str\n ) -> Dict[str, List[Dict[str, Any]]]:\n \"\"\"\n 批量生成主题\n\n Args:\n materials_list: 创作材料列表\n num_topics_per_batch: 每批生成数量\n month: 月份\n\n Returns:\n 材料索引->主题列表的字典\n \"\"\"\n results = {}\n \n for i, materials in enumerate(materials_list):\n try:\n logger.info(f\"批量生成主题 {i+1}/{len(materials_list)}\")\n \n request_id, topics = await self.generate_topics(\n creative_materials=materials,\n num_topics=num_topics_per_batch,\n month=month\n )\n \n results[f\"batch_{i+1}\"] = topics\n \n except Exception as e:\n logger.error(f\"批量生成第 {i+1} 项失败: {e}\")\n results[f\"batch_{i+1}\"] = []\n \n return results",
|
||
"code_hash": "062a41b2a1a32017293e260159357949"
|
||
},
|
||
{
|
||
"name": "get_generation_stats",
|
||
"line_start": 274,
|
||
"line_end": 288,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
}
|
||
],
|
||
"return_type": "Dict[str, Any]",
|
||
"docstring": "获取生成统计信息\n\nReturns:\n 统计信息字典",
|
||
"is_async": false,
|
||
"decorators": [],
|
||
"code": " def get_generation_stats(self) -> Dict[str, Any]:\n \"\"\"\n 获取生成统计信息\n\n Returns:\n 统计信息字典\n \"\"\"\n return {\n \"task_model_config\": self.task_model_config,\n \"field_config\": self.field_config,\n \"output_directory\": str(self.output_manager.run_output_dir),\n \"ai_model_info\": self.ai_service.get_model_info(),\n \"prompt_templates\": self.prompt_manager.get_available_templates().get(\"topic_generation\", {}),\n \"json_processor_enabled\": self.json_processor.enable_repair\n }",
|
||
"code_hash": "3b89da3c33befd464dfeb3ab5c151526"
|
||
},
|
||
{
|
||
"name": "test_generation",
|
||
"line_start": 290,
|
||
"line_end": 313,
|
||
"args": [
|
||
{
|
||
"name": "self"
|
||
}
|
||
],
|
||
"return_type": "bool",
|
||
"docstring": "测试主题生成功能\n\nReturns:\n 测试是否成功",
|
||
"is_async": true,
|
||
"decorators": [],
|
||
"code": " async def test_generation(self) -> bool:\n \"\"\"\n 测试主题生成功能\n\n Returns:\n 测试是否成功\n \"\"\"\n try:\n test_materials = \"\"\"\n 选题数量:1\n 选题日期:2024-12-01\n \n 测试创作材料\n \"\"\"\n \n _, topics = await self.generate_topics(\n creative_materials=test_materials,\n num_topics=1,\n month=\"2024-12-01\"\n )\n return len(topics) > 0\n except Exception as e:\n logger.error(f\"主题生成测试失败: {e}\")\n return False ",
|
||
"code_hash": "df4ed3ec6429833e5b5ccbf205884181"
|
||
}
|
||
],
|
||
"docstring": "主题生成器 - 重构版本\n负责生成旅游相关的选题,支持配置化参数和动态提示词",
|
||
"decorators": [],
|
||
"code": "class TopicGenerator:\n \"\"\"\n 主题生成器 - 重构版本\n 负责生成旅游相关的选题,支持配置化参数和动态提示词\n \"\"\"\n\n def __init__(self, config: AlgorithmConfig):\n \"\"\"\n 初始化主题生成器\n\n Args:\n config: 算法配置\n \"\"\"\n self.config = config\n self.ai_service = AIService(config.ai_model)\n self.output_manager = OutputManager(config.output)\n self.prompt_manager = PromptManager(config.prompts, config.resources)\n \n # 初始化JSON处理器\n self.json_processor = JSONProcessor(\n enable_repair=config.content_generation.enable_json_repair,\n max_repair_attempts=config.content_generation.json_repair_attempts\n )\n \n # 获取任务特定的模型配置和字段配置\n self.task_model_config = config.ai_model.get_task_config(\"topic_generation\")\n self.field_config = config.content_generation.result_field_mapping.get(\"topic_generation\", {})\n \n logger.info(f\"主题生成器初始化完成,使用模型参数: {self.task_model_config}\")\n\n async def generate_topics(\n self, \n creative_materials: str,\n num_topics: int,\n month: str,\n **kwargs\n ) -> Tuple[str, List[Dict[str, Any]]]:\n \"\"\"\n 生成主题列表\n\n Args:\n creative_materials: 创作材料\n num_topics: 生成数量\n month: 选题月份\n **kwargs: 其他参数\n\n Returns:\n Tuple[请求ID, 主题列表]\n\n Raises:\n ContentGenerationError: 生成失败时抛出\n \"\"\"\n try:\n logger.info(f\"开始执行主题生成流程,数量: {num_topics}, 月份: {month}\")\n\n # 1. 构建提示词\n system_prompt = self.prompt_manager.get_prompt(\"topic_generation\", \"system\")\n user_prompt_template = self.prompt_manager.get_prompt(\"topic_generation\", \"user\")\n \n # 格式化用户提示词\n user_prompt = self.prompt_manager.format_prompt(\n user_prompt_template,\n creative_materials=creative_materials,\n numTopics=num_topics,\n month=month,\n **kwargs\n )\n\n # 保存提示词(如果配置允许)\n if self.config.output.save_prompts:\n self.output_manager.save_text(system_prompt, \"system_prompt\", \"topic_generation\")\n self.output_manager.save_text(user_prompt, \"user_prompt\", \"topic_generation\")\n\n # 2. 调用AI生成\n content, input_tokens, output_tokens, elapsed_time = await self.ai_service.generate_text(\n system_prompt=system_prompt,\n user_prompt=user_prompt,\n stage=\"主题生成\",\n **self.task_model_config\n )\n\n # 保存原始响应(如果配置允许)\n if self.config.output.save_raw_responses:\n self.output_manager.save_text(content, \"raw_response\", \"topic_generation\")\n\n # 3. 解析结果(使用JSON处理器)\n topics = self._parse_topics_result(content)\n if not topics:\n raise ContentGenerationError(\"未能从AI响应中解析出任何有效主题\")\n\n # 4. 保存结果\n self.output_manager.save_json(topics, \"topics\")\n \n # 5. 保存元数据\n metadata = {\n \"creative_materials\": creative_materials[:200] + \"...\" if len(creative_materials) > 200 else creative_materials,\n \"num_topics\": num_topics,\n \"month\": month,\n \"generated_count\": len(topics),\n \"model_config\": self.task_model_config,\n \"field_config\": self.field_config,\n \"tokens\": {\n \"input\": input_tokens,\n \"output\": output_tokens\n },\n \"elapsed_time\": elapsed_time\n }\n self.output_manager.save_metadata(metadata, \"topic_generation\")\n\n logger.info(f\"成功生成并保存 {len(topics)} 个主题\")\n return self.output_manager.run_id, topics\n\n except Exception as e:\n error_msg = f\"主题生成失败: {str(e)}\"\n logger.error(error_msg, exc_info=True)\n raise ContentGenerationError(error_msg)\n\n def _parse_topics_result(self, raw_output: str) -> List[Dict[str, Any]]:\n \"\"\"\n 解析主题生成结果\n\n Args:\n raw_output: AI原始输出\n\n Returns:\n 解析后的主题列表\n \"\"\"\n try:\n # 使用JSON处理器解析\n parsed_data = self.json_processor.parse_llm_output(\n raw_output=raw_output,\n expected_fields=self.field_config.get(\"expected_fields\"),\n required_fields=self.field_config.get(\"required_fields\")\n )\n\n # 确保返回列表格式\n if isinstance(parsed_data, dict):\n # 如果是单个对象,转为列表\n topics = [parsed_data]\n elif isinstance(parsed_data, list):\n topics = parsed_data\n else:\n logger.error(f\"解析结果格式错误: {type(parsed_data)}\")\n return []\n\n # 验证和标准化主题数据\n validated_topics = []\n for i, topic in enumerate(topics):\n validated_topic = self._validate_and_normalize_topic(topic, i)\n if validated_topic:\n validated_topics.append(validated_topic)\n\n return validated_topics\n\n except Exception as e:\n logger.error(f\"主题结果解析失败: {e}\")\n return []\n\n def _validate_and_normalize_topic(self, topic: Dict[str, Any], index: int) -> Optional[Dict[str, Any]]:\n \"\"\"\n 验证和标准化单个主题数据\n\n Args:\n topic: 原始主题数据\n index: 主题索引\n\n Returns:\n 标准化后的主题数据\n \"\"\"\n try:\n if not isinstance(topic, dict):\n logger.warning(f\"主题 {index + 1} 不是字典格式,跳过\")\n return None\n\n # 根据实际提示词的字段进行标准化\n normalized_topic = {\n \"index\": topic.get(\"index\", str(index + 1)),\n \"date\": topic.get(\"date\", \"\"),\n \"logic\": topic.get(\"logic\", \"\"),\n \"object\": topic.get(\"object\", \"None\"),\n \"product\": topic.get(\"product\", \"None\"),\n \"productLogic\": topic.get(\"productLogic\", \"\"),\n \"style\": topic.get(\"style\", \"\"),\n \"styleLogic\": topic.get(\"styleLogic\", \"\"),\n \"targetAudience\": topic.get(\"targetAudience\", \"\"),\n \"targetAudienceLogic\": topic.get(\"targetAudienceLogic\", \"\"),\n # 额外的元数据\n \"metadata\": {\n \"original_data\": topic,\n \"generated_at\": self.output_manager.run_id,\n \"validation_passed\": True\n }\n }\n\n # 验证必需字段\n required_fields = self.field_config.get(\"required_fields\", [\"index\", \"date\", \"logic\", \"object\"])\n missing_fields = [field for field in required_fields if not normalized_topic.get(field)]\n \n if missing_fields:\n logger.warning(f\"主题 {index + 1} 缺少必需字段: {missing_fields}\")\n # 尝试填充默认值\n for field in missing_fields:\n if field == \"index\":\n normalized_topic[\"index\"] = str(index + 1)\n elif field == \"date\":\n normalized_topic[\"date\"] = \"未指定\"\n elif field == \"logic\":\n normalized_topic[\"logic\"] = f\"主题{index + 1}的策划逻辑\"\n elif field == \"object\":\n normalized_topic[\"object\"] = \"None\"\n\n return normalized_topic\n\n except Exception as e:\n logger.error(f\"主题 {index + 1} 验证失败: {e}\")\n return None\n\n async def generate_topics_batch(\n self, \n materials_list: List[str],\n num_topics_per_batch: int,\n month: str\n ) -> Dict[str, List[Dict[str, Any]]]:\n \"\"\"\n 批量生成主题\n\n Args:\n materials_list: 创作材料列表\n num_topics_per_batch: 每批生成数量\n month: 月份\n\n Returns:\n 材料索引->主题列表的字典\n \"\"\"\n results = {}\n \n for i, materials in enumerate(materials_list):\n try:\n logger.info(f\"批量生成主题 {i+1}/{len(materials_list)}\")\n \n request_id, topics = await self.generate_topics(\n creative_materials=materials,\n num_topics=num_topics_per_batch,\n month=month\n )\n \n results[f\"batch_{i+1}\"] = topics\n \n except Exception as e:\n logger.error(f\"批量生成第 {i+1} 项失败: {e}\")\n results[f\"batch_{i+1}\"] = []\n \n return results\n\n def get_generation_stats(self) -> Dict[str, Any]:\n \"\"\"\n 获取生成统计信息\n\n Returns:\n 统计信息字典\n \"\"\"\n return {\n \"task_model_config\": self.task_model_config,\n \"field_config\": self.field_config,\n \"output_directory\": str(self.output_manager.run_output_dir),\n \"ai_model_info\": self.ai_service.get_model_info(),\n \"prompt_templates\": self.prompt_manager.get_available_templates().get(\"topic_generation\", {}),\n \"json_processor_enabled\": self.json_processor.enable_repair\n }\n\n async def test_generation(self) -> bool:\n \"\"\"\n 测试主题生成功能\n\n Returns:\n 测试是否成功\n \"\"\"\n try:\n test_materials = \"\"\"\n 选题数量:1\n 选题日期:2024-12-01\n \n 测试创作材料\n \"\"\"\n \n _, topics = await self.generate_topics(\n creative_materials=test_materials,\n num_topics=1,\n month=\"2024-12-01\"\n )\n return len(topics) > 0\n except Exception as e:\n logger.error(f\"主题生成测试失败: {e}\")\n return False ",
|
||
"code_hash": "e97634262996862e24c8d112dc1c9fc7"
|
||
}
|
||
],
|
||
"imports": [
|
||
{
|
||
"type": "import",
|
||
"modules": [
|
||
"logging"
|
||
],
|
||
"aliases": []
|
||
},
|
||
{
|
||
"type": "from_import",
|
||
"module": "typing",
|
||
"names": [
|
||
"Dict",
|
||
"Any",
|
||
"List",
|
||
"Optional",
|
||
"Tuple"
|
||
],
|
||
"aliases": [],
|
||
"level": 0
|
||
},
|
||
{
|
||
"type": "import",
|
||
"modules": [
|
||
"json"
|
||
],
|
||
"aliases": []
|
||
},
|
||
{
|
||
"type": "from_import",
|
||
"module": "config",
|
||
"names": [
|
||
"AlgorithmConfig"
|
||
],
|
||
"aliases": [],
|
||
"level": 2
|
||
},
|
||
{
|
||
"type": "from_import",
|
||
"module": "core",
|
||
"names": [
|
||
"AIService",
|
||
"OutputManager",
|
||
"PromptManager",
|
||
"JSONProcessor"
|
||
],
|
||
"aliases": [],
|
||
"level": 2
|
||
},
|
||
{
|
||
"type": "from_import",
|
||
"module": "exceptions",
|
||
"names": [
|
||
"ContentGenerationError"
|
||
],
|
||
"aliases": [],
|
||
"level": 2
|
||
}
|
||
],
|
||
"constants": [],
|
||
"docstring": "Topic Generator\n主题生成器 - 重构版本,使用动态提示词加载和JSON修复",
|
||
"content_hash": "41e55f0fb2ec82a898b8997806cbc409"
|
||
} |