链接数据库,可以运行
This commit is contained in:
parent
9da9a39062
commit
19ca9d06ce
BIN
api/__pycache__/dependencies.cpython-312.pyc
Normal file
BIN
api/__pycache__/dependencies.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
45
api/dependencies.py
Normal file
45
api/dependencies.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
API依赖注入模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
from core.config import get_config_manager, ConfigManager
|
||||||
|
from core.ai import AIAgent
|
||||||
|
from utils.file_io import OutputManager
|
||||||
|
|
||||||
|
# 全局依赖
|
||||||
|
config_manager = None
|
||||||
|
ai_agent = None
|
||||||
|
output_manager = None
|
||||||
|
|
||||||
|
def initialize_dependencies():
|
||||||
|
"""初始化全局依赖"""
|
||||||
|
global config_manager, ai_agent, output_manager
|
||||||
|
|
||||||
|
# 初始化配置
|
||||||
|
config_manager = get_config_manager()
|
||||||
|
config_manager.load_from_directory("config")
|
||||||
|
|
||||||
|
# 初始化输出管理器
|
||||||
|
from datetime import datetime
|
||||||
|
run_id = f"api_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||||
|
output_manager = OutputManager("result", run_id)
|
||||||
|
|
||||||
|
# 初始化AI代理
|
||||||
|
from core.config import AIModelConfig
|
||||||
|
ai_config = config_manager.get_config('ai_model', AIModelConfig)
|
||||||
|
ai_agent = AIAgent(ai_config)
|
||||||
|
|
||||||
|
def get_config() -> ConfigManager:
|
||||||
|
"""获取配置管理器"""
|
||||||
|
return config_manager
|
||||||
|
|
||||||
|
def get_ai_agent() -> AIAgent:
|
||||||
|
"""获取AI代理"""
|
||||||
|
return ai_agent
|
||||||
|
|
||||||
|
def get_output_manager() -> OutputManager:
|
||||||
|
"""获取输出管理器"""
|
||||||
|
return output_manager
|
||||||
52
api/main.py
52
api/main.py
@ -7,16 +7,11 @@ TravelContentCreator API服务
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from fastapi import FastAPI, Depends
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from core.config import get_config_manager, AIModelConfig
|
from api import dependencies
|
||||||
from core.ai import AIAgent
|
|
||||||
from utils.file_io import OutputManager
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from api.routers import tweet, poster
|
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -26,32 +21,15 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 全局依赖
|
|
||||||
config_manager = None
|
|
||||||
ai_agent = None
|
|
||||||
output_manager = None
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""
|
"""
|
||||||
应用生命周期管理
|
应用生命周期管理
|
||||||
在应用启动时初始化全局依赖,在应用关闭时清理资源
|
在应用启动时初始化全局依赖,在应用关闭时清理资源
|
||||||
"""
|
"""
|
||||||
global config_manager, ai_agent, output_manager
|
|
||||||
|
|
||||||
# 初始化配置
|
# 初始化配置
|
||||||
logger.info("正在初始化API服务...")
|
logger.info("正在初始化API服务...")
|
||||||
config_manager = get_config_manager()
|
dependencies.initialize_dependencies()
|
||||||
config_manager.load_from_directory("config")
|
|
||||||
|
|
||||||
# 初始化输出管理器
|
|
||||||
run_id = f"api_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
||||||
output_manager = OutputManager("result", run_id)
|
|
||||||
|
|
||||||
# 初始化AI代理
|
|
||||||
ai_config = config_manager.get_config('ai_model', AIModelConfig)
|
|
||||||
ai_agent = AIAgent(ai_config)
|
|
||||||
|
|
||||||
logger.info("API服务初始化完成")
|
logger.info("API服务初始化完成")
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@ -77,37 +55,19 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# 依赖注入函数
|
# 导入路由
|
||||||
def get_config():
|
from api.routers import tweet, poster, prompt
|
||||||
return config_manager
|
|
||||||
|
|
||||||
def get_ai_agent():
|
|
||||||
return ai_agent
|
|
||||||
|
|
||||||
def get_output_manager():
|
|
||||||
return output_manager
|
|
||||||
|
|
||||||
# 包含路由
|
# 包含路由
|
||||||
app.include_router(tweet.router, prefix="/api/tweet", tags=["tweet"])
|
app.include_router(tweet.router, prefix="/api/tweet", tags=["tweet"])
|
||||||
app.include_router(poster.router, prefix="/api/poster", tags=["poster"])
|
app.include_router(poster.router, prefix="/api/poster", tags=["poster"])
|
||||||
|
app.include_router(prompt.router, prefix="/api/prompt", tags=["prompt"])
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
"""API根路径,返回简单的欢迎信息"""
|
"""API根路径,返回简单的欢迎信息"""
|
||||||
return {"message": "欢迎使用TravelContentCreator API服务"}
|
return {"message": "欢迎使用TravelContentCreator API服务"}
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health_check():
|
|
||||||
"""健康检查端点"""
|
|
||||||
return {
|
|
||||||
"status": "healthy",
|
|
||||||
"services": {
|
|
||||||
"ai_agent": ai_agent is not None,
|
|
||||||
"config": config_manager is not None,
|
|
||||||
"output_manager": output_manager is not None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run("api.main:app", host="0.0.0.0", port=8000, reload=True)
|
uvicorn.run("api.main:app", host="0.0.0.0", port=8000, reload=True)
|
||||||
BIN
api/models/__pycache__/poster.cpython-312.pyc
Normal file
BIN
api/models/__pycache__/poster.cpython-312.pyc
Normal file
Binary file not shown.
BIN
api/models/__pycache__/prompt.cpython-312.pyc
Normal file
BIN
api/models/__pycache__/prompt.cpython-312.pyc
Normal file
Binary file not shown.
186
api/models/prompt.py
Normal file
186
api/models/prompt.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
提示词API模型定义
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class StyleRequest(BaseModel):
|
||||||
|
"""风格请求模型"""
|
||||||
|
name: str = Field(..., description="风格名称")
|
||||||
|
description: Optional[str] = Field(None, description="风格描述,如果为空则表示获取")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"name": "攻略风",
|
||||||
|
"description": "详细的旅行攻略信息,包含行程安排、交通指南、住宿推荐等实用信息,语言平实靠谱"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StyleResponse(BaseModel):
|
||||||
|
"""风格响应模型"""
|
||||||
|
name: str = Field(..., description="风格名称")
|
||||||
|
description: str = Field(..., description="风格描述")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"name": "攻略风",
|
||||||
|
"description": "详细的旅行攻略信息,包含行程安排、交通指南、住宿推荐等实用信息,语言平实靠谱"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AudienceRequest(BaseModel):
|
||||||
|
"""受众请求模型"""
|
||||||
|
name: str = Field(..., description="受众名称")
|
||||||
|
description: Optional[str] = Field(None, description="受众描述,如果为空则表示获取")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"name": "亲子向",
|
||||||
|
"description": "25-45岁家长群体,孩子年龄3-12岁,注重安全和教育意义,偏好收藏实用攻略"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AudienceResponse(BaseModel):
|
||||||
|
"""受众响应模型"""
|
||||||
|
name: str = Field(..., description="受众名称")
|
||||||
|
description: str = Field(..., description="受众描述")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"name": "亲子向",
|
||||||
|
"description": "25-45岁家长群体,孩子年龄3-12岁,注重安全和教育意义,偏好收藏实用攻略"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ScenicSpotRequest(BaseModel):
|
||||||
|
"""景区请求模型"""
|
||||||
|
name: str = Field(..., description="景区名称")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"name": "天津冒险湾"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ScenicSpotResponse(BaseModel):
|
||||||
|
"""景区响应模型"""
|
||||||
|
name: str = Field(..., description="景区名称")
|
||||||
|
description: str = Field(..., description="景区描述")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"name": "天津冒险湾",
|
||||||
|
"description": "天津冒险湾位于天津市滨海新区,是华北地区最大的水上乐园..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StyleListResponse(BaseModel):
|
||||||
|
"""风格列表响应模型"""
|
||||||
|
styles: List[StyleResponse] = Field(..., description="风格列表")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"styles": [
|
||||||
|
{
|
||||||
|
"name": "攻略风",
|
||||||
|
"description": "详细的旅行攻略信息,包含行程安排、交通指南、住宿推荐等实用信息,语言平实靠谱"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "清新文艺风",
|
||||||
|
"description": "文艺范十足,清新脱俗的表达风格,注重意境和美感描述"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AudienceListResponse(BaseModel):
|
||||||
|
"""受众列表响应模型"""
|
||||||
|
audiences: List[AudienceResponse] = Field(..., description="受众列表")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"audiences": [
|
||||||
|
{
|
||||||
|
"name": "亲子向",
|
||||||
|
"description": "25-45岁家长群体,孩子年龄3-12岁,注重安全和教育意义,偏好收藏实用攻略"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "周边游",
|
||||||
|
"description": "全龄覆盖,主要围绕三天内的短期假期出游需求和周末出游需求"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ScenicSpotListResponse(BaseModel):
|
||||||
|
"""景区列表响应模型"""
|
||||||
|
spots: List[ScenicSpotResponse] = Field(..., description="景区列表")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"spots": [
|
||||||
|
{
|
||||||
|
"name": "天津冒险湾",
|
||||||
|
"description": "天津冒险湾位于天津市滨海新区,是华北地区最大的水上乐园..."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PromptBuilderRequest(BaseModel):
|
||||||
|
"""提示词构建请求模型"""
|
||||||
|
topic: Dict[str, Any] = Field(..., description="选题信息")
|
||||||
|
step: Optional[str] = Field(None, description="当前步骤,用于过滤参考内容")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"topic": {
|
||||||
|
"index": "1",
|
||||||
|
"date": "2025-07-15",
|
||||||
|
"object": "天津冒险湾",
|
||||||
|
"product": "冒险湾-2大2小套票",
|
||||||
|
"style": "攻略风",
|
||||||
|
"target_audience": "亲子向",
|
||||||
|
"logic": "暑期亲子游热门景点推荐"
|
||||||
|
},
|
||||||
|
"step": "content"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PromptBuilderResponse(BaseModel):
|
||||||
|
"""提示词构建响应模型"""
|
||||||
|
system_prompt: str = Field(..., description="系统提示词")
|
||||||
|
user_prompt: str = Field(..., description="用户提示词")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"system_prompt": "你是一位专业的旅游内容创作者...",
|
||||||
|
"user_prompt": "请根据以下信息创作一篇旅游文章..."
|
||||||
|
}
|
||||||
|
}
|
||||||
BIN
api/routers/__pycache__/poster.cpython-312.pyc
Normal file
BIN
api/routers/__pycache__/poster.cpython-312.pyc
Normal file
Binary file not shown.
BIN
api/routers/__pycache__/prompt.cpython-312.pyc
Normal file
BIN
api/routers/__pycache__/prompt.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
@ -19,8 +19,8 @@ from api.models.poster import (
|
|||||||
PosterTextRequest, PosterTextResponse
|
PosterTextRequest, PosterTextResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
# 从main.py中导入依赖
|
# 从依赖注入模块导入依赖
|
||||||
from api.main import get_config, get_ai_agent, get_output_manager
|
from api.dependencies import get_config, get_ai_agent, get_output_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
207
api/routers/prompt.py
Normal file
207
api/routers/prompt.py
Normal file
@ -0,0 +1,207 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
提示词API路由
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from core.config import ConfigManager
|
||||||
|
from api.services.prompt_service import PromptService
|
||||||
|
from api.services.prompt_builder import PromptBuilderService
|
||||||
|
from api.models.prompt import (
|
||||||
|
StyleRequest, StyleResponse, StyleListResponse,
|
||||||
|
AudienceRequest, AudienceResponse, AudienceListResponse,
|
||||||
|
ScenicSpotRequest, ScenicSpotResponse, ScenicSpotListResponse,
|
||||||
|
PromptBuilderRequest, PromptBuilderResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从依赖注入模块导入依赖
|
||||||
|
from api.dependencies import get_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 创建路由
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# 依赖注入函数
|
||||||
|
def get_prompt_service(
|
||||||
|
config_manager: ConfigManager = Depends(get_config)
|
||||||
|
) -> PromptService:
|
||||||
|
"""获取提示词服务"""
|
||||||
|
return PromptService(config_manager)
|
||||||
|
|
||||||
|
def get_prompt_builder(
|
||||||
|
config_manager: ConfigManager = Depends(get_config),
|
||||||
|
prompt_service: PromptService = Depends(get_prompt_service)
|
||||||
|
) -> PromptBuilderService:
|
||||||
|
"""获取提示词构建服务"""
|
||||||
|
return PromptBuilderService(config_manager, prompt_service)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/styles", response_model=StyleListResponse)
|
||||||
|
async def get_all_styles(
|
||||||
|
prompt_service: PromptService = Depends(get_prompt_service)
|
||||||
|
):
|
||||||
|
"""获取所有内容风格"""
|
||||||
|
try:
|
||||||
|
styles_dict = prompt_service.get_all_styles()
|
||||||
|
# 将字典列表转换为StyleResponse对象列表
|
||||||
|
styles = [StyleResponse(name=style["name"], description=style["description"]) for style in styles_dict]
|
||||||
|
return StyleListResponse(styles=styles)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取所有风格失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取风格列表失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/styles/{style_name}", response_model=StyleResponse)
|
||||||
|
async def get_style(
|
||||||
|
style_name: str,
|
||||||
|
prompt_service: PromptService = Depends(get_prompt_service)
|
||||||
|
):
|
||||||
|
"""获取指定内容风格"""
|
||||||
|
try:
|
||||||
|
content = prompt_service.get_style_content(style_name)
|
||||||
|
return StyleResponse(name=style_name, description=content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取风格 '{style_name}' 失败: {e}")
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到风格: {style_name}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/styles", response_model=StyleResponse)
|
||||||
|
async def create_or_update_style(
|
||||||
|
style: StyleRequest,
|
||||||
|
prompt_service: PromptService = Depends(get_prompt_service)
|
||||||
|
):
|
||||||
|
"""创建或更新内容风格"""
|
||||||
|
try:
|
||||||
|
if not style.description:
|
||||||
|
# 如果没有提供描述,则获取现有的
|
||||||
|
content = prompt_service.get_style_content(style.name)
|
||||||
|
return StyleResponse(name=style.name, description=content)
|
||||||
|
|
||||||
|
success = prompt_service.save_style(style.name, style.description)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail=f"保存风格 '{style.name}' 失败")
|
||||||
|
|
||||||
|
return StyleResponse(name=style.name, description=style.description)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存风格 '{style.name}' 失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/audiences", response_model=AudienceListResponse)
|
||||||
|
async def get_all_audiences(
|
||||||
|
prompt_service: PromptService = Depends(get_prompt_service)
|
||||||
|
):
|
||||||
|
"""获取所有目标受众"""
|
||||||
|
try:
|
||||||
|
audiences_dict = prompt_service.get_all_audiences()
|
||||||
|
# 将字典列表转换为AudienceResponse对象列表
|
||||||
|
audiences = [AudienceResponse(name=audience["name"], description=audience["description"]) for audience in audiences_dict]
|
||||||
|
return AudienceListResponse(audiences=audiences)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取所有受众失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取受众列表失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/audiences/{audience_name}", response_model=AudienceResponse)
|
||||||
|
async def get_audience(
|
||||||
|
audience_name: str,
|
||||||
|
prompt_service: PromptService = Depends(get_prompt_service)
|
||||||
|
):
|
||||||
|
"""获取指定目标受众"""
|
||||||
|
try:
|
||||||
|
content = prompt_service.get_audience_content(audience_name)
|
||||||
|
return AudienceResponse(name=audience_name, description=content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取受众 '{audience_name}' 失败: {e}")
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到受众: {audience_name}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/audiences", response_model=AudienceResponse)
|
||||||
|
async def create_or_update_audience(
|
||||||
|
audience: AudienceRequest,
|
||||||
|
prompt_service: PromptService = Depends(get_prompt_service)
|
||||||
|
):
|
||||||
|
"""创建或更新目标受众"""
|
||||||
|
try:
|
||||||
|
if not audience.description:
|
||||||
|
# 如果没有提供描述,则获取现有的
|
||||||
|
content = prompt_service.get_audience_content(audience.name)
|
||||||
|
return AudienceResponse(name=audience.name, description=content)
|
||||||
|
|
||||||
|
success = prompt_service.save_audience(audience.name, audience.description)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail=f"保存受众 '{audience.name}' 失败")
|
||||||
|
|
||||||
|
return AudienceResponse(name=audience.name, description=audience.description)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存受众 '{audience.name}' 失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/scenic-spots", response_model=ScenicSpotListResponse)
|
||||||
|
async def get_all_scenic_spots(
|
||||||
|
prompt_service: PromptService = Depends(get_prompt_service)
|
||||||
|
):
|
||||||
|
"""获取所有景区"""
|
||||||
|
try:
|
||||||
|
spots_dict = prompt_service.get_all_scenic_spots()
|
||||||
|
# 将字典列表转换为ScenicSpotResponse对象列表
|
||||||
|
spots = [ScenicSpotResponse(name=spot["name"], description=spot["description"]) for spot in spots_dict]
|
||||||
|
return ScenicSpotListResponse(spots=spots)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取所有景区失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取景区列表失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/scenic-spots/{spot_name}", response_model=ScenicSpotResponse)
|
||||||
|
async def get_scenic_spot(
|
||||||
|
spot_name: str,
|
||||||
|
prompt_service: PromptService = Depends(get_prompt_service)
|
||||||
|
):
|
||||||
|
"""获取指定景区信息"""
|
||||||
|
try:
|
||||||
|
content = prompt_service.get_scenic_spot_info(spot_name)
|
||||||
|
return ScenicSpotResponse(name=spot_name, description=content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取景区 '{spot_name}' 失败: {e}")
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到景区: {spot_name}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/build-prompt", response_model=PromptBuilderResponse)
|
||||||
|
async def build_prompt(
|
||||||
|
request: PromptBuilderRequest,
|
||||||
|
prompt_builder: PromptBuilderService = Depends(get_prompt_builder)
|
||||||
|
):
|
||||||
|
"""构建完整提示词"""
|
||||||
|
try:
|
||||||
|
# 根据请求中的step确定构建哪种类型的提示词
|
||||||
|
step = request.step or "content"
|
||||||
|
|
||||||
|
if step == "topic":
|
||||||
|
# 构建选题提示词
|
||||||
|
# 从topic中提取必要的参数
|
||||||
|
num_topics = request.topic.get("num_topics", 5)
|
||||||
|
month = request.topic.get("month", "7")
|
||||||
|
system_prompt, user_prompt = prompt_builder.build_topic_prompt(num_topics, month)
|
||||||
|
elif step == "judge":
|
||||||
|
# 构建审核提示词
|
||||||
|
# 需要提供生成的内容
|
||||||
|
content = request.topic.get("content", {})
|
||||||
|
system_prompt, user_prompt = prompt_builder.build_judge_prompt(request.topic, content)
|
||||||
|
else:
|
||||||
|
# 默认构建内容生成提示词
|
||||||
|
system_prompt, user_prompt = prompt_builder.build_content_prompt(request.topic, step)
|
||||||
|
|
||||||
|
return PromptBuilderResponse(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=user_prompt
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"构建提示词失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"构建提示词失败: {str(e)}")
|
||||||
@ -20,8 +20,8 @@ from api.models.tweet import (
|
|||||||
PipelineRequest, PipelineResponse
|
PipelineRequest, PipelineResponse
|
||||||
)
|
)
|
||||||
|
|
||||||
# 从main.py中导入依赖
|
# 从依赖注入模块导入依赖
|
||||||
from api.main import get_config, get_ai_agent, get_output_manager
|
from api.dependencies import get_config, get_ai_agent, get_output_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
BIN
api/services/__pycache__/poster.cpython-312.pyc
Normal file
BIN
api/services/__pycache__/poster.cpython-312.pyc
Normal file
Binary file not shown.
BIN
api/services/__pycache__/prompt_builder.cpython-312.pyc
Normal file
BIN
api/services/__pycache__/prompt_builder.cpython-312.pyc
Normal file
Binary file not shown.
BIN
api/services/__pycache__/prompt_service.cpython-312.pyc
Normal file
BIN
api/services/__pycache__/prompt_service.cpython-312.pyc
Normal file
Binary file not shown.
191
api/services/prompt_builder.py
Normal file
191
api/services/prompt_builder.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
提示词构建服务
|
||||||
|
负责根据选题信息构建完整的提示词
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Optional, Tuple
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from core.config import ConfigManager, GenerateContentConfig
|
||||||
|
from utils.prompts import PromptTemplate
|
||||||
|
from api.services.prompt_service import PromptService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptBuilderService:
|
||||||
|
"""提示词构建服务类"""
|
||||||
|
|
||||||
|
def __init__(self, config_manager: ConfigManager, prompt_service: PromptService):
|
||||||
|
"""
|
||||||
|
初始化提示词构建服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_manager: 配置管理器
|
||||||
|
prompt_service: 提示词服务
|
||||||
|
"""
|
||||||
|
self.config_manager = config_manager
|
||||||
|
self.prompt_service = prompt_service
|
||||||
|
self.content_config: GenerateContentConfig = config_manager.get_config('content_gen', GenerateContentConfig)
|
||||||
|
|
||||||
|
def build_content_prompt(self, topic: Dict[str, Any], step: str = "content") -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
构建内容生成提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: 选题信息
|
||||||
|
step: 当前步骤,用于过滤参考内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
系统提示词和用户提示词的元组
|
||||||
|
"""
|
||||||
|
# 加载系统提示词和用户提示词模板
|
||||||
|
system_prompt_path = self.content_config.content_system_prompt
|
||||||
|
user_prompt_path = self.content_config.content_user_prompt
|
||||||
|
|
||||||
|
# 创建提示词模板
|
||||||
|
template = PromptTemplate(system_prompt_path, user_prompt_path)
|
||||||
|
|
||||||
|
# 获取风格内容
|
||||||
|
style_filename = topic.get("style", "")
|
||||||
|
style_content = self.prompt_service.get_style_content(style_filename)
|
||||||
|
|
||||||
|
# 获取目标受众内容
|
||||||
|
demand_filename = topic.get("target_audience", "")
|
||||||
|
demand_content = self.prompt_service.get_audience_content(demand_filename)
|
||||||
|
|
||||||
|
# 获取景区信息
|
||||||
|
object_name = topic.get("object", "")
|
||||||
|
object_content = self.prompt_service.get_scenic_spot_info(object_name)
|
||||||
|
|
||||||
|
# 获取产品信息
|
||||||
|
product_name = topic.get("product", "")
|
||||||
|
product_content = self.prompt_service.get_product_info(product_name)
|
||||||
|
|
||||||
|
# 获取参考内容
|
||||||
|
refer_content = self.prompt_service.get_refer_content(step)
|
||||||
|
|
||||||
|
# 构建系统提示词
|
||||||
|
system_prompt = template.get_system_prompt()
|
||||||
|
|
||||||
|
# 构建用户提示词
|
||||||
|
user_prompt = template.build_user_prompt(
|
||||||
|
style_content=f"{style_filename}\n{style_content}",
|
||||||
|
demand_content=f"{demand_filename}\n{demand_content}",
|
||||||
|
object_content=f"{object_name}\n{object_content}",
|
||||||
|
product_content=f"{product_name}\n{product_content}",
|
||||||
|
refer_content=refer_content
|
||||||
|
)
|
||||||
|
|
||||||
|
return system_prompt, user_prompt
|
||||||
|
|
||||||
|
def build_topic_prompt(self, num_topics: int, month: str) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
构建选题生成提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_topics: 要生成的选题数量
|
||||||
|
month: 月份
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
系统提示词和用户提示词的元组
|
||||||
|
"""
|
||||||
|
# 从配置中获取选题提示词模板路径
|
||||||
|
topic_config = self.config_manager.get_config('topic_gen', dict)
|
||||||
|
if not topic_config:
|
||||||
|
raise ValueError("未找到选题生成配置")
|
||||||
|
|
||||||
|
system_prompt_path = topic_config.get("topic_system_prompt", "")
|
||||||
|
user_prompt_path = topic_config.get("topic_user_prompt", "")
|
||||||
|
|
||||||
|
if not system_prompt_path or not user_prompt_path:
|
||||||
|
raise ValueError("选题提示词模板路径不完整")
|
||||||
|
|
||||||
|
# 创建提示词模板
|
||||||
|
template = PromptTemplate(system_prompt_path, user_prompt_path)
|
||||||
|
|
||||||
|
# 获取风格列表
|
||||||
|
styles = self.prompt_service.get_all_styles()
|
||||||
|
style_content = "Style文件列表:\n" + "\n".join([f"- {style['name']}" for style in styles])
|
||||||
|
|
||||||
|
# 获取目标受众列表
|
||||||
|
audiences = self.prompt_service.get_all_audiences()
|
||||||
|
demand_content = "Demand文件列表:\n" + "\n".join([f"- {audience['name']}" for audience in audiences])
|
||||||
|
|
||||||
|
# 获取参考内容
|
||||||
|
refer_content = self.prompt_service.get_refer_content("topic")
|
||||||
|
|
||||||
|
# 获取景区信息列表
|
||||||
|
spots = self.prompt_service.get_all_scenic_spots()
|
||||||
|
object_content = "Object信息:\n" + "\n".join([f"- {spot['name']}" for spot in spots])
|
||||||
|
|
||||||
|
# 构建系统提示词
|
||||||
|
system_prompt = template.get_system_prompt()
|
||||||
|
|
||||||
|
# 构建创作资料
|
||||||
|
creative_materials = (
|
||||||
|
f"你拥有的创作资料如下:\n"
|
||||||
|
f"{style_content}\n\n"
|
||||||
|
f"{demand_content}\n\n"
|
||||||
|
f"{refer_content}\n\n"
|
||||||
|
f"{object_content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建用户提示词
|
||||||
|
user_prompt = template.build_user_prompt(
|
||||||
|
creative_materials=creative_materials,
|
||||||
|
num_topics=num_topics,
|
||||||
|
month=month
|
||||||
|
)
|
||||||
|
|
||||||
|
return system_prompt, user_prompt
|
||||||
|
|
||||||
|
def build_judge_prompt(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
构建内容审核提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
topic: 选题信息
|
||||||
|
content: 生成的内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
系统提示词和用户提示词的元组
|
||||||
|
"""
|
||||||
|
# 从配置中获取审核提示词模板路径
|
||||||
|
system_prompt_path = self.content_config.judger_system_prompt
|
||||||
|
user_prompt_path = self.content_config.judger_user_prompt
|
||||||
|
|
||||||
|
# 创建提示词模板
|
||||||
|
template = PromptTemplate(system_prompt_path, user_prompt_path)
|
||||||
|
|
||||||
|
# 获取景区信息
|
||||||
|
object_name = topic.get("object", "")
|
||||||
|
object_content = self.prompt_service.get_scenic_spot_info(object_name)
|
||||||
|
|
||||||
|
# 获取产品信息
|
||||||
|
product_name = topic.get("product", "")
|
||||||
|
product_content = self.prompt_service.get_product_info(product_name)
|
||||||
|
|
||||||
|
# 获取参考内容
|
||||||
|
refer_content = self.prompt_service.get_refer_content("judge")
|
||||||
|
|
||||||
|
# 构建系统提示词
|
||||||
|
system_prompt = template.get_system_prompt()
|
||||||
|
|
||||||
|
# 格式化内容
|
||||||
|
import json
|
||||||
|
tweet_content = json.dumps(content, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
# 构建用户提示词
|
||||||
|
user_prompt = template.build_user_prompt(
|
||||||
|
tweet_content=tweet_content,
|
||||||
|
object_content=object_content,
|
||||||
|
product_content=product_content,
|
||||||
|
refer_content=refer_content
|
||||||
|
)
|
||||||
|
|
||||||
|
return system_prompt, user_prompt
|
||||||
619
api/services/prompt_service.py
Normal file
619
api/services/prompt_service.py
Normal file
@ -0,0 +1,619 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
提示词服务层
|
||||||
|
负责提示词的存储、检索和构建
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Dict, Any, Optional, List, cast
|
||||||
|
from pathlib import Path
|
||||||
|
import mysql.connector
|
||||||
|
from mysql.connector import pooling
|
||||||
|
|
||||||
|
from core.config import ConfigManager, ResourceConfig
|
||||||
|
from utils.prompts import BasePromptBuilder, PromptTemplate
|
||||||
|
from utils.file_io import ResourceLoader
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptService:
|
||||||
|
"""提示词服务类"""
|
||||||
|
|
||||||
|
def __init__(self, config_manager: ConfigManager):
|
||||||
|
"""
|
||||||
|
初始化提示词服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_manager: 配置管理器
|
||||||
|
"""
|
||||||
|
self.config_manager = config_manager
|
||||||
|
self.resource_config: ResourceConfig = config_manager.get_config('resource', ResourceConfig)
|
||||||
|
|
||||||
|
# 初始化数据库连接池
|
||||||
|
self._init_db_pool()
|
||||||
|
|
||||||
|
def _init_db_pool(self):
|
||||||
|
"""初始化数据库连接池"""
|
||||||
|
try:
|
||||||
|
# 尝试直接从配置文件加载数据库配置
|
||||||
|
config_dir = Path("config")
|
||||||
|
db_config_path = config_dir / "database.json"
|
||||||
|
|
||||||
|
if not db_config_path.exists():
|
||||||
|
logger.warning(f"数据库配置文件不存在: {db_config_path}")
|
||||||
|
self.db_pool = None
|
||||||
|
return
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
with open(db_config_path, 'r', encoding='utf-8') as f:
|
||||||
|
db_config = json.load(f)
|
||||||
|
|
||||||
|
# 处理环境变量
|
||||||
|
processed_config = {}
|
||||||
|
for key, value in db_config.items():
|
||||||
|
if isinstance(value, str) and "${" in value:
|
||||||
|
# 匹配 ${ENV_VAR:-default} 格式
|
||||||
|
pattern = r'\${([^:-]+)(?::-([^}]+))?}'
|
||||||
|
match = re.match(pattern, value)
|
||||||
|
if match:
|
||||||
|
env_var, default = match.groups()
|
||||||
|
processed_value = os.environ.get(env_var, default)
|
||||||
|
# 尝试转换为数字
|
||||||
|
if key == "port":
|
||||||
|
try:
|
||||||
|
processed_value = int(processed_value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
processed_value = 3306
|
||||||
|
processed_config[key] = processed_value
|
||||||
|
else:
|
||||||
|
processed_config[key] = value
|
||||||
|
|
||||||
|
# 创建连接池
|
||||||
|
self.db_pool = pooling.MySQLConnectionPool(
|
||||||
|
pool_name="prompt_pool",
|
||||||
|
pool_size=5,
|
||||||
|
host=processed_config.get("host", "localhost"),
|
||||||
|
user=processed_config.get("user", "root"),
|
||||||
|
password=processed_config.get("password", ""),
|
||||||
|
database=processed_config.get("database", "travel_content"),
|
||||||
|
port=processed_config.get("port", 3306),
|
||||||
|
charset=processed_config.get("charset", "utf8mb4")
|
||||||
|
)
|
||||||
|
logger.info(f"数据库连接池初始化成功,连接到 {processed_config.get('host')}:{processed_config.get('port')}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化数据库连接池失败: {e}")
|
||||||
|
self.db_pool = None
|
||||||
|
|
||||||
|
def get_style_content(self, style_name: str) -> str:
|
||||||
|
"""
|
||||||
|
获取内容风格提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
style_name: 风格名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
风格提示词内容
|
||||||
|
"""
|
||||||
|
# 优先从数据库获取
|
||||||
|
if self.db_pool:
|
||||||
|
try:
|
||||||
|
conn = self.db_pool.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT description FROM contentStyle WHERE styleName = %s",
|
||||||
|
(style_name,)
|
||||||
|
)
|
||||||
|
result = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
logger.info(f"从数据库获取风格提示词: {style_name}")
|
||||||
|
return result["description"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取风格提示词失败: {e}")
|
||||||
|
|
||||||
|
# 回退到文件系统
|
||||||
|
try:
|
||||||
|
style_paths = self.resource_config.style.paths
|
||||||
|
for path_str in style_paths:
|
||||||
|
try:
|
||||||
|
if style_name in path_str:
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
if full_path.exists():
|
||||||
|
logger.info(f"从文件系统获取风格提示词: {style_name}")
|
||||||
|
return full_path.read_text('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取风格文件失败 {path_str}: {e}")
|
||||||
|
|
||||||
|
# 如果没有精确匹配,尝试模糊匹配
|
||||||
|
for path_str in style_paths:
|
||||||
|
try:
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
if full_path.exists() and full_path.is_file():
|
||||||
|
content = full_path.read_text('utf-8')
|
||||||
|
if style_name.lower() in full_path.stem.lower():
|
||||||
|
logger.info(f"通过模糊匹配找到风格提示词: {style_name} -> {full_path.name}")
|
||||||
|
return content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取风格文件失败 {path_str}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取风格提示词失败: {e}")
|
||||||
|
|
||||||
|
logger.warning(f"未找到风格提示词: {style_name},将使用默认值")
|
||||||
|
return "通用风格"
|
||||||
|
|
||||||
|
def get_audience_content(self, audience_name: str) -> str:
|
||||||
|
"""
|
||||||
|
获取目标受众提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audience_name: 受众名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
受众提示词内容
|
||||||
|
"""
|
||||||
|
# 优先从数据库获取
|
||||||
|
if self.db_pool:
|
||||||
|
try:
|
||||||
|
conn = self.db_pool.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT description FROM targetAudience WHERE audienceName = %s",
|
||||||
|
(audience_name,)
|
||||||
|
)
|
||||||
|
result = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
logger.info(f"从数据库获取受众提示词: {audience_name}")
|
||||||
|
return result["description"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取受众提示词失败: {e}")
|
||||||
|
|
||||||
|
# 回退到文件系统
|
||||||
|
try:
|
||||||
|
demand_paths = self.resource_config.demand.paths
|
||||||
|
for path_str in demand_paths:
|
||||||
|
try:
|
||||||
|
if audience_name in path_str:
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
if full_path.exists():
|
||||||
|
logger.info(f"从文件系统获取受众提示词: {audience_name}")
|
||||||
|
return full_path.read_text('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取受众文件失败 {path_str}: {e}")
|
||||||
|
|
||||||
|
# 如果没有精确匹配,尝试模糊匹配
|
||||||
|
for path_str in demand_paths:
|
||||||
|
try:
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
if full_path.exists() and full_path.is_file():
|
||||||
|
content = full_path.read_text('utf-8')
|
||||||
|
if audience_name.lower() in full_path.stem.lower():
|
||||||
|
logger.info(f"通过模糊匹配找到受众提示词: {audience_name} -> {full_path.name}")
|
||||||
|
return content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取受众文件失败 {path_str}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取受众提示词失败: {e}")
|
||||||
|
|
||||||
|
logger.warning(f"未找到受众提示词: {audience_name},将使用默认值")
|
||||||
|
return "通用用户画像"
|
||||||
|
|
||||||
|
def get_scenic_spot_info(self, spot_name: str) -> str:
|
||||||
|
"""
|
||||||
|
获取景区信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spot_name: 景区名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
景区信息内容
|
||||||
|
"""
|
||||||
|
# 优先从数据库获取
|
||||||
|
if self.db_pool:
|
||||||
|
try:
|
||||||
|
conn = self.db_pool.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT description FROM scenicSpot WHERE spotName = %s",
|
||||||
|
(spot_name,)
|
||||||
|
)
|
||||||
|
result = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
logger.info(f"从数据库获取景区信息: {spot_name}")
|
||||||
|
return result["description"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取景区信息失败: {e}")
|
||||||
|
|
||||||
|
# 回退到文件系统
|
||||||
|
try:
|
||||||
|
object_paths = self.resource_config.object.paths
|
||||||
|
for path_str in object_paths:
|
||||||
|
try:
|
||||||
|
if spot_name in path_str:
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
if full_path.exists():
|
||||||
|
logger.info(f"从文件系统获取景区信息: {spot_name}")
|
||||||
|
return full_path.read_text('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取景区文件失败 {path_str}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取景区信息失败: {e}")
|
||||||
|
|
||||||
|
logger.warning(f"未找到景区信息: {spot_name}")
|
||||||
|
return "无"
|
||||||
|
|
||||||
|
def get_product_info(self, product_name: str) -> str:
|
||||||
|
"""
|
||||||
|
获取产品信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
product_name: 产品名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
产品信息内容
|
||||||
|
"""
|
||||||
|
# 优先从数据库获取
|
||||||
|
if self.db_pool:
|
||||||
|
try:
|
||||||
|
conn = self.db_pool.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT description FROM product WHERE productName = %s",
|
||||||
|
(product_name,)
|
||||||
|
)
|
||||||
|
result = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
logger.info(f"从数据库获取产品信息: {product_name}")
|
||||||
|
return result["description"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取产品信息失败: {e}")
|
||||||
|
|
||||||
|
# 回退到文件系统
|
||||||
|
try:
|
||||||
|
product_paths = self.resource_config.product.paths
|
||||||
|
for path_str in product_paths:
|
||||||
|
try:
|
||||||
|
if product_name in path_str:
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
if full_path.exists():
|
||||||
|
logger.info(f"从文件系统获取产品信息: {product_name}")
|
||||||
|
return full_path.read_text('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取产品文件失败 {path_str}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取产品信息失败: {e}")
|
||||||
|
|
||||||
|
logger.warning(f"未找到产品信息: {product_name}")
|
||||||
|
return "无"
|
||||||
|
|
||||||
|
def get_refer_content(self, step: str = "") -> str:
|
||||||
|
"""
|
||||||
|
获取参考内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step: 当前步骤,用于过滤参考内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
参考内容
|
||||||
|
"""
|
||||||
|
refer_content = "参考内容:\n"
|
||||||
|
|
||||||
|
# 从文件系统获取参考内容
|
||||||
|
try:
|
||||||
|
refer_list = self.resource_config.refer.refer_list
|
||||||
|
filtered_configs = [
|
||||||
|
item for item in refer_list
|
||||||
|
if not item.step or item.step == step
|
||||||
|
]
|
||||||
|
|
||||||
|
for ref_item in filtered_configs:
|
||||||
|
try:
|
||||||
|
path_str = ref_item.path
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
|
||||||
|
if full_path.exists() and full_path.is_file():
|
||||||
|
content = full_path.read_text('utf-8')
|
||||||
|
refer_content += f"--- {full_path.name} ---\n{content}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取参考文件失败 {ref_item.path}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取参考内容失败: {e}")
|
||||||
|
|
||||||
|
return refer_content
|
||||||
|
|
||||||
|
def _get_full_path(self, path_str: str) -> Path:
|
||||||
|
"""根据基准目录解析相对或绝对路径"""
|
||||||
|
if not self.resource_config.resource_dirs:
|
||||||
|
raise ValueError("Resource directories list is empty in config.")
|
||||||
|
base_path = Path(self.resource_config.resource_dirs[0])
|
||||||
|
file_path = Path(path_str)
|
||||||
|
return file_path if file_path.is_absolute() else (base_path / file_path).resolve()
|
||||||
|
|
||||||
|
def get_all_styles(self) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
获取所有内容风格
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
风格列表,每个元素包含name和description
|
||||||
|
"""
|
||||||
|
styles = []
|
||||||
|
|
||||||
|
# 优先从数据库获取
|
||||||
|
if self.db_pool:
|
||||||
|
try:
|
||||||
|
conn = self.db_pool.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
cursor.execute("SELECT styleName as name, description FROM contentStyle")
|
||||||
|
results = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if results:
|
||||||
|
logger.info(f"从数据库获取所有风格: {len(results)}个")
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取所有风格失败: {e}")
|
||||||
|
|
||||||
|
# 回退到文件系统
|
||||||
|
try:
|
||||||
|
style_paths = self.resource_config.style.paths
|
||||||
|
for path_str in style_paths:
|
||||||
|
try:
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
if full_path.exists() and full_path.is_file():
|
||||||
|
content = full_path.read_text('utf-8')
|
||||||
|
name = full_path.stem
|
||||||
|
if "文案提示词" in name:
|
||||||
|
name = name.replace("文案提示词", "")
|
||||||
|
styles.append({
|
||||||
|
"name": name,
|
||||||
|
"description": content[:100] + "..." if len(content) > 100 else content
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取风格文件失败 {path_str}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取所有风格失败: {e}")
|
||||||
|
|
||||||
|
return styles
|
||||||
|
|
||||||
|
def get_all_audiences(self) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
获取所有目标受众
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
受众列表,每个元素包含name和description
|
||||||
|
"""
|
||||||
|
audiences = []
|
||||||
|
|
||||||
|
# 优先从数据库获取
|
||||||
|
if self.db_pool:
|
||||||
|
try:
|
||||||
|
conn = self.db_pool.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
cursor.execute("SELECT audienceName as name, description FROM targetAudience")
|
||||||
|
results = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if results:
|
||||||
|
logger.info(f"从数据库获取所有受众: {len(results)}个")
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取所有受众失败: {e}")
|
||||||
|
|
||||||
|
# 回退到文件系统
|
||||||
|
try:
|
||||||
|
demand_paths = self.resource_config.demand.paths
|
||||||
|
for path_str in demand_paths:
|
||||||
|
try:
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
if full_path.exists() and full_path.is_file():
|
||||||
|
content = full_path.read_text('utf-8')
|
||||||
|
name = full_path.stem
|
||||||
|
if "文旅需求" in name:
|
||||||
|
name = name.replace("文旅需求", "")
|
||||||
|
audiences.append({
|
||||||
|
"name": name,
|
||||||
|
"description": content[:100] + "..." if len(content) > 100 else content
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取受众文件失败 {path_str}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取所有受众失败: {e}")
|
||||||
|
|
||||||
|
return audiences
|
||||||
|
|
||||||
|
def get_all_scenic_spots(self) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
获取所有景区
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
景区列表,每个元素包含name和description
|
||||||
|
"""
|
||||||
|
spots = []
|
||||||
|
|
||||||
|
# 优先从数据库获取
|
||||||
|
if self.db_pool:
|
||||||
|
try:
|
||||||
|
conn = self.db_pool.get_connection()
|
||||||
|
cursor = conn.cursor(dictionary=True)
|
||||||
|
cursor.execute("SELECT spotName as name, description FROM scenicSpot")
|
||||||
|
results = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if results:
|
||||||
|
logger.info(f"从数据库获取所有景区: {len(results)}个")
|
||||||
|
return [{
|
||||||
|
"name": item["name"],
|
||||||
|
"description": item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"]
|
||||||
|
} for item in results]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取所有景区失败: {e}")
|
||||||
|
|
||||||
|
# 回退到文件系统
|
||||||
|
try:
|
||||||
|
object_paths = self.resource_config.object.paths
|
||||||
|
for path_str in object_paths:
|
||||||
|
try:
|
||||||
|
full_path = self._get_full_path(path_str)
|
||||||
|
if full_path.exists() and full_path.is_file():
|
||||||
|
content = full_path.read_text('utf-8')
|
||||||
|
spots.append({
|
||||||
|
"name": full_path.stem,
|
||||||
|
"description": content[:100] + "..." if len(content) > 100 else content
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取景区文件失败 {path_str}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取所有景区失败: {e}")
|
||||||
|
|
||||||
|
return spots
|
||||||
|
|
||||||
|
def save_style(self, name: str, description: str) -> bool:
|
||||||
|
"""
|
||||||
|
保存内容风格
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 风格名称
|
||||||
|
description: 风格描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否保存成功
|
||||||
|
"""
|
||||||
|
# 优先保存到数据库
|
||||||
|
if self.db_pool:
|
||||||
|
try:
|
||||||
|
conn = self.db_pool.get_connection()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 检查是否存在
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT COUNT(*) FROM contentStyle WHERE styleName = %s",
|
||||||
|
(name,)
|
||||||
|
)
|
||||||
|
count = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
# 更新
|
||||||
|
cursor.execute(
|
||||||
|
"UPDATE contentStyle SET description = %s WHERE styleName = %s",
|
||||||
|
(description, name)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 插入
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO contentStyle (styleName, description) VALUES (%s, %s)",
|
||||||
|
(name, description)
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
logger.info(f"风格保存到数据库成功: {name}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"风格保存到数据库失败: {e}")
|
||||||
|
|
||||||
|
# 回退到文件系统
|
||||||
|
try:
|
||||||
|
# 确保风格目录存在
|
||||||
|
style_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Style"
|
||||||
|
style_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
file_path = style_dir / f"{name}文案提示词.md"
|
||||||
|
with open(file_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(description)
|
||||||
|
|
||||||
|
# 更新配置
|
||||||
|
if str(file_path) not in self.resource_config.style.paths:
|
||||||
|
self.resource_config.style.paths.append(str(file_path))
|
||||||
|
|
||||||
|
logger.info(f"风格保存到文件系统成功: {file_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"风格保存到文件系统失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def save_audience(self, name: str, description: str) -> bool:
|
||||||
|
"""
|
||||||
|
保存目标受众
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 受众名称
|
||||||
|
description: 受众描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否保存成功
|
||||||
|
"""
|
||||||
|
# 优先保存到数据库
|
||||||
|
if self.db_pool:
|
||||||
|
try:
|
||||||
|
conn = self.db_pool.get_connection()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 检查是否存在
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT COUNT(*) FROM targetAudience WHERE audienceName = %s",
|
||||||
|
(name,)
|
||||||
|
)
|
||||||
|
count = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
# 更新
|
||||||
|
cursor.execute(
|
||||||
|
"UPDATE targetAudience SET description = %s WHERE audienceName = %s",
|
||||||
|
(description, name)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 插入
|
||||||
|
cursor.execute(
|
||||||
|
"INSERT INTO targetAudience (audienceName, description) VALUES (%s, %s)",
|
||||||
|
(name, description)
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
logger.info(f"受众保存到数据库成功: {name}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"受众保存到数据库失败: {e}")
|
||||||
|
|
||||||
|
# 回退到文件系统
|
||||||
|
try:
|
||||||
|
# 确保受众目录存在
|
||||||
|
demand_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Demand"
|
||||||
|
demand_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
file_path = demand_dir / f"{name}文旅需求.md"
|
||||||
|
with open(file_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(description)
|
||||||
|
|
||||||
|
# 更新配置
|
||||||
|
if str(file_path) not in self.resource_config.demand.paths:
|
||||||
|
self.resource_config.demand.paths.append(str(file_path))
|
||||||
|
|
||||||
|
logger.info(f"受众保存到文件系统成功: {file_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"受众保存到文件系统失败: {e}")
|
||||||
|
return False
|
||||||
Loading…
x
Reference in New Issue
Block a user