TravelContentCreator/test_poster.py

239 lines
9.2 KiB
Python
Raw Normal View History

2025-04-25 10:11:45 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试海报生成的脚本
"""
import os
import time
import argparse
import json
import logging
import sys
import traceback
from datetime import datetime
from core.ai_agent import AI_Agent
from utils.tweet_generator import generate_posters_for_topic
from utils.output_handler import FileSystemOutputHandler
from core.topic_parser import TopicParser
def load_config(config_path="poster_config.json"):
    """Read and validate the JSON configuration used for poster generation.

    Args:
        config_path: Location of the JSON config file.

    Returns:
        The object decoded by ``json.load`` once it has been confirmed to
        contain every mandatory key.

    On any failure -- file absent, undecodable JSON, a mandatory key
    missing, or any other error while reading -- a message is printed and
    the process terminates with ``sys.exit(1)``.
    """
    if not os.path.exists(config_path):
        print(f"错误:配置文件 '{config_path}' 未找到。")
        sys.exit(1)

    # Keys that must be present in the config file.
    mandatory = (
        "api_url", "model", "api_key", "resource_dir", "output_dir",
        "image_base_dir", "poster_assets_base_dir",
        "poster_content_system_prompt",
    )
    try:
        with open(config_path, 'r', encoding='utf-8') as fh:
            loaded = json.load(fh)
        absent = [name for name in mandatory if name not in loaded]
        if absent:
            print(f"错误:配置文件 '{config_path}' 缺少必需的键:{absent}")
            sys.exit(1)
        return loaded
    except json.JSONDecodeError:
        print(f"错误:无法从 '{config_path}' 解码JSON。请检查文件格式。")
        sys.exit(1)
    except Exception as err:
        # SystemExit derives from BaseException, so the sys.exit above is not
        # intercepted here; this only handles genuine read/processing errors.
        print(f"'{config_path}' 加载配置时出错:{err}")
        sys.exit(1)
def _parse_args():
    """Define and parse this script's command-line options.

    Returns:
        argparse.Namespace with ``config``, ``topics_file``, ``topic_index``,
        ``all_topics``, ``run_id`` and ``debug``. When ``--all_topics`` is
        given, ``topic_index`` is reset to None so every topic is processed.
    """
    parser = argparse.ArgumentParser(description="测试海报生成")
    parser.add_argument(
        "--config",
        type=str,
        default="/root/autodl-tmp/TravelContentCreator/poster_gen_config.json",
        help="配置文件路径例如poster_config.json"
    )
    # No default: argparse never applies a default to a required option, so a
    # hard-coded path here would be dead configuration.
    parser.add_argument(
        "--topics_file",
        type=str,
        required=True,
        help="必需的选题JSON文件路径用于获取海报生成的主题数据。"
    )
    # default=1 is kept for backward compatibility. Because of it topic_index
    # is never None from the CLI, so --all_topics is the way to request the
    # "process every topic" mode.
    parser.add_argument(
        "--topic_index",
        type=int,
        default=1,
        help="要生成海报的特定选题索引(默认为1)。使用--all_topics为所有选题生成海报。"
    )
    parser.add_argument(
        "--all_topics",
        action='store_true',
        help="为所有选题生成海报(忽略--topic_index)。"
    )
    parser.add_argument(
        "--run_id",
        type=str,
        default=None,
        help="可选的指定运行ID。如果未提供将生成基于时间戳的ID。"
    )
    parser.add_argument(
        "--debug",
        action='store_true',
        help="启用调试级别日志记录。"
    )
    args = parser.parse_args()
    if args.all_topics:
        args.topic_index = None
    return args


def _resolve_run_id(topics_file, run_id=None):
    """Return the run ID to use for this poster-generation run.

    An explicit ``run_id`` (anything other than None) is returned unchanged.
    Otherwise the ID is taken from a topics file named
    ``tweet_topic_<run_id>.json``. Failing that, or when that part is empty,
    a timestamp-based ID ending in ``_poster`` is generated.
    """
    if run_id is not None:
        return run_id
    prefix, suffix = "tweet_topic_", ".json"
    base = os.path.basename(topics_file)
    if base.startswith(prefix) and base.endswith(suffix):
        inferred = base[len(prefix):-len(suffix)]
        # "tweet_topic_.json" would give an empty ID, which would later skip
        # finalize(); fall through to a generated ID instead.
        if inferred:
            logging.info(f"从选题文件名推断的run_id{inferred}")
            return inferred
    generated = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_poster")
    logging.info(f"为海报生成的run_id{generated}")
    return generated


def _select_topics(topics_list, topic_index=None):
    """Pair each topic with its effective index, optionally keeping just one.

    A topic's effective index is its ``'index'`` value, or its 1-based
    position in ``topics_list`` when that value is missing or None. The same
    numbering is used for filtering and for processing, so a selected topic
    keeps the index it was selected by.

    Returns:
        A list of ``(index, topic)`` tuples. With ``topic_index`` None it holds
        all topics. Otherwise it holds at most one: the first topic whose
        effective index equals ``topic_index``, or nothing if none matches.
    """
    indexed = []
    for position, topic in enumerate(topics_list, start=1):
        idx = topic.get('index')
        indexed.append((position if idx is None else idx, topic))
    if topic_index is None:
        return indexed
    for idx, topic in indexed:
        # NOTE(review): string indices such as "2" are accepted as well, in case
        # the topics JSON stores them as strings -- confirm against TopicParser.
        if idx == topic_index or str(idx) == str(topic_index):
            return [(idx, topic)]
    return []


def main():
    """Generate posters for one (default) or all topics of a topics JSON file.

    Parses the CLI options, loads the JSON config via ``load_config`` and the
    topics via ``TopicParser.load_topics_from_json``. It then calls
    ``generate_posters_for_topic`` once per selected topic and finally
    ``output_handler.finalize(run_id)``.

    Exits with status 1 when:
    - the topics cannot be loaded;
    - the system-prompt file is missing;
    - a required asset path is unset;
    - the requested topic index does not exist.

    An exception raised while generating one topic's posters is logged, and
    processing continues with the next topic.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    args = _parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info("已启用调试日志记录。")
    logging.info("启动海报生成测试脚本...")
    logging.info(f"使用配置文件:{args.config}")
    logging.info(f"使用选题文件:{args.topics_file}")
    if args.topic_index is not None:
        logging.info(f"将仅处理选题索引:{args.topic_index}")
    if args.run_id:
        logging.info(f"使用指定的run_id{args.run_id}")

    config = load_config(args.config)
    # load_config exits on failure; this is only a defensive guard.
    if config is None:
        logging.critical("无法加载配置。退出。")
        sys.exit(1)

    # load_config guarantees "output_dir" exists; read it once and reuse it.
    output_dir = config.get("output_dir", "result")
    output_handler = FileSystemOutputHandler(output_dir)
    logging.info(f"使用输出处理器:{output_handler.__class__.__name__}")

    logging.info(f"从以下位置加载选题:{args.topics_file}")
    topics_list = TopicParser.load_topics_from_json(args.topics_file)
    if not topics_list:
        logging.error(f"无法从{args.topics_file}加载选题。无法继续。")
        sys.exit(1)
    logging.info(f"成功加载{len(topics_list)}个选题。")

    run_id = _resolve_run_id(args.topics_file, args.run_id)

    poster_content_system_prompt_path = config.get("poster_content_system_prompt")
    # Also reject a null/empty value: os.path.exists(None) raises TypeError.
    if not poster_content_system_prompt_path or not os.path.exists(poster_content_system_prompt_path):
        logging.error(f"海报内容系统提示词文件不存在:{poster_content_system_prompt_path}")
        sys.exit(1)
    with open(poster_content_system_prompt_path, "r", encoding="utf-8") as f:
        poster_content_system_prompt = f.read()

    # Poster-generation parameters, with the script's fallback defaults.
    poster_variants = config.get("variants", 1)
    poster_assets_dir = config.get("poster_assets_base_dir")
    img_base_dir = config.get("image_base_dir")
    res_dir_config = config.get("resource_dir", [])
    poster_size = tuple(config.get("poster_target_size", [900, 1200]))
    txt_possibility = config.get("text_possibility", 0.3)
    img_frame_possibility = config.get("img_frame_possibility", 0.7)
    text_bg_possibility = config.get("text_bg_possibility", 0)
    collage_subdir = config.get("output_collage_subdir", "collage_img")
    poster_subdir = config.get("output_poster_subdir", "poster")
    poster_filename = config.get("output_poster_filename", "poster.jpg")
    title_possibility = config.get("title_possibility", 0.3)
    collage_style = config.get("collage_style", None)
    timeout = config.get("request_timeout", 180)

    if not poster_assets_dir or not img_base_dir:
        logging.error("配置中缺少关键路径poster_assets_base_dir或image_base_dir。无法继续。")
        sys.exit(1)

    pipeline_start_time = time.time()
    logging.info("开始执行海报生成...")
    poster_success = False

    topics_to_process = _select_topics(topics_list, args.topic_index)
    if not topics_to_process:
        logging.error(f"未找到索引为{args.topic_index}的选题。")
        sys.exit(1)

    for topic_index, topic_item in topics_to_process:
        logging.info(f"--- 处理选题 {topic_index}: {topic_item.get('object', 'N/A')} ---")
        try:
            posters_attempted = generate_posters_for_topic(
                topic_item=topic_item,
                output_dir=output_dir,
                run_id=run_id,
                topic_index=topic_index,
                output_handler=output_handler,
                variants=poster_variants,
                poster_assets_base_dir=poster_assets_dir,
                image_base_dir=img_base_dir,
                resource_dir_config=res_dir_config,
                poster_target_size=poster_size,
                text_possibility=txt_possibility,
                img_frame_possibility=img_frame_possibility,
                text_bg_possibility=text_bg_possibility,
                output_collage_subdir=collage_subdir,
                output_poster_subdir=poster_subdir,
                output_poster_filename=poster_filename,
                system_prompt=poster_content_system_prompt,
                collage_style=collage_style,
                title_possibility=title_possibility,
                timeout=timeout
            )
            if posters_attempted:
                logging.info(f"选题{topic_index}的海报生成过程已完成。")
                poster_success = True
            else:
                logging.warning(f"选题{topic_index}的海报生成被跳过或在早期失败。")
        except Exception as e:
            # Per-topic best effort: log with traceback and move on.
            logging.exception(f"处理选题{topic_index}的海报生成时出错:")
            logging.error(f"错误信息: {e}")
        logging.info(f"--- 完成选题 {topic_index} ---")

    if run_id:
        output_handler.finalize(run_id)
    pipeline_end_time = time.time()
    if poster_success:
        logging.info(f"海报生成完成,耗时{pipeline_end_time - pipeline_start_time:.2f}秒。")
    else:
        logging.warning("海报生成完成,但可能遇到错误或未生成输出。")
    logging.info(f"运行ID'{run_id}'的结果位于:{os.path.join(output_dir, run_id)}")
if __name__ == "__main__":
    # Run the poster-generation test only when executed as a script, not on import.
    main()