diff --git a/api/models/__pycache__/tweet.cpython-312.pyc b/api/models/__pycache__/tweet.cpython-312.pyc index 55741ec..30503ab 100644 Binary files a/api/models/__pycache__/tweet.cpython-312.pyc and b/api/models/__pycache__/tweet.cpython-312.pyc differ diff --git a/api/models/tweet.py b/api/models/tweet.py index 63d02a0..7a3f158 100644 --- a/api/models/tweet.py +++ b/api/models/tweet.py @@ -57,20 +57,29 @@ class TopicResponse(BaseModel): class ContentRequest(BaseModel): """内容生成请求模型""" - topic: Dict[str, Any] = Field(..., description="选题信息") + topic: Optional[Dict[str, Any]] = Field(None, description="选题信息") + styles: Optional[List[str]] = Field(None, description="风格列表") + audiences: Optional[List[str]] = Field(None, description="受众列表") + scenic_spots: Optional[List[str]] = Field(None, description="景区列表") + products: Optional[List[str]] = Field(None, description="产品列表") + auto_judge: bool = Field(False, description="是否自动进行内容审核") class Config: schema_extra = { "example": { "topic": { "index": "1", - "date": "2023-07-15", - "object": "北京故宫", - "product": "故宫门票", - "style": "旅游攻略", - "target_audience": "年轻人", - "logic": "暑期旅游热门景点推荐" - } + "date": "2024-07-01", + "style": "攻略风", + "target_audience": "亲子向", + "object": "天津冒险湾", + "product": "冒险湾-2大2小套票" + }, + "styles": ["攻略风", "种草风"], + "audiences": ["亲子向", "情侣向"], + "scenic_spots": ["天津冒险湾", "北京故宫"], + "products": ["冒险湾-2大2小套票", "故宫门票"], + "auto_judge": True } } @@ -97,8 +106,12 @@ class ContentResponse(BaseModel): class JudgeRequest(BaseModel): """内容审核请求模型""" - topic: Dict[str, Any] = Field(..., description="选题信息") + topic: Optional[Dict[str, Any]] = Field(None, description="选题信息") content: Dict[str, Any] = Field(..., description="要审核的内容") + styles: Optional[List[str]] = Field(None, description="风格列表") + audiences: Optional[List[str]] = Field(None, description="受众列表") + scenic_spots: Optional[List[str]] = Field(None, description="景区列表") + products: Optional[List[str]] = Field(None, description="产品列表") class Config: schema_extra = { @@ -116,7 +129,11 @@ class JudgeRequest(BaseModel): "title": "【北京故宫】避开人潮的秘密路线,90%的人都不知道!", "content": "故宫,作为中国最著名的文化遗产之一...", "tag": ["北京旅游", "故宫", "旅游攻略", "避暑胜地"] - } + }, + "styles": ["旅游攻略"], + "audiences": ["年轻人"], + "scenic_spots": ["北京故宫"], + "products": ["故宫门票"] } } @@ -144,25 +161,27 @@ class JudgeResponse(BaseModel): class PipelineRequest(BaseModel): - """完整流程请求模型""" - dates: Optional[str] = Field(None, description="日期字符串,可能为单个日期、多个日期用逗号分隔或范围如'2023-01-01 to 2023-01-31'") - num_topics: int = Field(5, description="要生成的选题数量", ge=1, le=10) + """流水线请求模型""" + dates: Optional[str] = Field(None, description="日期范围,如:'2024-07-01 to 2024-07-31'") + num_topics: int = Field(5, description="要生成的选题数量") styles: Optional[List[str]] = Field(None, description="风格列表") audiences: Optional[List[str]] = Field(None, description="受众列表") scenic_spots: Optional[List[str]] = Field(None, description="景区列表") products: Optional[List[str]] = Field(None, description="产品列表") skip_judge: bool = Field(False, description="是否跳过内容审核步骤") + auto_judge: bool = Field(False, description="是否在内容生成时进行内嵌审核") class Config: schema_extra = { "example": { - "dates": "2023-07-01 to 2023-07-31", + "dates": "2024-07-01 to 2024-07-31", "num_topics": 3, - "styles": ["旅游攻略", "亲子游"], - "audiences": ["年轻人", "家庭"], - "scenic_spots": ["故宫", "长城"], - "products": ["门票", "导游服务"], - "skip_judge": False + "styles": ["攻略风", "种草风"], + "audiences": ["亲子向", "情侣向"], + "scenic_spots": ["天津冒险湾", "北京故宫"], + "products": ["冒险湾-2大2小套票", "故宫门票"], + "skip_judge": False, + "auto_judge": True } } diff --git a/api/routers/__init__.py b/api/routers/__init__.py index 15d094d..5dce1b3 100644 --- a/api/routers/__init__.py +++ b/api/routers/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - + """ API路由模块 """ \ No newline at end of file diff --git a/api/routers/__pycache__/__init__.cpython-312.pyc b/api/routers/__pycache__/__init__.cpython-312.pyc index 7e8bd9c..097640c 100644 Binary files a/api/routers/__pycache__/__init__.cpython-312.pyc and b/api/routers/__pycache__/__init__.cpython-312.pyc differ diff --git a/api/routers/__pycache__/tweet.cpython-312.pyc b/api/routers/__pycache__/tweet.cpython-312.pyc index 8275c03..ca7908e 100644 Binary files a/api/routers/__pycache__/tweet.cpython-312.pyc and b/api/routers/__pycache__/tweet.cpython-312.pyc differ diff --git a/api/routers/tweet.py b/api/routers/tweet.py index 7eb282d..180680a 100644 --- a/api/routers/tweet.py +++ b/api/routers/tweet.py @@ -92,13 +92,23 @@ async def generate_content( tweet_service: TweetService = Depends(get_tweet_service) ): """ - 为选题生成内容 + 生成内容 - **topic**: 选题信息 + - **styles**: 风格列表 + - **audiences**: 受众列表 + - **scenic_spots**: 景区列表 + - **products**: 产品列表 + - **auto_judge**: 是否自动进行内容审核 """ try: request_id, topic_index, content = await tweet_service.generate_content( - topic=request.topic + topic=request.topic, + styles=request.styles, + audiences=request.audiences, + scenic_spots=request.scenic_spots, + products=request.products, + auto_judge=request.auto_judge ) return ContentResponse( @@ -150,11 +160,19 @@ async def judge_content( - **topic**: 选题信息 - **content**: 要审核的内容 + - **styles**: 风格列表 + - **audiences**: 受众列表 + - **scenic_spots**: 景区列表 + - **products**: 产品列表 """ try: request_id, topic_index, judged_content, judge_success = await tweet_service.judge_content( topic=request.topic, - content=request.content + content=request.content, + styles=request.styles, + audiences=request.audiences, + scenic_spots=request.scenic_spots, + products=request.products ) return JudgeResponse( @@ -174,15 +192,16 @@ async def run_pipeline( tweet_service: TweetService = Depends(get_tweet_service) ): """ - 运行完整流水线,包括生成选题、生成内容和审核内容 + 运行完整流水线:生成选题 → 生成内容 → 审核内容 - - **dates**: 日期字符串,可能为单个日期、多个日期用逗号分隔或范围 + - **dates**: 日期范围 - **num_topics**: 要生成的选题数量 - **styles**: 风格列表 - **audiences**: 受众列表 - **scenic_spots**: 景区列表 - **products**: 产品列表 - **skip_judge**: 是否跳过内容审核步骤 + - **auto_judge**: 是否在内容生成时进行内嵌审核 """ try: request_id, topics, contents, judged_contents = await tweet_service.run_pipeline( @@ -192,7 +211,8 @@ async def run_pipeline( audiences=request.audiences, scenic_spots=request.scenic_spots, products=request.products, - skip_judge=request.skip_judge + skip_judge=request.skip_judge, + auto_judge=request.auto_judge ) return PipelineResponse( diff --git a/api/services/__pycache__/prompt_builder.cpython-312.pyc b/api/services/__pycache__/prompt_builder.cpython-312.pyc index 750111f..67c8817 100644 Binary files a/api/services/__pycache__/prompt_builder.cpython-312.pyc and b/api/services/__pycache__/prompt_builder.cpython-312.pyc differ diff --git a/api/services/__pycache__/tweet.cpython-312.pyc b/api/services/__pycache__/tweet.cpython-312.pyc index f55f292..2da576d 100644 Binary files a/api/services/__pycache__/tweet.cpython-312.pyc and b/api/services/__pycache__/tweet.cpython-312.pyc differ diff --git a/api/services/prompt_builder.py b/api/services/prompt_builder.py index a1c7c48..b1aacdb 100644 --- a/api/services/prompt_builder.py +++ b/api/services/prompt_builder.py @@ -109,6 +109,85 @@ class PromptBuilderService: return system_prompt, user_prompt + def build_content_prompt_with_params(self, topic: Dict[str, Any], + styles: Optional[List[str]] = None, + audiences: Optional[List[str]] = None, + scenic_spots: Optional[List[str]] = None, + products: Optional[List[str]] = None, + step: str = "content") -> Tuple[str, str]: + """ + 使用额外参数构建内容生成提示词 + + Args: + topic: 选题信息 + styles: 风格列表 + audiences: 受众列表 + scenic_spots: 景区列表 + products: 产品列表 + step: 当前步骤,用于过滤参考内容 + + Returns: + 系统提示词和用户提示词的元组 + """ + # 获取内容生成配置 + content_config = self._ensure_content_config() + + # 加载系统提示词和用户提示词模板 + system_prompt_path = content_config.content_system_prompt + user_prompt_path = content_config.content_user_prompt + + # 创建提示词模板 + template = PromptTemplate(system_prompt_path, user_prompt_path) + + # 获取风格内容 + style_content = '' + if styles: + style_content = '\n'.join([f"{style}: {self.prompt_service.get_style_content(style)}" for style in styles]) + else: + style_filename = topic.get("style", "") + style_content = f"{style_filename}\n{self.prompt_service.get_style_content(style_filename)}" + + # 获取目标受众内容 + demand_content = '' + if audiences: + demand_content = '\n'.join([f"{audience}: {self.prompt_service.get_audience_content(audience)}" for audience in audiences]) + else: + demand_filename = topic.get("target_audience", "") + demand_content = f"{demand_filename}\n{self.prompt_service.get_audience_content(demand_filename)}" + + # 获取景区信息 + object_content = '' + if scenic_spots: + object_content = '\n'.join([f"{spot}: {self.prompt_service.get_scenic_spot_info(spot)}" for spot in scenic_spots]) + else: + object_name = topic.get("object", "") + object_content = f"{object_name}\n{self.prompt_service.get_scenic_spot_info(object_name)}" + + # 获取产品信息 + product_content = '' + if products: + product_content = '\n'.join([f"{product}: {self.prompt_service.get_product_info(product)}" for product in products]) + else: + product_name = topic.get("product", "") + product_content = f"{product_name}\n{self.prompt_service.get_product_info(product_name)}" + + # 获取参考内容 + refer_content = self.prompt_service.get_refer_content(step) + + # 构建系统提示词 + system_prompt = template.get_system_prompt() + + # 构建用户提示词 + user_prompt = template.build_user_prompt( + style_content=style_content, + demand_content=demand_content, + object_content=object_content, + product_content=product_content, + refer_content=refer_content + ) + + return system_prompt, user_prompt + def build_poster_prompt(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str]: """ 构建海报生成提示词 @@ -292,4 +371,115 @@ class PromptBuilderService: refer_content=refer_content ) + return system_prompt, user_prompt + + def build_judge_prompt_with_params(self, topic: Dict[str, Any], content: Dict[str, Any], + styles: Optional[List[str]] = None, + audiences: Optional[List[str]] = None, + scenic_spots: Optional[List[str]] = None, + products: Optional[List[str]] = None) -> Tuple[str, str]: + """ + 使用额外参数构建内容审核提示词 + + Args: + topic: 选题信息 + content: 生成的内容 + styles: 风格列表 + audiences: 受众列表 + scenic_spots: 景区列表 + products: 产品列表 + + Returns: + 系统提示词和用户提示词的元组 + """ + # 获取内容生成配置 + content_config = self._ensure_content_config() + + # 从配置中获取审核提示词模板路径 + system_prompt_path = content_config.judger_system_prompt + user_prompt_path = content_config.judger_user_prompt + + # 创建提示词模板 + template = PromptTemplate(system_prompt_path, user_prompt_path) + + # 获取景区信息 + object_content = '' + if scenic_spots: + object_content = '\n'.join([f"{spot}: {self.prompt_service.get_scenic_spot_info(spot)}" for spot in scenic_spots]) + else: + object_name = topic.get("object", "") + object_content = f"{object_name}\n{self.prompt_service.get_scenic_spot_info(object_name)}" + + # 获取产品信息 + product_content = '' + if products: + product_content = '\n'.join([f"{product}: {self.prompt_service.get_product_info(product)}" for product in products]) + else: + product_name = topic.get("product", "") + product_content = f"{product_name}\n{self.prompt_service.get_product_info(product_name)}" + + # 获取参考内容 + refer_content = self.prompt_service.get_refer_content("judge") + + # 构建系统提示词 + system_prompt = template.get_system_prompt() + + # 格式化内容 + import json + tweet_content = json.dumps(content, ensure_ascii=False, indent=4) + + # 构建用户提示词 + user_prompt = template.build_user_prompt( + tweet_content=tweet_content, + object_content=object_content, + product_content=product_content, + refer_content=refer_content + ) + + return system_prompt, user_prompt + + def build_judge_prompt_simple(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str]: + """ + 构建简化的内容审核提示词(只需要产品信息、景区信息和文章) + + Args: + topic: 选题信息 + content: 生成的内容 + + Returns: + 系统提示词和用户提示词的元组 + """ + # 获取内容生成配置 + content_config = self._ensure_content_config() + + # 从配置中获取审核提示词模板路径 + system_prompt_path = content_config.judger_system_prompt + user_prompt_path = content_config.judger_user_prompt + + # 创建提示词模板 + template = PromptTemplate(system_prompt_path, user_prompt_path) + + # 获取景区信息 + object_name = topic.get("object", "") + object_content = self.prompt_service.get_scenic_spot_info(object_name) + + # 获取产品信息 + product_name = topic.get("product", "") + product_content = self.prompt_service.get_product_info(product_name) + + # 构建系统提示词 + system_prompt = template.get_system_prompt() + + # 格式化内容 + import json + tweet_content = json.dumps(content, ensure_ascii=False, indent=4) + + # 构建用户提示词(简化版,不包含参考内容) + user_prompt = template.build_user_prompt( + tweet_content=tweet_content, + object_content=object_content, + product_content=product_content, + refer_content="" # 简化版不使用参考内容 + ) + return system_prompt, user_prompt \ No newline at end of file diff --git a/api/services/tweet.py b/api/services/tweet.py index 12bf79e..0e17e23 100644 --- a/api/services/tweet.py +++ b/api/services/tweet.py @@ -97,24 +97,68 @@ class TweetService: logger.info(f"选题生成完成,请求ID: {request_id}, 数量: {len(topics)}") return request_id, topics - async def generate_content(self, topic: Dict[str, Any]) -> Tuple[str, str, Dict[str, Any]]: + async def generate_content(self, topic: Optional[Dict[str, Any]] = None, + styles: Optional[List[str]] = None, + audiences: Optional[List[str]] = None, + scenic_spots: Optional[List[str]] = None, + products: Optional[List[str]] = None, + auto_judge: bool = False) -> Tuple[str, str, Dict[str, Any]]: """ 为选题生成内容 Args: topic: 选题信息 + styles: 风格列表 + audiences: 受众列表 + scenic_spots: 景区列表 + products: 产品列表 + auto_judge: 是否自动进行内容审核 Returns: - 请求ID、选题索引和生成的内容 + 请求ID、选题索引和生成的内容(如果启用审核则返回审核后的内容) """ + # 如果没有提供topic,创建一个基础的topic + if not topic: + topic = {"index": "1", "date": "2024-07-01"} + topic_index = topic.get('index', 'unknown') - logger.info(f"开始为选题 {topic_index} 生成内容") + logger.info(f"开始为选题 {topic_index} 生成内容{'(含审核)' if auto_judge else ''}") + + # 创建topic的副本并应用覆盖参数 + enhanced_topic = topic.copy() + if styles and len(styles) > 0: + enhanced_topic['style'] = styles[0] # 使用第一个风格 + if audiences and len(audiences) > 0: + enhanced_topic['target_audience'] = audiences[0] # 使用第一个受众 + if scenic_spots and len(scenic_spots) > 0: + enhanced_topic['object'] = scenic_spots[0] # 使用第一个景区 + if products and len(products) > 0: + enhanced_topic['product'] = products[0] # 使用第一个产品 # 使用PromptBuilderService构建提示词 - system_prompt, user_prompt = self.prompt_builder.build_content_prompt(topic, "content") + system_prompt, user_prompt = self.prompt_builder.build_content_prompt(enhanced_topic, "content") # 使用预构建的提示词生成内容 - content = await self.content_generator.generate_content_with_prompt(topic, system_prompt, user_prompt) + content = await self.content_generator.generate_content_with_prompt(enhanced_topic, system_prompt, user_prompt) + + # 如果启用自动审核,则进行内嵌审核 + if auto_judge: + logger.info(f"开始对选题 {topic_index} 的内容进行内嵌审核") + try: + # 构建简化的审核提示词(只需要产品信息、景区信息和文章) + judge_system_prompt, judge_user_prompt = self.prompt_builder.build_judge_prompt_simple(enhanced_topic, content) + + # 进行审核 + judged_content = await self.content_judger.judge_content_with_prompt(content, enhanced_topic, judge_system_prompt, judge_user_prompt) + + if judged_content.get('judge_success', False): + logger.info(f"选题 {topic_index} 内容审核成功,使用审核后的内容") + content = judged_content + else: + logger.warning(f"选题 {topic_index} 内容审核失败,使用原始内容") + + except Exception as e: + logger.error(f"选题 {topic_index} 内嵌审核失败: {e},使用原始内容") # 生成请求ID request_id = f"content_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}" @@ -146,25 +190,52 @@ class TweetService: logger.info(f"内容生成完成,请求ID: {request_id}, 选题索引: {topic_index}") return request_id, topic_index, content - async def judge_content(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str, Dict[str, Any], bool]: + async def judge_content(self, topic: Optional[Dict[str, Any]] = None, content: Dict[str, Any] = {}, + styles: Optional[List[str]] = None, + audiences: Optional[List[str]] = None, + scenic_spots: Optional[List[str]] = None, + products: Optional[List[str]] = None) -> Tuple[str, str, Dict[str, Any], bool]: """ 审核内容 Args: topic: 选题信息 content: 要审核的内容 + styles: 风格列表 + audiences: 受众列表 + scenic_spots: 景区列表 + products: 产品列表 Returns: 请求ID、选题索引、审核后的内容和审核是否成功 """ + # 如果没有提供topic,创建一个基础的topic + if not topic: + topic = {"index": "1", "date": "2024-07-01"} + + # 如果没有提供content,返回错误 + if not content: + content = {"title": "未提供内容", "content": "未提供内容"} + topic_index = topic.get('index', 'unknown') logger.info(f"开始审核选题 {topic_index} 的内容") + # 创建topic的副本并应用覆盖参数 + enhanced_topic = topic.copy() + if styles and len(styles) > 0: + enhanced_topic['style'] = styles[0] # 使用第一个风格 + if audiences and len(audiences) > 0: + enhanced_topic['target_audience'] = audiences[0] # 使用第一个受众 + if scenic_spots and len(scenic_spots) > 0: + enhanced_topic['object'] = scenic_spots[0] # 使用第一个景区 + if products and len(products) > 0: + enhanced_topic['product'] = products[0] # 使用第一个产品 + # 使用PromptBuilderService构建提示词 - system_prompt, user_prompt = self.prompt_builder.build_judge_prompt(topic, content) + system_prompt, user_prompt = self.prompt_builder.build_judge_prompt(enhanced_topic, content) # 审核内容 - judged_data = await self.content_judger.judge_content_with_prompt(content, topic, system_prompt, user_prompt) + judged_data = await self.content_judger.judge_content_with_prompt(content, enhanced_topic, system_prompt, user_prompt) judge_success = judged_data.get('judge_success', False) # 生成请求ID @@ -178,7 +249,8 @@ class TweetService: audiences: Optional[List[str]] = None, scenic_spots: Optional[List[str]] = None, products: Optional[List[str]] = None, - skip_judge: bool = False) -> Tuple[str, List[Dict[str, Any]], Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]: + skip_judge: bool = False, + auto_judge: bool = False) -> Tuple[str, List[Dict[str, Any]], Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]: """ 运行完整流水线 @@ -189,12 +261,13 @@ class TweetService: audiences: 受众列表 scenic_spots: 景区列表 products: 产品列表 - skip_judge: 是否跳过内容审核步骤 + skip_judge: 是否跳过内容审核步骤(与auto_judge互斥) + auto_judge: 是否在内容生成时进行内嵌审核 Returns: 请求ID、生成的选题列表、生成的内容和审核后的内容 """ - logger.info(f"开始运行完整流水线,日期: {dates}, 数量: {num_topics}") + logger.info(f"开始运行完整流水线,日期: {dates}, 数量: {num_topics}, 内嵌审核: {auto_judge}") # 生成请求ID request_id = f"pipeline_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}" @@ -205,19 +278,19 @@ class TweetService: logger.error("未能生成任何选题,流程终止") return request_id, [], {}, {} - # 步骤2: 为每个选题生成内容 + # 步骤2: 为每个选题生成内容(可选择内嵌审核) contents = {} for topic in topics: topic_index = topic.get('index', 'unknown') - _, _, content = await self.generate_content(topic) + _, _, content = await self.generate_content(topic, auto_judge=auto_judge) contents[topic_index] = content - # 如果跳过审核,直接返回结果 - if skip_judge: - logger.info(f"跳过内容审核步骤,流水线完成,请求ID: {request_id}") + # 如果使用内嵌审核或跳过审核,直接返回结果 + if auto_judge or skip_judge: + logger.info(f"{'使用内嵌审核' if auto_judge else '跳过内容审核步骤'},流水线完成,请求ID: {request_id}") return request_id, topics, contents, contents - # 步骤3: 审核内容 + # 步骤3: 独立审核内容(仅在未使用内嵌审核且未跳过审核时执行) judged_contents = {} for topic_index, content in contents.items(): topic = next((t for t in topics if t.get('index') == topic_index), None) diff --git a/core/config/__pycache__/manager.cpython-312.pyc b/core/config/__pycache__/manager.cpython-312.pyc index e3edd69..5b3d31c 100644 Binary files a/core/config/__pycache__/manager.cpython-312.pyc and b/core/config/__pycache__/manager.cpython-312.pyc differ diff --git a/core/config/manager.py b/core/config/manager.py index 463554f..9a4deaa 100644 --- a/core/config/manager.py +++ b/core/config/manager.py @@ -128,7 +128,7 @@ class ConfigManager: Args: name: 配置名称 - + Returns: 原始配置数据字典 """ diff --git a/core/config/manager.py.backup b/core/config/manager.py.backup new file mode 100644 index 0000000..ffc7d1d --- /dev/null +++ b/core/config/manager.py.backup @@ -0,0 +1,281 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +统一配置管理器 +""" + +import json +import os +import logging +from pathlib import Path +from typing import Dict, Type, TypeVar, Optional, Any, cast, List, Set + +from core.config.models import ( + BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig, + GenerateContentConfig, PosterConfig, ContentConfig +) + +logger = logging.getLogger(__name__) + +T = TypeVar('T', bound=BaseConfig) + + +class ConfigManager: + """ + 统一配置管理器 + 负责加载、管理和访问所有配置 + """ + + # 服务端必要的全局配置 + SERVER_CONFIGS = {'system', 'ai_model', 'database'} + + # 单次生成任务的配置 + TASK_CONFIGS = {'topic_gen', 'content_gen', 'poster_gen', 'resource'} + + def __init__(self): + self._configs: Dict[str, BaseConfig] = {} + self._raw_configs: Dict[str, Dict[str, Any]] = {} # 存储原始配置数据 + self.config_dir: Optional[Path] = None + self.config_objects = { + 'ai_model': AIModelConfig(), + 'system': SystemConfig(), + 'resource': ResourceConfig() + } + self._loaded_configs: Set[str] = set() + + def load_from_directory(self, config_dir: str, server_mode: bool = False): + """ + 从目录加载配置 + + Args: + config_dir: 配置文件目录 + server_mode: 是否为服务器模式,如果是则只加载必要的全局配置 + """ + self.config_dir = Path(config_dir) + if not self.config_dir.is_dir(): + logger.error(f"配置目录不存在: {config_dir}") + raise FileNotFoundError(f"配置目录不存在: {config_dir}") + + # 注册所有已知的配置类型 + self._register_configs() + + # 动态加载目录中的所有.json文件 + self._load_all_configs_from_dir(server_mode) + + def _register_configs(self): + """注册所有配置""" + self.register_config('ai_model', AIModelConfig) + self.register_config('system', SystemConfig) + self.register_config('resource', ResourceConfig) + + # 这些配置在服务器模式下不会自动加载,但仍然需要注册类型 + self.register_config('poster', PosterConfig) + self.register_config('content', ContentConfig) + self.register_config('topic_gen', GenerateTopicConfig) + self.register_config('content_gen', GenerateContentConfig) + + def register_config(self, name: str, config_class: Type[T]) -> None: + """ + 注册一个配置类 + + Args: + name: 配置名称 + config_class: 配置类 (必须是 BaseConfig 的子类) + """ + if not issubclass(config_class, BaseConfig): + raise TypeError("config_class must be a subclass of BaseConfig") + if name not in self._configs: + self._configs[name] = config_class() + + def get_config(self, name: str, config_class: Type[T]) -> T: + """ + 获取配置实例 + + Args: + name: 配置名称 + config_class: 配置类 (用于类型提示) + + Returns: + 配置实例 + """ + config = self._configs.get(name) + if config is None: + # 如果配置不存在,先注册一个默认实例 + self.register_config(name, config_class) + config = self._configs.get(name) + + # 确保配置是正确的类型 + if not isinstance(config, config_class): + # 尝试转换配置 + try: + if isinstance(config, BaseConfig): + # 将现有配置转换为请求的类型 + new_config = config_class(**config.model_dump()) + self._configs[name] = new_config + config = new_config + else: + raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") + except Exception as e: + logger.error(f"转换配置 '{name}' 到类型 '{config_class.__name__}' 失败: {e}") + raise TypeError(f"Configuration '{name}' is not of type '{config_class.__name__}'") from e + + return cast(T, config) + + def get_raw_config(self, name: str) -> Dict[str, Any]: + """ + 获取原始配置数据 + + Args: + name: 配置名称 + + Returns: + 原始配置数据字典 + """ + if name in self._raw_configs: + return self._raw_configs[name] + + # 如果没有原始配置,但有对象配置,则转换为字典 + if name in self._configs: + return self._configs[name].to_dict() + + # 尝试从文件加载 + if self.config_dir: + config_path = self.config_dir / f"{name}.json" + if config_path.exists(): + try: + with open(config_path, 'r', encoding='utf-8') as f: + raw_config = json.load(f) + self._raw_configs[name] = raw_config + return raw_config + except Exception as e: + logger.error(f"加载原始配置 '{name}' 失败: {e}") + + # 返回空字典 + return {} + + def _load_all_configs_from_dir(self, server_mode: bool = False): + """ + 动态加载目录中的所有.json文件 + + Args: + server_mode: 是否为服务器模式,如果是则只加载必要的全局配置 + """ + try: + # 遍历并加载目录中所有其他的 .json 文件 + for config_path in self.config_dir.glob('*.json'): + config_name = config_path.stem # 'topic_gen.json' -> 'topic_gen' + + # 服务器模式下,只加载必要的全局配置 + if server_mode and config_name not in self.SERVER_CONFIGS: + logger.info(f"服务器模式下跳过非全局配置: {config_name}") + continue + + # 加载原始配置 + with open(config_path, 'r', encoding='utf-8') as f: + config_data = json.load(f) + self._raw_configs[config_name] = config_data + + # 更新对象配置 + if config_name in self._configs: + logger.info(f"加载配置文件 '{config_name}': {config_path}") + self._configs[config_name].update(config_data) + self._loaded_configs.add(config_name) + else: + logger.info(f"加载原始配置 '{config_name}': {config_path}") + + # 最后应用环境变量覆盖 + self._apply_env_overrides() + + except Exception as e: + logger.error(f"从目录 '{self.config_dir}' 加载配置失败: {e}", exc_info=True) + raise + + def load_task_config(self, config_name: str) -> bool: + """ + 按需加载任务配置 + + Args: + config_name: 配置名称 + + Returns: + 是否成功加载 + """ + if config_name in self._loaded_configs: + return True + + if self.config_dir: + config_path = self.config_dir / f"{config_name}.json" + if config_path.exists(): + try: + with open(config_path, 'r', encoding='utf-8') as f: + config_data = json.load(f) + self._raw_configs[config_name] = config_data + + if config_name in self._configs: + self._configs[config_name].update(config_data) + self._loaded_configs.add(config_name) + logger.info(f"按需加载任务配置 '{config_name}': {config_path}") + return True + except Exception as e: + logger.error(f"加载任务配置 '{config_name}' 失败: {e}") + + logger.warning(f"未找到任务配置: {config_name}") + return False + + def _apply_env_overrides(self): + """应用环境变量覆盖""" + logger.info("应用环境变量覆盖...") + # 示例: AI模型配置环境变量覆盖 + ai_model_config = self.get_config('ai_model', AIModelConfig) + if not ai_model_config: return # 如果没有AI配置则跳过 + + env_mapping = { + 'AI_MODEL': 'model', + 'API_URL': 'api_url', + 'API_KEY': 'api_key' + } + update_data = {} + for env_var, config_key in env_mapping.items(): + if os.getenv(env_var): + update_data[config_key] = os.getenv(env_var) + + if update_data: + ai_model_config.update(update_data) + # 更新原始配置 + if 'ai_model' in self._raw_configs: + for key, value in update_data.items(): + self._raw_configs['ai_model'][key] = value + logger.info(f"通过环境变量更新了AI模型配置: {list(update_data.keys())}") + + def save_config(self, name: str): + """ + 保存指定的配置到文件 + + Args: + name: 要保存的配置名称 + """ + if not self.config_dir: + raise ValueError("配置目录未设置,无法保存文件") + + path = self.config_dir / f"{name}.json" + config = self.get_config(name, BaseConfig) + config_data = config.to_dict() + + # 更新原始配置 + self._raw_configs[name] = config_data + + try: + with open(path, 'w', encoding='utf-8') as f: + json.dump(config_data, f, indent=4, ensure_ascii=False) + logger.info(f"配置 '{name}' 已保存到 {path}") + except Exception as e: + logger.error(f"保存配置 '{name}' 到 {path} 失败: {e}", exc_info=True) + raise + + +# 全局配置管理器实例 +config_manager = ConfigManager() + +def get_config_manager() -> ConfigManager: + return config_manager \ No newline at end of file