TravelContentCreator/scripts/reimage/regenerate_missing.py
2025-05-16 17:22:04 +08:00

1087 lines
55 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import json
import argparse
import logging
import shutil
from typing import List, Dict, Any, Tuple, Optional, Set
# Usage: python scripts/reimage/regenerate_missing.py --run_dir /root/autodl-tmp/TravelContentCreator/result/2025-05-12_21-36-33 --config poster_gen_config.json
# 添加项目根目录到系统路径
sys.path.append("/root/autodl-tmp/TravelContentCreator")
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 导入所需的模块
from utils.output_handler import FileSystemOutputHandler
from utils.poster_notes_creator import select_additional_images
from utils.tweet_generator import generate_posters_for_topic, generate_content_for_topic
from core.topic_parser import TopicParser
from core.ai_agent import AI_Agent
from utils.prompt_manager import PromptManager
def load_config(config_path="poster_gen_config.json"):
    """Load the JSON configuration file.

    Args:
        config_path: Path of the JSON configuration file.

    Returns:
        The parsed configuration dict, or ``None`` when the file does not
        exist, cannot be read or decoded, is not valid JSON, or whose top-level
        value is not a JSON object. Every failure is logged.
    """
    if not os.path.exists(config_path):
        logging.error(f"错误:配置文件 '{config_path}' 不存在")
        return None
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logging.error(f"加载配置文件时出错: {e}")
        return None
    # Callers use config.get(...), so a top-level list/string/number is unusable.
    if not isinstance(config, dict):
        logging.error(f"配置文件 '{config_path}' 的顶层必须是 JSON 对象")
        return None
    return config
def check_content_status(run_dir: str, topic_indices: List[int], variants_count: int,
                         poster_subdir: str = "poster") -> Dict:
    """Inspect every expected ``<topic>_<variant>`` directory under *run_dir*.

    Args:
        run_dir: Path of the run directory.
        topic_indices: Topic indices to inspect.
        variants_count: Number of variants each topic should have; variants
            are numbered from 1.
        poster_subdir: Name of the poster sub-directory inside each variant
            directory (defaults to ``"poster"``).

    Returns:
        ``{topic_index: {variant_index: flags}}`` where ``flags`` contains:

        - ``exists``: the variant directory exists;
        - ``has_content``: ``tweet_content.json`` exists in it;
        - ``has_poster``: the poster directory holds at least one ``.jpg``
          not prefixed ``additional_`` AND at least one ``*_metadata.json``;
        - ``has_additional``: it holds at least one ``additional_*.jpg``.
    """
    status = {}
    for topic_index in topic_indices:
        status[topic_index] = {}
        for variant_index in range(1, variants_count + 1):
            topic_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")
            # A variant whose directory is absent is reported as entirely missing.
            if not os.path.exists(topic_dir):
                status[topic_index][variant_index] = {
                    "exists": False,
                    "has_content": False,
                    "has_poster": False,
                    "has_additional": False
                }
                continue

            poster_dir = os.path.join(topic_dir, poster_subdir)
            # List the poster directory once; isdir() (rather than exists()) keeps a
            # stray regular file with that name from making os.listdir() raise.
            poster_files = os.listdir(poster_dir) if os.path.isdir(poster_dir) else []
            has_poster_image = any(
                f.endswith(".jpg") and not f.startswith("additional_") for f in poster_files
            )
            has_poster_metadata = any(f.endswith("_metadata.json") for f in poster_files)
            has_additional_images = any(
                f.startswith("additional_") and f.endswith(".jpg") for f in poster_files
            )

            status[topic_index][variant_index] = {
                "exists": True,
                "has_content": os.path.exists(os.path.join(topic_dir, "tweet_content.json")),
                # A poster only counts when both the image and its metadata are present.
                "has_poster": has_poster_image and has_poster_metadata,
                "has_additional": has_additional_images
            }
    return status
def _run_id_from_dir(run_dir: str) -> str:
    """Return the run id of *run_dir*, i.e. its last path component.

    normpath() drops a trailing separator, so ``"result/2025-05-12_21-36-33/"``
    yields ``"2025-05-12_21-36-33"`` instead of an empty string.
    """
    return os.path.basename(os.path.normpath(run_dir))


