修复了数据库获取数据的功能
This commit is contained in:
parent
19ca9d06ce
commit
f6a48031a0
Binary file not shown.
@ -5,22 +5,23 @@
|
||||
API依赖注入模块
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from core.config import get_config_manager, ConfigManager
|
||||
from core.ai import AIAgent
|
||||
from utils.file_io import OutputManager
|
||||
|
||||
# 全局依赖
|
||||
config_manager = None
|
||||
ai_agent = None
|
||||
output_manager = None
|
||||
config_manager: Optional[ConfigManager] = None
|
||||
ai_agent: Optional[AIAgent] = None
|
||||
output_manager: Optional[OutputManager] = None
|
||||
|
||||
def initialize_dependencies():
|
||||
"""初始化全局依赖"""
|
||||
global config_manager, ai_agent, output_manager
|
||||
|
||||
# 初始化配置
|
||||
# 初始化配置 - 使用服务器模式
|
||||
config_manager = get_config_manager()
|
||||
config_manager.load_from_directory("config")
|
||||
config_manager.load_from_directory("config", server_mode=True)
|
||||
|
||||
# 初始化输出管理器
|
||||
from datetime import datetime
|
||||
@ -34,12 +35,18 @@ def initialize_dependencies():
|
||||
|
||||
def get_config() -> ConfigManager:
|
||||
"""获取配置管理器"""
|
||||
if config_manager is None:
|
||||
raise RuntimeError("配置管理器未初始化")
|
||||
return config_manager
|
||||
|
||||
def get_ai_agent() -> AIAgent:
|
||||
"""获取AI代理"""
|
||||
if ai_agent is None:
|
||||
raise RuntimeError("AI代理未初始化")
|
||||
return ai_agent
|
||||
|
||||
def get_output_manager() -> OutputManager:
|
||||
"""获取输出管理器"""
|
||||
if output_manager is None:
|
||||
raise RuntimeError("输出管理器未初始化")
|
||||
return output_manager
|
||||
Binary file not shown.
Binary file not shown.
@ -10,7 +10,7 @@ import logging
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
from core.config import ConfigManager, GenerateContentConfig
|
||||
from core.config import ConfigManager, GenerateContentConfig, GenerateTopicConfig, PosterConfig
|
||||
from utils.prompts import PromptTemplate
|
||||
from api.services.prompt_service import PromptService
|
||||
|
||||
@ -30,7 +30,30 @@ class PromptBuilderService:
|
||||
"""
|
||||
self.config_manager = config_manager
|
||||
self.prompt_service = prompt_service
|
||||
self.content_config: GenerateContentConfig = config_manager.get_config('content_gen', GenerateContentConfig)
|
||||
|
||||
def _ensure_content_config(self) -> GenerateContentConfig:
|
||||
"""确保内容生成配置已加载"""
|
||||
# 按需加载内容生成配置
|
||||
if not self.config_manager.load_task_config('content_gen'):
|
||||
logger.warning("未找到内容生成配置,将使用默认配置")
|
||||
|
||||
return self.config_manager.get_config('content_gen', GenerateContentConfig)
|
||||
|
||||
def _ensure_topic_config(self) -> GenerateTopicConfig:
|
||||
"""确保选题生成配置已加载"""
|
||||
# 按需加载选题生成配置
|
||||
if not self.config_manager.load_task_config('topic_gen'):
|
||||
logger.warning("未找到选题生成配置,将使用默认配置")
|
||||
|
||||
return self.config_manager.get_config('topic_gen', GenerateTopicConfig)
|
||||
|
||||
def _ensure_poster_config(self) -> PosterConfig:
|
||||
"""确保海报生成配置已加载"""
|
||||
# 按需加载海报生成配置
|
||||
if not self.config_manager.load_task_config('poster_gen'):
|
||||
logger.warning("未找到海报生成配置,将使用默认配置")
|
||||
|
||||
return self.config_manager.get_config('poster_gen', PosterConfig)
|
||||
|
||||
def build_content_prompt(self, topic: Dict[str, Any], step: str = "content") -> Tuple[str, str]:
|
||||
"""
|
||||
@ -43,9 +66,12 @@ class PromptBuilderService:
|
||||
Returns:
|
||||
系统提示词和用户提示词的元组
|
||||
"""
|
||||
# 获取内容生成配置
|
||||
content_config = self._ensure_content_config()
|
||||
|
||||
# 加载系统提示词和用户提示词模板
|
||||
system_prompt_path = self.content_config.content_system_prompt
|
||||
user_prompt_path = self.content_config.content_user_prompt
|
||||
system_prompt_path = content_config.content_system_prompt
|
||||
user_prompt_path = content_config.content_user_prompt
|
||||
|
||||
# 创建提示词模板
|
||||
template = PromptTemplate(system_prompt_path, user_prompt_path)
|
||||
@ -83,6 +109,47 @@ class PromptBuilderService:
|
||||
|
||||
return system_prompt, user_prompt
|
||||
|
||||
def build_poster_prompt(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str]:
|
||||
"""
|
||||
构建海报生成提示词
|
||||
|
||||
Args:
|
||||
topic: 选题信息
|
||||
content: 生成的内容
|
||||
|
||||
Returns:
|
||||
系统提示词和用户提示词的元组
|
||||
"""
|
||||
# 获取海报生成配置
|
||||
poster_config = self._ensure_poster_config()
|
||||
|
||||
# 从配置中获取海报提示词模板路径
|
||||
system_prompt_path = poster_config.poster_system_prompt
|
||||
user_prompt_path = poster_config.poster_user_prompt
|
||||
|
||||
if not system_prompt_path or not user_prompt_path:
|
||||
raise ValueError("海报提示词模板路径不完整")
|
||||
|
||||
# 创建提示词模板
|
||||
template = PromptTemplate(system_prompt_path, user_prompt_path)
|
||||
|
||||
# 获取景区信息
|
||||
object_name = topic.get("object", "")
|
||||
object_content = self.prompt_service.get_scenic_spot_info(object_name)
|
||||
|
||||
# 构建系统提示词
|
||||
system_prompt = template.get_system_prompt()
|
||||
|
||||
# 构建用户提示词
|
||||
user_prompt = template.build_user_prompt(
|
||||
content=content.get("content", ""),
|
||||
title=content.get("title", ""),
|
||||
object_name=object_name,
|
||||
object_content=object_content
|
||||
)
|
||||
|
||||
return system_prompt, user_prompt
|
||||
|
||||
def build_topic_prompt(self, num_topics: int, month: str) -> Tuple[str, str]:
|
||||
"""
|
||||
构建选题生成提示词
|
||||
@ -94,13 +161,11 @@ class PromptBuilderService:
|
||||
Returns:
|
||||
系统提示词和用户提示词的元组
|
||||
"""
|
||||
# 从配置中获取选题提示词模板路径
|
||||
topic_config = self.config_manager.get_config('topic_gen', dict)
|
||||
if not topic_config:
|
||||
raise ValueError("未找到选题生成配置")
|
||||
# 获取选题生成配置
|
||||
topic_config = self._ensure_topic_config()
|
||||
|
||||
system_prompt_path = topic_config.get("topic_system_prompt", "")
|
||||
user_prompt_path = topic_config.get("topic_user_prompt", "")
|
||||
system_prompt_path = topic_config.topic_system_prompt
|
||||
user_prompt_path = topic_config.topic_user_prompt
|
||||
|
||||
if not system_prompt_path or not user_prompt_path:
|
||||
raise ValueError("选题提示词模板路径不完整")
|
||||
@ -155,9 +220,12 @@ class PromptBuilderService:
|
||||
Returns:
|
||||
系统提示词和用户提示词的元组
|
||||
"""
|
||||
# 获取内容生成配置
|
||||
content_config = self._ensure_content_config()
|
||||
|
||||
# 从配置中获取审核提示词模板路径
|
||||
system_prompt_path = self.content_config.judger_system_prompt
|
||||
user_prompt_path = self.content_config.judger_user_prompt
|
||||
system_prompt_path = content_config.judger_system_prompt
|
||||
user_prompt_path = content_config.judger_user_prompt
|
||||
|
||||
# 创建提示词模板
|
||||
template = PromptTemplate(system_prompt_path, user_prompt_path)
|
||||
|
||||
@ -10,6 +10,9 @@ import logging
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Any, Optional, List, cast
|
||||
from pathlib import Path
|
||||
import mysql.connector
|
||||
@ -33,121 +36,130 @@ class PromptService:
|
||||
config_manager: 配置管理器
|
||||
"""
|
||||
self.config_manager = config_manager
|
||||
self.resource_config: ResourceConfig = config_manager.get_config('resource', ResourceConfig)
|
||||
|
||||
# 按需加载resource配置
|
||||
if 'resource' not in self.config_manager._loaded_configs:
|
||||
self.config_manager.load_task_config('resource')
|
||||
|
||||
self.resource_config = self.config_manager.get_config('resource', ResourceConfig)
|
||||
|
||||
# ResourceLoader是静态类,不需要实例化
|
||||
self.resource_dirs = self.resource_config.resource_dirs
|
||||
|
||||
# 初始化数据库连接池
|
||||
self._init_db_pool()
|
||||
self.db_pool = self._init_db_pool()
|
||||
|
||||
def _init_db_pool(self):
|
||||
"""初始化数据库连接池"""
|
||||
try:
|
||||
# 尝试直接从配置文件加载数据库配置
|
||||
config_dir = Path("config")
|
||||
db_config_path = config_dir / "database.json"
|
||||
|
||||
if not db_config_path.exists():
|
||||
logger.warning(f"数据库配置文件不存在: {db_config_path}")
|
||||
self.db_pool = None
|
||||
return
|
||||
|
||||
# 加载配置文件
|
||||
with open(db_config_path, 'r', encoding='utf-8') as f:
|
||||
db_config = json.load(f)
|
||||
|
||||
# 处理环境变量
|
||||
processed_config = {}
|
||||
for key, value in db_config.items():
|
||||
if isinstance(value, str) and "${" in value:
|
||||
# 匹配 ${ENV_VAR:-default} 格式
|
||||
pattern = r'\${([^:-]+)(?::-([^}]+))?}'
|
||||
match = re.match(pattern, value)
|
||||
if match:
|
||||
env_var, default = match.groups()
|
||||
processed_value = os.environ.get(env_var, default)
|
||||
# 尝试转换为数字
|
||||
if key == "port":
|
||||
try:
|
||||
processed_value = int(processed_value)
|
||||
except (ValueError, TypeError):
|
||||
processed_value = 3306
|
||||
processed_config[key] = processed_value
|
||||
else:
|
||||
processed_config[key] = value
|
||||
|
||||
# 创建连接池
|
||||
self.db_pool = pooling.MySQLConnectionPool(
|
||||
pool_name="prompt_pool",
|
||||
pool_size=5,
|
||||
host=processed_config.get("host", "localhost"),
|
||||
user=processed_config.get("user", "root"),
|
||||
password=processed_config.get("password", ""),
|
||||
database=processed_config.get("database", "travel_content"),
|
||||
port=processed_config.get("port", 3306),
|
||||
charset=processed_config.get("charset", "utf8mb4")
|
||||
)
|
||||
logger.info(f"数据库连接池初始化成功,连接到 {processed_config.get('host')}:{processed_config.get('port')}")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化数据库连接池失败: {e}")
|
||||
self.db_pool = None
|
||||
# 创建必要的目录结构
|
||||
self._create_resource_directories()
|
||||
|
||||
def get_style_content(self, style_name: str) -> str:
|
||||
"""
|
||||
获取内容风格提示词
|
||||
def _create_resource_directories(self):
|
||||
"""创建必要的资源目录"""
|
||||
try:
|
||||
dirs_to_create = ['styles', 'attractions', 'products', 'audiences']
|
||||
for dir_name in dirs_to_create:
|
||||
dir_path = Path(dir_name)
|
||||
if not dir_path.exists():
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"创建目录: {dir_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建资源目录失败: {e}")
|
||||
|
||||
def _process_env_vars(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理配置中的环境变量"""
|
||||
processed = {}
|
||||
for key, value in config.items():
|
||||
if isinstance(value, str) and "${" in value:
|
||||
# 匹配 ${ENV_VAR:-default} 格式
|
||||
pattern = r'\${([^:-]+)(?::-([^}]+))?}'
|
||||
match = re.match(pattern, value)
|
||||
if match:
|
||||
env_var, default = match.groups()
|
||||
processed_value = os.environ.get(env_var, default or "")
|
||||
# 尝试转换为数字
|
||||
if key == "port":
|
||||
try:
|
||||
processed_value = int(processed_value)
|
||||
except (ValueError, TypeError):
|
||||
processed_value = 3306
|
||||
processed[key] = processed_value
|
||||
else:
|
||||
processed[key] = value
|
||||
else:
|
||||
processed[key] = value
|
||||
return processed
|
||||
|
||||
def _init_db_pool(self):
|
||||
"""初始化数据库连接池,尝试多种连接方式"""
|
||||
# 获取数据库配置
|
||||
raw_db_config = self.config_manager.get_raw_config('database')
|
||||
|
||||
Args:
|
||||
style_name: 风格名称
|
||||
|
||||
Returns:
|
||||
风格提示词内容
|
||||
"""
|
||||
# 优先从数据库获取
|
||||
# 处理环境变量
|
||||
db_config = self._process_env_vars(raw_db_config)
|
||||
|
||||
# 连接尝试配置
|
||||
connection_attempts = [
|
||||
{"desc": "使用配置文件中的设置", "config": db_config},
|
||||
{"desc": "使用明确的密码", "config": {**db_config, "password": "password"}},
|
||||
{"desc": "使用空密码", "config": {**db_config, "password": ""}},
|
||||
{"desc": "使用auth_plugin", "config": {**db_config, "auth_plugin": "mysql_native_password"}}
|
||||
]
|
||||
|
||||
# 尝试不同的连接方式
|
||||
for attempt in connection_attempts:
|
||||
try:
|
||||
# 打印连接信息(不包含密码)
|
||||
connection_info = {k: v for k, v in attempt["config"].items() if k != 'password'}
|
||||
logger.info(f"尝试连接数据库 ({attempt['desc']}): {connection_info}")
|
||||
|
||||
# 创建连接池
|
||||
pool = pooling.MySQLConnectionPool(
|
||||
pool_name=f"prompt_service_pool_{int(time.time())}",
|
||||
pool_size=5,
|
||||
**attempt["config"]
|
||||
)
|
||||
|
||||
# 测试连接
|
||||
with pool.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.fetchall()
|
||||
|
||||
logger.info(f"数据库连接池初始化成功 ({attempt['desc']})")
|
||||
return pool
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"数据库连接尝试 ({attempt['desc']}) 失败: {e}\n{error_details}")
|
||||
|
||||
logger.warning("所有数据库连接尝试均失败,将使用文件系统作为数据源")
|
||||
return None
|
||||
|
||||
def get_style_content(self, style_name: str) -> str:
|
||||
"""获取风格提示词内容"""
|
||||
# 尝试从数据库获取
|
||||
if self.db_pool:
|
||||
try:
|
||||
conn = self.db_pool.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT description FROM contentStyle WHERE styleName = %s",
|
||||
(style_name,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
if result:
|
||||
logger.info(f"从数据库获取风格提示词: {style_name}")
|
||||
return result["description"]
|
||||
with self.db_pool.get_connection() as conn:
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT description FROM contentStyle WHERE styleName = %s AND isDelete = 0",
|
||||
(style_name,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
return result['description']
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取风格提示词失败: {e}")
|
||||
|
||||
# 回退到文件系统
|
||||
try:
|
||||
style_paths = self.resource_config.style.paths
|
||||
for path_str in style_paths:
|
||||
try:
|
||||
if style_name in path_str:
|
||||
full_path = self._get_full_path(path_str)
|
||||
if full_path.exists():
|
||||
logger.info(f"从文件系统获取风格提示词: {style_name}")
|
||||
return full_path.read_text('utf-8')
|
||||
except Exception as e:
|
||||
logger.error(f"读取风格文件失败 {path_str}: {e}")
|
||||
|
||||
# 如果没有精确匹配,尝试模糊匹配
|
||||
for path_str in style_paths:
|
||||
try:
|
||||
full_path = self._get_full_path(path_str)
|
||||
if full_path.exists() and full_path.is_file():
|
||||
content = full_path.read_text('utf-8')
|
||||
if style_name.lower() in full_path.stem.lower():
|
||||
logger.info(f"通过模糊匹配找到风格提示词: {style_name} -> {full_path.name}")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"读取风格文件失败 {path_str}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"获取风格提示词失败: {e}")
|
||||
|
||||
logger.warning(f"未找到风格提示词: {style_name},将使用默认值")
|
||||
return "通用风格"
|
||||
# 从文件系统获取
|
||||
logger.info(f"从文件系统获取风格提示词: {style_name}")
|
||||
for dir_path in self.resource_dirs:
|
||||
style_file = ResourceLoader.find_file(os.path.join(dir_path, "styles"), style_name)
|
||||
if style_file:
|
||||
content = ResourceLoader.load_text_file(style_file)
|
||||
if content:
|
||||
return content
|
||||
|
||||
return f"未找到风格 '{style_name}' 的提示词"
|
||||
|
||||
def get_audience_content(self, audience_name: str) -> str:
|
||||
"""
|
||||
@ -224,7 +236,7 @@ class PromptService:
|
||||
conn = self.db_pool.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT description FROM scenicSpot WHERE spotName = %s",
|
||||
"SELECT description FROM scenicSpot WHERE name = %s",
|
||||
(spot_name,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
@ -271,7 +283,7 @@ class PromptService:
|
||||
conn = self.db_pool.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute(
|
||||
"SELECT description FROM product WHERE productName = %s",
|
||||
"SELECT description FROM product WHERE name = %s",
|
||||
(product_name,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
@ -453,7 +465,7 @@ class PromptService:
|
||||
try:
|
||||
conn = self.db_pool.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
cursor.execute("SELECT spotName as name, description FROM scenicSpot")
|
||||
cursor.execute("SELECT name as name, description FROM scenicSpot")
|
||||
results = cursor.fetchall()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
|
||||
8
config/database.json
Normal file
8
config/database.json
Normal file
@ -0,0 +1,8 @@
|
||||
{
|
||||
"host": "localhost",
|
||||
"user": "root",
|
||||
"password": "password",
|
||||
"database": "travel_content",
|
||||
"port": 3306,
|
||||
"charset": "utf8mb4"
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
@ -9,7 +9,7 @@ import json
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Type, TypeVar, Optional, Any, cast
|
||||
from typing import Dict, Type, TypeVar, Optional, Any, cast, List, Set
|
||||
|
||||
from core.config.models import (
|
||||
BaseConfig, AIModelConfig, SystemConfig, GenerateTopicConfig, ResourceConfig,
|
||||
@ -27,23 +27,30 @@ class ConfigManager:
|
||||
负责加载、管理和访问所有配置
|
||||
"""
|
||||
|
||||
# 服务端必要的全局配置
|
||||
SERVER_CONFIGS = {'system', 'ai_model', 'database'}
|
||||
|
||||
# 单次生成任务的配置
|
||||
TASK_CONFIGS = {'topic_gen', 'content_gen', 'poster_gen', 'resource'}
|
||||
|
||||
def __init__(self):
|
||||
self._configs: Dict[str, BaseConfig] = {}
|
||||
self._raw_configs: Dict[str, Dict[str, Any]] = {} # 存储原始配置数据
|
||||
self.config_dir: Optional[Path] = None
|
||||
self.config_objects = {
|
||||
'ai_model': AIModelConfig(),
|
||||
'system': SystemConfig(),
|
||||
'topic_gen': GenerateTopicConfig(),
|
||||
'content_gen': GenerateContentConfig(),
|
||||
'resource': ResourceConfig()
|
||||
}
|
||||
self._loaded_configs: Set[str] = set()
|
||||
|
||||
def load_from_directory(self, config_dir: str):
|
||||
def load_from_directory(self, config_dir: str, server_mode: bool = False):
|
||||
"""
|
||||
从目录加载配置
|
||||
|
||||
Args:
|
||||
config_dir: 配置文件目录
|
||||
server_mode: 是否为服务器模式,如果是则只加载必要的全局配置
|
||||
"""
|
||||
self.config_dir = Path(config_dir)
|
||||
if not self.config_dir.is_dir():
|
||||
@ -54,15 +61,17 @@ class ConfigManager:
|
||||
self._register_configs()
|
||||
|
||||
# 动态加载目录中的所有.json文件
|
||||
self._load_all_configs_from_dir()
|
||||
self._load_all_configs_from_dir(server_mode)
|
||||
|
||||
def _register_configs(self):
|
||||
"""注册所有配置"""
|
||||
self.register_config('ai_model', AIModelConfig)
|
||||
self.register_config('system', SystemConfig)
|
||||
self.register_config('resource', ResourceConfig)
|
||||
|
||||
# 这些配置在服务器模式下不会自动加载,但仍然需要注册类型
|
||||
self.register_config('poster', PosterConfig)
|
||||
self.register_config('content', ContentConfig)
|
||||
self.register_config('resource', ResourceConfig)
|
||||
self.register_config('system', SystemConfig)
|
||||
self.register_config('topic_gen', GenerateTopicConfig)
|
||||
self.register_config('content_gen', GenerateContentConfig)
|
||||
|
||||
@ -113,49 +122,106 @@ class ConfigManager:
|
||||
|
||||
return cast(T, config)
|
||||
|
||||
def _load_all_configs_from_dir(self):
|
||||
"""动态加载目录中的所有.json文件"""
|
||||
try:
|
||||
# 1. 优先加载旧的主配置文件以实现向后兼容
|
||||
main_config_path = self.config_dir / 'poster_gen_config.json'
|
||||
if main_config_path.exists():
|
||||
self._load_main_config(main_config_path)
|
||||
else:
|
||||
logger.warning(f"旧的主配置文件不存在: {main_config_path}")
|
||||
def get_raw_config(self, name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取原始配置数据
|
||||
|
||||
Args:
|
||||
name: 配置名称
|
||||
|
||||
# 2. 遍历并加载目录中所有其他的 .json 文件
|
||||
Returns:
|
||||
原始配置数据字典
|
||||
"""
|
||||
if name in self._raw_configs:
|
||||
return self._raw_configs[name]
|
||||
|
||||
# 如果没有原始配置,但有对象配置,则转换为字典
|
||||
if name in self._configs:
|
||||
return self._configs[name].to_dict()
|
||||
|
||||
# 尝试从文件加载
|
||||
if self.config_dir:
|
||||
config_path = self.config_dir / f"{name}.json"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
raw_config = json.load(f)
|
||||
self._raw_configs[name] = raw_config
|
||||
return raw_config
|
||||
except Exception as e:
|
||||
logger.error(f"加载原始配置 '{name}' 失败: {e}")
|
||||
|
||||
# 返回空字典
|
||||
return {}
|
||||
|
||||
def _load_all_configs_from_dir(self, server_mode: bool = False):
|
||||
"""
|
||||
动态加载目录中的所有.json文件
|
||||
|
||||
Args:
|
||||
server_mode: 是否为服务器模式,如果是则只加载必要的全局配置
|
||||
"""
|
||||
try:
|
||||
# 遍历并加载目录中所有其他的 .json 文件
|
||||
for config_path in self.config_dir.glob('*.json'):
|
||||
if config_path.name == 'poster_gen_config.json':
|
||||
continue # 跳过已处理的主配置文件
|
||||
|
||||
config_name = config_path.stem # 'topic_gen.json' -> 'topic_gen'
|
||||
|
||||
# 服务器模式下,只加载必要的全局配置
|
||||
if server_mode and config_name not in self.SERVER_CONFIGS:
|
||||
logger.info(f"服务器模式下跳过非全局配置: {config_name}")
|
||||
continue
|
||||
|
||||
# 加载原始配置
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_data = json.load(f)
|
||||
self._raw_configs[config_name] = config_data
|
||||
|
||||
# 更新对象配置
|
||||
if config_name in self._configs:
|
||||
logger.info(f"加载配置文件 '{config_name}': {config_path}")
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_data = json.load(f)
|
||||
self._configs[config_name].update(config_data)
|
||||
self._loaded_configs.add(config_name)
|
||||
else:
|
||||
logger.warning(f"在 '{config_path}' 中找到的配置 '{config_name}' 没有对应的已注册配置类型,已跳过。")
|
||||
logger.info(f"加载原始配置 '{config_name}': {config_path}")
|
||||
|
||||
# 3. 最后应用环境变量覆盖
|
||||
# 最后应用环境变量覆盖
|
||||
self._apply_env_overrides()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从目录 '{self.config_dir}' 加载配置失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _load_main_config(self, path: Path):
|
||||
"""加载主配置文件,并分发到各个配置对象"""
|
||||
logger.info(f"加载主配置文件: {path}")
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
config_data = json.load(f)
|
||||
def load_task_config(self, config_name: str) -> bool:
|
||||
"""
|
||||
按需加载任务配置
|
||||
|
||||
for name, config_obj in self._configs.items():
|
||||
# 主配置文件可能是扁平结构或嵌套结构
|
||||
if name in config_data: # 嵌套结构
|
||||
config_obj.update(config_data[name])
|
||||
else: # 尝试从根部更新 (扁平结构)
|
||||
config_obj.update(config_data)
|
||||
Args:
|
||||
config_name: 配置名称
|
||||
|
||||
Returns:
|
||||
是否成功加载
|
||||
"""
|
||||
if config_name in self._loaded_configs:
|
||||
return True
|
||||
|
||||
if self.config_dir:
|
||||
config_path = self.config_dir / f"{config_name}.json"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_data = json.load(f)
|
||||
self._raw_configs[config_name] = config_data
|
||||
|
||||
if config_name in self._configs:
|
||||
self._configs[config_name].update(config_data)
|
||||
self._loaded_configs.add(config_name)
|
||||
logger.info(f"按需加载任务配置 '{config_name}': {config_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"加载任务配置 '{config_name}' 失败: {e}")
|
||||
|
||||
logger.warning(f"未找到任务配置: {config_name}")
|
||||
return False
|
||||
|
||||
def _apply_env_overrides(self):
|
||||
"""应用环境变量覆盖"""
|
||||
@ -176,6 +242,10 @@ class ConfigManager:
|
||||
|
||||
if update_data:
|
||||
ai_model_config.update(update_data)
|
||||
# 更新原始配置
|
||||
if 'ai_model' in self._raw_configs:
|
||||
for key, value in update_data.items():
|
||||
self._raw_configs['ai_model'][key] = value
|
||||
logger.info(f"通过环境变量更新了AI模型配置: {list(update_data.keys())}")
|
||||
|
||||
def save_config(self, name: str):
|
||||
@ -191,6 +261,9 @@ class ConfigManager:
|
||||
path = self.config_dir / f"{name}.json"
|
||||
config = self.get_config(name, BaseConfig)
|
||||
config_data = config.to_dict()
|
||||
|
||||
# 更新原始配置
|
||||
self._raw_configs[name] = config_data
|
||||
|
||||
try:
|
||||
with open(path, 'w', encoding='utf-8') as f:
|
||||
|
||||
@ -161,6 +161,8 @@ class PosterConfig(BaseConfig):
|
||||
additional_images_enabled: bool = True
|
||||
template_selection: str = "random" # random, business, vibrant, original
|
||||
available_templates: List[str] = Field(default_factory=lambda: ["original", "business", "vibrant"])
|
||||
poster_system_prompt: str = "resource/prompt/generatePoster/system.txt"
|
||||
poster_user_prompt: str = "resource/prompt/generatePoster/user.txt"
|
||||
|
||||
|
||||
class ContentConfig(BaseConfig):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user