修复了数据库获取数据的功能

This commit is contained in:
jinye_huang 2025-07-11 15:29:30 +08:00
parent 19ca9d06ce
commit f6a48031a0
11 changed files with 330 additions and 160 deletions

View File

@ -5,22 +5,23 @@
API依赖注入模块 API依赖注入模块
""" """
from typing import Optional
from core.config import get_config_manager, ConfigManager from core.config import get_config_manager, ConfigManager
from core.ai import AIAgent from core.ai import AIAgent
from utils.file_io import OutputManager from utils.file_io import OutputManager
# 全局依赖 # 全局依赖
config_manager = None config_manager: Optional[ConfigManager] = None
ai_agent = None ai_agent: Optional[AIAgent] = None
output_manager = None output_manager: Optional[OutputManager] = None
def initialize_dependencies(): def initialize_dependencies():
"""初始化全局依赖""" """初始化全局依赖"""
global config_manager, ai_agent, output_manager global config_manager, ai_agent, output_manager
# 初始化配置 # 初始化配置 - 使用服务器模式
config_manager = get_config_manager() config_manager = get_config_manager()
config_manager.load_from_directory("config") config_manager.load_from_directory("config", server_mode=True)
# 初始化输出管理器 # 初始化输出管理器
from datetime import datetime from datetime import datetime
@ -34,12 +35,18 @@ def initialize_dependencies():
def get_config() -> ConfigManager: def get_config() -> ConfigManager:
"""获取配置管理器""" """获取配置管理器"""
if config_manager is None:
raise RuntimeError("配置管理器未初始化")
return config_manager return config_manager
def get_ai_agent() -> AIAgent: def get_ai_agent() -> AIAgent:
"""获取AI代理""" """获取AI代理"""
if ai_agent is None:
raise RuntimeError("AI代理未初始化")
return ai_agent return ai_agent
def get_output_manager() -> OutputManager: def get_output_manager() -> OutputManager:
"""获取输出管理器""" """获取输出管理器"""
if output_manager is None:
raise RuntimeError("输出管理器未初始化")
return output_manager return output_manager

View File