def _load_topics(run_dir: str, run_id: str) -> Optional[List[Dict[str, Any]]]:
    """Load a run's topic list, preferring ``tweet_topic_<run_id>.json``.

    Falls back to ``tweet_topic.json``. Returns None (after logging) when
    neither file exists or the parser returns an empty result.
    """
    topics_file = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
    if not os.path.exists(topics_file):
        alternative_topics_file = os.path.join(run_dir, "tweet_topic.json")
        if not os.path.exists(alternative_topics_file):
            logging.error(f"主题文件不存在: {topics_file}{alternative_topics_file}")
            return None
        topics_file = alternative_topics_file
    logging.info(f"使用主题文件: {topics_file}")
    topics_list = TopicParser.load_topics_from_json(topics_file)
    if not topics_list:
        logging.error("无法加载主题列表")
        return None
    return topics_list


def _load_poster_settings(config: Dict) -> Dict[str, Any]:
    """Read the poster, collage and additional-image settings from *config*.

    The defaults are the ones previously repeated at every call site.
    """
    image_selection_config = config.get("image_selection", {})
    return {
        "poster_assets_dir": config.get("poster_assets_base_dir"),
        "img_base_dir": config.get("image_base_dir"),
        "res_dir_config": config.get("resource_dir", []),
        "poster_size": tuple(config.get("poster_target_size", [900, 1200])),
        "txt_possibility": config.get("text_possibility", 0.3),
        "img_frame_possibility": config.get("img_frame_possibility", 0.7),
        "text_bg_possibility": config.get("text_bg_possibility", 0),
        "collage_subdir": config.get("output_collage_subdir", "collage_img"),
        "poster_subdir": config.get("output_poster_subdir", "poster"),
        "poster_filename": config.get("output_poster_filename", "poster.jpg"),
        "system_prompt_path": config.get("poster_content_system_prompt"),
        "collage_style": config.get("collage_style"),
        "title_possibility": config.get("title_possibility", 0.3),
        "request_timeout": config.get("request_timeout", 180),
        "additional_images_count": config.get("additional_images_count", 3),
        "variation_strength": image_selection_config.get("variation_strength", "medium"),
        "extra_effects": image_selection_config.get("extra_effects", True),
    }


def _read_poster_system_prompt(path: Optional[str]) -> Optional[str]:
    """Return the UTF-8 text of the poster system prompt, or None (after logging)."""
    if not path:
        logging.error("缺少关键配置参数: poster_content_system_prompt")
        return None
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except (OSError, UnicodeDecodeError) as e:
        logging.error(f"读取海报内容系统提示词失败: {e}")
        return None


def _create_ai_agent(config: Dict):
    """Build an AI_Agent from the API and timeout settings in *config*."""
    return AI_Agent(
        config.get("api_url"),
        config.get("model"),
        config.get("api_key"),
        timeout=config.get("request_timeout", 180),
        max_retries=config.get("max_retries", 3),
        stream_chunk_timeout=config.get("stream_chunk_timeout", 30),
    )


def _make_context(run_dir: str, run_id: str, config: Dict, settings: Dict[str, Any],
                  system_prompt: str, output_handler) -> Dict[str, Any]:
    """Bundle the per-run values that every generation helper below needs."""
    return {
        "run_dir": run_dir,
        "run_id": run_id,
        "config": config,
        "settings": settings,
        "system_prompt": system_prompt,
        "output_handler": output_handler,
    }


def _object_name(topic_item: Dict) -> str:
    """Derive the image-folder name from a topic's ``object`` field.

    ``"景点信息-青青世界.txt"`` -> ``"青青世界"``: the text before the first
    ``.``, without the ``景点信息-`` prefix, stripped. Missing/None -> ``""``.
    """
    return (topic_item.get("object") or "").split(".")[0].replace("景点信息-", "").strip()


def _find_poster_metadata(poster_dir: str) -> Optional[str]:
    """Return the alphabetically first ``*_metadata.json`` file in *poster_dir*, or None."""
    metadata_files = sorted(
        f for f in os.listdir(poster_dir)
        if f.endswith("_metadata.json") and os.path.isfile(os.path.join(poster_dir, f))
    )
    return os.path.join(poster_dir, metadata_files[0]) if metadata_files else None


