81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
API依赖注入模块
|
|
"""
|
|
|
|
from typing import Optional
|
|
from datetime import datetime
|
|
import uuid
|
|
from fastapi import Depends
|
|
from core.config import get_config_manager, ConfigManager
|
|
from core.ai import AIAgent
|
|
from utils.file_io import OutputManager
|
|
|
|
# Global dependencies.
# Both singletons start out as None and are filled in by
# initialize_dependencies(); get_config() and get_ai_agent() raise
# RuntimeError for as long as they remain None.
config_manager: Optional[ConfigManager] = None
ai_agent: Optional[AIAgent] = None


def initialize_dependencies(config_dir: str = "config") -> None:
    """Initialize the module-level dependencies.

    Loads configuration from ``config_dir`` in server mode, then builds the
    shared ``AIAgent`` from the ``'ai_model'`` section (read as
    ``AIModelConfig``).

    The globals ``config_manager`` and ``ai_agent`` are assigned only after
    every step has succeeded. A failure part-way through therefore does not
    leave ``get_config()`` handing out a manager whose loading aborted.
    NOTE(review): if ``get_config_manager()`` returns a shared singleton,
    that object itself may still be partially loaded after a failure --
    confirm against ``core.config``.

    Args:
        config_dir: Directory passed to ``load_from_directory``. Defaults to
            ``"config"``, the previously hard-coded value, so existing
            callers are unaffected.

    Raises:
        Any exception raised by ``load_from_directory``, ``get_config`` or
        ``AIAgent``; the module globals are left unchanged in that case.
    """
    global config_manager, ai_agent

    # Initialize configuration - using server mode.
    # Build into a local first; publish to the global only on success.
    manager = get_config_manager()
    manager.load_from_directory(config_dir, server_mode=True)

    # Initialize the AI agent.
    # Local import kept from the original (presumably to avoid an import
    # cycle with core.config).
    from core.config import AIModelConfig

    ai_config = manager.get_config('ai_model', AIModelConfig)
    agent = AIAgent(ai_config)

    # Everything succeeded: publish both dependencies together.
    config_manager = manager
    ai_agent = agent


def get_config() -> ConfigManager:
    """Return the global configuration manager.

    Raises:
        RuntimeError: while the global ``config_manager`` is still None,
            i.e. before ``initialize_dependencies()`` has assigned it.
    """
    manager = config_manager
    if manager is not None:
        return manager
    raise RuntimeError("配置管理器未初始化")


def get_ai_agent() -> AIAgent:
    """Return the global AI agent.

    Raises:
        RuntimeError: while the global ``ai_agent`` is still None, i.e.
            before ``initialize_dependencies()`` has assigned it.
    """
    agent = ai_agent
    if agent is not None:
        return agent
    raise RuntimeError("AI代理未初始化")


def create_output_manager() -> OutputManager:
    """Build a fresh ``OutputManager`` rooted at ``"result"`` for one request.

    The run id looks like ``api_request-YYYYmmdd-HHMMSS-xxxxxxxx``: a
    local-time (naive ``datetime.now()``) timestamp followed by the first
    eight hex characters of a random UUID4, which makes collisions between
    requests in the same second unlikely.
    """
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    # uuid4().hex[:8] equals str(uuid4())[:8]: the first eight characters of
    # the canonical form precede its first hyphen.
    short_id = uuid.uuid4().hex[:8]
    run_id = f"api_request-{timestamp}-{short_id}"
    return OutputManager("result", run_id)


def get_output_manager() -> OutputManager:
    """Return a new ``OutputManager``; every call creates its own instance.

    Delegates to ``create_output_manager()``. NOTE(review): presumably used
    as a FastAPI ``Depends`` provider -- keep it parameter-less, since FastAPI
    would interpret any parameter as request input.
    """
    fresh_manager = create_output_manager()
    return fresh_manager


def get_tweet_service():
    """Build a text-content ``TweetService``.

    The service receives the shared AI agent, the shared config manager and
    a freshly created ``OutputManager``; a new service is built per call.

    Raises:
        RuntimeError: if the global dependencies have not been initialized.
    """
    # Local import as in the original (presumably to avoid an import cycle).
    from api.services.tweet import TweetService

    # Same evaluation order as the original argument list.
    agent = get_ai_agent()
    config = get_config()
    output = get_output_manager()
    return TweetService(agent, config, output)


def get_poster_service():
    """Build a ``PosterService``.

    The service receives the shared AI agent, the shared config manager and
    a freshly created ``OutputManager``; a new service is built per call.

    Raises:
        RuntimeError: if the global dependencies have not been initialized.
    """
    # Local import as in the original (presumably to avoid an import cycle).
    from api.services.poster import PosterService

    # Same evaluation order as the original argument list.
    agent = get_ai_agent()
    config = get_config()
    output = get_output_manager()
    return PosterService(agent, config, output)


def get_database_service():
    """Build a ``DatabaseService`` bound to the shared config manager.

    Raises:
        RuntimeError: if the config manager has not been initialized.
    """
    # Local import as in the original (presumably to avoid an import cycle).
    from api.services.database_service import DatabaseService

    config = get_config()
    return DatabaseService(config)


def get_prompt_builder():
    """Build a ``PromptBuilderService`` together with its ``PromptService``.

    The global config manager is looked up once and handed to both services.
    This guarantees they share the same ``ConfigManager`` instance, even if
    the globals are re-initialized in between, and avoids a redundant lookup.

    Raises:
        RuntimeError: if the config manager has not been initialized.
    """
    # Local imports as in the original (presumably to avoid import cycles).
    from api.services.prompt_builder import PromptBuilderService
    from api.services.prompt_service import PromptService

    config = get_config()
    prompt_service = PromptService(config)
    return PromptBuilderService(config, prompt_service)


def get_integration_service():
    """Build an ``IntegrationService``.

    The service receives the shared config manager and a freshly created
    ``OutputManager``; a new service is built per call.

    Raises:
        RuntimeError: if the config manager has not been initialized.
    """
    # Local import as in the original (presumably to avoid an import cycle).
    from api.services.integration_service import IntegrationService

    # Same evaluation order as the original argument list.
    config = get_config()
    output = get_output_manager()
    return IntegrationService(config, output)