diff --git a/api/__pycache__/dependencies.cpython-312.pyc b/api/__pycache__/dependencies.cpython-312.pyc index 17cdd93..bd44757 100644 Binary files a/api/__pycache__/dependencies.cpython-312.pyc and b/api/__pycache__/dependencies.cpython-312.pyc differ diff --git a/api/dependencies.py b/api/dependencies.py index 7d14fdf..854c355 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -5,22 +5,23 @@ API依赖注入模块 """ +from typing import Optional from core.config import get_config_manager, ConfigManager from core.ai import AIAgent from utils.file_io import OutputManager # 全局依赖 -config_manager = None -ai_agent = None -output_manager = None +config_manager: Optional[ConfigManager] = None +ai_agent: Optional[AIAgent] = None +output_manager: Optional[OutputManager] = None def initialize_dependencies(): """初始化全局依赖""" global config_manager, ai_agent, output_manager - # 初始化配置 + # 初始化配置 - 使用服务器模式 config_manager = get_config_manager() - config_manager.load_from_directory("config") + config_manager.load_from_directory("config", server_mode=True) # 初始化输出管理器 from datetime import datetime @@ -34,12 +35,18 @@ def initialize_dependencies(): def get_config() -> ConfigManager: """获取配置管理器""" + if config_manager is None: + raise RuntimeError("配置管理器未初始化") return config_manager def get_ai_agent() -> AIAgent: """获取AI代理""" + if ai_agent is None: + raise RuntimeError("AI代理未初始化") return ai_agent def get_output_manager() -> OutputManager: """获取输出管理器""" + if output_manager is None: + raise RuntimeError("输出管理器未初始化") return output_manager \ No newline at end of file diff --git a/api/services/__pycache__/prompt_builder.cpython-312.pyc b/api/services/__pycache__/prompt_builder.cpython-312.pyc index a8788c4..901ba5e 100644 Binary files a/api/services/__pycache__/prompt_builder.cpython-312.pyc and b/api/services/__pycache__/prompt_builder.cpython-312.pyc differ diff --git a/api/services/__pycache__/prompt_service.cpython-312.pyc b/api/services/__pycache__/prompt_service.cpython-312.pyc index cf75200..10051f9 100644 Binary files a/api/services/__pycache__/prompt_service.cpython-312.pyc and b/api/services/__pycache__/prompt_service.cpython-312.pyc differ diff --git a/api/services/prompt_builder.py b/api/services/prompt_builder.py index 0a074f4..db27d5f 100644 --- a/api/services/prompt_builder.py +++ b/api/services/prompt_builder.py @@ -10,7 +10,7 @@ import logging from typing import Dict, Any, Optional, Tuple from pathlib import Path -from core.config import ConfigManager, GenerateContentConfig +from core.config import ConfigManager, GenerateContentConfig, GenerateTopicConfig, PosterConfig from utils.prompts import PromptTemplate from api.services.prompt_service import PromptService @@ -30,7 +30,30 @@ class PromptBuilderService: """ self.config_manager = config_manager self.prompt_service = prompt_service - self.content_config: GenerateContentConfig = config_manager.get_config('content_gen', GenerateContentConfig) + + def _ensure_content_config(self) -> GenerateContentConfig: + """确保内容生成配置已加载""" + # 按需加载内容生成配置 + if not self.config_manager.load_task_config('content_gen'): + logger.warning("未找到内容生成配置,将使用默认配置") + + return self.config_manager.get_config('content_gen', GenerateContentConfig) + + def _ensure_topic_config(self) -> GenerateTopicConfig: + """确保选题生成配置已加载""" + # 按需加载选题生成配置 + if not self.config_manager.load_task_config('topic_gen'): + logger.warning("未找到选题生成配置,将使用默认配置") + + return self.config_manager.get_config('topic_gen', GenerateTopicConfig) + + def _ensure_poster_config(self) -> PosterConfig: + """确保海报生成配置已加载""" + # 按需加载海报生成配置 + if not self.config_manager.load_task_config('poster_gen'): + logger.warning("未找到海报生成配置,将使用默认配置") + + return self.config_manager.get_config('poster_gen', PosterConfig) def build_content_prompt(self, topic: Dict[str, Any], step: str = "content") -> Tuple[str, str]: """ @@ -43,9 +66,12 @@ class PromptBuilderService: Returns: 系统提示词和用户提示词的元组 """ + # 获取内容生成配置 + content_config = self._ensure_content_config() + # 加载系统提示词和用户提示词模板 - system_prompt_path = self.content_config.content_system_prompt - user_prompt_path = self.content_config.content_user_prompt + system_prompt_path = content_config.content_system_prompt + user_prompt_path = content_config.content_user_prompt # 创建提示词模板 template = PromptTemplate(system_prompt_path, user_prompt_path) @@ -83,6 +109,47 @@ class PromptBuilderService: return system_prompt, user_prompt + def build_poster_prompt(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str]: + """ + 构建海报生成提示词 + + Args: + topic: 选题信息 + content: 生成的内容 + + Returns: + 系统提示词和用户提示词的元组 + """ + # 获取海报生成配置 + poster_config = self._ensure_poster_config() + + # 从配置中获取海报提示词模板路径 + system_prompt_path = poster_config.poster_system_prompt + user_prompt_path = poster_config.poster_user_prompt + + if not system_prompt_path or not user_prompt_path: + raise ValueError("海报提示词模板路径不完整") + + # 创建提示词模板 + template = PromptTemplate(system_prompt_path, user_prompt_path) + + # 获取景区信息 + object_name = topic.get("object", "") + object_content = self.prompt_service.get_scenic_spot_info(object_name) + + # 构建系统提示词 + system_prompt = template.get_system_prompt() + + # 构建用户提示词 + user_prompt = template.build_user_prompt( + content=content.get("content", ""), + title=content.get("title", ""), + object_name=object_name, + object_content=object_content + ) + + return system_prompt, user_prompt + def build_topic_prompt(self, num_topics: int, month: str) -> Tuple[str, str]: """ 构建选题生成提示词 @@ -94,13 +161,11 @@ class PromptBuilderService: Returns: 系统提示词和用户提示词的元组 """ - # 从配置中获取选题提示词模板路径 - topic_config = self.config_manager.get_config('topic_gen', dict) - if not topic_config: - raise ValueError("未找到选题生成配置") + # 获取选题生成配置 + topic_config = self._ensure_topic_config() - system_prompt_path = topic_config.get("topic_system_prompt", "") - user_prompt_path = topic_config.get("topic_user_prompt", "") + system_prompt_path = topic_config.topic_system_prompt + user_prompt_path = topic_config.topic_user_prompt if not system_prompt_path or not user_prompt_path: raise ValueError("选题提示词模板路径不完整") @@ -155,9 +220,12 @@ class PromptBuilderService: Returns: 系统提示词和用户提示词的元组 """ + # 获取内容生成配置 + content_config = self._ensure_content_config() + # 从配置中获取审核提示词模板路径 - system_prompt_path = self.content_config.judger_system_prompt - user_prompt_path = self.content_config.judger_user_prompt + system_prompt_path = content_config.judger_system_prompt + user_prompt_path = content_config.judger_user_prompt # 创建提示词模板 template = PromptTemplate(system_prompt_path, user_prompt_path) diff --git a/api/services/prompt_service.py b/api/services/prompt_service.py index 076fd57..de8fe01 100644 --- a/api/services/prompt_service.py +++ b/api/services/prompt_service.py @@ -10,6 +10,9 @@ import logging import json import os import re +import sys +import traceback +import time from typing import Dict, Any, Optional, List, cast from pathlib import Path import mysql.connector @@ -33,121 +36,130 @@ class PromptService: config_manager: 配置管理器 """ self.config_manager = config_manager - self.resource_config: ResourceConfig = config_manager.get_config('resource', ResourceConfig) + + # 按需加载resource配置 + if 'resource' not in self.config_manager._loaded_configs: + self.config_manager.load_task_config('resource') + + self.resource_config = self.config_manager.get_config('resource', ResourceConfig) + + # ResourceLoader是静态类,不需要实例化 + self.resource_dirs = self.resource_config.resource_dirs # 初始化数据库连接池 - self._init_db_pool() + self.db_pool = self._init_db_pool() - def _init_db_pool(self): - """初始化数据库连接池""" - try: - # 尝试直接从配置文件加载数据库配置 - config_dir = Path("config") - db_config_path = config_dir / "database.json" - - if not db_config_path.exists(): - logger.warning(f"数据库配置文件不存在: {db_config_path}") - self.db_pool = None - return - - # 加载配置文件 - with open(db_config_path, 'r', encoding='utf-8') as f: - db_config = json.load(f) - - # 处理环境变量 - processed_config = {} - for key, value in db_config.items(): - if isinstance(value, str) and "${" in value: - # 匹配 ${ENV_VAR:-default} 格式 - pattern = r'\${([^:-]+)(?::-([^}]+))?}' - match = re.match(pattern, value) - if match: - env_var, default = match.groups() - processed_value = os.environ.get(env_var, default) - # 尝试转换为数字 - if key == "port": - try: - processed_value = int(processed_value) - except (ValueError, TypeError): - processed_value = 3306 - processed_config[key] = processed_value - else: - processed_config[key] = value - - # 创建连接池 - self.db_pool = pooling.MySQLConnectionPool( - pool_name="prompt_pool", - pool_size=5, - host=processed_config.get("host", "localhost"), - user=processed_config.get("user", "root"), - password=processed_config.get("password", ""), - database=processed_config.get("database", "travel_content"), - port=processed_config.get("port", 3306), - charset=processed_config.get("charset", "utf8mb4") - ) - logger.info(f"数据库连接池初始化成功,连接到 {processed_config.get('host')}:{processed_config.get('port')}") - except Exception as e: - logger.error(f"初始化数据库连接池失败: {e}") - self.db_pool = None + # 创建必要的目录结构 + self._create_resource_directories() - def get_style_content(self, style_name: str) -> str: - """ - 获取内容风格提示词 + def _create_resource_directories(self): + """创建必要的资源目录""" + try: + dirs_to_create = ['styles', 'attractions', 'products', 'audiences'] + for dir_name in dirs_to_create: + dir_path = Path(dir_name) + if not dir_path.exists(): + dir_path.mkdir(parents=True, exist_ok=True) + logger.info(f"创建目录: {dir_path}") + except Exception as e: + logger.error(f"创建资源目录失败: {e}") + + def _process_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]: + """处理配置中的环境变量""" + processed = {} + for key, value in config.items(): + if isinstance(value, str) and "${" in value: + # 匹配 ${ENV_VAR:-default} 格式 + pattern = r'\${([^:-]+)(?::-([^}]+))?}' + match = re.match(pattern, value) + if match: + env_var, default = match.groups() + processed_value = os.environ.get(env_var, default or "") + # 尝试转换为数字 + if key == "port": + try: + processed_value = int(processed_value) + except (ValueError, TypeError): + processed_value = 3306 + processed[key] = processed_value + else: + processed[key] = value + else: + processed[key] = value + return processed + + def _init_db_pool(self): + """初始化数据库连接池,尝试多种连接方式""" + # 获取数据库配置 + raw_db_config = self.config_manager.get_raw_config('database') - Args: - style_name: 风格名称 - - Returns: - 风格提示词内容 - """ - # 优先从数据库获取 + # 处理环境变量 + db_config = self._process_env_vars(raw_db_config) + + # 连接尝试配置 + connection_attempts = [ + {"desc": "使用配置文件中的设置", "config": db_config}, + {"desc": "使用明确的密码", "config": {**db_config, "password": "password"}}, + {"desc": "使用空密码", "config": {**db_config, "password": ""}}, + {"desc": "使用auth_plugin", "config": {**db_config, "auth_plugin": "mysql_native_password"}} + ] + + # 尝试不同的连接方式 + for attempt in connection_attempts: + try: + # 打印连接信息(不包含密码) + connection_info = {k: v for k, v in attempt["config"].items() if k != 'password'} + logger.info(f"尝试连接数据库 ({attempt['desc']}): {connection_info}") + + # 创建连接池 + pool = pooling.MySQLConnectionPool( + pool_name=f"prompt_service_pool_{int(time.time())}", + pool_size=5, + **attempt["config"] + ) + + # 测试连接 + with pool.get_connection() as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.fetchall() + + logger.info(f"数据库连接池初始化成功 ({attempt['desc']})") + return pool + except Exception as e: + error_details = traceback.format_exc() + logger.error(f"数据库连接尝试 ({attempt['desc']}) 失败: {e}\n{error_details}") + + logger.warning("所有数据库连接尝试均失败,将使用文件系统作为数据源") + return None + + def get_style_content(self, style_name: str) -> str: + """获取风格提示词内容""" + # 尝试从数据库获取 if self.db_pool: try: - conn = self.db_pool.get_connection() - cursor = conn.cursor(dictionary=True) - cursor.execute( - "SELECT description FROM contentStyle WHERE styleName = %s", - (style_name,) - ) - result = cursor.fetchone() - cursor.close() - conn.close() - - if result: - logger.info(f"从数据库获取风格提示词: {style_name}") - return result["description"] + with self.db_pool.get_connection() as conn: + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT description FROM contentStyle WHERE styleName = %s AND isDelete = 0", + (style_name,) + ) + result = cursor.fetchone() + if result: + return result['description'] except Exception as e: logger.error(f"从数据库获取风格提示词失败: {e}") - # 回退到文件系统 - try: - style_paths = self.resource_config.style.paths - for path_str in style_paths: - try: - if style_name in path_str: - full_path = self._get_full_path(path_str) - if full_path.exists(): - logger.info(f"从文件系统获取风格提示词: {style_name}") - return full_path.read_text('utf-8') - except Exception as e: - logger.error(f"读取风格文件失败 {path_str}: {e}") - - # 如果没有精确匹配,尝试模糊匹配 - for path_str in style_paths: - try: - full_path = self._get_full_path(path_str) - if full_path.exists() and full_path.is_file(): - content = full_path.read_text('utf-8') - if style_name.lower() in full_path.stem.lower(): - logger.info(f"通过模糊匹配找到风格提示词: {style_name} -> {full_path.name}") - return content - except Exception as e: - logger.error(f"读取风格文件失败 {path_str}: {e}") - except Exception as e: - logger.error(f"获取风格提示词失败: {e}") - - logger.warning(f"未找到风格提示词: {style_name},将使用默认值") - return "通用风格" + # 从文件系统获取 + logger.info(f"从文件系统获取风格提示词: {style_name}") + for dir_path in self.resource_dirs: + style_file = ResourceLoader.find_file(os.path.join(dir_path, "styles"), style_name) + if style_file: + content = ResourceLoader.load_text_file(style_file) + if content: + return content + + return f"未找到风格 '{style_name}' 的提示词" def get_audience_content(self, audience_name: str) -> str: """ @@ -224,7 +236,7 @@ class PromptService: conn = self.db_pool.get_connection() cursor = conn.cursor(dictionary=True) cursor.execute( - "SELECT description FROM scenicSpot WHERE spotName = %s", + "SELECT description FROM scenicSpot WHERE name = %s", (spot_name,) ) result = cursor.fetchone() @@ -271,7 +283,7 @@ class PromptService: conn = self.db_pool.get_connection() cursor = conn.cursor(dictionary=True) cursor.execute( - "SELECT description FROM product WHERE productName = %s", + "SELECT description FROM product WHERE name = %s", (product_name,) ) result = cursor.fetchone() @@ -453,7 +465,7 @@ class PromptService: try: conn = self.db_pool.get_connection() cursor = conn.cursor(dictionary=True) - cursor.execute("SELECT spotName as name, description FROM scenicSpot") + cursor.execute("SELECT name as name, description FROM scenicSpot") results = cursor.fetchall() cursor.close() conn.close() diff --git a/config/database.json b/config/database.json new file mode 100644 index 0000000..67f7c88 --- /dev/null +++ b/config/database.json @@ -0,0 +1,8 @@ +{ + "host": "localhost", + "user": "root", + "password": "password", + "database": "travel_content", + "port": 3306, + "charset": "utf8mb4" +} \ No newline at end of file diff --git a/core/config/__pycache__/manager.cpython-312.pyc b/core/config/__pycache__/manager.cpython-312.pyc index f76d679..d913be1 100644 Binary files a/core/config/__pycache__/manager.cpython-312.pyc and b/core/config/__pycache__/manager.cpython-312.pyc differ diff --git a/core/config/__pycache__/models.cpython-312.pyc b/core/config/__pycache__/models.cpython-312.pyc index 6e86875..76b0456 100644 Binary files a/core/config/__pycache__/models.cpython-312.pyc and b/core/config/__pycache__/models.cpython-312.pyc differ diff --git a/core/config/manager.py b/core/config/manager.py index f73a7e0..463554f 100644 --- a/core/config/manager.py +++ b/core/config/manager.py @@ -9,7 +9,7 @@ import json import os import logging from pathlib import Path -from typing import Dict, Type, TypeVar, Optional, Any, cast +from typing import Dict, Type, TypeVar, Optional, Any, cast, List, Set from core.config.models import ( BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig, @@ -27,23 +27,30 @@ class ConfigManager: 负责加载、管理和访问所有配置 """ + # 服务端必要的全局配置 + SERVER_CONFIGS = {'system', 'ai_model', 'database'} + + # 单次生成任务的配置 + TASK_CONFIGS = {'topic_gen', 'content_gen', 'poster_gen', 'resource'} + def __init__(self): self._configs: Dict[str, BaseConfig] = {} + self._raw_configs: Dict[str, Dict[str, Any]] = {} # 存储原始配置数据 self.config_dir: Optional[Path] = None self.config_objects = { 'ai_model': AIModelConfig(), 'system': SystemConfig(), - 'topic_gen': GenerateTopicConfig(), - 'content_gen': GenerateContentConfig(), 'resource': ResourceConfig() } + self._loaded_configs: Set[str] = set() - def load_from_directory(self, config_dir: str): + def load_from_directory(self, config_dir: str, server_mode: bool = False): """ 从目录加载配置 Args: config_dir: 配置文件目录 + server_mode: 是否为服务器模式,如果是则只加载必要的全局配置 """ self.config_dir = Path(config_dir) if not self.config_dir.is_dir(): @@ -54,15 +61,17 @@ class ConfigManager: self._register_configs() # 动态加载目录中的所有.json文件 - self._load_all_configs_from_dir() + self._load_all_configs_from_dir(server_mode) def _register_configs(self): """注册所有配置""" self.register_config('ai_model', AIModelConfig) + self.register_config('system', SystemConfig) + self.register_config('resource', ResourceConfig) + + # 这些配置在服务器模式下不会自动加载,但仍然需要注册类型 self.register_config('poster', PosterConfig) self.register_config('content', ContentConfig) - self.register_config('resource', ResourceConfig) - self.register_config('system', SystemConfig) self.register_config('topic_gen', GenerateTopicConfig) self.register_config('content_gen', GenerateContentConfig) @@ -113,49 +122,106 @@ class ConfigManager: return cast(T, config) - def _load_all_configs_from_dir(self): - """动态加载目录中的所有.json文件""" - try: - # 1. 优先加载旧的主配置文件以实现向后兼容 - main_config_path = self.config_dir / 'poster_gen_config.json' - if main_config_path.exists(): - self._load_main_config(main_config_path) - else: - logger.warning(f"旧的主配置文件不存在: {main_config_path}") + def get_raw_config(self, name: str) -> Dict[str, Any]: + """ + 获取原始配置数据 + + Args: + name: 配置名称 - # 2. 遍历并加载目录中所有其他的 .json 文件 + Returns: + 原始配置数据字典 + """ + if name in self._raw_configs: + return self._raw_configs[name] + + # 如果没有原始配置,但有对象配置,则转换为字典 + if name in self._configs: + return self._configs[name].to_dict() + + # 尝试从文件加载 + if self.config_dir: + config_path = self.config_dir / f"{name}.json" + if config_path.exists(): + try: + with open(config_path, 'r', encoding='utf-8') as f: + raw_config = json.load(f) + self._raw_configs[name] = raw_config + return raw_config + except Exception as e: + logger.error(f"加载原始配置 '{name}' 失败: {e}") + + # 返回空字典 + return {} + + def _load_all_configs_from_dir(self, server_mode: bool = False): + """ + 动态加载目录中的所有.json文件 + + Args: + server_mode: 是否为服务器模式,如果是则只加载必要的全局配置 + """ + try: + # 遍历并加载目录中所有其他的 .json 文件 for config_path in self.config_dir.glob('*.json'): - if config_path.name == 'poster_gen_config.json': - continue # 跳过已处理的主配置文件 - config_name = config_path.stem # 'topic_gen.json' -> 'topic_gen' + + # 服务器模式下,只加载必要的全局配置 + if server_mode and config_name not in self.SERVER_CONFIGS: + logger.info(f"服务器模式下跳过非全局配置: {config_name}") + continue + + # 加载原始配置 + with open(config_path, 'r', encoding='utf-8') as f: + config_data = json.load(f) + self._raw_configs[config_name] = config_data + + # 更新对象配置 if config_name in self._configs: logger.info(f"加载配置文件 '{config_name}': {config_path}") - with open(config_path, 'r', encoding='utf-8') as f: - config_data = json.load(f) self._configs[config_name].update(config_data) + self._loaded_configs.add(config_name) else: - logger.warning(f"在 '{config_path}' 中找到的配置 '{config_name}' 没有对应的已注册配置类型,已跳过。") + logger.info(f"加载原始配置 '{config_name}': {config_path}") - # 3. 最后应用环境变量覆盖 + # 最后应用环境变量覆盖 self._apply_env_overrides() except Exception as e: logger.error(f"从目录 '{self.config_dir}' 加载配置失败: {e}", exc_info=True) raise - def _load_main_config(self, path: Path): - """加载主配置文件,并分发到各个配置对象""" - logger.info(f"加载主配置文件: {path}") - with open(path, 'r', encoding='utf-8') as f: - config_data = json.load(f) + def load_task_config(self, config_name: str) -> bool: + """ + 按需加载任务配置 - for name, config_obj in self._configs.items(): - # 主配置文件可能是扁平结构或嵌套结构 - if name in config_data: # 嵌套结构 - config_obj.update(config_data[name]) - else: # 尝试从根部更新 (扁平结构) - config_obj.update(config_data) + Args: + config_name: 配置名称 + + Returns: + 是否成功加载 + """ + if config_name in self._loaded_configs: + return True + + if self.config_dir: + config_path = self.config_dir / f"{config_name}.json" + if config_path.exists(): + try: + with open(config_path, 'r', encoding='utf-8') as f: + config_data = json.load(f) + self._raw_configs[config_name] = config_data + + if config_name in self._configs: + self._configs[config_name].update(config_data) + self._loaded_configs.add(config_name) + logger.info(f"按需加载任务配置 '{config_name}': {config_path}") + return True + except Exception as e: + logger.error(f"加载任务配置 '{config_name}' 失败: {e}") + + logger.warning(f"未找到任务配置: {config_name}") + return False def _apply_env_overrides(self): """应用环境变量覆盖""" @@ -176,6 +242,10 @@ class ConfigManager: if update_data: ai_model_config.update(update_data) + # 更新原始配置 + if 'ai_model' in self._raw_configs: + for key, value in update_data.items(): + self._raw_configs['ai_model'][key] = value logger.info(f"通过环境变量更新了AI模型配置: {list(update_data.keys())}") def save_config(self, name: str): @@ -191,6 +261,9 @@ class ConfigManager: path = self.config_dir / f"{name}.json" config = self.get_config(name, BaseConfig) config_data = config.to_dict() + + # 更新原始配置 + self._raw_configs[name] = config_data try: with open(path, 'w', encoding='utf-8') as f: diff --git a/core/config/models.py b/core/config/models.py index 7e9c217..e9655b7 100644 --- a/core/config/models.py +++ b/core/config/models.py @@ -161,6 +161,8 @@ class PosterConfig(BaseConfig): additional_images_enabled: bool = True template_selection: str = "random" # random, business, vibrant, original available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"]) + poster_system_prompt: str = "resource/prompt/generatePoster/system.txt" + poster_user_prompt: str = "resource/prompt/generatePoster/user.txt" class ContentConfig(BaseConfig):