def _copy_sibling_content(run_dir: str, topic_index, variant_index: int, variants_count: int) -> bool:
    """Copy ``tweet_content.json`` from the first sibling variant of the topic that has one.

    Creates the target variant directory when a source is found.
    Returns True when a file was copied.
    """
    variant_key = f"{topic_index}_{variant_index}"
    for other_variant in range(1, variants_count + 1):
        if other_variant == variant_index:
            continue
        source = os.path.join(run_dir, f"{topic_index}_{other_variant}", "tweet_content.json")
        if os.path.exists(source):
            target_dir = os.path.join(run_dir, variant_key)
            os.makedirs(target_dir, exist_ok=True)
            shutil.copy2(source, os.path.join(target_dir, "tweet_content.json"))
            logging.info(f"已从变体 {topic_index}_{other_variant} 复制内容文件到 {variant_key}")
            return True
    return False


def _reset_poster_dir(poster_dir: str, variant_key: str) -> None:
    """Delete the regular files in *poster_dir* (sub-directories are kept), or create it.

    Raises OSError if a file cannot be listed or removed.
    """
    if os.path.exists(poster_dir):
        for file_name in os.listdir(poster_dir):
            file_path = os.path.join(poster_dir, file_name)
            if os.path.isfile(file_path):
                os.remove(file_path)
                logging.debug(f"已删除文件: {file_path}")
        logging.info(f"已清空变体 {variant_key} 的海报目录")
    else:
        os.makedirs(poster_dir, exist_ok=True)
        logging.info(f"创建海报目录: {poster_dir}")


def _generate_poster(ctx: Dict[str, Any], topic_index, topic_item: Dict) -> bool:
    """Generate one poster for *topic_item*; True when the generator reports success.

    Any exception is logged and reported as failure (False).
    NOTE(review): no variant index is passed to generate_posters_for_topic
    (only topic_index and variants=1); confirm it writes into the intended
    ``<topic>_<variant>`` directory when the variant being repaired is not 1.
    """
    config = ctx["config"]
    s = ctx["settings"]
    try:
        return bool(generate_posters_for_topic(
            topic_item=topic_item,
            output_dir=config["output_dir"],
            run_id=ctx["run_id"],
            topic_index=topic_index,
            output_handler=ctx["output_handler"],
            model_name=config["model"],
            base_url=config["api_url"],
            api_key=config["api_key"],
            variants=1,
            title_possibility=s["title_possibility"],
            poster_assets_base_dir=s["poster_assets_dir"],
            image_base_dir=s["img_base_dir"],
            resource_dir_config=s["res_dir_config"],
            poster_target_size=s["poster_size"],
            text_possibility=s["txt_possibility"],
            img_frame_possibility=s["img_frame_possibility"],
            text_bg_possibility=s["text_bg_possibility"],
            output_collage_subdir=s["collage_subdir"],
            output_poster_subdir=s["poster_subdir"],
            output_poster_filename=s["poster_filename"],
            system_prompt=ctx["system_prompt"],
            collage_style=s["collage_style"],
            timeout=s["request_timeout"],
        ))
    except Exception as e:
        logging.exception(f"生成海报时出错: {e}")
        return False


def _generate_additional_images(ctx: Dict[str, Any], topic_index, variant_index: int,
                                topic_item: Dict, clear_existing: bool = False) -> None:
    """Create additional images for one variant from its poster metadata.

    When *clear_existing* is True, existing ``additional_*.jpg`` files are
    deleted first. Every failure is logged; nothing is raised.
    """
    settings = ctx["settings"]
    variant_key = f"{topic_index}_{variant_index}"
    poster_dir = os.path.join(ctx["run_dir"], variant_key, settings["poster_subdir"])
    if not os.path.isdir(poster_dir):
        logging.warning(f"海报目录不存在: {poster_dir},无法生成配图")
        return
    try:
        poster_metadata_path = _find_poster_metadata(poster_dir)
    except OSError as e:
        logging.warning(f"访问海报目录时出错: {e}")
        return
    if poster_metadata_path is None:
        logging.warning("未找到海报元数据文件,无法生成额外配图")
        return
    logging.info(f"为变体 {variant_key} 生成额外配图,使用元数据: {poster_metadata_path}")

    img_base_dir = settings["img_base_dir"]
    object_name = _object_name(topic_item)
    # An empty object name would make the source directory the image root itself.
    if not img_base_dir or not object_name:
        logging.warning(f"缺少图片根目录或景点名称,无法为变体 {variant_key} 生成额外配图")
        return
    source_image_dir = os.path.join(img_base_dir, object_name)
    if not os.path.isdir(source_image_dir):
        logging.warning(f"源图像目录不存在: {source_image_dir}")
        return

    try:
        if clear_existing:
            for file_name in os.listdir(poster_dir):
                if file_name.startswith("additional_") and file_name.endswith(".jpg"):
                    os.remove(os.path.join(poster_dir, file_name))
                    logging.debug(f"已删除额外配图: {file_name}")
        additional_paths = select_additional_images(
            run_id=ctx["run_id"],
            topic_index=topic_index,
            variant_index=variant_index,
            poster_metadata_path=poster_metadata_path,
            source_image_dir=source_image_dir,
            num_additional_images=settings["additional_images_count"],
            output_handler=ctx["output_handler"],
            variation_strength=settings["variation_strength"],
            extra_effects=settings["extra_effects"],
        )
    except Exception as e:
        logging.exception(f"生成额外配图时出错: {e}")
        return
    if additional_paths:
        logging.info(f"已为变体 {variant_key} 生成 {len(additional_paths)} 张额外配图")
    else:
        logging.warning(f"未能为变体 {variant_key} 生成任何额外配图")


