TravelContentCreator/test_poster.py

232 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试海报生成的脚本
"""
import os
import time
import argparse
import json
import logging
import sys
import traceback
from datetime import datetime
from core.ai_agent import AI_Agent
from utils.tweet_generator import generate_posters_for_topic
from utils.output_handler import FileSystemOutputHandler
from core.topic_parser import TopicParser
def load_config(config_path="poster_config.json"):
    """Load and validate the poster-generation configuration from a JSON file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        dict: The decoded configuration. It is guaranteed to be a JSON
        object containing every required key.

    Note:
        This function never returns on failure. It prints a message and
        terminates the process with ``sys.exit(1)`` when the file is missing,
        cannot be read or decoded, is not a JSON object, or lacks a required
        key.
    """
    required_keys = ["api_url", "model", "api_key", "resource_dir", "output_dir",
                     "image_base_dir", "poster_assets_base_dir",
                     "poster_content_system_prompt"]

    if not os.path.exists(config_path):
        print(f"错误：配置文件 '{config_path}' 未找到。")
        sys.exit(1)

    # Keep the try body limited to the read/parse step so that validation
    # failures below are reported with their own, more specific messages.
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except json.JSONDecodeError:
        print(f"错误：无法从 '{config_path}' 解码JSON。请检查文件格式。")
        sys.exit(1)
    except Exception as e:
        # CLI boundary: any other read failure is reported and is fatal.
        print(f"'{config_path}' 加载配置时出错：{e}")
        sys.exit(1)

    # A top-level array/string/number would make the `in` checks below
    # meaningless (e.g. a list of key names would pass), so reject it here.
    if not isinstance(config, dict):
        print(f"错误：配置文件 '{config_path}' 的顶层必须是JSON对象。")
        sys.exit(1)

    # Basic validation: compute the missing keys once and report them all.
    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        print(f"错误：配置文件 '{config_path}' 缺少必需的键：{missing_keys}")
        sys.exit(1)

    return config
def _build_arg_parser():
    """Create the command-line parser for the poster test script."""
    parser = argparse.ArgumentParser(description="测试海报生成")
    parser.add_argument(
        "--config",
        type=str,
        default="poster_config.json",
        help="配置文件路径例如poster_config.json"
    )
    parser.add_argument(
        "--topics_file",
        type=str,
        required=True,
        help="必需的选题JSON文件路径用于获取海报生成的主题数据。"
    )
    parser.add_argument(
        "--topic_index",
        type=int,
        default=None,
        help="要生成海报的特定选题索引。如果未提供，将为所有选题生成海报。"
    )
    parser.add_argument(
        "--run_id",
        type=str,
        default=None,
        help="可选的指定运行ID。如果未提供将生成基于时间戳的ID。"
    )
    parser.add_argument(
        "--debug",
        action='store_true',
        help="启用调试级别日志记录。"
    )
    return parser


def _resolve_run_id(requested_run_id, topics_file):
    """Return a non-empty run id: the requested one, one inferred from the
    topics file name (``tweet_topic_<run_id>.json``), or a timestamp."""
    # An empty string is treated like "not provided" so downstream code never
    # works with an empty run id.
    if requested_run_id:
        return requested_run_id

    prefix, suffix = "tweet_topic_", ".json"
    base = os.path.basename(topics_file)
    if base.startswith(prefix) and base.endswith(suffix):
        inferred = base[len(prefix): -len(suffix)]
        if inferred:
            logging.info(f"从选题文件名推断的run_id{inferred}")
            return inferred

    run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_poster")
    logging.info(f"为海报生成的run_id{run_id}")
    return run_id


def _index_topics(topics_list):
    """Pair each topic with its effective index: the topic's own 'index'
    value, or its 1-based position in the list when that value is absent."""
    indexed = []
    for position, topic in enumerate(topics_list, start=1):
        explicit_index = topic.get('index')
        effective_index = position if explicit_index is None else explicit_index
        indexed.append((effective_index, topic))
    return indexed


def _select_topics(indexed_topics, wanted_index):
    """Return the (index, topic) pairs to process; at most one pair when a
    specific index is requested, every pair when ``wanted_index`` is None."""
    if wanted_index is None:
        return indexed_topics
    for effective_index, topic in indexed_topics:
        # Also compare the string forms so an index stored as "3" in the
        # topics file matches the integer given on the command line.
        if effective_index == wanted_index or str(effective_index) == str(wanted_index):
            return [(effective_index, topic)]
    return []


def main():
    """Run poster generation for the topics in a topics JSON file.

    Steps: configure logging, parse the command line, load the configuration
    and the topics, determine the run id, read the poster system prompt, then
    call ``generate_posters_for_topic`` once per selected topic and finally
    call ``output_handler.finalize(run_id)``.

    The process exits with status 1 when the configuration cannot be loaded,
    no topics are loaded, the system-prompt file does not exist, a required
    path is empty, or the requested ``--topic_index`` matches no topic.
    Exceptions raised while generating a single topic are logged and do not
    stop the remaining topics.
    """
    # Set up logging.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Parse command-line arguments.
    args = _build_arg_parser().parse_args()

    # Raise verbosity if debugging was requested.
    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info("已启用调试日志记录。")

    logging.info("启动海报生成测试脚本...")
    logging.info(f"使用配置文件：{args.config}")
    logging.info(f"使用选题文件：{args.topics_file}")
    if args.topic_index is not None:
        logging.info(f"将仅处理选题索引：{args.topic_index}")
    if args.run_id:
        logging.info(f"使用指定的run_id{args.run_id}")

    # Load configuration.
    config = load_config(args.config)
    if config is None:
        logging.critical("无法加载配置。退出。")
        sys.exit(1)

    # Initialise the output handler.
    output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))
    logging.info(f"使用输出处理器：{output_handler.__class__.__name__}")

    # Load topic data.
    logging.info(f"从以下位置加载选题：{args.topics_file}")
    topics_list = TopicParser.load_topics_from_json(args.topics_file)
    if not topics_list:
        logging.error(f"无法从{args.topics_file}加载选题。无法继续。")
        sys.exit(1)
    logging.info(f"成功加载{len(topics_list)}个选题。")

    # Determine the run id (explicit, inferred from the file name, or timestamp).
    run_id = _resolve_run_id(args.run_id, args.topics_file)

    # Load the poster-content system prompt.
    poster_content_system_prompt_path = config.get("poster_content_system_prompt")
    if not os.path.exists(poster_content_system_prompt_path):
        logging.error(f"海报内容系统提示词文件不存在：{poster_content_system_prompt_path}")
        sys.exit(1)
    with open(poster_content_system_prompt_path, "r", encoding="utf-8") as f:
        poster_content_system_prompt = f.read()

    # Prepare poster-generation parameters.
    poster_variants = config.get("variants", 1)
    poster_assets_dir = config.get("poster_assets_base_dir")
    img_base_dir = config.get("image_base_dir")
    res_dir_config = config.get("resource_dir", [])
    poster_size = tuple(config.get("poster_target_size", [900, 1200]))
    txt_possibility = config.get("text_possibility", 0.3)
    img_frame_possibility = config.get("img_frame_possibility", 0.7)
    text_bg_possibility = config.get("text_bg_possibility", 0)
    collage_subdir = config.get("output_collage_subdir", "collage_img")
    poster_subdir = config.get("output_poster_subdir", "poster")
    poster_filename = config.get("output_poster_filename", "poster.jpg")

    # The keys are required to exist, but their values may still be empty.
    if not poster_assets_dir or not img_base_dir:
        logging.error("配置中缺少关键路径poster_assets_base_dir或image_base_dir。无法继续。")
        sys.exit(1)

    # Start poster generation.
    pipeline_start_time = time.time()
    logging.info("开始执行海报生成...")
    poster_success = False

    # Resolve every topic's index once, before filtering, so that a selected
    # topic keeps the same index it would have in a full run.
    topics_to_process = _select_topics(_index_topics(topics_list), args.topic_index)
    if not topics_to_process:
        logging.error(f"未找到索引为{args.topic_index}的选题。")
        sys.exit(1)

    # Process the topics one by one.
    for topic_index, topic_item in topics_to_process:
        logging.info(f"--- 处理选题 {topic_index}: {topic_item.get('object', 'N/A')} ---")
        try:
            posters_attempted = generate_posters_for_topic(
                topic_item=topic_item,
                output_dir=config["output_dir"],
                run_id=run_id,
                topic_index=topic_index,
                output_handler=output_handler,
                variants=poster_variants,
                poster_assets_base_dir=poster_assets_dir,
                image_base_dir=img_base_dir,
                resource_dir_config=res_dir_config,
                poster_target_size=poster_size,
                text_possibility=txt_possibility,
                img_frame_possibility=img_frame_possibility,
                text_bg_possibility=text_bg_possibility,
                output_collage_subdir=collage_subdir,
                output_poster_subdir=poster_subdir,
                output_poster_filename=poster_filename,
                system_prompt=poster_content_system_prompt
            )
        except Exception:
            # Best effort: record the traceback and continue with the next topic.
            logging.exception(f"处理选题{topic_index}的海报生成时出错：")
        else:
            if posters_attempted:
                logging.info(f"选题{topic_index}的海报生成过程已完成。")
                poster_success = True
            else:
                logging.warning(f"选题{topic_index}的海报生成被跳过或在早期失败。")
        logging.info(f"--- 完成选题 {topic_index} ---")

    # Finalise output; run_id is guaranteed non-empty by _resolve_run_id.
    output_handler.finalize(run_id)

    pipeline_end_time = time.time()
    if poster_success:
        logging.info(f"海报生成完成，耗时{pipeline_end_time - pipeline_start_time:.2f}秒。")
    else:
        logging.warning("海报生成完成，但可能遇到错误或未生成输出。")
    logging.info(f"运行ID'{run_id}'的结果位于：{os.path.join(config.get('output_dir', 'result'), run_id)}")
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()