139 lines
4.9 KiB
Python
139 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Travel Content Creator
|
|
主入口文件
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import time
|
|
import logging
|
|
import asyncio
|
|
import argparse
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from core.config import get_config_manager
|
|
from core.ai import AIAgent
|
|
from utils.file_io import OutputManager
|
|
from tweet.topic_generator import TopicGenerator
|
|
from tweet.content_generator import ContentGenerator
|
|
from tweet.content_judger import ContentJudger
|
|
from poster.poster_generator import PosterGenerator
|
|
|
|
# 配置日志
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
datefmt="%Y-%m-%d %H:%M:%S"
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Pipeline:
|
|
"""
|
|
内容生成流水线
|
|
协调各个模块的工作
|
|
"""
|
|
|
|
def __init__(self):
|
|
# 初始化配置
|
|
self.config_manager = get_config_manager()
|
|
self.config_manager.load_from_directory("config")
|
|
|
|
# 初始化输出管理器
|
|
run_id = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
self.output_manager = OutputManager("result", run_id)
|
|
|
|
# 初始化AI代理
|
|
ai_config = self.config_manager.get_config('ai_model', AIModelConfig)
|
|
self.ai_agent = AIAgent(ai_config)
|
|
|
|
# 初始化各个组件
|
|
self.topic_generator = TopicGenerator(self.ai_agent, self.config_manager, self.output_manager)
|
|
self.content_generator = ContentGenerator(self.ai_agent, self.config_manager, self.output_manager)
|
|
self.content_judger = ContentJudger(self.ai_agent, self.config_manager, self.output_manager)
|
|
self.poster_generator = PosterGenerator(self.config_manager, self.output_manager)
|
|
|
|
async def run(self):
|
|
"""运行完整流水线"""
|
|
start_time = time.time()
|
|
logger.info("--- 开始执行内容生成流水线 ---")
|
|
|
|
# 步骤1: 生成选题
|
|
logger.info("--- 步骤 1: 开始生成选题 ---")
|
|
topics = await self.topic_generator.generate_topics()
|
|
if not topics:
|
|
logger.error("未能生成任何选题,流程终止。")
|
|
return
|
|
|
|
logger.info(f"成功生成 {len(topics)} 个选题")
|
|
|
|
# 步骤2: 为每个选题生成内容
|
|
logger.info("--- 步骤 2: 开始生成内容 ---")
|
|
contents = {}
|
|
for topic in topics:
|
|
topic_index = topic.get('index', 'unknown')
|
|
logger.info(f"--- 步骤 2: 开始为选题 {topic_index} 生成内容 ---")
|
|
content = await self.content_generator.generate_content_for_topic(topic)
|
|
contents[topic_index] = content
|
|
|
|
# 步骤3: 审核内容
|
|
logger.info("--- 步骤 3: 开始审核内容 ---")
|
|
judged_contents = {}
|
|
for topic_index, content in contents.items():
|
|
topic = next((t for t in topics if t.get('index') == topic_index), None)
|
|
if not topic:
|
|
logger.warning(f"找不到选题 {topic_index} 的原始数据,跳过审核")
|
|
continue
|
|
|
|
logger.info(f"--- 步骤 3: 开始审核选题 {topic_index} 的内容 ---")
|
|
try:
|
|
judged_data = await self.content_judger.judge_content(content, topic)
|
|
judged_contents[topic_index] = judged_data
|
|
except Exception as e:
|
|
logger.critical(f"为选题 {topic_index} 处理内容审核时发生意外错误: {e}", exc_info=True)
|
|
|
|
# 步骤4: 生成海报
|
|
# logger.info("--- 步骤 4: 开始生成海报 ---")
|
|
# posters = {}
|
|
# for topic_index, content in judged_contents.items():
|
|
# if not content.get('judge_success', False):
|
|
# logger.warning(f"选题 {topic_index} 的内容审核未通过,跳过海报生成")
|
|
# continue
|
|
|
|
# logger.info(f"--- 步骤 4: 开始为选题 {topic_index} 生成海报 ---")
|
|
# poster_path = self.poster_generator.generate_poster(content, topic_index)
|
|
# if poster_path:
|
|
# posters[topic_index] = poster_path
|
|
|
|
# 完成
|
|
logger.info("--- 所有任务已完成 ---")
|
|
end_time = time.time()
|
|
logger.info(f"--- 运行结束 --- 耗时: {end_time - start_time:.2f} 秒 ---")
|
|
|
|
|
|
async def main():
|
|
"""主函数"""
|
|
parser = argparse.ArgumentParser(description="Travel Content Creator")
|
|
parser.add_argument("--config-dir", default="config", help="配置目录路径")
|
|
args = parser.parse_args()
|
|
|
|
# 检查配置目录
|
|
if not os.path.isdir(args.config_dir):
|
|
logger.error(f"配置目录不存在: {args.config_dir}")
|
|
sys.exit(1)
|
|
|
|
# 运行流水线
|
|
pipeline = Pipeline()
|
|
await pipeline.run()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# 导入这里避免循环导入
|
|
from core.config import AIModelConfig
|
|
|
|
asyncio.run(main())
|
|
|