TravelContentCreator/api/dependencies.py

81 lines
2.6 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
API依赖注入模块
"""
from typing import Optional
from datetime import datetime
import uuid
from fastapi import Depends
from core.config import get_config_manager, ConfigManager
from core.ai import AIAgent
from utils.file_io import OutputManager
# Global dependencies: shared module-level instances. Both stay None until
# initialize_dependencies() assigns them.
config_manager: Optional[ConfigManager] = None
ai_agent: Optional[AIAgent] = None
def initialize_dependencies():
    """Populate the module-level ``config_manager`` and ``ai_agent`` globals.

    Obtains the configuration manager via ``get_config_manager()``, loads the
    ``"config"`` directory with ``server_mode=True``, and then constructs the
    AI agent from the ``'ai_model'`` configuration entry.
    """
    global config_manager, ai_agent

    # Configuration comes first (server mode); the AI agent below is built
    # from a section of it.
    config_manager = get_config_manager()
    config_manager.load_from_directory("config", server_mode=True)

    # AI agent, created from the 'ai_model' configuration section.
    from core.config import AIModelConfig

    model_settings = config_manager.get_config('ai_model', AIModelConfig)
    ai_agent = AIAgent(model_settings)
def get_config() -> ConfigManager:
    """Return the shared configuration manager.

    Raises:
        RuntimeError: If the module-level ``config_manager`` is still ``None``,
            i.e. it has not been initialized.
    """
    manager = config_manager
    if manager is not None:
        return manager
    raise RuntimeError("配置管理器未初始化")
def get_ai_agent() -> AIAgent:
    """Return the shared AI agent.

    Raises:
        RuntimeError: If the module-level ``ai_agent`` is still ``None``,
            i.e. it has not been initialized.
    """
    agent = ai_agent
    if agent is not None:
        return agent
    raise RuntimeError("AI代理未初始化")
def create_output_manager() -> OutputManager:
    """Build a new output manager with a freshly generated run id.

    The run id has the form ``api_request-<YYYYmmdd-HHMMSS>-<token>``, where
    the token is the first 8 characters of a random UUID4 string, so each
    request is intended to get its own id. The manager is created with
    ``"result"`` as its first argument.
    """
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    short_token = str(uuid.uuid4())[:8]
    run_id = f"api_request-{timestamp}-{short_token}"
    return OutputManager("result", run_id)
def get_output_manager() -> OutputManager:
    """Provide an output manager; every call creates a new instance."""
    fresh_manager = create_output_manager()
    return fresh_manager
def get_tweet_service():
    """Create a text-content (tweet) service.

    The service receives the shared AI agent, the shared configuration
    manager, and a newly created output manager, in that order.
    """
    # Imported inside the function, as in the rest of this module's providers.
    from api.services.tweet import TweetService

    agent = get_ai_agent()
    config = get_config()
    output_manager = get_output_manager()
    return TweetService(agent, config, output_manager)
def get_poster_service():
    """Create a poster service.

    The service receives the shared AI agent, the shared configuration
    manager, and a newly created output manager, in that order.
    """
    # Imported inside the function, as in the rest of this module's providers.
    from api.services.poster import PosterService

    agent = get_ai_agent()
    config = get_config()
    output_manager = get_output_manager()
    return PosterService(agent, config, output_manager)
def get_database_service():
    """Create a database service backed by the shared configuration manager."""
    # Imported inside the function, as in the rest of this module's providers.
    from api.services.database_service import DatabaseService

    config = get_config()
    return DatabaseService(config)
def get_prompt_builder():
    """Create a prompt-builder service.

    A ``PromptService`` is constructed from the shared configuration manager
    and handed, together with that same manager, to ``PromptBuilderService``.
    """
    # Imported inside the function, as in the rest of this module's providers.
    from api.services.prompt_builder import PromptBuilderService
    from api.services.prompt_service import PromptService

    # Fetched after the imports so an uninitialized config still fails at the
    # same point; the same manager object is used for both constructors.
    config = get_config()
    prompt_service = PromptService(config)
    return PromptBuilderService(config, prompt_service)
def get_integration_service():
    """Create an integration service.

    The service receives the shared configuration manager and a newly created
    output manager, in that order.
    """
    # Imported inside the function, as in the rest of this module's providers.
    from api.services.integration_service import IntegrationService

    config = get_config()
    output_manager = get_output_manager()
    return IntegrationService(config, output_manager)