@ -10,7 +10,7 @@ import logging
from typing import Dict, Any, Optional, Tuple from typing import Dict, Any, Optional, Tuple
from pathlib import Path from pathlib import Path
from core.config import ConfigManager, GenerateContentConfig from core.config import ConfigManager, GenerateContentConfig, GenerateTopicConfig, PosterConfig
from utils.prompts import PromptTemplate from utils.prompts import PromptTemplate
from api.services.prompt_service import PromptService from api.services.prompt_service import PromptService
@ -30,7 +30,30 @@ class PromptBuilderService:
""" """
self.config_manager = config_manager self.config_manager = config_manager
self.prompt_service = prompt_service self.prompt_service = prompt_service
self.content_config: GenerateContentConfig = config_manager.get_config('content_gen', GenerateContentConfig)
def _ensure_content_config(self) -> GenerateContentConfig:
"""确保内容生成配置已加载"""
# 按需加载内容生成配置
if not self.config_manager.load_task_config('content_gen'):
logger.warning("未找到内容生成配置,将使用默认配置")
return self.config_manager.get_config('content_gen', GenerateContentConfig)
def _ensure_topic_config(self) -> GenerateTopicConfig:
"""确保选题生成配置已加载"""
# 按需加载选题生成配置
if not self.config_manager.load_task_config('topic_gen'):
logger.warning("未找到选题生成配置,将使用默认配置")
return self.config_manager.get_config('topic_gen', GenerateTopicConfig)
def _ensure_poster_config(self) -> PosterConfig:
"""确保海报生成配置已加载"""
# 按需加载海报生成配置
if not self.config_manager.load_task_config('poster_gen'):
logger.warning("未找到海报生成配置,将使用默认配置")
return self.config_manager.get_config('poster_gen', PosterConfig)
def build_content_prompt(self, topic: Dict[str, Any], step: str = "content") -> Tuple[str, str]: def build_content_prompt(self, topic: Dict[str, Any], step: str = "content") -> Tuple[str, str]:
""" """
@ -43,9 +66,12 @@ class PromptBuilderService:
Returns: Returns:
系统提示词和用户提示词的元组 系统提示词和用户提示词的元组
""" """
# 获取内容生成配置
content_config = self._ensure_content_config()
# 加载系统提示词和用户提示词模板 # 加载系统提示词和用户提示词模板
system_prompt_path = self.content_config.content_system_prompt system_prompt_path = content_config.content_system_prompt
user_prompt_path = self.content_config.content_user_prompt user_prompt_path = content_config.content_user_prompt
# 创建提示词模板 # 创建提示词模板
template = PromptTemplate(system_prompt_path, user_prompt_path) template = PromptTemplate(system_prompt_path, user_prompt_path)
@ -83,6 +109,47 @@ class PromptBuilderService:
return system_prompt, user_prompt return system_prompt, user_prompt
def build_poster_prompt(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str]:
"""
构建海报生成提示词
Args:
topic: 选题信息
content: 生成的内容
Returns:
系统提示词和用户提示词的元组
"""
# 获取海报生成配置
poster_config = self._ensure_poster_config()
# 从配置中获取海报提示词模板路径
system_prompt_path = poster_config.poster_system_prompt
user_prompt_path = poster_config.poster_user_prompt
if not system_prompt_path or not user_prompt_path:
raise ValueError("海报提示词模板路径不完整")
# 创建提示词模板
template = PromptTemplate(system_prompt_path, user_prompt_path)
# 获取景区信息
object_name = topic.get("object", "")
object_content = self.prompt_service.get_scenic_spot_info(object_name)
# 构建系统提示词
system_prompt = template.get_system_prompt()
# 构建用户提示词
user_prompt = template.build_user_prompt(
content=content.get("content", ""),
title=content.get("title", ""),
object_name=object_name,
object_content=object_content
)
return system_prompt, user_prompt
def build_topic_prompt(self, num_topics: int, month: str) -> Tuple[str, str]: def build_topic_prompt(self, num_topics: int, month: str) -> Tuple[str, str]:
""" """
构建选题生成提示词 构建选题生成提示词
@ -94,13 +161,11 @@ class PromptBuilderService:
Returns: Returns:
系统提示词和用户提示词的元组 系统提示词和用户提示词的元组
""" """
# 从配置中获取选题提示词模板路径 # 获取选题生成配置
topic_config = self.config_manager.get_config('topic_gen', dict) topic_config = self._ensure_topic_config()
if not topic_config:
raise ValueError("未找到选题生成配置")
system_prompt_path = topic_config.get("topic_system_prompt", "") system_prompt_path = topic_config.topic_system_prompt
user_prompt_path = topic_config.get("topic_user_prompt", "") user_prompt_path = topic_config.topic_user_prompt
if not system_prompt_path or not user_prompt_path: if not system_prompt_path or not user_prompt_path:
raise ValueError("选题提示词模板路径不完整") raise ValueError("选题提示词模板路径不完整")
@ -155,9 +220,12 @@ class PromptBuilderService:
Returns: Returns:
系统提示词和用户提示词的元组 系统提示词和用户提示词的元组
""" """
# 获取内容生成配置
content_config = self._ensure_content_config()
# 从配置中获取审核提示词模板路径 # 从配置中获取审核提示词模板路径
system_prompt_path = self.content_config.judger_system_prompt system_prompt_path = content_config.judger_system_prompt
user_prompt_path = self.content_config.judger_user_prompt user_prompt_path = content_config.judger_user_prompt
# 创建提示词模板 # 创建提示词模板
template = PromptTemplate(system_prompt_path, user_prompt_path) template = PromptTemplate(system_prompt_path, user_prompt_path)

View File