def _rebuild_variant(ctx: Dict[str, Any], topic_index, variant_index: int, topic_item: Dict) -> None:
    """Delete a variant's poster files, then regenerate its poster and additional images."""
    variant_key = f"{topic_index}_{variant_index}"
    poster_dir = os.path.join(ctx["run_dir"], variant_key, ctx["settings"]["poster_subdir"])
    try:
        _reset_poster_dir(poster_dir, variant_key)
    except OSError as e:
        logging.error(f"删除原有海报文件时出错: {e}")
        return
    logging.info(f"为变体 {variant_key} 重新生成海报...")
    if not _generate_poster(ctx, topic_index, topic_item):
        logging.warning(f"变体 {variant_key} 海报生成失败")
        return
    logging.info(f"变体 {variant_key} 海报生成完成")
    _generate_additional_images(ctx, topic_index, variant_index, topic_item, clear_existing=True)


def _create_missing_variant(ctx: Dict[str, Any], topic_index, variant_index: int, topic_item: Dict,
                            ai_agent, prompt_manager, variants_count: int) -> None:
    """Create a variant whose directory does not exist: content, poster, additional images.

    Content is copied from a sibling variant when possible, otherwise generated.
    NOTE(review): generate_content_for_topic receives only topic_index and
    variants=1; confirm it writes the content of *variant_index*.
    """
    variant_key = f"{topic_index}_{variant_index}"
    logging.info(f"创建完全缺失的变体: {variant_key}")
    if not _copy_sibling_content(ctx["run_dir"], topic_index, variant_index, variants_count):
        logging.info(f"为变体 {variant_key} 生成新内容...")
        config = ctx["config"]
        try:
            content_success = generate_content_for_topic(
                ai_agent,
                prompt_manager,
                topic_item,
                ctx["run_id"],
                topic_index,
                ctx["output_handler"],
                variants=1,
                temperature=config.get("content_temperature", 0.3),
                top_p=config.get("content_top_p", 0.4),
                presence_penalty=config.get("content_presence_penalty", 1.5),
                enable_content_judge=config.get("enable_content_judge", False)
            )
        except Exception as e:
            logging.exception(f"为变体 {variant_key} 生成内容时出错: {e}")
            return
        if not content_success:
            logging.error(f"为变体 {variant_key} 生成内容失败,跳过后续处理")
            return
        logging.info(f"已为变体 {variant_key} 生成内容")

    logging.info(f"为变体 {variant_key} 生成海报...")
    if not _generate_poster(ctx, topic_index, topic_item):
        logging.warning(f"变体 {variant_key} 海报生成失败")
        return
    logging.info(f"变体 {variant_key} 海报生成完成")
    _generate_additional_images(ctx, topic_index, variant_index, topic_item)


