链接数据库,可以运行

This commit is contained in:
jinye_huang 2025-07-11 13:50:08 +08:00
parent 9da9a39062
commit 19ca9d06ce
18 changed files with 1258 additions and 50 deletions

Binary file not shown.

Binary file not shown.

45
api/dependencies.py Normal file
View File

@ -0,0 +1,45 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
API依赖注入模块
"""
from core.config import get_config_manager, ConfigManager
from core.ai import AIAgent
from utils.file_io import OutputManager
# 全局依赖
config_manager = None
ai_agent = None
output_manager = None
def initialize_dependencies():
"""初始化全局依赖"""
global config_manager, ai_agent, output_manager
# 初始化配置
config_manager = get_config_manager()
config_manager.load_from_directory("config")
# 初始化输出管理器
from datetime import datetime
run_id = f"api_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
output_manager = OutputManager("result", run_id)
# 初始化AI代理
from core.config import AIModelConfig
ai_config = config_manager.get_config('ai_model', AIModelConfig)
ai_agent = AIAgent(ai_config)
def get_config() -> ConfigManager:
"""获取配置管理器"""
return config_manager
def get_ai_agent() -> AIAgent:
"""获取AI代理"""
return ai_agent
def get_output_manager() -> OutputManager:
"""获取输出管理器"""
return output_manager

View File

@ -7,16 +7,11 @@ TravelContentCreator API服务
"""
import logging
from fastapi import FastAPI, Depends
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from core.config import get_config_manager, AIModelConfig
from core.ai import AIAgent
from utils.file_io import OutputManager
from datetime import datetime
from api.routers import tweet, poster
from api import dependencies
# 配置日志
logging.basicConfig(
@ -26,32 +21,15 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
# 全局依赖
config_manager = None
ai_agent = None
output_manager = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
应用生命周期管理
在应用启动时初始化全局依赖在应用关闭时清理资源
"""
global config_manager, ai_agent, output_manager
# 初始化配置
logger.info("正在初始化API服务...")
config_manager = get_config_manager()
config_manager.load_from_directory("config")
# 初始化输出管理器
run_id = f"api_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
output_manager = OutputManager("result", run_id)
# 初始化AI代理
ai_config = config_manager.get_config('ai_model', AIModelConfig)
ai_agent = AIAgent(ai_config)
dependencies.initialize_dependencies()
logger.info("API服务初始化完成")
yield
@ -77,37 +55,19 @@ app.add_middleware(
allow_headers=["*"],
)
# 依赖注入函数
def get_config():
return config_manager
def get_ai_agent():
return ai_agent
def get_output_manager():
return output_manager
# 导入路由
from api.routers import tweet, poster, prompt
# 包含路由
app.include_router(tweet.router, prefix="/api/tweet", tags=["tweet"])
app.include_router(poster.router, prefix="/api/poster", tags=["poster"])
app.include_router(prompt.router, prefix="/api/prompt", tags=["prompt"])
@app.get("/")
async def root():
"""API根路径返回简单的欢迎信息"""
return {"message": "欢迎使用TravelContentCreator API服务"}
@app.get("/health")
async def health_check():
"""健康检查端点"""
return {
"status": "healthy",
"services": {
"ai_agent": ai_agent is not None,
"config": config_manager is not None,
"output_manager": output_manager is not None
}
}
if __name__ == "__main__":
import uvicorn
uvicorn.run("api.main:app", host="0.0.0.0", port=8000, reload=True)

Binary file not shown.

Binary file not shown.

186
api/models/prompt.py Normal file
View File

@ -0,0 +1,186 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词API模型定义
"""
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
class StyleRequest(BaseModel):
"""风格请求模型"""
name: str = Field(..., description="风格名称")
description: Optional[str] = Field(None, description="风格描述,如果为空则表示获取")
class Config:
schema_extra = {
"example": {
"name": "攻略风",
"description": "详细的旅行攻略信息,包含行程安排、交通指南、住宿推荐等实用信息,语言平实靠谱"
}
}
class StyleResponse(BaseModel):
"""风格响应模型"""
name: str = Field(..., description="风格名称")
description: str = Field(..., description="风格描述")
class Config:
schema_extra = {
"example": {
"name": "攻略风",
"description": "详细的旅行攻略信息,包含行程安排、交通指南、住宿推荐等实用信息,语言平实靠谱"
}
}
class AudienceRequest(BaseModel):
"""受众请求模型"""
name: str = Field(..., description="受众名称")
description: Optional[str] = Field(None, description="受众描述,如果为空则表示获取")
class Config:
schema_extra = {
"example": {
"name": "亲子向",
"description": "25-45岁家长群体孩子年龄3-12岁注重安全和教育意义偏好收藏实用攻略"
}
}
class AudienceResponse(BaseModel):
"""受众响应模型"""
name: str = Field(..., description="受众名称")
description: str = Field(..., description="受众描述")
class Config:
schema_extra = {
"example": {
"name": "亲子向",
"description": "25-45岁家长群体孩子年龄3-12岁注重安全和教育意义偏好收藏实用攻略"
}
}
class ScenicSpotRequest(BaseModel):
"""景区请求模型"""
name: str = Field(..., description="景区名称")
class Config:
schema_extra = {
"example": {
"name": "天津冒险湾"
}
}
class ScenicSpotResponse(BaseModel):
"""景区响应模型"""
name: str = Field(..., description="景区名称")
description: str = Field(..., description="景区描述")
class Config:
schema_extra = {
"example": {
"name": "天津冒险湾",
"description": "天津冒险湾位于天津市滨海新区,是华北地区最大的水上乐园..."
}
}
class StyleListResponse(BaseModel):
"""风格列表响应模型"""
styles: List[StyleResponse] = Field(..., description="风格列表")
class Config:
schema_extra = {
"example": {
"styles": [
{
"name": "攻略风",
"description": "详细的旅行攻略信息,包含行程安排、交通指南、住宿推荐等实用信息,语言平实靠谱"
},
{
"name": "清新文艺风",
"description": "文艺范十足,清新脱俗的表达风格,注重意境和美感描述"
}
]
}
}
class AudienceListResponse(BaseModel):
"""受众列表响应模型"""
audiences: List[AudienceResponse] = Field(..., description="受众列表")
class Config:
schema_extra = {
"example": {
"audiences": [
{
"name": "亲子向",
"description": "25-45岁家长群体孩子年龄3-12岁注重安全和教育意义偏好收藏实用攻略"
},
{
"name": "周边游",
"description": "全龄覆盖,主要围绕三天内的短期假期出游需求和周末出游需求"
}
]
}
}
class ScenicSpotListResponse(BaseModel):
"""景区列表响应模型"""
spots: List[ScenicSpotResponse] = Field(..., description="景区列表")
class Config:
schema_extra = {
"example": {
"spots": [
{
"name": "天津冒险湾",
"description": "天津冒险湾位于天津市滨海新区,是华北地区最大的水上乐园..."
}
]
}
}
class PromptBuilderRequest(BaseModel):
"""提示词构建请求模型"""
topic: Dict[str, Any] = Field(..., description="选题信息")
step: Optional[str] = Field(None, description="当前步骤,用于过滤参考内容")
class Config:
schema_extra = {
"example": {
"topic": {
"index": "1",
"date": "2025-07-15",
"object": "天津冒险湾",
"product": "冒险湾-2大2小套票",
"style": "攻略风",
"target_audience": "亲子向",
"logic": "暑期亲子游热门景点推荐"
},
"step": "content"
}
}
class PromptBuilderResponse(BaseModel):
"""提示词构建响应模型"""
system_prompt: str = Field(..., description="系统提示词")
user_prompt: str = Field(..., description="用户提示词")
class Config:
schema_extra = {
"example": {
"system_prompt": "你是一位专业的旅游内容创作者...",
"user_prompt": "请根据以下信息创作一篇旅游文章..."
}
}

Binary file not shown.

Binary file not shown.

View File

@ -19,8 +19,8 @@ from api.models.poster import (
PosterTextRequest, PosterTextResponse
)
# 从main.py中导入依赖
from api.main import get_config, get_ai_agent, get_output_manager
# 从依赖注入模块导入依赖
from api.dependencies import get_config, get_ai_agent, get_output_manager
logger = logging.getLogger(__name__)

207
api/routers/prompt.py Normal file
View File

@ -0,0 +1,207 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词API路由
"""
import logging
from fastapi import APIRouter, Depends, HTTPException
from typing import List, Dict, Any
from core.config import ConfigManager
from api.services.prompt_service import PromptService
from api.services.prompt_builder import PromptBuilderService
from api.models.prompt import (
StyleRequest, StyleResponse, StyleListResponse,
AudienceRequest, AudienceResponse, AudienceListResponse,
ScenicSpotRequest, ScenicSpotResponse, ScenicSpotListResponse,
PromptBuilderRequest, PromptBuilderResponse
)
# 从依赖注入模块导入依赖
from api.dependencies import get_config
logger = logging.getLogger(__name__)
# 创建路由
router = APIRouter()
# 依赖注入函数
def get_prompt_service(
config_manager: ConfigManager = Depends(get_config)
) -> PromptService:
"""获取提示词服务"""
return PromptService(config_manager)
def get_prompt_builder(
config_manager: ConfigManager = Depends(get_config),
prompt_service: PromptService = Depends(get_prompt_service)
) -> PromptBuilderService:
"""获取提示词构建服务"""
return PromptBuilderService(config_manager, prompt_service)
@router.get("/styles", response_model=StyleListResponse)
async def get_all_styles(
prompt_service: PromptService = Depends(get_prompt_service)
):
"""获取所有内容风格"""
try:
styles_dict = prompt_service.get_all_styles()
# 将字典列表转换为StyleResponse对象列表
styles = [StyleResponse(name=style["name"], description=style["description"]) for style in styles_dict]
return StyleListResponse(styles=styles)
except Exception as e:
logger.error(f"获取所有风格失败: {e}")
raise HTTPException(status_code=500, detail=f"获取风格列表失败: {str(e)}")
@router.get("/styles/{style_name}", response_model=StyleResponse)
async def get_style(
style_name: str,
prompt_service: PromptService = Depends(get_prompt_service)
):
"""获取指定内容风格"""
try:
content = prompt_service.get_style_content(style_name)
return StyleResponse(name=style_name, description=content)
except Exception as e:
logger.error(f"获取风格 '{style_name}' 失败: {e}")
raise HTTPException(status_code=404, detail=f"未找到风格: {style_name}")
@router.post("/styles", response_model=StyleResponse)
async def create_or_update_style(
style: StyleRequest,
prompt_service: PromptService = Depends(get_prompt_service)
):
"""创建或更新内容风格"""
try:
if not style.description:
# 如果没有提供描述,则获取现有的
content = prompt_service.get_style_content(style.name)
return StyleResponse(name=style.name, description=content)
success = prompt_service.save_style(style.name, style.description)
if not success:
raise HTTPException(status_code=500, detail=f"保存风格 '{style.name}' 失败")
return StyleResponse(name=style.name, description=style.description)
except Exception as e:
logger.error(f"保存风格 '{style.name}' 失败: {e}")
raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}")
@router.get("/audiences", response_model=AudienceListResponse)
async def get_all_audiences(
prompt_service: PromptService = Depends(get_prompt_service)
):
"""获取所有目标受众"""
try:
audiences_dict = prompt_service.get_all_audiences()
# 将字典列表转换为AudienceResponse对象列表
audiences = [AudienceResponse(name=audience["name"], description=audience["description"]) for audience in audiences_dict]
return AudienceListResponse(audiences=audiences)
except Exception as e:
logger.error(f"获取所有受众失败: {e}")
raise HTTPException(status_code=500, detail=f"获取受众列表失败: {str(e)}")
@router.get("/audiences/{audience_name}", response_model=AudienceResponse)
async def get_audience(
audience_name: str,
prompt_service: PromptService = Depends(get_prompt_service)
):
"""获取指定目标受众"""
try:
content = prompt_service.get_audience_content(audience_name)
return AudienceResponse(name=audience_name, description=content)
except Exception as e:
logger.error(f"获取受众 '{audience_name}' 失败: {e}")
raise HTTPException(status_code=404, detail=f"未找到受众: {audience_name}")
@router.post("/audiences", response_model=AudienceResponse)
async def create_or_update_audience(
audience: AudienceRequest,
prompt_service: PromptService = Depends(get_prompt_service)
):
"""创建或更新目标受众"""
try:
if not audience.description:
# 如果没有提供描述,则获取现有的
content = prompt_service.get_audience_content(audience.name)
return AudienceResponse(name=audience.name, description=content)
success = prompt_service.save_audience(audience.name, audience.description)
if not success:
raise HTTPException(status_code=500, detail=f"保存受众 '{audience.name}' 失败")
return AudienceResponse(name=audience.name, description=audience.description)
except Exception as e:
logger.error(f"保存受众 '{audience.name}' 失败: {e}")
raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}")
@router.get("/scenic-spots", response_model=ScenicSpotListResponse)
async def get_all_scenic_spots(
prompt_service: PromptService = Depends(get_prompt_service)
):
"""获取所有景区"""
try:
spots_dict = prompt_service.get_all_scenic_spots()
# 将字典列表转换为ScenicSpotResponse对象列表
spots = [ScenicSpotResponse(name=spot["name"], description=spot["description"]) for spot in spots_dict]
return ScenicSpotListResponse(spots=spots)
except Exception as e:
logger.error(f"获取所有景区失败: {e}")
raise HTTPException(status_code=500, detail=f"获取景区列表失败: {str(e)}")
@router.get("/scenic-spots/{spot_name}", response_model=ScenicSpotResponse)
async def get_scenic_spot(
spot_name: str,
prompt_service: PromptService = Depends(get_prompt_service)
):
"""获取指定景区信息"""
try:
content = prompt_service.get_scenic_spot_info(spot_name)
return ScenicSpotResponse(name=spot_name, description=content)
except Exception as e:
logger.error(f"获取景区 '{spot_name}' 失败: {e}")
raise HTTPException(status_code=404, detail=f"未找到景区: {spot_name}")
@router.post("/build-prompt", response_model=PromptBuilderResponse)
async def build_prompt(
request: PromptBuilderRequest,
prompt_builder: PromptBuilderService = Depends(get_prompt_builder)
):
"""构建完整提示词"""
try:
# 根据请求中的step确定构建哪种类型的提示词
step = request.step or "content"
if step == "topic":
# 构建选题提示词
# 从topic中提取必要的参数
num_topics = request.topic.get("num_topics", 5)
month = request.topic.get("month", "7")
system_prompt, user_prompt = prompt_builder.build_topic_prompt(num_topics, month)
elif step == "judge":
# 构建审核提示词
# 需要提供生成的内容
content = request.topic.get("content", {})
system_prompt, user_prompt = prompt_builder.build_judge_prompt(request.topic, content)
else:
# 默认构建内容生成提示词
system_prompt, user_prompt = prompt_builder.build_content_prompt(request.topic, step)
return PromptBuilderResponse(
system_prompt=system_prompt,
user_prompt=user_prompt
)
except Exception as e:
logger.error(f"构建提示词失败: {e}")
raise HTTPException(status_code=500, detail=f"构建提示词失败: {str(e)}")

View File

@ -20,8 +20,8 @@ from api.models.tweet import (
PipelineRequest, PipelineResponse
)
# 从main.py中导入依赖
from api.main import get_config, get_ai_agent, get_output_manager
# 从依赖注入模块导入依赖
from api.dependencies import get_config, get_ai_agent, get_output_manager
logger = logging.getLogger(__name__)

Binary file not shown.

View File

@ -0,0 +1,191 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词构建服务
负责根据选题信息构建完整的提示词
"""
import logging
from typing import Dict, Any, Optional, Tuple
from pathlib import Path
from core.config import ConfigManager, GenerateContentConfig
from utils.prompts import PromptTemplate
from api.services.prompt_service import PromptService
logger = logging.getLogger(__name__)
class PromptBuilderService:
"""提示词构建服务类"""
def __init__(self, config_manager: ConfigManager, prompt_service: PromptService):
"""
初始化提示词构建服务
Args:
config_manager: 配置管理器
prompt_service: 提示词服务
"""
self.config_manager = config_manager
self.prompt_service = prompt_service
self.content_config: GenerateContentConfig = config_manager.get_config('content_gen', GenerateContentConfig)
def build_content_prompt(self, topic: Dict[str, Any], step: str = "content") -> Tuple[str, str]:
"""
构建内容生成提示词
Args:
topic: 选题信息
step: 当前步骤用于过滤参考内容
Returns:
系统提示词和用户提示词的元组
"""
# 加载系统提示词和用户提示词模板
system_prompt_path = self.content_config.content_system_prompt
user_prompt_path = self.content_config.content_user_prompt
# 创建提示词模板
template = PromptTemplate(system_prompt_path, user_prompt_path)
# 获取风格内容
style_filename = topic.get("style", "")
style_content = self.prompt_service.get_style_content(style_filename)
# 获取目标受众内容
demand_filename = topic.get("target_audience", "")
demand_content = self.prompt_service.get_audience_content(demand_filename)
# 获取景区信息
object_name = topic.get("object", "")
object_content = self.prompt_service.get_scenic_spot_info(object_name)
# 获取产品信息
product_name = topic.get("product", "")
product_content = self.prompt_service.get_product_info(product_name)
# 获取参考内容
refer_content = self.prompt_service.get_refer_content(step)
# 构建系统提示词
system_prompt = template.get_system_prompt()
# 构建用户提示词
user_prompt = template.build_user_prompt(
style_content=f"{style_filename}\n{style_content}",
demand_content=f"{demand_filename}\n{demand_content}",
object_content=f"{object_name}\n{object_content}",
product_content=f"{product_name}\n{product_content}",
refer_content=refer_content
)
return system_prompt, user_prompt
def build_topic_prompt(self, num_topics: int, month: str) -> Tuple[str, str]:
"""
构建选题生成提示词
Args:
num_topics: 要生成的选题数量
month: 月份
Returns:
系统提示词和用户提示词的元组
"""
# 从配置中获取选题提示词模板路径
topic_config = self.config_manager.get_config('topic_gen', dict)
if not topic_config:
raise ValueError("未找到选题生成配置")
system_prompt_path = topic_config.get("topic_system_prompt", "")
user_prompt_path = topic_config.get("topic_user_prompt", "")
if not system_prompt_path or not user_prompt_path:
raise ValueError("选题提示词模板路径不完整")
# 创建提示词模板
template = PromptTemplate(system_prompt_path, user_prompt_path)
# 获取风格列表
styles = self.prompt_service.get_all_styles()
style_content = "Style文件列表:\n" + "\n".join([f"- {style['name']}" for style in styles])
# 获取目标受众列表
audiences = self.prompt_service.get_all_audiences()
demand_content = "Demand文件列表:\n" + "\n".join([f"- {audience['name']}" for audience in audiences])
# 获取参考内容
refer_content = self.prompt_service.get_refer_content("topic")
# 获取景区信息列表
spots = self.prompt_service.get_all_scenic_spots()
object_content = "Object信息:\n" + "\n".join([f"- {spot['name']}" for spot in spots])
# 构建系统提示词
system_prompt = template.get_system_prompt()
# 构建创作资料
creative_materials = (
f"你拥有的创作资料如下:\n"
f"{style_content}\n\n"
f"{demand_content}\n\n"
f"{refer_content}\n\n"
f"{object_content}"
)
# 构建用户提示词
user_prompt = template.build_user_prompt(
creative_materials=creative_materials,
num_topics=num_topics,
month=month
)
return system_prompt, user_prompt
def build_judge_prompt(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str]:
"""
构建内容审核提示词
Args:
topic: 选题信息
content: 生成的内容
Returns:
系统提示词和用户提示词的元组
"""
# 从配置中获取审核提示词模板路径
system_prompt_path = self.content_config.judger_system_prompt
user_prompt_path = self.content_config.judger_user_prompt
# 创建提示词模板
template = PromptTemplate(system_prompt_path, user_prompt_path)
# 获取景区信息
object_name = topic.get("object", "")
object_content = self.prompt_service.get_scenic_spot_info(object_name)
# 获取产品信息
product_name = topic.get("product", "")
product_content = self.prompt_service.get_product_info(product_name)
# 获取参考内容
refer_content = self.prompt_service.get_refer_content("judge")
# 构建系统提示词
system_prompt = template.get_system_prompt()
# 格式化内容
import json
tweet_content = json.dumps(content, ensure_ascii=False, indent=4)
# 构建用户提示词
user_prompt = template.build_user_prompt(
tweet_content=tweet_content,
object_content=object_content,
product_content=product_content,
refer_content=refer_content
)
return system_prompt, user_prompt

View File

@ -0,0 +1,619 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词服务层
负责提示词的存储检索和构建
"""
import logging
import json
import os
import re
from typing import Dict, Any, Optional, List, cast
from pathlib import Path
import mysql.connector
from mysql.connector import pooling
from core.config import ConfigManager, ResourceConfig
from utils.prompts import BasePromptBuilder, PromptTemplate
from utils.file_io import ResourceLoader
logger = logging.getLogger(__name__)
class PromptService:
"""提示词服务类"""
def __init__(self, config_manager: ConfigManager):
"""
初始化提示词服务
Args:
config_manager: 配置管理器
"""
self.config_manager = config_manager
self.resource_config: ResourceConfig = config_manager.get_config('resource', ResourceConfig)
# 初始化数据库连接池
self._init_db_pool()
def _init_db_pool(self):
"""初始化数据库连接池"""
try:
# 尝试直接从配置文件加载数据库配置
config_dir = Path("config")
db_config_path = config_dir / "database.json"
if not db_config_path.exists():
logger.warning(f"数据库配置文件不存在: {db_config_path}")
self.db_pool = None
return
# 加载配置文件
with open(db_config_path, 'r', encoding='utf-8') as f:
db_config = json.load(f)
# 处理环境变量
processed_config = {}
for key, value in db_config.items():
if isinstance(value, str) and "${" in value:
# 匹配 ${ENV_VAR:-default} 格式
pattern = r'\${([^:-]+)(?::-([^}]+))?}'
match = re.match(pattern, value)
if match:
env_var, default = match.groups()
processed_value = os.environ.get(env_var, default)
# 尝试转换为数字
if key == "port":
try:
processed_value = int(processed_value)
except (ValueError, TypeError):
processed_value = 3306
processed_config[key] = processed_value
else:
processed_config[key] = value
# 创建连接池
self.db_pool = pooling.MySQLConnectionPool(
pool_name="prompt_pool",
pool_size=5,
host=processed_config.get("host", "localhost"),
user=processed_config.get("user", "root"),
password=processed_config.get("password", ""),
database=processed_config.get("database", "travel_content"),
port=processed_config.get("port", 3306),
charset=processed_config.get("charset", "utf8mb4")
)
logger.info(f"数据库连接池初始化成功,连接到 {processed_config.get('host')}:{processed_config.get('port')}")
except Exception as e:
logger.error(f"初始化数据库连接池失败: {e}")
self.db_pool = None
def get_style_content(self, style_name: str) -> str:
"""
获取内容风格提示词
Args:
style_name: 风格名称
Returns:
风格提示词内容
"""
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT description FROM contentStyle WHERE styleName = %s",
(style_name,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"从数据库获取风格提示词: {style_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取风格提示词失败: {e}")
# 回退到文件系统
try:
style_paths = self.resource_config.style.paths
for path_str in style_paths:
try:
if style_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取风格提示词: {style_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取风格文件失败 {path_str}: {e}")
# 如果没有精确匹配,尝试模糊匹配
for path_str in style_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
if style_name.lower() in full_path.stem.lower():
logger.info(f"通过模糊匹配找到风格提示词: {style_name} -> {full_path.name}")
return content
except Exception as e:
logger.error(f"读取风格文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取风格提示词失败: {e}")
logger.warning(f"未找到风格提示词: {style_name},将使用默认值")
return "通用风格"
def get_audience_content(self, audience_name: str) -> str:
"""
获取目标受众提示词
Args:
audience_name: 受众名称
Returns:
受众提示词内容
"""
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT description FROM targetAudience WHERE audienceName = %s",
(audience_name,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"从数据库获取受众提示词: {audience_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取受众提示词失败: {e}")
# 回退到文件系统
try:
demand_paths = self.resource_config.demand.paths
for path_str in demand_paths:
try:
if audience_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取受众提示词: {audience_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取受众文件失败 {path_str}: {e}")
# 如果没有精确匹配,尝试模糊匹配
for path_str in demand_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
if audience_name.lower() in full_path.stem.lower():
logger.info(f"通过模糊匹配找到受众提示词: {audience_name} -> {full_path.name}")
return content
except Exception as e:
logger.error(f"读取受众文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取受众提示词失败: {e}")
logger.warning(f"未找到受众提示词: {audience_name},将使用默认值")
return "通用用户画像"
def get_scenic_spot_info(self, spot_name: str) -> str:
"""
获取景区信息
Args:
spot_name: 景区名称
Returns:
景区信息内容
"""
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT description FROM scenicSpot WHERE spotName = %s",
(spot_name,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"从数据库获取景区信息: {spot_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取景区信息失败: {e}")
# 回退到文件系统
try:
object_paths = self.resource_config.object.paths
for path_str in object_paths:
try:
if spot_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取景区信息: {spot_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取景区文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取景区信息失败: {e}")
logger.warning(f"未找到景区信息: {spot_name}")
return ""
def get_product_info(self, product_name: str) -> str:
"""
获取产品信息
Args:
product_name: 产品名称
Returns:
产品信息内容
"""
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute(
"SELECT description FROM product WHERE productName = %s",
(product_name,)
)
result = cursor.fetchone()
cursor.close()
conn.close()
if result:
logger.info(f"从数据库获取产品信息: {product_name}")
return result["description"]
except Exception as e:
logger.error(f"从数据库获取产品信息失败: {e}")
# 回退到文件系统
try:
product_paths = self.resource_config.product.paths
for path_str in product_paths:
try:
if product_name in path_str:
full_path = self._get_full_path(path_str)
if full_path.exists():
logger.info(f"从文件系统获取产品信息: {product_name}")
return full_path.read_text('utf-8')
except Exception as e:
logger.error(f"读取产品文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取产品信息失败: {e}")
logger.warning(f"未找到产品信息: {product_name}")
return ""
def get_refer_content(self, step: str = "") -> str:
"""
获取参考内容
Args:
step: 当前步骤用于过滤参考内容
Returns:
参考内容
"""
refer_content = "参考内容:\n"
# 从文件系统获取参考内容
try:
refer_list = self.resource_config.refer.refer_list
filtered_configs = [
item for item in refer_list
if not item.step or item.step == step
]
for ref_item in filtered_configs:
try:
path_str = ref_item.path
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
refer_content += f"--- {full_path.name} ---\n{content}\n\n"
except Exception as e:
logger.error(f"读取参考文件失败 {ref_item.path}: {e}")
except Exception as e:
logger.error(f"获取参考内容失败: {e}")
return refer_content
def _get_full_path(self, path_str: str) -> Path:
"""根据基准目录解析相对或绝对路径"""
if not self.resource_config.resource_dirs:
raise ValueError("Resource directories list is empty in config.")
base_path = Path(self.resource_config.resource_dirs[0])
file_path = Path(path_str)
return file_path if file_path.is_absolute() else (base_path / file_path).resolve()
def get_all_styles(self) -> List[Dict[str, str]]:
"""
获取所有内容风格
Returns:
风格列表每个元素包含name和description
"""
styles = []
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT styleName as name, description FROM contentStyle")
results = cursor.fetchall()
cursor.close()
conn.close()
if results:
logger.info(f"从数据库获取所有风格: {len(results)}")
return results
except Exception as e:
logger.error(f"从数据库获取所有风格失败: {e}")
# 回退到文件系统
try:
style_paths = self.resource_config.style.paths
for path_str in style_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
name = full_path.stem
if "文案提示词" in name:
name = name.replace("文案提示词", "")
styles.append({
"name": name,
"description": content[:100] + "..." if len(content) > 100 else content
})
except Exception as e:
logger.error(f"读取风格文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取所有风格失败: {e}")
return styles
def get_all_audiences(self) -> List[Dict[str, str]]:
"""
获取所有目标受众
Returns:
受众列表每个元素包含name和description
"""
audiences = []
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT audienceName as name, description FROM targetAudience")
results = cursor.fetchall()
cursor.close()
conn.close()
if results:
logger.info(f"从数据库获取所有受众: {len(results)}")
return results
except Exception as e:
logger.error(f"从数据库获取所有受众失败: {e}")
# 回退到文件系统
try:
demand_paths = self.resource_config.demand.paths
for path_str in demand_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
name = full_path.stem
if "文旅需求" in name:
name = name.replace("文旅需求", "")
audiences.append({
"name": name,
"description": content[:100] + "..." if len(content) > 100 else content
})
except Exception as e:
logger.error(f"读取受众文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取所有受众失败: {e}")
return audiences
def get_all_scenic_spots(self) -> List[Dict[str, str]]:
"""
获取所有景区
Returns:
景区列表每个元素包含name和description
"""
spots = []
# 优先从数据库获取
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT spotName as name, description FROM scenicSpot")
results = cursor.fetchall()
cursor.close()
conn.close()
if results:
logger.info(f"从数据库获取所有景区: {len(results)}")
return [{
"name": item["name"],
"description": item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"]
} for item in results]
except Exception as e:
logger.error(f"从数据库获取所有景区失败: {e}")
# 回退到文件系统
try:
object_paths = self.resource_config.object.paths
for path_str in object_paths:
try:
full_path = self._get_full_path(path_str)
if full_path.exists() and full_path.is_file():
content = full_path.read_text('utf-8')
spots.append({
"name": full_path.stem,
"description": content[:100] + "..." if len(content) > 100 else content
})
except Exception as e:
logger.error(f"读取景区文件失败 {path_str}: {e}")
except Exception as e:
logger.error(f"获取所有景区失败: {e}")
return spots
def save_style(self, name: str, description: str) -> bool:
"""
保存内容风格
Args:
name: 风格名称
description: 风格描述
Returns:
是否保存成功
"""
# 优先保存到数据库
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor()
# 检查是否存在
cursor.execute(
"SELECT COUNT(*) FROM contentStyle WHERE styleName = %s",
(name,)
)
count = cursor.fetchone()[0]
if count > 0:
# 更新
cursor.execute(
"UPDATE contentStyle SET description = %s WHERE styleName = %s",
(description, name)
)
else:
# 插入
cursor.execute(
"INSERT INTO contentStyle (styleName, description) VALUES (%s, %s)",
(name, description)
)
conn.commit()
cursor.close()
conn.close()
logger.info(f"风格保存到数据库成功: {name}")
return True
except Exception as e:
logger.error(f"风格保存到数据库失败: {e}")
# 回退到文件系统
try:
# 确保风格目录存在
style_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Style"
style_dir.mkdir(parents=True, exist_ok=True)
# 保存文件
file_path = style_dir / f"{name}文案提示词.md"
with open(file_path, 'w', encoding='utf-8') as f:
f.write(description)
# 更新配置
if str(file_path) not in self.resource_config.style.paths:
self.resource_config.style.paths.append(str(file_path))
logger.info(f"风格保存到文件系统成功: {file_path}")
return True
except Exception as e:
logger.error(f"风格保存到文件系统失败: {e}")
return False
def save_audience(self, name: str, description: str) -> bool:
"""
保存目标受众
Args:
name: 受众名称
description: 受众描述
Returns:
是否保存成功
"""
# 优先保存到数据库
if self.db_pool:
try:
conn = self.db_pool.get_connection()
cursor = conn.cursor()
# 检查是否存在
cursor.execute(
"SELECT COUNT(*) FROM targetAudience WHERE audienceName = %s",
(name,)
)
count = cursor.fetchone()[0]
if count > 0:
# 更新
cursor.execute(
"UPDATE targetAudience SET description = %s WHERE audienceName = %s",
(description, name)
)
else:
# 插入
cursor.execute(
"INSERT INTO targetAudience (audienceName, description) VALUES (%s, %s)",
(name, description)
)
conn.commit()
cursor.close()
conn.close()
logger.info(f"受众保存到数据库成功: {name}")
return True
except Exception as e:
logger.error(f"受众保存到数据库失败: {e}")
# 回退到文件系统
try:
# 确保受众目录存在
demand_dir = Path(self.resource_config.resource_dirs[0]) / "resource" / "prompt" / "Demand"
demand_dir.mkdir(parents=True, exist_ok=True)
# 保存文件
file_path = demand_dir / f"{name}文旅需求.md"
with open(file_path, 'w', encoding='utf-8') as f:
f.write(description)
# 更新配置
if str(file_path) not in self.resource_config.demand.paths:
self.resource_config.demand.paths.append(str(file_path))
logger.info(f"受众保存到文件系统成功: {file_path}")
return True
except Exception as e:
logger.error(f"受众保存到文件系统失败: {e}")
return False