@ -10,6 +10,9 @@ import logging
import json import json
import os import os
import re import re
import sys
import traceback
import time
from typing import Dict, Any, Optional, List, cast from typing import Dict, Any, Optional, List, cast
from pathlib import Path from pathlib import Path
import mysql.connector import mysql.connector
@ -33,121 +36,130 @@ class PromptService:
config_manager: 配置管理器 config_manager: 配置管理器
""" """
self.config_manager = config_manager self.config_manager = config_manager
self.resource_config: ResourceConfig = config_manager.get_config('resource', ResourceConfig)
# 按需加载resource配置
if 'resource' not in self.config_manager._loaded_configs:
self.config_manager.load_task_config('resource')
self.resource_config = self.config_manager.get_config('resource', ResourceConfig)
# ResourceLoader是静态类不需要实例化
self.resource_dirs = self.resource_config.resource_dirs
# 初始化数据库连接池 # 初始化数据库连接池
self._init_db_pool() self.db_pool = self._init_db_pool()
def _init_db_pool(self): # 创建必要的目录结构
"""初始化数据库连接池""" self._create_resource_directories()
try:
# 尝试直接从配置文件加载数据库配置
config_dir = Path("config")
db_config_path = config_dir / "database.json"
if not db_config_path.exists():
logger.warning(f"数据库配置文件不存在: {db_config_path}")
self.db_pool = None
return
# 加载配置文件
with open(db_config_path, 'r', encoding='utf-8') as f:
db_config = json.load(f)
# 处理环境变量
processed_config = {}
for key, value in db_config.items():
if isinstance(value, str) and "${" in value:
# 匹配 ${ENV_VAR:-default} 格式
pattern = r'\${([^:-]+)(?::-([^}]+))?}'
match = re.match(pattern, value)
if match:
env_var, default = match.groups()
processed_value = os.environ.get(env_var, default)
# 尝试转换为数字
if key == "port":
try:
processed_value = int(processed_value)
except (ValueError, TypeError):
processed_value = 3306
processed_config[key] = processed_value
else:
processed_config[key] = value
# 创建连接池
self.db_pool = pooling.MySQLConnectionPool(
pool_name="prompt_pool",
pool_size=5,
host=processed_config.get("host", "localhost"),
user=processed_config.get("user", "root"),
password=processed_config.get("password", ""),
database=processed_config.get("database", "travel_content"),
port=processed_config.get("port", 3306),
charset=processed_config.get("charset", "utf8mb4")
)
logger.info(f"数据库连接池初始化成功,连接到 {processed_config.get('host')}:{processed_config.get('port')}")
except Exception as e:
logger.error(f"初始化数据库连接池失败: {e}")
self.db_pool = None
def get_style_content(self, style_name: str) -> str: def _create_resource_directories(self):
""" """创建必要的资源目录"""
获取内容风格提示词 try:
dirs_to_create = ['styles', 'attractions', 'products', 'audiences']
for dir_name in dirs_to_create:
dir_path = Path(dir_name)
if not dir_path.exists():
dir_path.mkdir(parents=True, exist_ok=True)
logger.info(f"创建目录: {dir_path}")
except Exception as e:
logger.error(f"创建资源目录失败: {e}")
def _process_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""处理配置中的环境变量"""
processed = {}
for key, value in config.items():
if isinstance(value, str) and "${" in value:
# 匹配 ${ENV_VAR:-default} 格式
pattern = r'\${([^:-]+)(?::-([^}]+))?}'
match = re.match(pattern, value)
if match:
env_var, default = match.groups()
processed_value = os.environ.get(env_var, default or "")
# 尝试转换为数字
if key == "port":
try:
processed_value = int(processed_value)
except (ValueError, TypeError):
processed_value = 3306
processed[key] = processed_value
else:
processed[key] = value
else:
processed[key] = value
return processed
def _init_db_pool(self):
"""初始化数据库连接池,尝试多种连接方式"""
# 获取数据库配置
raw_db_config = self.config_manager.get_raw_config('database')
Args: # 处理环境变量
style_name: 风格名称 db_config = self._process_env_vars(raw_db_config)
Returns: # 连接尝试配置
风格提示词内容 connection_attempts = [
""" {"desc": "使用配置文件中的设置", "config": db_config},
# 优先从数据库获取 {"desc": "使用明确的密码", "config": {**db_config, "password": "password"}},
{"desc": "使用空密码", "config": {**db_config, "password": ""}},
{"desc": "使用auth_plugin", "config": {**db_config, "auth_plugin": "mysql_native_password"}}
]
# 尝试不同的连接方式
for attempt in connection_attempts:
try:
# 打印连接信息(不包含密码)
connection_info = {k: v for k, v in attempt["config"].items() if k != 'password'}
logger.info(f"尝试连接数据库 ({attempt['desc']}): {connection_info}")
# 创建连接池
pool = pooling.MySQLConnectionPool(
pool_name=f"prompt_service_pool_{int(time.time())}",
pool_size=5,
**attempt["config"]
)
# 测试连接
with pool.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1")
cursor.fetchall()
logger.info(f"数据库连接池初始化成功 ({attempt['desc']})")
return pool
except Exception as e:
error_details = traceback.format_exc()
logger.error(f"数据库连接尝试 ({attempt['desc']}) 失败: {e}\n{error_details}")
logger.warning("所有数据库连接尝试均失败,将使用文件系统作为数据源")
return None
def get_style_content(self, style_name: str) -> str:
"""获取风格提示词内容"""
# 尝试从数据库获取
if self.db_pool: if self.db_pool:
try: try:
conn = self.db_pool.get_connection() with self.db_pool.get_connection() as conn:
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
cursor.execute( cursor.execute(
"SELECT description FROM contentStyle WHERE styleName = %s", "SELECT description FROM contentStyle WHERE styleName = %s AND isDelete = 0",
(style_name,) (style_name,)
) )
result = cursor.fetchone() result = cursor.fetchone()
cursor.close() if result:
conn.close() return result['description']
if result:
logger.info(f"从数据库获取风格提示词: {style_name}")
return result["description"]
except Exception as e: except Exception as e:
logger.error(f"从数据库获取风格提示词失败: {e}") logger.error(f"从数据库获取风格提示词失败: {e}")
# 回退到文件系统 # 从文件系统获取
try: logger.info(f"从文件系统获取风格提示词: {style_name}")
style_paths = self.resource_config.style.paths for dir_path in self.resource_dirs:
for path_str in style_paths: style_file = ResourceLoader.find_file(os.path.join(dir_path, "styles"), style_name)
try: if style_file:
if style_name in path_str: content = ResourceLoader.load_text_file(style_file)
full_path = self._get_full_path(path_str) if content:
if full_path.exists(): return content
logger.info(f"从文件系统获取风格提示词: {style_name}")
return full_path.read_text('utf-8') return f"未找到风格 '{style_name}' 的提示词"
except Exception as e:
logger.error(f"读取风格文件失败 {path_str}: {e}")
# 如果没有精确匹配,尝试模糊匹配
for path_str in style_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
if style_name.lower() in full_path.stem.lower():
logger.info(f"通过模糊匹配找到风格提示词: {style_name} -> {full_path.name}")
return content
except Exception as e:
logger.error(f"读取风格文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取风格提示词失败: {e}")
logger.warning(f"未找到风格提示词: {style_name},将使用默认值")
return "通用风格"
def get_audience_content(self, audience_name: str) -> str: def get_audience_content(self, audience_name: str) -> str:
""" """
@ -224,7 +236,7 @@ class PromptService:
conn = self.db_pool.get_connection() conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
cursor.execute( cursor.execute(
"SELECT description FROM scenicSpot WHERE spotName = %s", "SELECT description FROM scenicSpot WHERE name = %s",
(spot_name,) (spot_name,)
) )
result = cursor.fetchone() result = cursor.fetchone()
@ -271,7 +283,7 @@ class PromptService:
conn = self.db_pool.get_connection() conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
cursor.execute( cursor.execute(
"SELECT description FROM product WHERE productName = %s", "SELECT description FROM product WHERE name = %s",
(product_name,) (product_name,)
) )
result = cursor.fetchone() result = cursor.fetchone()
@ -453,7 +465,7 @@ class PromptService:
try: try:
conn = self.db_pool.get_connection() conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True) cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT spotName as name, description FROM scenicSpot") cursor.execute("SELECT name as name, description FROM scenicSpot")
results = cursor.fetchall() results = cursor.fetchall()
cursor.close() cursor.close()
conn.close() conn.close()

8
config/database.json Normal file
View File

@ -0,0 +1,8 @@
{
"host": "localhost",
"user": "root",
"password": "password",
"database": "travel_content",
"port": 3306,
"charset": "utf8mb4"
}

View File

@ -9,7 +9,7 @@ import json
import os import os
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Type, TypeVar, Optional, Any, cast from typing import Dict, Type, TypeVar, Optional, Any, cast, List, Set
from core.config.models import ( from core.config.models import (
BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig, BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig,
@ -27,23 +27,30 @@ class ConfigManager:
负责加载管理和访问所有配置 负责加载管理和访问所有配置
""" """
# 服务端必要的全局配置
SERVER_CONFIGS = {'system', 'ai_model', 'database'}
# 单次生成任务的配置
TASK_CONFIGS = {'topic_gen', 'content_gen', 'poster_gen', 'resource'}
def __init__(self): def __init__(self):
self._configs: Dict[str, BaseConfig] = {} self._configs: Dict[str, BaseConfig] = {}
self._raw_configs: Dict[str, Dict[str, Any]] = {} # 存储原始配置数据
self.config_dir: Optional[Path] = None self.config_dir: Optional[Path] = None
self.config_objects = { self.config_objects = {
'ai_model': AIModelConfig(), 'ai_model': AIModelConfig(),
'system': SystemConfig(), 'system': SystemConfig(),
'topic_gen': GenerateTopicConfig(),
'content_gen': GenerateContentConfig(),
'resource': ResourceConfig() 'resource': ResourceConfig()
} }
self._loaded_configs: Set[str] = set()
def load_from_directory(self, config_dir: str): def load_from_directory(self, config_dir: str, server_mode: bool = False):
""" """
从目录加载配置 从目录加载配置
Args: Args:
config_dir: 配置文件目录 config_dir: 配置文件目录
server_mode: 是否为服务器模式如果是则只加载必要的全局配置
""" """
self.config_dir = Path(config_dir) self.config_dir = Path(config_dir)
if not self.config_dir.is_dir(): if not self.config_dir.is_dir():
@ -54,15 +61,17 @@ class ConfigManager:
self._register_configs() self._register_configs()
# 动态加载目录中的所有.json文件 # 动态加载目录中的所有.json文件
self._load_all_configs_from_dir() self._load_all_configs_from_dir(server_mode)
def _register_configs(self): def _register_configs(self):
"""注册所有配置""" """注册所有配置"""
self.register_config('ai_model', AIModelConfig) self.register_config('ai_model', AIModelConfig)
self.register_config('system', SystemConfig)
self.register_config('resource', ResourceConfig)
# 这些配置在服务器模式下不会自动加载,但仍然需要注册类型
self.register_config('poster', PosterConfig) self.register_config('poster', PosterConfig)
self.register_config('content', ContentConfig) self.register_config('content', ContentConfig)
self.register_config('resource', ResourceConfig)
self.register_config('system', SystemConfig)
self.register_config('topic_gen', GenerateTopicConfig) self.register_config('topic_gen', GenerateTopicConfig)
self.register_config('content_gen', GenerateContentConfig) self.register_config('content_gen', GenerateContentConfig)
@ -113,49 +122,106 @@ class ConfigManager:
return cast(T, config) return cast(T, config)
def _load_all_configs_from_dir(self): def get_raw_config(self, name: str) -> Dict[str, Any]:
"""动态加载目录中的所有.json文件""" """
try: 获取原始配置数据
# 1. 优先加载旧的主配置文件以实现向后兼容
main_config_path = self.config_dir / 'poster_gen_config.json' Args:
if main_config_path.exists(): name: 配置名称
self._load_main_config(main_config_path)
else:
logger.warning(f"旧的主配置文件不存在: {main_config_path}")
# 2. 遍历并加载目录中所有其他的 .json 文件 Returns:
原始配置数据字典
"""
if name in self._raw_configs:
return self._raw_configs[name]
# 如果没有原始配置,但有对象配置,则转换为字典
if name in self._configs:
return self._configs[name].to_dict()
# 尝试从文件加载
if self.config_dir:
config_path = self.config_dir / f"{name}.json"
if config_path.exists():
try:
with open(config_path, 'r', encoding='utf-8') as f:
raw_config = json.load(f)
self._raw_configs[name] = raw_config
return raw_config
except Exception as e:
logger.error(f"加载原始配置 '{name}' 失败: {e}")
# 返回空字典
return {}
def _load_all_configs_from_dir(self, server_mode: bool = False):
"""
动态加载目录中的所有.json文件
Args:
server_mode: 是否为服务器模式如果是则只加载必要的全局配置
"""
try:
# 遍历并加载目录中所有其他的 .json 文件
for config_path in self.config_dir.glob('*.json'): for config_path in self.config_dir.glob('*.json'):
if config_path.name == 'poster_gen_config.json':
continue # 跳过已处理的主配置文件
config_name = config_path.stem # 'topic_gen.json' -> 'topic_gen' config_name = config_path.stem # 'topic_gen.json' -> 'topic_gen'
# 服务器模式下,只加载必要的全局配置
if server_mode and config_name not in self.SERVER_CONFIGS:
logger.info(f"服务器模式下跳过非全局配置: {config_name}")
continue
# 加载原始配置
with open(config_path, 'r', encoding='utf-8') as f:
config_data = json.load(f)
self._raw_configs[config_name] = config_data
# 更新对象配置
if config_name in self._configs: if config_name in self._configs:
logger.info(f"加载配置文件 '{config_name}': {config_path}") logger.info(f"加载配置文件 '{config_name}': {config_path}")
with open(config_path, 'r', encoding='utf-8') as f:
config_data = json.load(f)
self._configs[config_name].update(config_data) self._configs[config_name].update(config_data)
self._loaded_configs.add(config_name)
else: else:
logger.warning(f"'{config_path}' 中找到的配置 '{config_name}' 没有对应的已注册配置类型,已跳过。") logger.info(f"加载原始配置 '{config_name}': {config_path}")
# 3. 最后应用环境变量覆盖 # 最后应用环境变量覆盖
self._apply_env_overrides() self._apply_env_overrides()
except Exception as e: except Exception as e:
logger.error(f"从目录 '{self.config_dir}' 加载配置失败: {e}", exc_info=True) logger.error(f"从目录 '{self.config_dir}' 加载配置失败: {e}", exc_info=True)
raise raise
def _load_main_config(self, path: Path): def load_task_config(self, config_name: str) -> bool:
"""加载主配置文件,并分发到各个配置对象""" """
logger.info(f"加载主配置文件: {path}") 按需加载任务配置
with open(path, 'r', encoding='utf-8') as f:
config_data = json.load(f)
for name, config_obj in self._configs.items(): Args:
# 主配置文件可能是扁平结构或嵌套结构 config_name: 配置名称
if name in config_data: # 嵌套结构
config_obj.update(config_data[name]) Returns:
else: # 尝试从根部更新 (扁平结构) 是否成功加载
config_obj.update(config_data) """
if config_name in self._loaded_configs:
return True
if self.config_dir:
config_path = self.config_dir / f"{config_name}.json"
if config_path.exists():
try:
with open(config_path, 'r', encoding='utf-8') as f:
config_data = json.load(f)
self._raw_configs[config_name] = config_data
if config_name in self._configs:
self._configs[config_name].update(config_data)
self._loaded_configs.add(config_name)
logger.info(f"按需加载任务配置 '{config_name}': {config_path}")
return True
except Exception as e:
logger.error(f"加载任务配置 '{config_name}' 失败: {e}")
logger.warning(f"未找到任务配置: {config_name}")
return False
def _apply_env_overrides(self): def _apply_env_overrides(self):
"""应用环境变量覆盖""" """应用环境变量覆盖"""
@ -176,6 +242,10 @@ class ConfigManager:
if update_data: if update_data:
ai_model_config.update(update_data) ai_model_config.update(update_data)
# 更新原始配置
if 'ai_model' in self._raw_configs:
for key, value in update_data.items():
self._raw_configs['ai_model'][key] = value
logger.info(f"通过环境变量更新了AI模型配置: {list(update_data.keys())}") logger.info(f"通过环境变量更新了AI模型配置: {list(update_data.keys())}")
def save_config(self, name: str): def save_config(self, name: str):
@ -191,6 +261,9 @@ class ConfigManager:
path = self.config_dir / f"{name}.json" path = self.config_dir / f"{name}.json"
config = self.get_config(name, BaseConfig) config = self.get_config(name, BaseConfig)
config_data = config.to_dict() config_data = config.to_dict()
# 更新原始配置
self._raw_configs[name] = config_data
try: try:
with open(path, 'w', encoding='utf-8') as f: with open(path, 'w', encoding='utf-8') as f:

View File

@ -161,6 +161,8 @@ class PosterConfig(BaseConfig):
additional_images_enabled: bool = True additional_images_enabled: bool = True
template_selection: str = "random" # random, business, vibrant, original template_selection: str = "random" # random, business, vibrant, original
available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"]) available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"])
poster_system_prompt: str = "resource/prompt/generatePoster/system.txt"
poster_user_prompt: str = "resource/prompt/generatePoster/user.txt"
class ContentConfig(BaseConfig): class ContentConfig(BaseConfig):