diff --git a/api/__pycache__/dependencies.cpython-312.pyc b/api/__pycache__/dependencies.cpython-312.pyc new file mode 100644 index 0000000..17cdd93 Binary files /dev/null and b/api/__pycache__/dependencies.cpython-312.pyc differ diff --git a/api/__pycache__/main.cpython-312.pyc b/api/__pycache__/main.cpython-312.pyc index 632e26c..2ef7585 100644 Binary files a/api/__pycache__/main.cpython-312.pyc and b/api/__pycache__/main.cpython-312.pyc differ diff --git a/api/dependencies.py b/api/dependencies.py new file mode 100644 index 0000000..7d14fdf --- /dev/null +++ b/api/dependencies.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +API依赖注入模块 +""" + +from core.config import get_config_manager, ConfigManager +from core.ai import AIAgent +from utils.file_io import OutputManager + +# 全局依赖 +config_manager = None +ai_agent = None +output_manager = None + +def initialize_dependencies(): + """初始化全局依赖""" + global config_manager, ai_agent, output_manager + + # 初始化配置 + config_manager = get_config_manager() + config_manager.load_from_directory("config") + + # 初始化输出管理器 + from datetime import datetime + run_id = f"api_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + output_manager = OutputManager("result", run_id) + + # 初始化AI代理 + from core.config import AIModelConfig + ai_config = config_manager.get_config('ai_model', AIModelConfig) + ai_agent = AIAgent(ai_config) + +def get_config() -> ConfigManager: + """获取配置管理器""" + return config_manager + +def get_ai_agent() -> AIAgent: + """获取AI代理""" + return ai_agent + +def get_output_manager() -> OutputManager: + """获取输出管理器""" + return output_manager \ No newline at end of file diff --git a/api/main.py b/api/main.py index 27683a3..40209ed 100644 --- a/api/main.py +++ b/api/main.py @@ -7,16 +7,11 @@ TravelContentCreator API服务 """ import logging -from fastapi import FastAPI, Depends +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager -from core.config import get_config_manager, AIModelConfig -from core.ai import AIAgent -from utils.file_io import OutputManager -from datetime import datetime - -from api.routers import tweet, poster +from api import dependencies # 配置日志 logging.basicConfig( @@ -26,32 +21,15 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) -# 全局依赖 -config_manager = None -ai_agent = None -output_manager = None - @asynccontextmanager async def lifespan(app: FastAPI): """ 应用生命周期管理 在应用启动时初始化全局依赖,在应用关闭时清理资源 """ - global config_manager, ai_agent, output_manager - # 初始化配置 logger.info("正在初始化API服务...") - config_manager = get_config_manager() - config_manager.load_from_directory("config") - - # 初始化输出管理器 - run_id = f"api_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - output_manager = OutputManager("result", run_id) - - # 初始化AI代理 - ai_config = config_manager.get_config('ai_model', AIModelConfig) - ai_agent = AIAgent(ai_config) - + dependencies.initialize_dependencies() logger.info("API服务初始化完成") yield @@ -77,37 +55,19 @@ app.add_middleware( allow_headers=["*"], ) -# 依赖注入函数 -def get_config(): - return config_manager - -def get_ai_agent(): - return ai_agent - -def get_output_manager(): - return output_manager +# 导入路由 +from api.routers import tweet, poster, prompt # 包含路由 app.include_router(tweet.router, prefix="/api/tweet", tags=["tweet"]) app.include_router(poster.router, prefix="/api/poster", tags=["poster"]) +app.include_router(prompt.router, prefix="/api/prompt", tags=["prompt"]) @app.get("/") async def root(): """API根路径,返回简单的欢迎信息""" return {"message": "欢迎使用TravelContentCreator API服务"} -@app.get("/health") -async def health_check(): - """健康检查端点""" - return { - "status": "healthy", - "services": { - "ai_agent": ai_agent is not None, - "config": config_manager is not None, - "output_manager": output_manager is not None - } - } - if __name__ == "__main__": import uvicorn uvicorn.run("api.main:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file diff --git a/api/models/__pycache__/poster.cpython-312.pyc b/api/models/__pycache__/poster.cpython-312.pyc new file mode 100644 index 0000000..639d214 Binary files /dev/null and b/api/models/__pycache__/poster.cpython-312.pyc differ diff --git a/api/models/__pycache__/prompt.cpython-312.pyc b/api/models/__pycache__/prompt.cpython-312.pyc new file mode 100644 index 0000000..bc0694a Binary files /dev/null and b/api/models/__pycache__/prompt.cpython-312.pyc differ diff --git a/api/models/prompt.py b/api/models/prompt.py new file mode 100644 index 0000000..07acbcf --- /dev/null +++ b/api/models/prompt.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +提示词API模型定义 +""" + +from typing import List, Dict, Any, Optional +from pydantic import BaseModel, Field + + +class StyleRequest(BaseModel): + """风格请求模型""" + name: str = Field(..., description="风格名称") + description: Optional[str] = Field(None, description="风格描述,如果为空则表示获取") + + class Config: + schema_extra = { + "example": { + "name": "攻略风", + "description": "详细的旅行攻略信息,包含行程安排、交通指南、住宿推荐等实用信息,语言平实靠谱" + } + } + + +class StyleResponse(BaseModel): + """风格响应模型""" + name: str = Field(..., description="风格名称") + description: str = Field(..., description="风格描述") + + class Config: + schema_extra = { + "example": { + "name": "攻略风", + "description": "详细的旅行攻略信息,包含行程安排、交通指南、住宿推荐等实用信息,语言平实靠谱" + } + } + + +class AudienceRequest(BaseModel): + """受众请求模型""" + name: str = Field(..., description="受众名称") + description: Optional[str] = Field(None, description="受众描述,如果为空则表示获取") + + class Config: + schema_extra = { + "example": { + "name": "亲子向", + "description": "25-45岁家长群体,孩子年龄3-12岁,注重安全和教育意义,偏好收藏实用攻略" + } + } + + +class AudienceResponse(BaseModel): + """受众响应模型""" + name: str = Field(..., description="受众名称") + description: str = Field(..., description="受众描述") + + class Config: + schema_extra = { + "example": { + "name": "亲子向", + "description": "25-45岁家长群体,孩子年龄3-12岁,注重安全和教育意义,偏好收藏实用攻略" + } + } + + +class ScenicSpotRequest(BaseModel): + """景区请求模型""" + name: str = Field(..., description="景区名称") + + class Config: + schema_extra = { + "example": { + "name": "天津冒险湾" + } + } + + +class ScenicSpotResponse(BaseModel): + """景区响应模型""" + name: str = Field(..., description="景区名称") + description: str = Field(..., description="景区描述") + + class Config: + schema_extra = { + "example": { + "name": "天津冒险湾", + "description": "天津冒险湾位于天津市滨海新区,是华北地区最大的水上乐园..." + } + } + + +class StyleListResponse(BaseModel): + """风格列表响应模型""" + styles: List[StyleResponse] = Field(..., description="风格列表") + + class Config: + schema_extra = { + "example": { + "styles": [ + { + "name": "攻略风", + "description": "详细的旅行攻略信息,包含行程安排、交通指南、住宿推荐等实用信息,语言平实靠谱" + }, + { + "name": "清新文艺风", + "description": "文艺范十足,清新脱俗的表达风格,注重意境和美感描述" + } + ] + } + } + + +class AudienceListResponse(BaseModel): + """受众列表响应模型""" + audiences: List[AudienceResponse] = Field(..., description="受众列表") + + class Config: + schema_extra = { + "example": { + "audiences": [ + { + "name": "亲子向", + "description": "25-45岁家长群体,孩子年龄3-12岁,注重安全和教育意义,偏好收藏实用攻略" + }, + { + "name": "周边游", + "description": "全龄覆盖,主要围绕三天内的短期假期出游需求和周末出游需求" + } + ] + } + } + + +class ScenicSpotListResponse(BaseModel): + """景区列表响应模型""" + spots: List[ScenicSpotResponse] = Field(..., description="景区列表") + + class Config: + schema_extra = { + "example": { + "spots": [ + { + "name": "天津冒险湾", + "description": "天津冒险湾位于天津市滨海新区,是华北地区最大的水上乐园..." + } + ] + } + } + + +class PromptBuilderRequest(BaseModel): + """提示词构建请求模型""" + topic: Dict[str, Any] = Field(..., description="选题信息") + step: Optional[str] = Field(None, description="当前步骤,用于过滤参考内容") + + class Config: + schema_extra = { + "example": { + "topic": { + "index": "1", + "date": "2025-07-15", + "object": "天津冒险湾", + "product": "冒险湾-2大2小套票", + "style": "攻略风", + "target_audience": "亲子向", + "logic": "暑期亲子游热门景点推荐" + }, + "step": "content" + } + } + + +class PromptBuilderResponse(BaseModel): + """提示词构建响应模型""" + system_prompt: str = Field(..., description="系统提示词") + user_prompt: str = Field(..., description="用户提示词") + + class Config: + schema_extra = { + "example": { + "system_prompt": "你是一位专业的旅游内容创作者...", + "user_prompt": "请根据以下信息创作一篇旅游文章..." + } + } \ No newline at end of file diff --git a/api/routers/__pycache__/poster.cpython-312.pyc b/api/routers/__pycache__/poster.cpython-312.pyc new file mode 100644 index 0000000..1bd8334 Binary files /dev/null and b/api/routers/__pycache__/poster.cpython-312.pyc differ diff --git a/api/routers/__pycache__/prompt.cpython-312.pyc b/api/routers/__pycache__/prompt.cpython-312.pyc new file mode 100644 index 0000000..c344530 Binary files /dev/null and b/api/routers/__pycache__/prompt.cpython-312.pyc differ diff --git a/api/routers/__pycache__/tweet.cpython-312.pyc b/api/routers/__pycache__/tweet.cpython-312.pyc index f90c12e..43b8497 100644 Binary files a/api/routers/__pycache__/tweet.cpython-312.pyc and b/api/routers/__pycache__/tweet.cpython-312.pyc differ diff --git a/api/routers/poster.py b/api/routers/poster.py index 152f57c..829d2c6 100644 --- a/api/routers/poster.py +++ b/api/routers/poster.py @@ -19,8 +19,8 @@ from api.models.poster import ( PosterTextRequest, PosterTextResponse ) -# 从main.py中导入依赖 -from api.main import get_config, get_ai_agent, get_output_manager +# 从依赖注入模块导入依赖 +from api.dependencies import get_config, get_ai_agent, get_output_manager logger = logging.getLogger(__name__) diff --git a/api/routers/prompt.py b/api/routers/prompt.py new file mode 100644 index 0000000..5a83eac --- /dev/null +++ b/api/routers/prompt.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +提示词API路由 +""" + +import logging +from fastapi import APIRouter, Depends, HTTPException +from typing import List, Dict, Any + +from core.config import ConfigManager +from api.services.prompt_service import PromptService +from api.services.prompt_builder import PromptBuilderService +from api.models.prompt import ( + StyleRequest, StyleResponse, StyleListResponse, + AudienceRequest, AudienceResponse, AudienceListResponse, + ScenicSpotRequest, ScenicSpotResponse, ScenicSpotListResponse, + PromptBuilderRequest, PromptBuilderResponse +) + +# 从依赖注入模块导入依赖 +from api.dependencies import get_config + +logger = logging.getLogger(__name__) + +# 创建路由 +router = APIRouter() + +# 依赖注入函数 +def get_prompt_service( + config_manager: ConfigManager = Depends(get_config) +) -> PromptService: + """获取提示词服务""" + return PromptService(config_manager) + +def get_prompt_builder( + config_manager: ConfigManager = Depends(get_config), + prompt_service: PromptService = Depends(get_prompt_service) +) -> PromptBuilderService: + """获取提示词构建服务""" + return PromptBuilderService(config_manager, prompt_service) + + +@router.get("/styles", response_model=StyleListResponse) +async def get_all_styles( + prompt_service: PromptService = Depends(get_prompt_service) +): + """获取所有内容风格""" + try: + styles_dict = prompt_service.get_all_styles() + # 将字典列表转换为StyleResponse对象列表 + styles = [StyleResponse(name=style["name"], description=style["description"]) for style in styles_dict] + return StyleListResponse(styles=styles) + except Exception as e: + logger.error(f"获取所有风格失败: {e}") + raise HTTPException(status_code=500, detail=f"获取风格列表失败: {str(e)}") + + +@router.get("/styles/{style_name}", response_model=StyleResponse) +async def get_style( + style_name: str, + prompt_service: PromptService = Depends(get_prompt_service) +): + """获取指定内容风格""" + try: + content = prompt_service.get_style_content(style_name) + return StyleResponse(name=style_name, description=content) + except Exception as e: + logger.error(f"获取风格 '{style_name}' 失败: {e}") + raise HTTPException(status_code=404, detail=f"未找到风格: {style_name}") + + +@router.post("/styles", response_model=StyleResponse) +async def create_or_update_style( + style: StyleRequest, + prompt_service: PromptService = Depends(get_prompt_service) +): + """创建或更新内容风格""" + try: + if not style.description: + # 如果没有提供描述,则获取现有的 + content = prompt_service.get_style_content(style.name) + return StyleResponse(name=style.name, description=content) + + success = prompt_service.save_style(style.name, style.description) + if not success: + raise HTTPException(status_code=500, detail=f"保存风格 '{style.name}' 失败") + + return StyleResponse(name=style.name, description=style.description) + except Exception as e: + logger.error(f"保存风格 '{style.name}' 失败: {e}") + raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}") + + +@router.get("/audiences", response_model=AudienceListResponse) +async def get_all_audiences( + prompt_service: PromptService = Depends(get_prompt_service) +): + """获取所有目标受众""" + try: + audiences_dict = prompt_service.get_all_audiences() + # 将字典列表转换为AudienceResponse对象列表 + audiences = [AudienceResponse(name=audience["name"], description=audience["description"]) for audience in audiences_dict] + return AudienceListResponse(audiences=audiences) + except Exception as e: + logger.error(f"获取所有受众失败: {e}") + raise HTTPException(status_code=500, detail=f"获取受众列表失败: {str(e)}") + + +@router.get("/audiences/{audience_name}", response_model=AudienceResponse) +async def get_audience( + audience_name: str, + prompt_service: PromptService = Depends(get_prompt_service) +): + """获取指定目标受众""" + try: + content = prompt_service.get_audience_content(audience_name) + return AudienceResponse(name=audience_name, description=content) + except Exception as e: + logger.error(f"获取受众 '{audience_name}' 失败: {e}") + raise HTTPException(status_code=404, detail=f"未找到受众: {audience_name}") + + +@router.post("/audiences", response_model=AudienceResponse) +async def create_or_update_audience( + audience: AudienceRequest, + prompt_service: PromptService = Depends(get_prompt_service) +): + """创建或更新目标受众""" + try: + if not audience.description: + # 如果没有提供描述,则获取现有的 + content = prompt_service.get_audience_content(audience.name) + return AudienceResponse(name=audience.name, description=content) + + success = prompt_service.save_audience(audience.name, audience.description) + if not success: + raise HTTPException(status_code=500, detail=f"保存受众 '{audience.name}' 失败") + + return AudienceResponse(name=audience.name, description=audience.description) + except Exception as e: + logger.error(f"保存受众 '{audience.name}' 失败: {e}") + raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}") + + +@router.get("/scenic-spots", response_model=ScenicSpotListResponse) +async def get_all_scenic_spots( + prompt_service: PromptService = Depends(get_prompt_service) +): + """获取所有景区""" + try: + spots_dict = prompt_service.get_all_scenic_spots() + # 将字典列表转换为ScenicSpotResponse对象列表 + spots = [ScenicSpotResponse(name=spot["name"], description=spot["description"]) for spot in spots_dict] + return ScenicSpotListResponse(spots=spots) + except Exception as e: + logger.error(f"获取所有景区失败: {e}") + raise HTTPException(status_code=500, detail=f"获取景区列表失败: {str(e)}") + + +@router.get("/scenic-spots/{spot_name}", response_model=ScenicSpotResponse) +async def get_scenic_spot( + spot_name: str, + prompt_service: PromptService = Depends(get_prompt_service) +): + """获取指定景区信息""" + try: + content = prompt_service.get_scenic_spot_info(spot_name) + return ScenicSpotResponse(name=spot_name, description=content) + except Exception as e: + logger.error(f"获取景区 '{spot_name}' 失败: {e}") + raise HTTPException(status_code=404, detail=f"未找到景区: {spot_name}") + + +@router.post("/build-prompt", response_model=PromptBuilderResponse) +async def build_prompt( + request: PromptBuilderRequest, + prompt_builder: PromptBuilderService = Depends(get_prompt_builder) +): + """构建完整提示词""" + try: + # 根据请求中的step确定构建哪种类型的提示词 + step = request.step or "content" + + if step == "topic": + # 构建选题提示词 + # 从topic中提取必要的参数 + num_topics = request.topic.get("num_topics", 5) + month = request.topic.get("month", "7") + system_prompt, user_prompt = prompt_builder.build_topic_prompt(num_topics, month) + elif step == "judge": + # 构建审核提示词 + # 需要提供生成的内容 + content = request.topic.get("content", {}) + system_prompt, user_prompt = prompt_builder.build_judge_prompt(request.topic, content) + else: + # 默认构建内容生成提示词 + system_prompt, user_prompt = prompt_builder.build_content_prompt(request.topic, step) + + return PromptBuilderResponse( + system_prompt=system_prompt, + user_prompt=user_prompt + ) + except Exception as e: + logger.error(f"构建提示词失败: {e}") + raise HTTPException(status_code=500, detail=f"构建提示词失败: {str(e)}") \ No newline at end of file diff --git a/api/routers/tweet.py b/api/routers/tweet.py index 3e71c4b..2ef0b88 100644 --- a/api/routers/tweet.py +++ b/api/routers/tweet.py @@ -20,8 +20,8 @@ from api.models.tweet import ( PipelineRequest, PipelineResponse ) -# 从main.py中导入依赖 -from api.main import get_config, get_ai_agent, get_output_manager +# 从依赖注入模块导入依赖 +from api.dependencies import get_config, get_ai_agent, get_output_manager logger = logging.getLogger(__name__) diff --git a/api/services/__pycache__/poster.cpython-312.pyc b/api/services/__pycache__/poster.cpython-312.pyc new file mode 100644 index 0000000..135d452 Binary files /dev/null and b/api/services/__pycache__/poster.cpython-312.pyc differ diff --git a/api/services/__pycache__/prompt_builder.cpython-312.pyc b/api/services/__pycache__/prompt_builder.cpython-312.pyc new file mode 100644 index 0000000..a8788c4 Binary files /dev/null and b/api/services/__pycache__/prompt_builder.cpython-312.pyc differ diff --git a/api/services/__pycache__/prompt_service.cpython-312.pyc b/api/services/__pycache__/prompt_service.cpython-312.pyc new file mode 100644 index 0000000..cf75200 Binary files /dev/null and b/api/services/__pycache__/prompt_service.cpython-312.pyc differ diff --git a/api/services/prompt_builder.py b/api/services/prompt_builder.py new file mode 100644 index 0000000..0a074f4 --- /dev/null +++ b/api/services/prompt_builder.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +提示词构建服务 +负责根据选题信息构建完整的提示词 +""" + +import logging +from typing import Dict, Any, Optional, Tuple +from pathlib import Path + +from core.config import ConfigManager, GenerateContentConfig +from utils.prompts import PromptTemplate +from api.services.prompt_service import PromptService + +logger = logging.getLogger(__name__) + + +class PromptBuilderService: + """提示词构建服务类""" + + def __init__(self, config_manager: ConfigManager, prompt_service: PromptService): + """ + 初始化提示词构建服务 + + Args: + config_manager: 配置管理器 + prompt_service: 提示词服务 + """ + self.config_manager = config_manager + self.prompt_service = prompt_service + self.content_config: GenerateContentConfig = config_manager.get_config('content_gen', GenerateContentConfig) + + def build_content_prompt(self, topic: Dict[str, Any], step: str = "content") -> Tuple[str, str]: + """ + 构建内容生成提示词 + + Args: + topic: 选题信息 + step: 当前步骤,用于过滤参考内容 + + Returns: + 系统提示词和用户提示词的元组 + """ + # 加载系统提示词和用户提示词模板 + system_prompt_path = self.content_config.content_system_prompt + user_prompt_path = self.content_config.content_user_prompt + + # 创建提示词模板 + template = PromptTemplate(system_prompt_path, user_prompt_path) + + # 获取风格内容 + style_filename = topic.get("style", "") + style_content = self.prompt_service.get_style_content(style_filename) + + # 获取目标受众内容 + demand_filename = topic.get("target_audience", "") + demand_content = self.prompt_service.get_audience_content(demand_filename) + + # 获取景区信息 + object_name = topic.get("object", "") + object_content = self.prompt_service.get_scenic_spot_info(object_name) + + # 获取产品信息 + product_name = topic.get("product", "") + product_content = self.prompt_service.get_product_info(product_name) + + # 获取参考内容 + refer_content = self.prompt_service.get_refer_content(step) + + # 构建系统提示词 + system_prompt = template.get_system_prompt() + + # 构建用户提示词 + user_prompt = template.build_user_prompt( + style_content=f"{style_filename}\n{style_content}", + demand_content=f"{demand_filename}\n{demand_content}", + object_content=f"{object_name}\n{object_content}", + product_content=f"{product_name}\n{product_content}", + refer_content=refer_content + ) + + return system_prompt, user_prompt + + def build_topic_prompt(self, num_topics: int, month: str) -> Tuple[str, str]: + """ + 构建选题生成提示词 + + Args: + num_topics: 要生成的选题数量 + month: 月份 + + Returns: + 系统提示词和用户提示词的元组 + """ + # 从配置中获取选题提示词模板路径 + topic_config = self.config_manager.get_config('topic_gen', dict) + if not topic_config: + raise ValueError("未找到选题生成配置") + + system_prompt_path = topic_config.get("topic_system_prompt", "") + user_prompt_path = topic_config.get("topic_user_prompt", "") + + if not system_prompt_path or not user_prompt_path: + raise ValueError("选题提示词模板路径不完整") + + # 创建提示词模板 + template = PromptTemplate(system_prompt_path, user_prompt_path) + + # 获取风格列表 + styles = self.prompt_service.get_all_styles() + style_content = "Style文件列表:\n" + "\n".join([f"- {style['name']}" for style in styles]) + + # 获取目标受众列表 + audiences = self.prompt_service.get_all_audiences() + demand_content = "Demand文件列表:\n" + "\n".join([f"- {audience['name']}" for audience in audiences]) + + # 获取参考内容 + refer_content = self.prompt_service.get_refer_content("topic") + + # 获取景区信息列表 + spots = self.prompt_service.get_all_scenic_spots() + object_content = "Object信息:\n" + "\n".join([f"- {spot['name']}" for spot in spots]) + + # 构建系统提示词 + system_prompt = template.get_system_prompt() + + # 构建创作资料 + creative_materials = ( + f"你拥有的创作资料如下:\n" + f"{style_content}\n\n" + f"{demand_content}\n\n" + f"{refer_content}\n\n" + f"{object_content}" + ) + + # 构建用户提示词 + user_prompt = template.build_user_prompt( + creative_materials=creative_materials, + num_topics=num_topics, + month=month + ) + + return system_prompt, user_prompt + + def build_judge_prompt(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str]: + """ + 构建内容审核提示词 + + Args: + topic: 选题信息 + content: 生成的内容 + + Returns: + 系统提示词和用户提示词的元组 + """ + # 从配置中获取审核提示词模板路径 + system_prompt_path = self.content_config.judger_system_prompt + user_prompt_path = self.content_config.judger_user_prompt + + # 创建提示词模板 + template = PromptTemplate(system_prompt_path, user_prompt_path) + + # 获取景区信息 + object_name = topic.get("object", "") + object_content = self.prompt_service.get_scenic_spot_info(object_name) + + # 获取产品信息 + product_name = topic.get("product", "") + product_content = self.prompt_service.get_product_info(product_name) + + # 获取参考内容 + refer_content = self.prompt_service.get_refer_content("judge") + + # 构建系统提示词 + system_prompt = template.get_system_prompt() + + # 格式化内容 + import json + tweet_content = json.dumps(content, ensure_ascii=False, indent=4) + + # 构建用户提示词 + user_prompt = template.build_user_prompt( + tweet_content=tweet_content, + object_content=object_content, + product_content=product_content, + refer_content=refer_content + ) + + return system_prompt, user_prompt \ No newline at end of file diff --git a/api/services/prompt_service.py b/api/services/prompt_service.py new file mode 100644 index 0000000..076fd57 --- /dev/null +++ b/api/services/prompt_service.py @@ -0,0 +1,619 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +提示词服务层 +负责提示词的存储、检索和构建 +""" + +import logging +import json +import os +import re +from typing import Dict, Any, Optional, List, cast +from pathlib import Path +import mysql.connector +from mysql.connector import pooling + +from core.config import ConfigManager, ResourceConfig +from utils.prompts import BasePromptBuilder, PromptTemplate +from utils.file_io import ResourceLoader + +logger = logging.getLogger(__name__) + + +class PromptService: + """提示词服务类""" + + def __init__(self, config_manager: ConfigManager): + """ + 初始化提示词服务 + + Args: + config_manager: 配置管理器 + """ + self.config_manager = config_manager + self.resource_config: ResourceConfig = config_manager.get_config('resource', ResourceConfig) + + # 初始化数据库连接池 + self._init_db_pool() + + def _init_db_pool(self): + """初始化数据库连接池""" + try: + # 尝试直接从配置文件加载数据库配置 + config_dir = Path("config") + db_config_path = config_dir / "database.json" + + if not db_config_path.exists(): + logger.warning(f"数据库配置文件不存在: {db_config_path}") + self.db_pool = None + return + + # 加载配置文件 + with open(db_config_path, 'r', encoding='utf-8') as f: + db_config = json.load(f) + + # 处理环境变量 + processed_config = {} + for key, value in db_config.items(): + if isinstance(value, str) and "${" in value: + # 匹配 ${ENV_VAR:-default} 格式 + pattern = r'\${([^:-]+)(?::-([^}]+))?}' + match = re.match(pattern, value) + if match: + env_var, default = match.groups() + processed_value = os.environ.get(env_var, default) + # 尝试转换为数字 + if key == "port": + try: + processed_value = int(processed_value) + except (ValueError, TypeError): + processed_value = 3306 + processed_config[key] = processed_value + else: + processed_config[key] = value + + # 创建连接池 + self.db_pool = pooling.MySQLConnectionPool( + pool_name="prompt_pool", + pool_size=5, + host=processed_config.get("host", "localhost"), + user=processed_config.get("user", "root"), + password=processed_config.get("password", ""), + database=processed_config.get("database", "travel_content"), + port=processed_config.get("port", 3306), + charset=processed_config.get("charset", "utf8mb4") + ) + logger.info(f"数据库连接池初始化成功,连接到 {processed_config.get('host')}:{processed_config.get('port')}") + except Exception as e: + logger.error(f"初始化数据库连接池失败: {e}") + self.db_pool = None + + def get_style_content(self, style_name: str) -> str: + """ + 获取内容风格提示词 + + Args: + style_name: 风格名称 + + Returns: + 风格提示词内容 + """ + # 优先从数据库获取 + if self.db_pool: + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT description FROM contentStyle WHERE styleName = %s", + (style_name,) + ) + result = cursor.fetchone() + cursor.close() + conn.close() + + if result: + logger.info(f"从数据库获取风格提示词: {style_name}") + return result["description"] + except Exception as e: + logger.error(f"从数据库获取风格提示词失败: {e}") + + # 回退到文件系统 + try: + style_paths = self.resource_config.style.paths + for path_str in style_paths: + try: + if style_name in path_str: + full_path = self._get_full_path(path_str) + if full_path.exists(): + logger.info(f"从文件系统获取风格提示词: {style_name}") + return full_path.read_text('utf-8') + except Exception as e: + logger.error(f"读取风格文件失败 {path_str}: {e}") + + # 如果没有精确匹配,尝试模糊匹配 + for path_str in style_paths: + try: + full_path = self._get_full_path(path_str) + if full_path.exists() and full_path.is_file(): + content = full_path.read_text('utf-8') + if style_name.lower() in full_path.stem.lower(): + logger.info(f"通过模糊匹配找到风格提示词: {style_name} -> {full_path.name}") + return content + except Exception as e: + logger.error(f"读取风格文件失败 {path_str}: {e}") + except Exception as e: + logger.error(f"获取风格提示词失败: {e}") + + logger.warning(f"未找到风格提示词: {style_name},将使用默认值") + return "通用风格" + + def get_audience_content(self, audience_name: str) -> str: + """ + 获取目标受众提示词 + + Args: + audience_name: 受众名称 + + Returns: + 受众提示词内容 + """ + # 优先从数据库获取 + if self.db_pool: + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT description FROM targetAudience WHERE audienceName = %s", + (audience_name,) + ) + result = cursor.fetchone() + cursor.close() + conn.close() + + if result: + logger.info(f"从数据库获取受众提示词: {audience_name}") + return result["description"] + except Exception as e: + logger.error(f"从数据库获取受众提示词失败: {e}") + + # 回退到文件系统 + try: + demand_paths = self.resource_config.demand.paths + for path_str in demand_paths: + try: + if audience_name in path_str: + full_path = self._get_full_path(path_str) + if full_path.exists(): + logger.info(f"从文件系统获取受众提示词: {audience_name}") + return full_path.read_text('utf-8') + except Exception as e: + logger.error(f"读取受众文件失败 {path_str}: {e}") + + # 如果没有精确匹配,尝试模糊匹配 + for path_str in demand_paths: + try: + full_path = self._get_full_path(path_str) + if full_path.exists() and full_path.is_file(): + content = full_path.read_text('utf-8') + if audience_name.lower() in full_path.stem.lower(): + logger.info(f"通过模糊匹配找到受众提示词: {audience_name} -> {full_path.name}") + return content + except Exception as e: + logger.error(f"读取受众文件失败 {path_str}: {e}") + except Exception as e: + logger.error(f"获取受众提示词失败: {e}") + + logger.warning(f"未找到受众提示词: {audience_name},将使用默认值") + return "通用用户画像" + + def get_scenic_spot_info(self, spot_name: str) -> str: + """ + 获取景区信息 + + Args: + spot_name: 景区名称 + + Returns: + 景区信息内容 + """ + # 优先从数据库获取 + if self.db_pool: + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT description FROM scenicSpot WHERE spotName = %s", + (spot_name,) + ) + result = cursor.fetchone() + cursor.close() + conn.close() + + if result: + logger.info(f"从数据库获取景区信息: {spot_name}") + return result["description"] + except Exception as e: + logger.error(f"从数据库获取景区信息失败: {e}") + + # 回退到文件系统 + try: + object_paths = self.resource_config.object.paths + for path_str in object_paths: + try: + if spot_name in path_str: + full_path = self._get_full_path(path_str) + if full_path.exists(): + logger.info(f"从文件系统获取景区信息: {spot_name}") + return full_path.read_text('utf-8') + except Exception as e: + logger.error(f"读取景区文件失败 {path_str}: {e}") + except Exception as e: + logger.error(f"获取景区信息失败: {e}") + + logger.warning(f"未找到景区信息: {spot_name}") + return "无" + + def get_product_info(self, product_name: str) -> str: + """ + 获取产品信息 + + Args: + product_name: 产品名称 + + Returns: + 产品信息内容 + """ + # 优先从数据库获取 + if self.db_pool: + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute( + "SELECT description FROM product WHERE productName = %s", + (product_name,) + ) + result = cursor.fetchone() + cursor.close() + conn.close() + + if result: + logger.info(f"从数据库获取产品信息: {product_name}") + return result["description"] + except Exception as e: + logger.error(f"从数据库获取产品信息失败: {e}") + + # 回退到文件系统 + try: + product_paths = self.resource_config.product.paths + for path_str in product_paths: + try: + if product_name in path_str: + full_path = self._get_full_path(path_str) + if full_path.exists(): + logger.info(f"从文件系统获取产品信息: {product_name}") + return full_path.read_text('utf-8') + except Exception as e: + logger.error(f"读取产品文件失败 {path_str}: {e}") + except Exception as e: + logger.error(f"获取产品信息失败: {e}") + + logger.warning(f"未找到产品信息: {product_name}") + return "无" + + def get_refer_content(self, step: str = "") -> str: + """ + 获取参考内容 + + Args: + step: 当前步骤,用于过滤参考内容 + + Returns: + 参考内容 + """ + refer_content = "参考内容:\n" + + # 从文件系统获取参考内容 + try: + refer_list = self.resource_config.refer.refer_list + filtered_configs = [ + item for item in refer_list + if not item.step or item.step == step + ] + + for ref_item in filtered_configs: + try: + path_str = ref_item.path + full_path = self._get_full_path(path_str) + + if full_path.exists() and full_path.is_file(): + content = full_path.read_text('utf-8') + refer_content += f"--- {full_path.name} ---\n{content}\n\n" + except Exception as e: + logger.error(f"读取参考文件失败 {ref_item.path}: {e}") + except Exception as e: + logger.error(f"获取参考内容失败: {e}") + + return refer_content + + def _get_full_path(self, path_str: str) -> Path: + """根据基准目录解析相对或绝对路径""" + if not self.resource_config.resource_dirs: + raise ValueError("Resource directories list is empty in config.") + base_path = Path(self.resource_config.resource_dirs[0]) + file_path = Path(path_str) + return file_path if file_path.is_absolute() else (base_path / file_path).resolve() + + def get_all_styles(self) -> List[Dict[str, str]]: + """ + 获取所有内容风格 + + Returns: + 风格列表,每个元素包含name和description + """ + styles = [] + + # 优先从数据库获取 + if self.db_pool: + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute("SELECT styleName as name, description FROM contentStyle") + results = cursor.fetchall() + cursor.close() + conn.close() + + if results: + logger.info(f"从数据库获取所有风格: {len(results)}个") + return results + except Exception as e: + logger.error(f"从数据库获取所有风格失败: {e}") + + # 回退到文件系统 + try: + style_paths = self.resource_config.style.paths + for path_str in style_paths: + try: + full_path = self._get_full_path(path_str) + if full_path.exists() and full_path.is_file(): + content = full_path.read_text('utf-8') + name = full_path.stem + if "文案提示词" in name: + name = name.replace("文案提示词", "") + styles.append({ + "name": name, + "description": content[:100] + "..." if len(content) > 100 else content + }) + except Exception as e: + logger.error(f"读取风格文件失败 {path_str}: {e}") + except Exception as e: + logger.error(f"获取所有风格失败: {e}") + + return styles + + def get_all_audiences(self) -> List[Dict[str, str]]: + """ + 获取所有目标受众 + + Returns: + 受众列表,每个元素包含name和description + """ + audiences = [] + + # 优先从数据库获取 + if self.db_pool: + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute("SELECT audienceName as name, description FROM targetAudience") + results = cursor.fetchall() + cursor.close() + conn.close() + + if results: + logger.info(f"从数据库获取所有受众: {len(results)}个") + return results + except Exception as e: + logger.error(f"从数据库获取所有受众失败: {e}") + + # 回退到文件系统 + try: + demand_paths = self.resource_config.demand.paths + for path_str in demand_paths: + try: + full_path = self._get_full_path(path_str) + if full_path.exists() and full_path.is_file(): + content = full_path.read_text('utf-8') + name = full_path.stem + if "文旅需求" in name: + name = name.replace("文旅需求", "") + audiences.append({ + "name": name, + "description": content[:100] + "..." if len(content) > 100 else content + }) + except Exception as e: + logger.error(f"读取受众文件失败 {path_str}: {e}") + except Exception as e: + logger.error(f"获取所有受众失败: {e}") + + return audiences + + def get_all_scenic_spots(self) -> List[Dict[str, str]]: + """ + 获取所有景区 + + Returns: + 景区列表,每个元素包含name和description + """ + spots = [] + + # 优先从数据库获取 + if self.db_pool: + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor(dictionary=True) + cursor.execute("SELECT spotName as name, description FROM scenicSpot") + results = cursor.fetchall() + cursor.close() + conn.close() + + if results: + logger.info(f"从数据库获取所有景区: {len(results)}个") + return [{ + "name": item["name"], + "description": item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"] + } for item in results] + except Exception as e: + logger.error(f"从数据库获取所有景区失败: {e}") + + # 回退到文件系统 + try: + object_paths = self.resource_config.object.paths + for path_str in object_paths: + try: + full_path = self._get_full_path(path_str) + if full_path.exists() and full_path.is_file(): + content = full_path.read_text('utf-8') + spots.append({ + "name": full_path.stem, + "description": content[:100] + "..." if len(content) > 100 else content + }) + except Exception as e: + logger.error(f"读取景区文件失败 {path_str}: {e}") + except Exception as e: + logger.error(f"获取所有景区失败: {e}") + + return spots + + def save_style(self, name: str, description: str) -> bool: + """ + 保存内容风格 + + Args: + name: 风格名称 + description: 风格描述 + + Returns: + 是否保存成功 + """ + # 优先保存到数据库 + if self.db_pool: + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor() + + # 检查是否存在 + cursor.execute( + "SELECT COUNT(*) FROM contentStyle WHERE styleName = %s", + (name,) + ) + count = cursor.fetchone()[0] + + if count > 0: + # 更新 + cursor.execute( + "UPDATE contentStyle SET description = %s WHERE styleName = %s", + (description, name) + ) + else: + # 插入 + cursor.execute( + "INSERT INTO contentStyle (styleName, description) VALUES (%s, %s)", + (name, description) + ) + + conn.commit() + cursor.close() + conn.close() + logger.info(f"风格保存到数据库成功: {name}") + return True + except Exception as e: + logger.error(f"风格保存到数据库失败: {e}") + + # 回退到文件系统 + try: + # 确保风格目录存在 + style_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Style" + style_dir.mkdir(parents=True, exist_ok=True) + + # 保存文件 + file_path = style_dir / f"{name}文案提示词.md" + with open(file_path, 'w', encoding='utf-8') as f: + f.write(description) + + # 更新配置 + if str(file_path) not in self.resource_config.style.paths: + self.resource_config.style.paths.append(str(file_path)) + + logger.info(f"风格保存到文件系统成功: {file_path}") + return True + except Exception as e: + logger.error(f"风格保存到文件系统失败: {e}") + return False + + def save_audience(self, name: str, description: str) -> bool: + """ + 保存目标受众 + + Args: + name: 受众名称 + description: 受众描述 + + Returns: + 是否保存成功 + """ + # 优先保存到数据库 + if self.db_pool: + try: + conn = self.db_pool.get_connection() + cursor = conn.cursor() + + # 检查是否存在 + cursor.execute( + "SELECT COUNT(*) FROM targetAudience WHERE audienceName = %s", + (name,) + ) + count = cursor.fetchone()[0] + + if count > 0: + # 更新 + cursor.execute( + "UPDATE targetAudience SET description = %s WHERE audienceName = %s", + (description, name) + ) + else: + # 插入 + cursor.execute( + "INSERT INTO targetAudience (audienceName, description) VALUES (%s, %s)", + (name, description) + ) + + conn.commit() + cursor.close() + conn.close() + logger.info(f"受众保存到数据库成功: {name}") + return True + except Exception as e: + logger.error(f"受众保存到数据库失败: {e}") + + # 回退到文件系统 + try: + # 确保受众目录存在 + demand_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Demand" + demand_dir.mkdir(parents=True, exist_ok=True) + + # 保存文件 + file_path = demand_dir / f"{name}文旅需求.md" + with open(file_path, 'w', encoding='utf-8') as f: + f.write(description) + + # 更新配置 + if str(file_path) not in self.resource_config.demand.paths: + self.resource_config.demand.paths.append(str(file_path)) + + logger.info(f"受众保存到文件系统成功: {file_path}") + return True + except Exception as e: + logger.error(f"受众保存到文件系统失败: {e}") + return False \ No newline at end of file