239 lines
9.2 KiB
Python
239 lines
9.2 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试海报生成的脚本
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
import argparse
|
||
import json
|
||
import logging
|
||
import sys
|
||
import traceback
|
||
from datetime import datetime
|
||
|
||
from core.ai_agent import AI_Agent
|
||
from utils.tweet_generator import generate_posters_for_topic
|
||
from utils.output_handler import FileSystemOutputHandler
|
||
from core.topic_parser import TopicParser
|
||
|
||
def load_config(config_path="poster_config.json"):
|
||
"""从JSON文件加载配置"""
|
||
if not os.path.exists(config_path):
|
||
print(f"错误:配置文件 '{config_path}' 未找到。")
|
||
sys.exit(1)
|
||
try:
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
config = json.load(f)
|
||
|
||
# 基本验证
|
||
required_keys = ["api_url", "model", "api_key", "resource_dir", "output_dir",
|
||
"image_base_dir", "poster_assets_base_dir",
|
||
"poster_content_system_prompt"]
|
||
|
||
if not all(key in config for key in required_keys):
|
||
missing_keys = [key for key in required_keys if key not in config]
|
||
print(f"错误:配置文件 '{config_path}' 缺少必需的键:{missing_keys}")
|
||
sys.exit(1)
|
||
|
||
return config
|
||
except json.JSONDecodeError:
|
||
print(f"错误:无法从 '{config_path}' 解码JSON。请检查文件格式。")
|
||
sys.exit(1)
|
||
except Exception as e:
|
||
print(f"从 '{config_path}' 加载配置时出错:{e}")
|
||
sys.exit(1)
|
||
|
||
def main():
|
||
# 设置日志记录
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
|
||
datefmt='%Y-%m-%d %H:%M:%S'
|
||
)
|
||
|
||
# 解析命令行参数
|
||
parser = argparse.ArgumentParser(description="测试海报生成")
|
||
parser.add_argument(
|
||
"--config",
|
||
type=str,
|
||
default="/root/autodl-tmp/TravelContentCreator/poster_gen_config.json",
|
||
help="配置文件路径(例如,poster_config.json)"
|
||
)
|
||
parser.add_argument(
|
||
"--topics_file",
|
||
type=str,
|
||
required=True,
|
||
default="/root/autodl-tmp/TravelContentCreator/result/2025-04-25_17-27-03/tweet_topic_2025-04-25_17-27-03.json",
|
||
help="必需的选题JSON文件路径,用于获取海报生成的主题数据。"
|
||
)
|
||
parser.add_argument(
|
||
"--topic_index",
|
||
type=int,
|
||
default=1,
|
||
help="要生成海报的特定选题索引。如果未提供,将为所有选题生成海报。"
|
||
)
|
||
parser.add_argument(
|
||
"--run_id",
|
||
type=str,
|
||
default=None,
|
||
help="可选的指定运行ID。如果未提供,将生成基于时间戳的ID。"
|
||
)
|
||
parser.add_argument(
|
||
"--debug",
|
||
action='store_true',
|
||
help="启用调试级别日志记录。"
|
||
)
|
||
args = parser.parse_args()
|
||
|
||
# 调整日志级别(如果启用了调试)
|
||
if args.debug:
|
||
logging.getLogger().setLevel(logging.DEBUG)
|
||
logging.info("已启用调试日志记录。")
|
||
|
||
logging.info("启动海报生成测试脚本...")
|
||
logging.info(f"使用配置文件:{args.config}")
|
||
logging.info(f"使用选题文件:{args.topics_file}")
|
||
if args.topic_index is not None:
|
||
logging.info(f"将仅处理选题索引:{args.topic_index}")
|
||
if args.run_id:
|
||
logging.info(f"使用指定的run_id:{args.run_id}")
|
||
|
||
# 加载配置
|
||
config = load_config(args.config)
|
||
if config is None:
|
||
logging.critical("无法加载配置。退出。")
|
||
sys.exit(1)
|
||
|
||
# 初始化输出处理器
|
||
output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))
|
||
logging.info(f"使用输出处理器:{output_handler.__class__.__name__}")
|
||
|
||
# 加载选题数据
|
||
logging.info(f"从以下位置加载选题:{args.topics_file}")
|
||
topics_list = TopicParser.load_topics_from_json(args.topics_file)
|
||
if not topics_list:
|
||
logging.error(f"无法从{args.topics_file}加载选题。无法继续。")
|
||
sys.exit(1)
|
||
|
||
logging.info(f"成功加载{len(topics_list)}个选题。")
|
||
|
||
# 设置run_id
|
||
run_id = args.run_id
|
||
if run_id is None:
|
||
# 尝试从文件名推断run_id
|
||
try:
|
||
base = os.path.basename(args.topics_file)
|
||
if base.startswith("tweet_topic_") and base.endswith(".json"):
|
||
run_id = base[len("tweet_topic_"): -len(".json")]
|
||
logging.info(f"从选题文件名推断的run_id:{run_id}")
|
||
else:
|
||
run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_poster")
|
||
logging.info(f"为海报生成的run_id:{run_id}")
|
||
except Exception as e:
|
||
run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_poster")
|
||
logging.info(f"生成的run_id:{run_id}")
|
||
|
||
# 加载海报内容系统提示词
|
||
poster_content_system_prompt_path = config.get("poster_content_system_prompt")
|
||
if not os.path.exists(poster_content_system_prompt_path):
|
||
logging.error(f"海报内容系统提示词文件不存在:{poster_content_system_prompt_path}")
|
||
sys.exit(1)
|
||
|
||
with open(poster_content_system_prompt_path, "r", encoding="utf-8") as f:
|
||
poster_content_system_prompt = f.read()
|
||
|
||
# 准备海报生成参数
|
||
poster_variants = config.get("variants", 1)
|
||
poster_assets_dir = config.get("poster_assets_base_dir")
|
||
img_base_dir = config.get("image_base_dir")
|
||
res_dir_config = config.get("resource_dir", [])
|
||
poster_size = tuple(config.get("poster_target_size", [900, 1200]))
|
||
txt_possibility = config.get("text_possibility", 0.3)
|
||
img_frame_possibility = config.get("img_frame_possibility", 0.7)
|
||
text_bg_possibility = config.get("text_bg_possibility", 0)
|
||
collage_subdir = config.get("output_collage_subdir", "collage_img")
|
||
poster_subdir = config.get("output_poster_subdir", "poster")
|
||
poster_filename = config.get("output_poster_filename", "poster.jpg")
|
||
title_possibility = config.get("title_possibility", 0.3)
|
||
collage_style = config.get("collage_style", None)
|
||
timeout = config.get("request_timeout", 180)
|
||
|
||
# 检查关键路径
|
||
if not poster_assets_dir or not img_base_dir:
|
||
logging.error("配置中缺少关键路径(poster_assets_base_dir或image_base_dir)。无法继续。")
|
||
sys.exit(1)
|
||
|
||
# 开始海报生成
|
||
pipeline_start_time = time.time()
|
||
logging.info("开始执行海报生成...")
|
||
|
||
poster_success = False
|
||
|
||
# 如果指定了topic_index,只处理该选题
|
||
if args.topic_index is not None:
|
||
topics_to_process = []
|
||
for topic in topics_list:
|
||
if topic.get('index') == args.topic_index or (topic.get('index') is None and int(args.topic_index) == 1):
|
||
topics_to_process.append(topic)
|
||
break
|
||
if not topics_to_process:
|
||
logging.error(f"未找到索引为{args.topic_index}的选题。")
|
||
sys.exit(1)
|
||
else:
|
||
topics_to_process = topics_list
|
||
|
||
# 逐个处理选题
|
||
for i, topic_item in enumerate(topics_to_process):
|
||
topic_index = topic_item.get('index', i + 1)
|
||
logging.info(f"--- 处理选题 {topic_index}: {topic_item.get('object', 'N/A')} ---")
|
||
|
||
try:
|
||
posters_attempted = generate_posters_for_topic(
|
||
topic_item=topic_item,
|
||
output_dir=config["output_dir"],
|
||
run_id=run_id,
|
||
topic_index=topic_index,
|
||
output_handler=output_handler,
|
||
variants=poster_variants,
|
||
poster_assets_base_dir=poster_assets_dir,
|
||
image_base_dir=img_base_dir,
|
||
resource_dir_config=res_dir_config,
|
||
poster_target_size=poster_size,
|
||
text_possibility=txt_possibility,
|
||
img_frame_possibility=img_frame_possibility,
|
||
text_bg_possibility=text_bg_possibility,
|
||
output_collage_subdir=collage_subdir,
|
||
output_poster_subdir=poster_subdir,
|
||
output_poster_filename=poster_filename,
|
||
system_prompt=poster_content_system_prompt,
|
||
collage_style=collage_style,
|
||
title_possibility=title_possibility,
|
||
timeout=timeout
|
||
)
|
||
|
||
if posters_attempted:
|
||
logging.info(f"选题{topic_index}的海报生成过程已完成。")
|
||
poster_success = True
|
||
else:
|
||
logging.warning(f"选题{topic_index}的海报生成被跳过或在早期失败。")
|
||
except Exception as e:
|
||
logging.exception(f"处理选题{topic_index}的海报生成时出错:")
|
||
logging.error(f"错误信息: {e}")
|
||
logging.info(f"--- 完成选题 {topic_index} ---")
|
||
|
||
# 最终化输出
|
||
if run_id:
|
||
output_handler.finalize(run_id)
|
||
|
||
pipeline_end_time = time.time()
|
||
if poster_success:
|
||
logging.info(f"海报生成完成,耗时{pipeline_end_time - pipeline_start_time:.2f}秒。")
|
||
else:
|
||
logging.warning("海报生成完成,但可能遇到错误或未生成输出。")
|
||
|
||
logging.info(f"运行ID'{run_id}'的结果位于:{os.path.join(config.get('output_dir', 'result'), run_id)}")
|
||
|
||
if __name__ == "__main__":
|
||
main() |