#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Poster configuration manager.

Loads and manages poster-related configuration (prompt definitions,
template metadata and default settings) from a YAML file.
"""

import importlib
import json
import logging
import os
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple

import yaml

logger = logging.getLogger(__name__)


class PosterConfigManager:
    """Poster configuration manager.

    Wraps a YAML configuration file whose top level is a mapping with the
    sections ``poster_prompts``, ``templates`` and ``defaults``, and exposes
    accessors for them. If the file is missing, unreadable, unparsable, or its
    top level is not a mapping, the built-in default configuration is used.
    """

    def __init__(self, config_file: Optional[str] = None):
        """
        Initialize the manager and load the configuration immediately.

        Args:
            config_file: Path to the YAML config file. When None, defaults to
                ``poster_prompts.yaml`` in the same directory as this module.
        """
        if config_file is None:
            config_file = os.path.join(os.path.dirname(__file__), "poster_prompts.yaml")
        self.config_file = config_file
        self.config: Dict[str, Any] = {}
        self.load_config()

    def load_config(self):
        """Load the configuration file into ``self.config``.

        Falls back to ``_get_default_config()`` when the file does not exist,
        cannot be read or parsed, or does not contain a mapping at the top
        level (for an empty file ``yaml.safe_load`` returns None).
        """
        try:
            with open(self.config_file, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except FileNotFoundError:
            logger.error(f"配置文件不存在: {self.config_file}")
            self.config = self._get_default_config()
            return
        except yaml.YAMLError as e:
            logger.error(f"解析配置文件失败: {e}")
            self.config = self._get_default_config()
            return
        except Exception as e:
            logger.error(f"加载配置文件时发生未知错误: {e}")
            self.config = self._get_default_config()
            return

        if not isinstance(loaded, dict):
            # Every accessor calls .get() on self.config, so a None/list/scalar
            # document would crash later; use the defaults instead.
            logger.error(f"配置文件内容无效,顶层必须是字典: {self.config_file}")
            self.config = self._get_default_config()
            return

        self.config = loaded
        logger.info(f"成功加载配置文件: {self.config_file}")

    def _get_default_config(self) -> Dict[str, Any]:
        """Return the built-in fallback configuration (no prompts, no templates)."""
        return {
            "poster_prompts": {},
            "templates": {},
            "defaults": {
                "template": "vibrant",
                "temperature": 0.7,
                "output_dir": "result/posters",
                "image_dir": "/root/TravelContentCreator/data/images",
                "font_dir": "/root/TravelContentCreator/assets/font"
            }
        }

    def _section(self, name: str) -> Dict[str, Any]:
        """Return top-level section ``name``; {} if it is absent, null, or not a mapping."""
        value = self.config.get(name)
        return value if isinstance(value, dict) else {}

    def get_template_list(self) -> List[Dict[str, Any]]:
        """Return a summary dict for every template defined under ``templates``.

        Each summary has the keys ``id``, ``name``, ``description``, ``size``,
        ``required_fields`` and ``optional_fields``; missing values are filled
        with the defaults shown below. Null/non-mapping entries are treated as
        empty mappings.
        """
        summaries = []
        for template_id, raw_info in self._section("templates").items():
            template_info = raw_info if isinstance(raw_info, dict) else {}
            summaries.append({
                "id": template_id,
                "name": template_info.get("name", template_id),
                "description": template_info.get("description", ""),
                "size": template_info.get("size", [900, 1200]),
                "required_fields": template_info.get("required_fields", []),
                "optional_fields": template_info.get("optional_fields", [])
            })
        return summaries

    def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Return the raw config mapping for ``template_id``, or None if it is not defined as a mapping."""
        template_info = self._section("templates").get(template_id)
        return template_info if isinstance(template_info, dict) else None

    def get_prompt_config(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Return the prompt config mapping for a template, or None.

        The template's ``prompt_key`` (defaulting to the template id) selects
        the entry under ``poster_prompts``.
        """
        template_info = self.get_template_info(template_id)
        if template_info is None:
            return None

        prompt_key = template_info.get("prompt_key", template_id)
        prompt_config = self._section("poster_prompts").get(prompt_key)
        return prompt_config if isinstance(prompt_config, dict) else None

    def get_system_prompt(self, template_id: str) -> Optional[str]:
        """Return the ``system_prompt`` for a template, or None if unavailable."""
        prompt_config = self.get_prompt_config(template_id)
        if prompt_config:
            return prompt_config.get("system_prompt")
        return None

    def get_user_prompt_template(self, template_id: str) -> Optional[str]:
        """Return the ``user_prompt_template`` for a template, or None if unavailable."""
        prompt_config = self.get_prompt_config(template_id)
        if prompt_config:
            return prompt_config.get("user_prompt_template")
        return None

    def format_user_prompt(self, template_id: str, **kwargs) -> Optional[str]:
        """
        Fill the template's user prompt with ``str.format``.

        Args:
            template_id: Template id.
            **kwargs: Values substituted into the prompt. dict/list values are
                serialized to pretty-printed JSON (non-ASCII kept); everything
                else is converted with ``str()``.

        Returns:
            The formatted prompt, or None if the template has no user prompt
            or formatting fails (e.g. a placeholder has no matching kwarg).
        """
        template = self.get_user_prompt_template(template_id)
        if not template:
            return None

        try:
            # Serialize containers to JSON so they render readably in the prompt.
            formatted_kwargs = {}
            for key, value in kwargs.items():
                if isinstance(value, (dict, list)):
                    formatted_kwargs[key] = json.dumps(value, ensure_ascii=False, indent=2)
                else:
                    formatted_kwargs[key] = str(value)

            return template.format(**formatted_kwargs)
        except KeyError as e:
            logger.error(f"格式化提示词失败,缺少参数: {e}")
            return None
        except Exception as e:
            logger.error(f"格式化提示词时发生错误: {e}")
            return None

    def get_default_config(self, key: str) -> Any:
        """Return ``defaults[key]`` from the config, or None if it is not set."""
        return self._section("defaults").get(key)

    def validate_template_content(self, template_id: str, content: Dict[str, Any]) -> Tuple[bool, List[str]]:
        """
        Check that ``content`` provides every required field of a template.

        Args:
            template_id: Template id.
            content: Content to validate.

        Returns:
            ``(is_valid, errors)``. A required field is an error when it is
            missing or falsy (None, "", 0, False, empty container).
        """
        template_info = self.get_template_info(template_id)
        if template_info is None:
            return False, [f"未知的模板ID: {template_id}"]

        errors = []
        required_fields = template_info.get("required_fields", [])

        for field in required_fields:
            if field not in content:
                errors.append(f"缺少必填字段: {field}")
            elif not content[field]:
                errors.append(f"必填字段 {field} 不能为空")

        return len(errors) == 0, errors

    def get_template_class(self, template_id: str):
        """
        Import and return the template class for ``template_id``.

        Args:
            template_id: Template id ("vibrant" or "business").

        Returns:
            The template class object.

        Raises:
            ValueError: If the id is not supported, or the module/class cannot
                be imported (the underlying error is chained as the cause).
        """
        template_mapping = {
            "vibrant": "poster.templates.vibrant_template.VibrantTemplate",
            "business": "poster.templates.business_template.BusinessTemplate"
        }

        class_path = template_mapping.get(template_id)
        if not class_path:
            raise ValueError(f"未支持的模板类型: {template_id}")

        module_path, class_name = class_path.rsplit(".", 1)
        try:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            # AttributeError: module imported but the class is missing.
            logger.error(f"导入模板类失败: {e}")
            raise ValueError(f"无法加载模板类: {template_id}") from e

    def reload_config(self):
        """Re-read the configuration file (same fallback rules as ``load_config``)."""
        logger.info("正在重新加载配置...")
        self.load_config()


# Lazily created module-wide instance returned by get_poster_config_manager().
_config_manager = None


def get_poster_config_manager() -> PosterConfigManager:
    """Return the shared PosterConfigManager, creating it with the default config path on first use."""
    global _config_manager
    if _config_manager is None:
        _config_manager = PosterConfigManager()
    return _config_manager