def _complete_variant(ctx: Dict[str, Any], topic_index, variant_index: int,
                      topic_item: Dict, status: Dict[str, bool]) -> None:
    """Fill in what an existing variant lacks: the poster first, then additional images."""
    variant_key = f"{topic_index}_{variant_index}"
    has_poster = status["has_poster"]
    if not has_poster:
        logging.info(f"为变体 {variant_key} 补充生成海报...")
        has_poster = _generate_poster(ctx, topic_index, topic_item)
        if has_poster:
            logging.info(f"变体 {variant_key} 海报生成完成")
        else:
            logging.warning(f"变体 {variant_key} 海报生成失败")
    # Additional images are derived from poster metadata, so they need a poster.
    if has_poster and not status["has_additional"]:
        _generate_additional_images(ctx, topic_index, variant_index, topic_item)


def regenerate_missing_content(run_dir: str, config: Dict, force_regenerate: bool = False):
    """Regenerate missing posters and additional images of a previous run.

    Variants are classified with check_content_status() and handled as:

    1. directory missing: content is copied from a sibling variant (or
       generated), then a poster and additional images are created;
    2. present but incomplete: the poster is generated if missing, then the
       additional images if they are missing;
    3. complete: only when *force_regenerate* is set, the poster directory's
       files are deleted and everything is regenerated.

    Args:
        run_dir: Path of the run directory.
        config: Configuration dict (as returned by load_config()).
        force_regenerate: Also rebuild variants that are already complete.

    All errors are logged; the function always returns None. The AI agent is
    closed on every exit path once it has been created.
    """
    run_id = _run_id_from_dir(run_dir)
    logging.info(f"处理运行ID: {run_id}")

    topics_list = _load_topics(run_dir, run_id)
    if not topics_list:
        return
    # `is not None` keeps a topic with index 0, which a truthiness test would drop.
    topic_map = {topic.get('index'): topic for topic in topics_list if topic.get('index') is not None}
    topic_indices = list(topic_map)
    variants_count = config.get("variants", 1)
    logging.info(f"找到 {len(topic_indices)} 个主题,每个主题应有 {variants_count} 个变体")

    content_status = check_content_status(run_dir, topic_indices, variants_count)
    missing_variants = []
    incomplete_variants = []
    existing_variants = []
    for topic_index, variants in content_status.items():
        for variant_index, status in variants.items():
            variant_key = f"{topic_index}_{variant_index}"
            if not status["exists"]:
                missing_variants.append((topic_index, variant_index))
                logging.info(f"变体 {variant_key} 完全不存在,需要创建")
            elif not (status["has_poster"] and status["has_additional"]):
                incomplete_variants.append((topic_index, variant_index))
                logging.info(f"变体 {variant_key} 不完整:海报={status['has_poster']} 配图={status['has_additional']}")
            else:
                existing_variants.append((topic_index, variant_index))
                if force_regenerate:
                    logging.info(f"变体 {variant_key} 完整,但将强制重新生成海报")
    rebuild_existing = force_regenerate and bool(existing_variants)
    if rebuild_existing:
        logging.info(f"将强制重新生成 {len(existing_variants)} 个已存在的完整海报")
    logging.info(f"需要创建 {len(missing_variants)} 个完全缺失的变体,补充 {len(incomplete_variants)} 个不完整的变体")
    if not (missing_variants or incomplete_variants or rebuild_existing):
        return

    # Validate configuration before creating any client, so a bad config leaves nothing open.
    settings = _load_poster_settings(config)
    if not settings["poster_assets_dir"] or not settings["img_base_dir"] or not settings["system_prompt_path"]:
        logging.error("缺少关键配置参数")
        return
    poster_content_system_prompt = _read_poster_system_prompt(settings["system_prompt_path"])
    if poster_content_system_prompt is None:
        return

    output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))
    try:
        prompt_manager = PromptManager(
            topic_system_prompt_path=config.get("topic_system_prompt"),
            topic_user_prompt_path=config.get("topic_user_prompt"),
            content_system_prompt_path=config.get("content_system_prompt"),
            prompts_config=config.get("prompts_config"),
            resource_dir_config=config.get("resource_dir", []),
            topic_gen_num=config.get("num", 1),
            topic_gen_date=config.get("date", ""),
            content_judger_system_prompt_path=config.get("content_judger_system_prompt")
        )
        logging.info("PromptManager实例创建成功")
    except Exception as e:
        logging.error(f"创建PromptManager实例失败: {e}")
        return
    try:
        ai_agent = _create_ai_agent(config)
        logging.info("AI代理创建成功")
    except Exception as e:
        logging.error(f"创建AI代理失败: {e}")
        return

    ctx = _make_context(run_dir, run_id, config, settings, poster_content_system_prompt, output_handler)
    try:
        for topic_index, variant_index in missing_variants:
            _create_missing_variant(ctx, topic_index, variant_index, topic_map[topic_index],
                                    ai_agent, prompt_manager, variants_count)
        for topic_index, variant_index in incomplete_variants:
            _complete_variant(ctx, topic_index, variant_index, topic_map[topic_index],
                              content_status[topic_index][variant_index])
        if rebuild_existing:
            logging.info(f"开始强制重新生成 {len(existing_variants)} 个已存在的完整海报")
            for topic_index, variant_index in existing_variants:
                _rebuild_variant(ctx, topic_index, variant_index, topic_map[topic_index])
    finally:
        ai_agent.close()
        logging.info("AI代理已关闭")


