TravelContentCreator/api/config/poster_config_manager.py

231 lines
8.0 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
海报配置管理器
负责加载和管理海报相关的配置信息
"""
import os
import yaml
import json
import logging
from typing import Dict, Any, Optional, List
from pathlib import Path
logger = logging.getLogger(__name__)
class PosterConfigManager:
    """Load poster configuration from a YAML file and answer lookups on it.

    The loaded configuration is expected to be a mapping with three top-level
    sections: ``templates``, ``poster_prompts`` and ``defaults``.  A section
    that is missing, null, or not a mapping is treated as empty, so lookups
    return ``None`` / ``[]`` instead of raising.
    """

    # Built-in template classes, used when a template entry does not declare
    # its own ``handler_path`` / ``class_name``.
    _DEFAULT_TEMPLATE_CLASSES = {
        "vibrant": "poster.templates.vibrant_template.VibrantTemplate",
        "business": "poster.templates.business_template.BusinessTemplate"
    }

    def __init__(self, config_file: Optional[str] = None):
        """Create the manager and load its configuration immediately.

        Args:
            config_file: Path of the YAML configuration file.  When ``None``,
                ``poster_prompts.yaml`` next to this module is used.
        """
        if config_file is None:
            config_file = os.path.join(os.path.dirname(__file__), "poster_prompts.yaml")
        self.config_file = config_file
        self.config = {}
        self.load_config()

    def load_config(self):
        """Read ``self.config_file`` into ``self.config``.

        Falls back to :meth:`_get_default_config` when the file is missing,
        cannot be parsed, cannot be read for any other reason, or does not
        contain a mapping at the top level.  Failures are logged, not raised.
        """
        try:
            with open(self.config_file, 'r', encoding='utf-8') as f:
                loaded = yaml.safe_load(f)
        except FileNotFoundError:
            logger.error(f"配置文件不存在: {self.config_file}")
        except yaml.YAMLError as e:
            logger.error(f"解析配置文件失败: {e}")
        except Exception as e:
            # Deliberate best-effort: any other read problem also falls back
            # to the defaults rather than breaking construction.
            logger.error(f"加载配置文件时发生未知错误: {e}")
        else:
            if isinstance(loaded, dict):
                self.config = loaded
                logger.info(f"成功加载配置文件: {self.config_file}")
                return
            # safe_load returns None for an empty document and may return a
            # list or scalar; every accessor below needs a mapping.
            logger.error(f"配置文件内容为空或不是字典结构: {self.config_file}")
        self.config = self._get_default_config()

    def _get_default_config(self) -> Dict[str, Any]:
        """Return the configuration used when the file cannot be loaded."""
        return {
            "poster_prompts": {},
            "templates": {},
            "defaults": {
                "template": "vibrant",
                "temperature": 0.7,
                "output_dir": "result/posters",
                "image_dir": "/root/TravelContentCreator/data/images",
                "font_dir": "/root/TravelContentCreator/assets/font"
            }
        }

    def _section(self, name: str) -> Dict[str, Any]:
        """Return top-level section ``name``, or ``{}`` if it is not a mapping."""
        config = self.config if isinstance(self.config, dict) else {}
        section = config.get(name)
        return section if isinstance(section, dict) else {}

    def get_template_list(self) -> List[Dict[str, Any]]:
        """Return a summary dict for every configured template.

        Each summary holds ``id``, ``name`` (defaults to the id),
        ``description``, ``size`` (defaults to ``[900, 1200]``),
        ``required_fields`` and ``optional_fields``.
        """
        summaries = []
        for template_id, template_info in self._section("templates").items():
            # An entry declared without a body is loaded as None.
            info = template_info if isinstance(template_info, dict) else {}
            summaries.append({
                "id": template_id,
                "name": info.get("name", template_id),
                "description": info.get("description", ""),
                "size": info.get("size", [900, 1200]),
                "required_fields": info.get("required_fields", []),
                "optional_fields": info.get("optional_fields", [])
            })
        return summaries

    def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Return the configuration entry of ``template_id``, or ``None``."""
        return self._section("templates").get(template_id)

    def get_prompt_config(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Return the prompt configuration associated with ``template_id``.

        The prompt is looked up in ``poster_prompts`` under the template's
        ``prompt_key``, falling back to the template id itself.  Returns
        ``None`` when the template or its prompt entry is not configured.
        """
        template_info = self.get_template_info(template_id)
        if not template_info:
            return None
        prompt_key = template_info.get("prompt_key", template_id)
        return self._section("poster_prompts").get(prompt_key)

    def get_system_prompt(self, template_id: str) -> Optional[str]:
        """Return the ``system_prompt`` of the template, or ``None``."""
        prompt_config = self.get_prompt_config(template_id)
        if prompt_config:
            return prompt_config.get("system_prompt")
        return None

    def get_user_prompt_template(self, template_id: str) -> Optional[str]:
        """Return the ``user_prompt_template`` of the template, or ``None``."""
        prompt_config = self.get_prompt_config(template_id)
        if prompt_config:
            return prompt_config.get("user_prompt_template")
        return None

    def format_user_prompt(self, template_id: str, **kwargs) -> Optional[str]:
        """Fill the template's user prompt with the given values.

        Args:
            template_id: Template whose ``user_prompt_template`` is used.
            **kwargs: Values for the ``str.format`` placeholders.  Dicts and
                lists are rendered as indented JSON; everything else goes
                through ``str()``.

        Returns:
            The formatted prompt, or ``None`` when no template is configured
            or formatting fails (the failure is logged).
        """
        template = self.get_user_prompt_template(template_id)
        if not template:
            return None
        # Containers are serialised as JSON so they read well in the prompt.
        formatted_kwargs = {
            key: (
                json.dumps(value, ensure_ascii=False, indent=2)
                if isinstance(value, (dict, list))
                else str(value)
            )
            for key, value in kwargs.items()
        }
        try:
            return template.format(**formatted_kwargs)
        except KeyError as e:
            logger.error(f"格式化提示词失败,缺少参数: {e}")
            return None
        except Exception as e:
            logger.error(f"格式化提示词时发生错误: {e}")
            return None

    def get_default_config(self, key: str) -> Any:
        """Return the value of ``key`` in the ``defaults`` section, or ``None``."""
        return self._section("defaults").get(key)

    def validate_template_content(self, template_id: str, content: Dict[str, Any]) -> tuple[bool, List[str]]:
        """Check that ``content`` supplies every required field of a template.

        Args:
            template_id: Template to validate against.
            content: Field values to check.

        Returns:
            ``(is_valid, errors)``.  ``errors`` has one message per required
            field that is absent or falsy, or a single message when the
            template id is not configured.
        """
        template_info = self.get_template_info(template_id)
        if not template_info:
            return False, [f"未知的模板ID: {template_id}"]
        # ``or []`` covers a ``required_fields`` key that is present but null.
        required_fields = template_info.get("required_fields") or []
        errors = []
        for field in required_fields:
            if field not in content:
                errors.append(f"缺少必填字段: {field}")
            elif not content[field]:
                errors.append(f"必填字段 {field} 不能为空")
        return not errors, errors

    def get_template_class(self, template_id: str) -> type:
        """Import and return the class implementing ``template_id``.

        The location comes from the template entry's ``handler_path`` and
        ``class_name`` when both are set; otherwise the built-in mapping for
        the id is used.

        Args:
            template_id: Template whose implementation class is wanted.

        Returns:
            The template class object.

        Raises:
            ValueError: If the template is not configured, has no known
                implementation, or its module/class cannot be loaded.
        """
        template_info = self.get_template_info(template_id)
        if not template_info:
            raise ValueError(f"未知的模板ID: {template_id}")
        # Prefer the handler location declared in the template entry itself.
        handler_path = template_info.get("handler_path")
        class_name = template_info.get("class_name")
        if not (handler_path and class_name):
            # Fall back to the built-in mapping.
            class_path = self._DEFAULT_TEMPLATE_CLASSES.get(template_id)
            if not class_path:
                raise ValueError(f"未支持的模板类型: {template_id}")
            handler_path, class_name = class_path.rsplit(".", 1)
        return self._import_template_class(template_id, handler_path, class_name)

    @staticmethod
    def _import_template_class(template_id: str, module_path: str, class_name: str) -> type:
        """Import ``module_path`` and return its ``class_name`` attribute."""
        try:
            # A non-empty fromlist makes __import__ return the leaf module
            # rather than the top-level package.
            module = __import__(module_path, fromlist=[class_name])
            return getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            # AttributeError: the module imported but does not define the class.
            logger.error(f"导入模板类失败: {e}")
            raise ValueError(f"无法加载模板类: {template_id}") from e

    def reload_config(self):
        """Re-read the configuration file, replacing ``self.config``."""
        logger.info("正在重新加载配置...")
        self.load_config()
# Module-wide configuration manager, created on first request.
_config_manager = None


def get_poster_config_manager() -> PosterConfigManager:
    """Return the shared ``PosterConfigManager``, building it on first call."""
    global _config_manager
    if _config_manager is not None:
        return _config_manager
    _config_manager = PosterConfigManager()
    return _config_manager