确认了根据获得的提示词进行内容生成的功能
This commit is contained in:
parent
0653355c5e
commit
db1319eb11
Binary file not shown.
@ -6,6 +6,7 @@ API依赖注入模块
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import Depends
|
||||
from core.config import get_config_manager, ConfigManager
|
||||
from core.ai import AIAgent
|
||||
from utils.file_io import OutputManager
|
||||
@ -49,4 +50,9 @@ def get_output_manager() -> OutputManager:
|
||||
"""获取输出管理器"""
|
||||
if output_manager is None:
|
||||
raise RuntimeError("输出管理器未初始化")
|
||||
return output_manager
|
||||
return output_manager
|
||||
|
||||
def get_tweet_service():
|
||||
"""获取文字内容服务"""
|
||||
from api.services.tweet import TweetService
|
||||
return TweetService(get_ai_agent(), get_config(), get_output_manager())
|
||||
Binary file not shown.
Binary file not shown.
@ -6,171 +6,138 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.config import ConfigManager
|
||||
from api.services.prompt_service import PromptService
|
||||
from api.services.prompt_builder import PromptBuilderService
|
||||
from api.models.prompt import (
|
||||
StyleRequest, StyleResponse, StyleListResponse,
|
||||
AudienceRequest, AudienceResponse, AudienceListResponse,
|
||||
ScenicSpotRequest, ScenicSpotResponse, ScenicSpotListResponse,
|
||||
PromptBuilderRequest, PromptBuilderResponse
|
||||
)
|
||||
|
||||
# 从依赖注入模块导入依赖
|
||||
from api.dependencies import get_config
|
||||
from api.dependencies import get_config, get_tweet_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
router = APIRouter(
|
||||
tags=["prompt"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
# 依赖注入函数
|
||||
def get_prompt_service(
|
||||
config_manager: ConfigManager = Depends(get_config)
|
||||
) -> PromptService:
|
||||
def get_prompt_service():
|
||||
"""获取提示词服务"""
|
||||
return PromptService(config_manager)
|
||||
from core.config import get_config_manager
|
||||
return PromptService(get_config_manager())
|
||||
|
||||
def get_prompt_builder(
|
||||
config_manager: ConfigManager = Depends(get_config),
|
||||
prompt_service: PromptService = Depends(get_prompt_service)
|
||||
) -> PromptBuilderService:
|
||||
"""获取提示词构建服务"""
|
||||
def get_prompt_builder():
|
||||
"""获取提示词构建器服务"""
|
||||
from core.config import get_config_manager
|
||||
config_manager = get_config_manager()
|
||||
prompt_service = PromptService(config_manager)
|
||||
return PromptBuilderService(config_manager, prompt_service)
|
||||
|
||||
|
||||
@router.get("/styles", response_model=StyleListResponse)
|
||||
async def get_all_styles(
|
||||
# 请求和响应模型
|
||||
class PromptRequest(BaseModel):
|
||||
"""提示词请求模型"""
|
||||
style: str = Field(..., description="内容风格")
|
||||
audience: str = Field(..., description="目标受众")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"style": "攻略风",
|
||||
"audience": "亲子向"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PromptResponse(BaseModel):
|
||||
"""提示词响应模型"""
|
||||
style_content: str = Field(..., description="风格提示词")
|
||||
audience_content: str = Field(..., description="受众提示词")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"style_content": "以实用信息为主,包含详细的游玩路线...",
|
||||
"audience_content": "家庭出游,有小孩同行,关注安全性..."
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PromptBuilderRequest(BaseModel):
|
||||
"""提示词构建器请求模型"""
|
||||
topic: Dict[str, Any] = Field(..., description="选题信息")
|
||||
step: Optional[str] = Field(None, description="步骤,如topic、content、judge等")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"topic": {
|
||||
"index": "1",
|
||||
"date": "2023-07-15",
|
||||
"object": "北京故宫",
|
||||
"product": "故宫门票",
|
||||
"style": "攻略风",
|
||||
"target_audience": "亲子向",
|
||||
"logic": "暑期旅游热门景点推荐"
|
||||
},
|
||||
"step": "content"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PromptBuilderResponse(BaseModel):
|
||||
"""提示词构建器响应模型"""
|
||||
system_prompt: str = Field(..., description="系统提示词")
|
||||
user_prompt: str = Field(..., description="用户提示词")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"system_prompt": "你是一位专业的旅游内容创作者...",
|
||||
"user_prompt": "请为北京故宫创作一篇旅游攻略..."
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class GenerateContentResponse(BaseModel):
|
||||
"""生成内容响应模型"""
|
||||
request_id: str = Field(..., description="请求ID")
|
||||
topic_index: str = Field(..., description="选题索引")
|
||||
content: Dict[str, Any] = Field(..., description="生成的内容")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"request_id": "content_20230715_123456",
|
||||
"topic_index": "1",
|
||||
"content": {
|
||||
"title": "【北京故宫】避开人潮的秘密路线,90%的人都不知道!",
|
||||
"content": "故宫,作为中国最著名的文化遗产之一...",
|
||||
"tag": ["北京旅游", "故宫", "旅游攻略", "避暑胜地"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/get-style-audience", response_model=PromptResponse)
|
||||
async def get_style_audience(
|
||||
request: PromptRequest,
|
||||
prompt_service: PromptService = Depends(get_prompt_service)
|
||||
):
|
||||
"""获取所有内容风格"""
|
||||
"""获取风格和受众提示词"""
|
||||
try:
|
||||
styles_dict = prompt_service.get_all_styles()
|
||||
# 将字典列表转换为StyleResponse对象列表
|
||||
styles = [StyleResponse(name=style["name"], description=style["description"]) for style in styles_dict]
|
||||
return StyleListResponse(styles=styles)
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有风格失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取风格列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/styles/{style_name}", response_model=StyleResponse)
|
||||
async def get_style(
|
||||
style_name: str,
|
||||
prompt_service: PromptService = Depends(get_prompt_service)
|
||||
):
|
||||
"""获取指定内容风格"""
|
||||
try:
|
||||
content = prompt_service.get_style_content(style_name)
|
||||
return StyleResponse(name=style_name, description=content)
|
||||
except Exception as e:
|
||||
logger.error(f"获取风格 '{style_name}' 失败: {e}")
|
||||
raise HTTPException(status_code=404, detail=f"未找到风格: {style_name}")
|
||||
|
||||
|
||||
@router.post("/styles", response_model=StyleResponse)
|
||||
async def create_or_update_style(
|
||||
style: StyleRequest,
|
||||
prompt_service: PromptService = Depends(get_prompt_service)
|
||||
):
|
||||
"""创建或更新内容风格"""
|
||||
try:
|
||||
if not style.description:
|
||||
# 如果没有提供描述,则获取现有的
|
||||
content = prompt_service.get_style_content(style.name)
|
||||
return StyleResponse(name=style.name, description=content)
|
||||
style_content = prompt_service.get_style_content(request.style)
|
||||
audience_content = prompt_service.get_audience_content(request.audience)
|
||||
|
||||
success = prompt_service.save_style(style.name, style.description)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail=f"保存风格 '{style.name}' 失败")
|
||||
|
||||
return StyleResponse(name=style.name, description=style.description)
|
||||
return PromptResponse(
|
||||
style_content=style_content,
|
||||
audience_content=audience_content
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"保存风格 '{style.name}' 失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/audiences", response_model=AudienceListResponse)
|
||||
async def get_all_audiences(
|
||||
prompt_service: PromptService = Depends(get_prompt_service)
|
||||
):
|
||||
"""获取所有目标受众"""
|
||||
try:
|
||||
audiences_dict = prompt_service.get_all_audiences()
|
||||
# 将字典列表转换为AudienceResponse对象列表
|
||||
audiences = [AudienceResponse(name=audience["name"], description=audience["description"]) for audience in audiences_dict]
|
||||
return AudienceListResponse(audiences=audiences)
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有受众失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取受众列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/audiences/{audience_name}", response_model=AudienceResponse)
|
||||
async def get_audience(
|
||||
audience_name: str,
|
||||
prompt_service: PromptService = Depends(get_prompt_service)
|
||||
):
|
||||
"""获取指定目标受众"""
|
||||
try:
|
||||
content = prompt_service.get_audience_content(audience_name)
|
||||
return AudienceResponse(name=audience_name, description=content)
|
||||
except Exception as e:
|
||||
logger.error(f"获取受众 '{audience_name}' 失败: {e}")
|
||||
raise HTTPException(status_code=404, detail=f"未找到受众: {audience_name}")
|
||||
|
||||
|
||||
@router.post("/audiences", response_model=AudienceResponse)
|
||||
async def create_or_update_audience(
|
||||
audience: AudienceRequest,
|
||||
prompt_service: PromptService = Depends(get_prompt_service)
|
||||
):
|
||||
"""创建或更新目标受众"""
|
||||
try:
|
||||
if not audience.description:
|
||||
# 如果没有提供描述,则获取现有的
|
||||
content = prompt_service.get_audience_content(audience.name)
|
||||
return AudienceResponse(name=audience.name, description=content)
|
||||
|
||||
success = prompt_service.save_audience(audience.name, audience.description)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail=f"保存受众 '{audience.name}' 失败")
|
||||
|
||||
return AudienceResponse(name=audience.name, description=audience.description)
|
||||
except Exception as e:
|
||||
logger.error(f"保存受众 '{audience.name}' 失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"操作失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/scenic-spots", response_model=ScenicSpotListResponse)
|
||||
async def get_all_scenic_spots(
|
||||
prompt_service: PromptService = Depends(get_prompt_service)
|
||||
):
|
||||
"""获取所有景区"""
|
||||
try:
|
||||
spots_dict = prompt_service.get_all_scenic_spots()
|
||||
# 将字典列表转换为ScenicSpotResponse对象列表
|
||||
spots = [ScenicSpotResponse(name=spot["name"], description=spot["description"]) for spot in spots_dict]
|
||||
return ScenicSpotListResponse(spots=spots)
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有景区失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取景区列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/scenic-spots/{spot_name}", response_model=ScenicSpotResponse)
|
||||
async def get_scenic_spot(
|
||||
spot_name: str,
|
||||
prompt_service: PromptService = Depends(get_prompt_service)
|
||||
):
|
||||
"""获取指定景区信息"""
|
||||
try:
|
||||
content = prompt_service.get_scenic_spot_info(spot_name)
|
||||
return ScenicSpotResponse(name=spot_name, description=content)
|
||||
except Exception as e:
|
||||
logger.error(f"获取景区 '{spot_name}' 失败: {e}")
|
||||
raise HTTPException(status_code=404, detail=f"未找到景区: {spot_name}")
|
||||
logger.error(f"获取提示词失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取提示词失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/build-prompt", response_model=PromptBuilderResponse)
|
||||
@ -204,4 +171,33 @@ async def build_prompt(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"构建提示词失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"构建提示词失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"构建提示词失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/generate-content", response_model=GenerateContentResponse)
|
||||
async def generate_content(
|
||||
request: PromptBuilderRequest,
|
||||
prompt_builder: PromptBuilderService = Depends(get_prompt_builder),
|
||||
tweet_service = Depends(get_tweet_service)
|
||||
):
|
||||
"""使用构建的提示词生成内容"""
|
||||
try:
|
||||
# 构建提示词
|
||||
step = request.step or "content"
|
||||
system_prompt, user_prompt = prompt_builder.build_content_prompt(request.topic, step)
|
||||
|
||||
# 使用提示词生成内容
|
||||
request_id, topic_index, content = await tweet_service.generate_content_with_prompt(
|
||||
topic=request.topic,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt
|
||||
)
|
||||
|
||||
return GenerateContentResponse(
|
||||
request_id=request_id,
|
||||
topic_index=topic_index,
|
||||
content=content
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"生成内容失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}")
|
||||
@ -6,36 +6,51 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from core.config import ConfigManager
|
||||
from core.ai import AIAgent
|
||||
from utils.file_io import OutputManager
|
||||
from api.services.tweet import TweetService
|
||||
from api.models.tweet import (
|
||||
TopicRequest, TopicResponse,
|
||||
ContentRequest, ContentResponse,
|
||||
JudgeRequest, JudgeResponse,
|
||||
ContentRequest, ContentResponse,
|
||||
JudgeRequest, JudgeResponse,
|
||||
PipelineRequest, PipelineResponse
|
||||
)
|
||||
from api.services.tweet import TweetService
|
||||
from api.dependencies import get_tweet_service
|
||||
|
||||
# 从依赖注入模块导入依赖
|
||||
from api.dependencies import get_config, get_ai_agent, get_output_manager
|
||||
# 创建一个新的模型用于接收预构建提示词的请求
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class ContentWithPromptRequest(BaseModel):
|
||||
"""带有预构建提示词的内容生成请求模型"""
|
||||
topic: Dict[str, Any] = Field(..., description="选题信息")
|
||||
system_prompt: str = Field(..., description="系统提示词")
|
||||
user_prompt: str = Field(..., description="用户提示词")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"topic": {
|
||||
"index": "1",
|
||||
"date": "2023-07-15",
|
||||
"object": "北京故宫",
|
||||
"product": "故宫门票",
|
||||
"style": "旅游攻略",
|
||||
"target_audience": "年轻人",
|
||||
"logic": "暑期旅游热门景点推荐"
|
||||
},
|
||||
"system_prompt": "你是一位专业的旅游内容创作者...",
|
||||
"user_prompt": "请为以下景点创作一篇旅游文章..."
|
||||
}
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
# 依赖注入函数
|
||||
def get_tweet_service(
|
||||
config_manager: ConfigManager = Depends(get_config),
|
||||
ai_agent: AIAgent = Depends(get_ai_agent),
|
||||
output_manager: OutputManager = Depends(get_output_manager)
|
||||
) -> TweetService:
|
||||
"""获取文字内容服务"""
|
||||
return TweetService(ai_agent, config_manager, output_manager)
|
||||
router = APIRouter(
|
||||
prefix="/tweet",
|
||||
tags=["tweet"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/topics", response_model=TopicResponse, summary="生成选题")
|
||||
@ -93,6 +108,35 @@ async def generate_content(
|
||||
raise HTTPException(status_code=500, detail=f"生成内容失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/content-with-prompt", response_model=ContentResponse, summary="使用预构建提示词生成内容")
|
||||
async def generate_content_with_prompt(
|
||||
request: ContentWithPromptRequest,
|
||||
tweet_service: TweetService = Depends(get_tweet_service)
|
||||
):
|
||||
"""
|
||||
使用预构建的提示词为选题生成内容
|
||||
|
||||
- **topic**: 选题信息
|
||||
- **system_prompt**: 系统提示词
|
||||
- **user_prompt**: 用户提示词
|
||||
"""
|
||||
try:
|
||||
request_id, topic_index, content = await tweet_service.generate_content_with_prompt(
|
||||
topic=request.topic,
|
||||
system_prompt=request.system_prompt,
|
||||
user_prompt=request.user_prompt
|
||||
)
|
||||
|
||||
return ContentResponse(
|
||||
request_id=request_id,
|
||||
topic_index=topic_index,
|
||||
content=content
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"使用预构建提示词生成内容失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"使用预构建提示词生成内容失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/judge", response_model=JudgeResponse, summary="审核内容")
|
||||
async def judge_content(
|
||||
request: JudgeRequest,
|
||||
|
||||
Binary file not shown.
@ -98,6 +98,30 @@ class TweetService:
|
||||
logger.info(f"内容生成完成,请求ID: {request_id}, 选题索引: {topic_index}")
|
||||
return request_id, topic_index, content
|
||||
|
||||
async def generate_content_with_prompt(self, topic: Dict[str, Any], system_prompt: str, user_prompt: str) -> Tuple[str, str, Dict[str, Any]]:
|
||||
"""
|
||||
使用预构建的提示词为选题生成内容
|
||||
|
||||
Args:
|
||||
topic: 选题信息
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户提示词
|
||||
|
||||
Returns:
|
||||
请求ID、选题索引和生成的内容
|
||||
"""
|
||||
topic_index = topic.get('index', 'unknown')
|
||||
logger.info(f"开始使用预构建提示词为选题 {topic_index} 生成内容")
|
||||
|
||||
# 使用预构建的提示词生成内容
|
||||
content = await self.content_generator.generate_content_with_prompt(topic, system_prompt, user_prompt)
|
||||
|
||||
# 生成请求ID
|
||||
request_id = f"content_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{str(uuid.uuid4())[:8]}"
|
||||
|
||||
logger.info(f"内容生成完成,请求ID: {request_id}, 选题索引: {topic_index}")
|
||||
return request_id, topic_index, content
|
||||
|
||||
async def judge_content(self, topic: Dict[str, Any], content: Dict[str, Any]) -> Tuple[str, str, Dict[str, Any], bool]:
|
||||
"""
|
||||
审核内容
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user