def _pick_target_object(object_name: Optional[str], img_base_dir: Optional[str]) -> str:
    """Choose the ``景点信息-<name>.txt`` object for a topic built from a poster config.

    Priority: the explicit *object_name*; otherwise the first (sorted)
    sub-directory of *img_base_dir* whose name contains "青青", else its
    first sub-directory; otherwise the default "景点信息-青青世界.txt".
    """
    if object_name:
        logging.info(f"使用指定的景点: {object_name}")
        return f"景点信息-{object_name}.txt"
    # Guard against an unset image_base_dir: os.path.exists(None) raises TypeError.
    if img_base_dir and os.path.isdir(img_base_dir):
        subdirs = sorted(
            d for d in os.listdir(img_base_dir) if os.path.isdir(os.path.join(img_base_dir, d))
        )
        if subdirs:
            qingqing_dirs = [d for d in subdirs if "青青" in d]
            chosen = (qingqing_dirs or subdirs)[0]
            logging.info(f"使用图片目录: {chosen}")
            return f"景点信息-{chosen}.txt"
    return "景点信息-青青世界.txt"


def _topic_item_from_poster_config(poster_config_path: str, topic_index: int,
                                   object_name: Optional[str],
                                   img_base_dir: Optional[str]) -> Optional[Dict[str, Any]]:
    """Build a temporary topic item for *topic_index* from a poster config JSON list.

    Returns None (after logging) when the file is missing or unreadable, is
    not a JSON list, or has no entry whose ``index`` equals *topic_index*.
    """
    if not os.path.exists(poster_config_path):
        logging.error(f"指定的海报配置文件不存在: {poster_config_path}")
        return None
    try:
        with open(poster_config_path, 'r', encoding='utf-8') as f:
            poster_configs = json.load(f)
    except (OSError, ValueError) as e:
        logging.exception(f"读取海报配置文件时出错: {e}")
        return None
    logging.info(f"已加载海报配置文件: {poster_config_path}")
    if not isinstance(poster_configs, list):
        logging.error(f"海报配置文件的顶层必须是列表: {poster_config_path}")
        return None

    topic_config = next(
        (item for item in poster_configs if isinstance(item, dict) and item.get('index') == topic_index),
        None,
    )
    if topic_config is None:
        logging.error(f"在配置文件中找不到主题 {topic_index}")
        return None
    logging.info(f"找到主题 {topic_index} 配置: {json.dumps(topic_config, ensure_ascii=False)}")

    temp_topic_item = {
        "index": topic_index,
        "title": topic_config.get("main_title", f"Topic {topic_index}"),
        "object": _pick_target_object(object_name, img_base_dir),
    }
    logging.info(f"已创建临时主题对象: {json.dumps(temp_topic_item, ensure_ascii=False)}")
    return temp_topic_item


def _topic_item_from_run(run_dir: str, topic_index: int) -> Optional[Dict[str, Any]]:
    """Return the topic with ``index == topic_index`` from the run's topic file, or None."""
    topics_list = _load_topics(run_dir, _run_id_from_dir(run_dir))
    if not topics_list:
        return None
    topic_item = next((topic for topic in topics_list if topic.get('index') == topic_index), None)
    if topic_item is None:
        logging.error(f"找不到主题 {topic_index}")
    return topic_item


