231 lines
8.0 KiB
Python
231 lines
8.0 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
海报配置管理器
|
|||
|
|
负责加载和管理海报相关的配置信息
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import yaml
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
from typing import Dict, Any, Optional, List
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PosterConfigManager:
|
|||
|
|
"""海报配置管理器"""
|
|||
|
|
|
|||
|
|
def __init__(self, config_file: Optional[str] = None):
|
|||
|
|
"""
|
|||
|
|
初始化配置管理器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
config_file: 配置文件路径,如果为空则使用默认路径
|
|||
|
|
"""
|
|||
|
|
if config_file is None:
|
|||
|
|
config_file = os.path.join(os.path.dirname(__file__), "poster_prompts.yaml")
|
|||
|
|
|
|||
|
|
self.config_file = config_file
|
|||
|
|
self.config = {}
|
|||
|
|
self.load_config()
|
|||
|
|
|
|||
|
|
def load_config(self):
|
|||
|
|
"""加载配置文件"""
|
|||
|
|
try:
|
|||
|
|
with open(self.config_file, 'r', encoding='utf-8') as f:
|
|||
|
|
self.config = yaml.safe_load(f)
|
|||
|
|
logger.info(f"成功加载配置文件: {self.config_file}")
|
|||
|
|
except FileNotFoundError:
|
|||
|
|
logger.error(f"配置文件不存在: {self.config_file}")
|
|||
|
|
self.config = self._get_default_config()
|
|||
|
|
except yaml.YAMLError as e:
|
|||
|
|
logger.error(f"解析配置文件失败: {e}")
|
|||
|
|
self.config = self._get_default_config()
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"加载配置文件时发生未知错误: {e}")
|
|||
|
|
self.config = self._get_default_config()
|
|||
|
|
|
|||
|
|
def _get_default_config(self) -> Dict[str, Any]:
|
|||
|
|
"""获取默认配置"""
|
|||
|
|
return {
|
|||
|
|
"poster_prompts": {},
|
|||
|
|
"templates": {},
|
|||
|
|
"defaults": {
|
|||
|
|
"template": "vibrant",
|
|||
|
|
"temperature": 0.7,
|
|||
|
|
"output_dir": "result/posters",
|
|||
|
|
"image_dir": "/root/TravelContentCreator/data/images",
|
|||
|
|
"font_dir": "/root/TravelContentCreator/assets/font"
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
def get_template_list(self) -> List[Dict[str, Any]]:
|
|||
|
|
"""获取所有可用的模板列表"""
|
|||
|
|
templates = self.config.get("templates", {})
|
|||
|
|
return [
|
|||
|
|
{
|
|||
|
|
"id": template_id,
|
|||
|
|
"name": template_info.get("name", template_id),
|
|||
|
|
"description": template_info.get("description", ""),
|
|||
|
|
"size": template_info.get("size", [900, 1200]),
|
|||
|
|
"required_fields": template_info.get("required_fields", []),
|
|||
|
|
"optional_fields": template_info.get("optional_fields", [])
|
|||
|
|
}
|
|||
|
|
for template_id, template_info in templates.items()
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
def get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
|
|||
|
|
"""获取指定模板的详细信息"""
|
|||
|
|
return self.config.get("templates", {}).get(template_id)
|
|||
|
|
|
|||
|
|
def get_prompt_config(self, template_id: str) -> Optional[Dict[str, Any]]:
|
|||
|
|
"""获取指定模板的提示词配置"""
|
|||
|
|
template_info = self.get_template_info(template_id)
|
|||
|
|
if not template_info:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
prompt_key = template_info.get("prompt_key", template_id)
|
|||
|
|
return self.config.get("poster_prompts", {}).get(prompt_key)
|
|||
|
|
|
|||
|
|
def get_system_prompt(self, template_id: str) -> Optional[str]:
|
|||
|
|
"""获取系统提示词"""
|
|||
|
|
prompt_config = self.get_prompt_config(template_id)
|
|||
|
|
if prompt_config:
|
|||
|
|
return prompt_config.get("system_prompt")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
def get_user_prompt_template(self, template_id: str) -> Optional[str]:
|
|||
|
|
"""获取用户提示词模板"""
|
|||
|
|
prompt_config = self.get_prompt_config(template_id)
|
|||
|
|
if prompt_config:
|
|||
|
|
return prompt_config.get("user_prompt_template")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
def format_user_prompt(self, template_id: str, **kwargs) -> Optional[str]:
|
|||
|
|
"""
|
|||
|
|
格式化用户提示词
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
template_id: 模板ID
|
|||
|
|
**kwargs: 用于格式化的参数
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
格式化后的用户提示词
|
|||
|
|
"""
|
|||
|
|
template = self.get_user_prompt_template(template_id)
|
|||
|
|
if not template:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# 确保参数中的字典类型转为JSON字符串
|
|||
|
|
formatted_kwargs = {}
|
|||
|
|
for key, value in kwargs.items():
|
|||
|
|
if isinstance(value, (dict, list)):
|
|||
|
|
formatted_kwargs[key] = json.dumps(value, ensure_ascii=False, indent=2)
|
|||
|
|
else:
|
|||
|
|
formatted_kwargs[key] = str(value)
|
|||
|
|
|
|||
|
|
return template.format(**formatted_kwargs)
|
|||
|
|
except KeyError as e:
|
|||
|
|
logger.error(f"格式化提示词失败,缺少参数: {e}")
|
|||
|
|
return None
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"格式化提示词时发生错误: {e}")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
def get_default_config(self, key: str) -> Any:
|
|||
|
|
"""获取默认配置值"""
|
|||
|
|
return self.config.get("defaults", {}).get(key)
|
|||
|
|
|
|||
|
|
def validate_template_content(self, template_id: str, content: Dict[str, Any]) -> tuple[bool, List[str]]:
|
|||
|
|
"""
|
|||
|
|
验证模板内容是否符合要求
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
template_id: 模板ID
|
|||
|
|
content: 要验证的内容
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
(是否有效, 错误信息列表)
|
|||
|
|
"""
|
|||
|
|
template_info = self.get_template_info(template_id)
|
|||
|
|
if not template_info:
|
|||
|
|
return False, [f"未知的模板ID: {template_id}"]
|
|||
|
|
|
|||
|
|
errors = []
|
|||
|
|
required_fields = template_info.get("required_fields", [])
|
|||
|
|
|
|||
|
|
# 检查必填字段
|
|||
|
|
for field in required_fields:
|
|||
|
|
if field not in content:
|
|||
|
|
errors.append(f"缺少必填字段: {field}")
|
|||
|
|
elif not content[field]:
|
|||
|
|
errors.append(f"必填字段 {field} 不能为空")
|
|||
|
|
|
|||
|
|
return len(errors) == 0, errors
|
|||
|
|
|
|||
|
|
def get_template_class(self, template_id: str):
|
|||
|
|
"""
|
|||
|
|
动态获取模板类
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
template_id: 模板ID
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
模板类
|
|||
|
|
"""
|
|||
|
|
template_info = self.get_template_info(template_id)
|
|||
|
|
if not template_info:
|
|||
|
|
raise ValueError(f"未知的模板ID: {template_id}")
|
|||
|
|
|
|||
|
|
# 优先使用数据库中配置的路径
|
|||
|
|
handler_path = template_info.get("handler_path")
|
|||
|
|
class_name = template_info.get("class_name")
|
|||
|
|
|
|||
|
|
if handler_path and class_name:
|
|||
|
|
try:
|
|||
|
|
module = __import__(handler_path, fromlist=[class_name])
|
|||
|
|
return getattr(module, class_name)
|
|||
|
|
except ImportError as e:
|
|||
|
|
logger.error(f"导入模板类失败: {e}")
|
|||
|
|
raise ValueError(f"无法加载模板类: {template_id}")
|
|||
|
|
|
|||
|
|
# 回退到默认映射
|
|||
|
|
template_mapping = {
|
|||
|
|
"vibrant": "poster.templates.vibrant_template.VibrantTemplate",
|
|||
|
|
"business": "poster.templates.business_template.BusinessTemplate"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
class_path = template_mapping.get(template_id)
|
|||
|
|
if not class_path:
|
|||
|
|
raise ValueError(f"未支持的模板类型: {template_id}")
|
|||
|
|
|
|||
|
|
# 动态导入模板类
|
|||
|
|
module_path, class_name = class_path.rsplit(".", 1)
|
|||
|
|
try:
|
|||
|
|
module = __import__(module_path, fromlist=[class_name])
|
|||
|
|
return getattr(module, class_name)
|
|||
|
|
except ImportError as e:
|
|||
|
|
logger.error(f"导入模板类失败: {e}")
|
|||
|
|
raise ValueError(f"无法加载模板类: {template_id}")
|
|||
|
|
|
|||
|
|
def reload_config(self):
|
|||
|
|
"""重新加载配置"""
|
|||
|
|
logger.info("正在重新加载配置...")
|
|||
|
|
self.load_config()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 全局配置管理器实例
|
|||
|
|
_config_manager = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_poster_config_manager() -> PosterConfigManager:
|
|||
|
|
"""获取全局配置管理器实例"""
|
|||
|
|
global _config_manager
|
|||
|
|
if _config_manager is None:
|
|||
|
|
_config_manager = PosterConfigManager()
|
|||
|
|
return _config_manager
|