def _regenerate_topic(run_dir: str, topic_index: int, topic_item: Dict, config: Dict) -> bool:
    """Rebuild the poster and additional images of every variant of one topic.

    Missing variant directories are created, and a missing ``tweet_content.json``
    is copied from a sibling variant when one exists.
    Returns False when setup fails or an unexpected error aborts processing.
    """
    settings = _load_poster_settings(config)
    system_prompt = _read_poster_system_prompt(settings["system_prompt_path"])
    if system_prompt is None:
        return False
    try:
        output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))
        # NOTE(review): this agent is not passed to any generator in this path; it is
        # kept to preserve the original side effects of constructing it. Confirm
        # whether it can be removed.
        ai_agent = _create_ai_agent(config)
    except Exception as e:
        logging.exception(f"处理主题 {topic_index} 时出错: {e}")
        return False
    logging.info("AI代理创建成功")

    ctx = _make_context(run_dir, _run_id_from_dir(run_dir), config, settings, system_prompt, output_handler)
    variants_count = config.get("variants", 1)
    try:
        for variant_index in range(1, variants_count + 1):
            variant_key = f"{topic_index}_{variant_index}"
            topic_dir = os.path.join(run_dir, variant_key)
            if not os.path.exists(topic_dir):
                os.makedirs(topic_dir, exist_ok=True)
                logging.info(f"创建变体目录: {topic_dir}")
            content_file = os.path.join(topic_dir, "tweet_content.json")
            if not os.path.exists(content_file) and not _copy_sibling_content(
                    run_dir, topic_index, variant_index, variants_count):
                logging.warning(f"变体 {variant_key} 缺少内容文件,且没有可复制的同主题变体")
            _rebuild_variant(ctx, topic_index, variant_index, topic_item)
    except Exception as e:
        logging.exception(f"处理主题 {topic_index} 时出错: {e}")
        return False
    finally:
        ai_agent.close()
    return True


def main():
    """Command-line entry point.

    Without ``--topic``: repair every missing or incomplete variant of
    ``--run_dir`` (``--force`` also rebuilds complete ones).
    With ``--topic``: rebuild every variant of that topic, taking the topic
    from ``--poster-config`` when given, otherwise from the run's topic file.
    Exits with status 1 on configuration or processing errors.
    """
    parser = argparse.ArgumentParser(description="补充生成丢失的海报和配图")
    parser.add_argument("--run_dir", required=True, help="之前运行结果的目录")
    parser.add_argument("--config", default="poster_gen_config.json", help="配置文件路径")
    parser.add_argument("--debug", action="store_true", help="启用调试日志")
    parser.add_argument("--force", action="store_true", help="强制重新生成所有海报,不仅是缺失的")
    parser.add_argument("--topic", type=int, help="指定要重新生成的主题ID")
    parser.add_argument("--poster-config", help="指定海报配置文件路径,优先使用该配置")
    parser.add_argument("--object", help="指定要使用的景点名称,例如'青青世界'")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    config = load_config(args.config)
    if not config:
        sys.exit(1)
    if not os.path.exists(args.run_dir):
        logging.error(f"指定的运行目录不存在: {args.run_dir}")
        sys.exit(1)

    # `is None` rather than truthiness so that `--topic 0` is honoured.
    if args.topic is None:
        if args.poster_config or args.object:
            logging.warning("--poster-config 和 --object 仅在指定 --topic 时生效,已忽略")
        if args.force:
            logging.info(f"开始为 {args.run_dir} 强制重新生成所有海报和配图")
        else:
            logging.info(f"开始为 {args.run_dir} 补充生成缺失的海报和配图")
        regenerate_missing_content(args.run_dir, config, args.force)
    elif args.poster_config:
        logging.info(f"开始使用指定配置文件为主题 {args.topic} 重新生成海报")
        topic_item = _topic_item_from_poster_config(
            args.poster_config, args.topic, args.object, config.get("image_base_dir"))
        if topic_item is None or not _regenerate_topic(args.run_dir, args.topic, topic_item, config):
            sys.exit(1)
        logging.info(f"使用指定配置文件为主题 {args.topic} 的海报和配图生成完成")
    else:
        if args.object:
            logging.warning("--object 仅在同时指定 --poster-config 时生效,已忽略")
        logging.info(f"开始为主题 {args.topic} 重新生成海报和配图")
        topic_item = _topic_item_from_run(args.run_dir, args.topic)
        if topic_item is None or not _regenerate_topic(args.run_dir, args.topic, topic_item, config):
            sys.exit(1)
        logging.info(f"主题 {args.topic} 的海报和配图生成完成")
    logging.info("处理完成")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()