Changed how data is passed between components; the pipeline is now better suited to being wrapped and called as an API

This commit is contained in:
jinye_huang 2025-04-22 18:14:31 +08:00
parent b176add779
commit 67722a5c72
12 changed files with 769 additions and 598 deletions

View File

@@ -237,9 +237,9 @@ finally:
**Note:** If you only need the final, complete result and do not need streaming, you can call the `ai_agent.work(...)` method directly; it handles the concatenation internally and returns the result string.
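For reference, a non-streaming call looks roughly like the sketch below. The argument order follows the call sites in `utils/tweet_generator.py` shown later in this diff; `ai_agent`, `system_prompt`, and `user_prompt` are assumed to be set up as in the preceding examples, and the numeric values are placeholders.

```python
# Non-streaming: work() returns the full result string plus a token count and elapsed time.
result, tokens, time_cost = ai_agent.work(
    system_prompt,   # system prompt text
    user_prompt,     # user prompt text
    "",              # third positional argument; passed as "" at the call sites in this repo
    0.3,             # temperature
    0.4,             # top_p
    1.5,             # presence_penalty
)
print(result)
```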
### Planned Refactoring: Decoupling Generation and Output Handling
### Refactoring Complete: Decoupling Generation and Output Handling
To enhance the flexibility and extensibility of this tool, we are planning a refactoring effort to separate the core content/image generation logic from the output handling (currently, saving results to the local filesystem).
To enhance the flexibility and extensibility of this tool, a refactoring effort has been completed to separate the core content/image generation logic from the output handling (previously, saving results directly to the local filesystem).
**Motivation:**
@@ -247,27 +247,29 @@ To enhance the flexibility and extensibility of this tool, we are planning a ref
* **Alternative Storage:** Enable saving results to different backends like databases, cloud storage (e.g., S3, OSS), etc.
* **Modularity:** Improve code structure by separating concerns.
**Approach:**
**Approach Taken:**
1. **Modify Core Functions:** Functions responsible for generating topics, content, and posters (primarily in `utils/tweet_generator.py` and potentially `core` modules) will be updated to **return** the generated data (e.g., Python dictionaries, lists, PIL Image objects, or image bytes) rather than directly writing files to the `./result` directory.
2. **Introduce Output Handlers:** An "Output Handler" pattern will be implemented.
* An abstract base class or interface (`OutputHandler`) will define methods for processing different types of results (topics, content, configurations, images).
* An initial concrete implementation (`FileSystemOutputHandler`) will replicate the current behavior of saving all results to the `./result/{run_id}/...` directory structure.
3. **Update Main Workflow:** The main script (`main.py`) will be modified to:
* Instantiate a specific `OutputHandler` (initially the `FileSystemOutputHandler`).
* Call the generation functions to get the data.
* Pass the returned data to the `OutputHandler` instance for processing (e.g., saving).
1. **Modified Core Functions:** Functions responsible for generating topics, content, and posters (primarily in `utils/tweet_generator.py`, `core/simple_collage.py`, `core/posterGen.py`) have been updated:
* Topic generation now returns the generated data (`run_id`, `topics_list`, prompts).
* Image generation functions (`simple_collage.process_directory`, `posterGen.create_poster`) now return PIL Image objects instead of saving files.
* Content and poster generation workflows accept an `OutputHandler` instance to process results (content JSON, prompts, configurations, image data) immediately after generation.
2. **Introduced Output Handlers:** An "Output Handler" pattern has been implemented (`utils/output_handler.py`).
* An abstract base class (`OutputHandler`) defines methods for processing different types of results.
* A concrete implementation (`FileSystemOutputHandler`) replicates the original behavior of saving all results to the `./result/{run_id}/...` directory structure.
3. **Updated Main Workflow:** The main script (`main.py`) now:
* Instantiates a specific `OutputHandler` (currently `FileSystemOutputHandler`).
* Calls the generation functions, passing the `OutputHandler` where needed.
* Uses the `OutputHandler` to process data returned by the topic generation step.
4. **Reduced Config Dependency:** Core logic functions (`PromptManager`, `generate_content_for_topic`, `generate_posters_for_topic`, etc.) now receive necessary configuration values as specific parameters rather than relying on the entire `config` dictionary, making them more independent and testable (the overall wiring is sketched below).
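To show how these pieces fit together, here is a condensed, illustrative sketch of the new wiring. Error handling and most configuration keys are omitted, and `generate_content_and_posters_step` (defined in `main.py`) is only referenced in a comment.

```python
import json

from utils.output_handler import FileSystemOutputHandler
from utils.tweet_generator import run_topic_generation_pipeline

def run_pipeline(config_path: str = "poster_gen_config.json") -> None:
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Pick an output handler; only the filesystem handler is implemented today.
    output_handler = FileSystemOutputHandler(config.get("output_dir", "result"))

    # Step 1: topic generation now returns raw data instead of writing files.
    run_id, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config)
    if run_id is None or topics_list is None:
        return

    # Persist (or forward) the topic results through the handler.
    output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt)

    # Step 2 (in main.py): content/poster generation receives the handler and calls
    # handle_content_variant / handle_generated_image as results are produced, e.g.:
    # generate_content_and_posters_step(config, run_id, topics_list, output_handler)

    # Let the handler flush, upload manifests, etc.
    output_handler.finalize(run_id)
```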
**Future Possibilities:**
This refactoring will make it straightforward to add new output handlers in the future, such as:
This refactoring makes it straightforward to add new output handlers in the future, such as the following (one is sketched after this list):
* `ApiOutputHandler`: Formats results for API responses.
* `DatabaseOutputHandler`: Stores results in a database.
* `CloudStorageOutputHandler`: Uploads results (especially images) to cloud storage and potentially stores metadata elsewhere.
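As an illustration of that extension point, a minimal, hypothetical `ApiOutputHandler` could subclass the `OutputHandler` ABC from `utils/output_handler.py` and simply collect results in memory. Only the abstract method names come from this commit; everything else here (the `runs` attribute, the `as_response` helper) is invented for the sketch.

```python
# Hypothetical sketch of an ApiOutputHandler built on the OutputHandler ABC.
from utils.output_handler import OutputHandler

class ApiOutputHandler(OutputHandler):
    """Collects results in memory so they can be returned as an API payload."""

    def __init__(self):
        self.runs: dict[str, dict] = {}

    def _run(self, run_id: str) -> dict:
        return self.runs.setdefault(
            run_id, {"topics": [], "variants": {}, "poster_configs": {}, "images": []}
        )

    def handle_topic_results(self, run_id, topics_list, system_prompt, user_prompt):
        self._run(run_id)["topics"] = topics_list

    def handle_content_variant(self, run_id, topic_index, variant_index, content_data, prompt_data):
        self._run(run_id)["variants"][f"{topic_index}_{variant_index}"] = content_data

    def handle_poster_configs(self, run_id, topic_index, config_data):
        self._run(run_id)["poster_configs"][topic_index] = config_data

    def handle_generated_image(self, run_id, topic_index, variant_index, image_type, image_data, output_filename):
        # An API layer would likely upload the image elsewhere and keep only a reference here.
        self._run(run_id)["images"].append(
            {"topic": topic_index, "variant": variant_index, "type": image_type, "filename": output_filename}
        )

    def finalize(self, run_id):
        pass  # Nothing to flush for an in-memory handler.

    def as_response(self, run_id: str) -> dict:
        """Hypothetical helper: shape the collected data for an HTTP response."""
        return self.runs.get(run_id, {})
```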
**(Note:** This refactoring is in progress. The current documentation reflects the existing file-saving behavior, but expect changes in how results are handled internally.)
### Configuration (配置文件说明)
The main configuration file is `poster_gen_config.json` (you can copy `example_config.json` and modify it). It mainly contains the following sections:

View File

@@ -590,102 +590,94 @@ class PosterGenerator:
draw.text((x, y), text, font=font, fill=text_color)
def _print_text_debug_info(self, text_type, font, text_width, x, y, font_path):
"""打印文字调试信息"""
print(f"{text_type}信息:")
print(f"- 字体大小: {font.size}")
print(f"- 文字宽度: {text_width}")
print(f"- 文字位置: x={x}, y={y}")
print(f"- 使用字体: {os.path.basename(font_path)}")
pass
# print(f" {text_type}: Font={os.path.basename(font_path)}, Size={font.size}, Width={text_width:.0f}, Pos=({x:.0f}, {y:.0f})")
def create_poster(self, image_path, text_data, output_name):
"""生成海报"""
def create_poster(self, image_path, text_data):
"""
Creates a poster by combining the base image, frame (optional), stickers (optional),
and text layers.
Args:
image_path: Path to the base image (e.g., the generated collage).
text_data: Dictionary containing text information (
{
'title': 'Main Title Text',
'subtitle': 'Subtitle Text (optional)',
'additional_texts': [{'text': '...', 'position': 'top/bottom/center', 'size_factor': 0.8}, ...]
}
).
Returns:
A PIL Image object representing the final poster, or None if creation failed.
"""
target_size = (900, 1200) # TODO: Make target_size a parameter?
print(f"\n--- Creating Poster --- ")
print(f"Input Image: {image_path}")
print(f"Text Data: {text_data}")
# print(f"Output Name: {output_name}") # output_name is removed
try:
# 增加计数器
self.poster_count += 1
# 设置目标尺寸为3:4比例
target_size = (900, 1200) # 3:4比例
# 创建三个图层
# 1. 创建底层(图片)
base_layer = self.create_base_layer(image_path, target_size)
middle_layer = self.create_middle_layer(target_size)
# 先合成底层和中间层
final_image = Image.new('RGBA', target_size, (0, 0, 0, 0))
final_image.paste(base_layer, (0, 0))
final_image.alpha_composite(middle_layer)
# 创建并添加文字层
text_layer = self.create_text_layer(target_size, text_data)
final_image.alpha_composite(text_layer) # 确保文字层在最上面
# 使用模10的余数决定是否添加边框
# 每十张中的第1、5、9张添加边框(余数为1,5,9)
add_frame_flag = random.random() < self.img_frame_posbility
if add_frame_flag:
final_image = self.add_frame(final_image, target_size)
# 保存结果
# 检查output_name是否已经是完整路径
if os.path.dirname(output_name):
output_path = output_name
# 确保目录存在
os.makedirs(os.path.dirname(output_path), exist_ok=True)
if not base_layer:
raise ValueError("Failed to create base layer.")
print("Base layer created.")
# 2. (可选) 添加边框 - 根据计数器决定
# if self.poster_count % 5 == 0: # 每5张海报添加一次边框
if random.random() < self.img_frame_posbility:
print("Attempting to add frame...")
base_layer = self.add_frame(base_layer, target_size)
else:
# 如果只是文件名,拼接输出目录
output_path = os.path.join(self.output_dir, output_name)
# 如果没有扩展名,添加.jpg
if not output_path.lower().endswith(('.jpg', '.jpeg', '.png')):
output_path += '.jpg'
final_image.convert('RGB').save(output_path)
print(f"海报已保存: {output_path}")
print(f"图片尺寸: {target_size[0]}x{target_size[1]} (3:4比例)")
print(f"使用的文字特效: {self.selected_effect}")
print("Skipping frame addition.")
self.poster_count += 1 # 增加计数器
return output_path
# 3. 创建中间层(文本框底图)
# middle_layer 包含文本区域定义 (self.title_area, self.additional_text_area)
middle_layer = self.create_middle_layer(target_size)
print("Middle layer (text areas defined) created.")
# 4. 创建文本层
text_layer = self.create_text_layer(target_size, text_data)
print("Text layer created.")
# 5. 合成图层
# Start with the base layer (which might already have a frame)
final_poster = base_layer
# Add middle layer (text backgrounds)
final_poster.alpha_composite(middle_layer)
# Add text layer
final_poster.alpha_composite(text_layer)
# (可选) 添加贴纸
final_poster = self.add_stickers(final_poster)
print("Layers composed.")
# 转换回 RGB 以保存为 JPG/PNG (移除 alpha通道)
final_poster_rgb = final_poster.convert("RGB")
# 移除保存逻辑,直接返回 Image 对象
# final_save_path = os.path.join(self.output_dir, output_name)
# os.makedirs(os.path.dirname(final_save_path), exist_ok=True)
# final_poster_rgb.save(final_save_path)
# print(f"Final poster saved to: {final_save_path}")
# return final_save_path
return final_poster_rgb
except Exception as e:
print(f"生成海报失败: {e}")
print(f"Error creating poster: {e}")
traceback.print_exc()
return None
def process_directory(self, input_dir, text_data=None):
pass
"""遍历处理目录中的所有图片"""
# 支持的图片格式
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
try:
# 获取目录中的所有文件
files = os.listdir(input_dir)
image_files = [f for f in files if f.lower().endswith(image_extensions)]
if not image_files:
print(f"在目录 {input_dir} 中未找到图片文件")
return
print(f"找到 {len(image_files)} 个图片文件")
# 处理每个图片文件
for i, image_file in enumerate(image_files, 1):
image_path = os.path.join(input_dir, image_file)
print(f"\n处理第 {i}/{len(image_files)} 个图片: {image_file}")
try:
# 构建输出文件名
output_name = os.path.splitext(image_file)[0]
# 生成海报
self.create_poster(image_path, text_data, output_name)
print(f"完成处理: {image_file}")
except Exception as e:
print(f"处理图片 {image_file} 时出错: {e}")
continue
except Exception as e:
print(f"处理目录时出错: {e}")
"""处理目录中的所有图片并为每张图片生成海报"""
# ... (此函数可能不再需要,或者需要重构以适应新的流程) ...
pass
def main():
# 设置是否使用文本框底图True为使用False为不使用
@@ -708,7 +700,7 @@ def main():
}
# 处理目录中的所有图片
img_path = "/root/autodl-tmp/poster_baseboard_0403/output_collage/random_collage_1_collage.png"
generator.create_poster(img_path, text_data, f"{item['index']}.jpg")
generator.create_poster(img_path, text_data)
if __name__ == "__main__":
main()
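As a sketch of the new return-value contract: `create_poster` now returns a PIL Image (or None) instead of saving a file, so persistence belongs to the caller. The helper below is hypothetical; `generator` stands for an already-constructed `PosterGenerator`, and the handler call mirrors the `OutputHandler` interface introduced in this commit.

from utils.output_handler import OutputHandler

def render_and_store_poster(generator, output_handler: OutputHandler, run_id: str,
                            topic_index: int, variant_index: int, collage_path: str) -> bool:
    """Hypothetical caller: render one poster and pass it to the configured handler."""
    text_data = {
        "title": "Main Title Text",
        "subtitle": "Subtitle Text",  # optional, per the create_poster docstring
        "additional_texts": [],
    }
    poster_img = generator.create_poster(collage_path, text_data)
    if poster_img is None:
        return False
    # Persistence is now the caller's job; FileSystemOutputHandler writes under ./result/{run_id}/...
    output_handler.handle_generated_image(
        run_id, topic_index, variant_index,
        image_type="poster", image_data=poster_img, output_filename="poster.jpg",
    )
    return True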

View File

@@ -687,47 +687,58 @@ class ImageCollageCreator:
self.collage_style = collage_style
return self.collage_style
def process_directory(input_dir, target_size=(900, 1200), output_count=10, output_dir=None, collage_styles = [
"grid_2x2", # 标准2x2网格
"asymmetric", # 非对称布局
"filmstrip", # 胶片条布局
"circles", # 圆形布局
"overlap", # 重叠风格
"mosaic", # 马赛克风格 3x3
"fullscreen", # 全覆盖拼图样式
"vertical_stack" # 新增:上下拼图样式
]):
"""处理目录中的图片并生成指定数量的拼贴画"""
try:
# 创建拼贴画生成器
creator = ImageCollageCreator()
creator.set_collage_style(collage_styles)
# 设置3:4比例的输出尺寸
target_size = target_size # 3:4比例
output_list = []
for i in range(output_count):
print(f"\n生成第 {i+1}/{output_count} 个拼贴画")
# 使用指定尺寸创建拼贴画
collage = creator.create_collage_with_style(input_dir, target_size=target_size)
# 保存拼贴画
if collage:
creator.save_collage(collage, os.path.join(output_dir, f"random_collage_{i}.png"))
print(f"完成第 {i+1} 个拼贴画")
output_list.append(
{ "collage_style": collage_styles[i],
"collage_index": i,
"collage": collage,
"path": os.path.join(output_dir, f"random_collage_{i}.png"),
}
)
else:
print(f"{i+1} 个拼贴画创建失败")
return output_list
except Exception as e:
print(f"处理目录时出错: {e}")
traceback.print_exc()
def process_directory(directory_path, target_size=(900, 1200), output_count=1):
"""
Processes images in a directory: finds main subject, adjusts contrast/saturation,
performs smart cropping/resizing, creates a collage, and returns PIL Image objects.
Args:
directory_path: Path to the directory containing images.
target_size: Tuple (width, height) for the final collage.
output_count: Number of collages to generate.
Returns:
A list containing the generated PIL Image objects for the collages,
or an empty list if processing fails.
"""
image_files = [os.path.join(directory_path, f) for f in os.listdir(directory_path)
if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
if not image_files:
print(f"No images found in {directory_path}")
return []
# Create collage
collage_images = []
for i in range(output_count):
collage = create_collage(image_files, target_size)
if collage:
# collage_filename = f"collage_{i}.png"
# save_path = os.path.join(output_dir, collage_filename)
# collage.save(save_path)
# print(f"Collage saved to {save_path}")
# collage_images.append({'path': save_path, 'image': collage})
collage_images.append(collage) # Return the PIL Image object directly
else:
print(f"Failed to create collage {i}")
return collage_images
def create_collage(image_paths, target_size=(900, 1200)):
# ... (internal logic, including find_main_subject, adjust_image, smart_crop_and_resize) ...
pass
def find_main_subject(image):
# ... (keep the existing implementation) ...
pass
def adjust_image(image, contrast=1.0, saturation=1.0):
# ... (keep the existing implementation) ...
pass
def smart_crop_and_resize(image, target_aspect_ratio):
# ... (keep the existing implementation) ...
pass
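# Hypothetical usage sketch (names and paths below are placeholders, not part of this module's API):
# the refactored process_directory() returns PIL Image objects, and the caller decides how to
# persist them, here via the FileSystemOutputHandler added in utils/output_handler.py.
def _example_collage_flow(image_dir="path/to/images", run_id="demo_run"):
    from utils.output_handler import FileSystemOutputHandler
    handler = FileSystemOutputHandler("result")
    collages = process_directory(image_dir, target_size=(900, 1200), output_count=2)
    for i, collage in enumerate(collages, start=1):
        handler.handle_generated_image(
            run_id, topic_index=1, variant_index=i,
            image_type="collage", image_data=collage, output_filename=f"random_collage_{i}.png",
        )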
def main():
# 设置基础路径

main.py (155 changed lines)
View File

@@ -20,6 +20,10 @@ from utils.tweet_generator import ( # Import the moved functions
)
from utils.prompt_manager import PromptManager # Import PromptManager
import random
# Import Output Handlers
from utils.output_handler import FileSystemOutputHandler, OutputHandler
from core.topic_parser import TopicParser
from utils.tweet_generator import tweetTopicRecord # Needed only if loading old topics files?
def load_config(config_path="poster_gen_config.json"):
"""Loads configuration from a JSON file."""
@@ -51,27 +55,26 @@ def load_config(config_path="poster_gen_config.json"):
# --- Main Orchestration Step (Remains in main.py) ---
def generate_content_and_posters_step(config, run_id, tweet_topic_record):
def generate_content_and_posters_step(config, run_id, topics_list, output_handler):
"""
Step 2: Generates content and posters for each topic in the record.
Returns True if successful (at least partially), False otherwise.
"""
if not tweet_topic_record or not tweet_topic_record.topics_list:
if not topics_list:
# print("Skipping content/poster generation: No valid topics found in the record.")
logging.warning("Skipping content/poster generation: No valid topics found in the record.")
return False
# print(f"\n--- Starting Step 2: Content and Poster Generation for run_id: {run_id} ---")
logging.info(f"Starting Step 2: Content and Poster Generation for run_id: {run_id}")
# print(f"Processing {len(tweet_topic_record.topics_list)} topics...")
logging.info(f"Processing {len(tweet_topic_record.topics_list)} topics...")
# print(f"Processing {len(topics_list)} topics...")
logging.info(f"Processing {len(topics_list)} topics...")
success_flag = False
prompt_manager = PromptManager(config)
ai_agent = None
try:
# --- Initialize AI Agent for Content Generation ---
# print("Initializing AI Agent for content generation...")
request_timeout = config.get("request_timeout", 30) # Get timeout from config
max_retries = config.get("max_retries", 3) # Get max_retries from config
ai_agent = AI_Agent(
@@ -84,34 +87,95 @@ def generate_content_and_posters_step(config, run_id, tweet_topic_record):
logging.info("AI Agent for content generation initialized.")
# --- Iterate through Topics ---
for i, topic_item in enumerate(tweet_topic_record.topics_list):
for i, topic_item in enumerate(topics_list):
topic_index = topic_item.get('index', i + 1) # Use parsed index if available
# print(f"\nProcessing Topic {topic_index}/{len(tweet_topic_record.topics_list)}: {topic_item.get('object', 'N/A')}")
logging.info(f"--- Processing Topic {topic_index}/{len(tweet_topic_record.topics_list)}: {topic_item.get('object', 'N/A')} ---") # Make it stand out
logging.info(f"--- Processing Topic {topic_index}/{len(topics_list)}: {topic_item.get('object', 'N/A')} ---") # Make it stand out
# --- Generate Content Variants ---
tweet_content_list = generate_content_for_topic(
ai_agent, prompt_manager, config, topic_item, config["output_dir"], run_id, topic_index
# 读取内容生成需要的参数
content_variants = config.get("variants", 1)
content_temp = config.get("content_temperature", 0.3)
content_top_p = config.get("content_top_p", 0.4)
content_presence_penalty = config.get("content_presence_penalty", 1.5)
# 调用修改后的 generate_content_for_topic
content_success = generate_content_for_topic(
ai_agent,
prompt_manager, # Pass PromptManager instance
topic_item,
run_id,
topic_index,
output_handler, # Pass OutputHandler instance
# 传递具体参数
variants=content_variants,
temperature=content_temp,
top_p=content_top_p,
presence_penalty=content_presence_penalty
)
if tweet_content_list:
# print(f" Content generation successful for Topic {topic_index}.")
# if tweet_content_list: # generate_content_for_topic 现在返回 bool
if content_success:
logging.info(f"Content generation successful for Topic {topic_index}.")
# --- Generate Posters ---
# TODO: 重构 generate_posters_for_topic 以移除 config 依赖
# TODO: 需要确定如何将 content 数据传递给 poster 生成步骤 (已解决:函数内部读取)
# 临时方案:可能需要在这里读取由 output_handler 保存的 content 文件
# 或者修改 generate_content_for_topic 以返回收集到的 content 数据列表 (选项1)
# 暂时跳过 poster 生成的调用,直到确定方案
# --- 重新启用 poster 生成调用 ---
logging.info(f"Proceeding to poster generation for Topic {topic_index}...")
# --- 读取 Poster 生成所需参数 ---
poster_variants = config.get("variants", 1) # 通常与 content variants 相同
poster_assets_dir = config.get("poster_assets_base_dir")
img_base_dir = config.get("image_base_dir")
mod_img_subdir = config.get("modify_image_subdir", "modify")
res_dir_config = config.get("resource_dir", [])
poster_size = tuple(config.get("poster_target_size", [900, 1200]))
txt_possibility = config.get("text_possibility", 0.3)
collage_subdir = config.get("output_collage_subdir", "collage_img")
poster_subdir = config.get("output_poster_subdir", "poster")
poster_filename = config.get("output_poster_filename", "poster.jpg")
cam_img_subdir = config.get("camera_image_subdir", "相机")
# 检查关键路径是否存在
if not poster_assets_dir or not img_base_dir:
logging.error(f"Missing critical paths for poster generation (poster_assets_base_dir or image_base_dir) in config. Skipping posters for topic {topic_index}.")
continue # 跳过此主题的海报生成
# --- 结束读取参数 ---
posters_attempted = generate_posters_for_topic(
config, topic_item, tweet_content_list, config["output_dir"], run_id, topic_index
topic_item=topic_item,
output_dir=config["output_dir"], # Base output dir is still needed
run_id=run_id,
topic_index=topic_index,
output_handler=output_handler, # <--- 传递 Output Handler
# 传递具体参数
variants=poster_variants,
poster_assets_base_dir=poster_assets_dir,
image_base_dir=img_base_dir,
modify_image_subdir=mod_img_subdir,
resource_dir_config=res_dir_config,
poster_target_size=poster_size,
text_possibility=txt_possibility,
output_collage_subdir=collage_subdir,
output_poster_subdir=poster_subdir,
output_poster_filename=poster_filename,
camera_image_subdir=cam_img_subdir
)
if posters_attempted: # Log even if no posters were *successfully* created, but the attempt was made
# print(f" Poster generation process completed for Topic {topic_index}.")
if posters_attempted:
logging.info(f"Poster generation process completed for Topic {topic_index}.")
success_flag = True # Mark success if content and poster attempts were made
success_flag = True # Mark overall success if content AND poster attempts were made
else:
# print(f" Poster generation skipped or failed early for Topic {topic_index}.")
logging.warning(f"Poster generation skipped or failed early for Topic {topic_index}.")
# 即使海报失败,只要内容成功了,也算部分成功?根据需求决定 success_flag
# success_flag = True # 取决于是否认为内容成功就足够
# logging.warning(f"Skipping poster generation for Topic {topic_index} pending refactor and data passing strategy.")
# Mark overall success if content generation succeeded
# success_flag = True
else:
# print(f" Content generation failed or yielded no valid results for Topic {topic_index}. Skipping posters.")
logging.warning(f"Content generation failed or yielded no valid results for Topic {topic_index}. Skipping posters.")
# print(f"--- Finished Topic {topic_index} ---")
logging.info(f"--- Finished Topic {topic_index} ---")
@@ -181,28 +245,31 @@ def main():
logging.info("Debug logging enabled.")
# --- End Debug Level Adjustment ---
# print("Starting Travel Content Creator Pipeline...")
logging.info("Starting Travel Content Creator Pipeline...")
# print(f"Using configuration file: {args.config}")
logging.info(f"Using configuration file: {args.config}")
# if args.run_id: print(f"Using specific run_id: {args.run_id}")
# if args.topics_file: print(f"Using existing topics file: {args.topics_file}")
if args.run_id: logging.info(f"Using specific run_id: {args.run_id}")
if args.topics_file: logging.info(f"Using existing topics file: {args.topics_file}")
config = load_config(args.config)
if config is None:
# print("Critical: Failed to load configuration. Exiting.")
logging.critical("Failed to load configuration. Exiting.")
sys.exit(1)
# --- Initialize Output Handler ---
# For now, always use FileSystemOutputHandler. Later, this could be configurable.
output_handler: OutputHandler = FileSystemOutputHandler(config.get("output_dir", "result"))
logging.info(f"Using Output Handler: {output_handler.__class__.__name__}")
# --- End Output Handler Init ---
run_id = args.run_id
tweet_topic_record = None
# tweet_topic_record = None # No longer the primary way to pass data
topics_list = None
system_prompt = None
user_prompt = None
pipeline_start_time = time.time()
# --- Step 1: Topic Generation (or Load Existing) ---
if args.topics_file:
# print(f"Skipping Topic Generation (Step 1) - Loading topics from: {args.topics_file}")
logging.info(f"Skipping Topic Generation (Step 1) - Loading topics from: {args.topics_file}")
topics_list = TopicParser.load_topics_from_json(args.topics_file)
if topics_list:
@@ -232,12 +299,12 @@ def main():
# print(f"Generated run_id for loaded topics: {run_id}")
logging.info(f"Generated run_id for loaded topics: {run_id}")
# Create a minimal tweetTopicRecord (prompts might be missing)
# We need the output_dir from config to make the record useful
output_dir = config.get("output_dir", "result") # Default to result if not in config
tweet_topic_record = tweetTopicRecord(topics_list, "", "", output_dir, run_id)
# print(f"Successfully loaded {len(topics_list)} topics for run_id: {run_id}")
logging.info(f"Successfully loaded {len(topics_list)} topics for run_id: {run_id}")
# Prompts are missing when loading from file, handle this if needed later
system_prompt = "" # Placeholder
user_prompt = "" # Placeholder
logging.info(f"Successfully loaded {len(topics_list)} topics for run_id: {run_id}. Prompts are not available.")
# Optionally, save the loaded topics using the handler?
# output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt)
else:
# print(f"Error: Failed to load topics from {args.topics_file}. Cannot proceed.")
logging.error(f"Failed to load topics from {args.topics_file}. Cannot proceed.")
@@ -246,22 +313,29 @@ def main():
# print("--- Executing Topic Generation (Step 1) ---")
logging.info("Executing Topic Generation (Step 1)...")
step1_start = time.time()
run_id, tweet_topic_record = run_topic_generation_pipeline(config, args.run_id) # Pass run_id if provided
# Call the updated function, receive raw data
run_id, topics_list, system_prompt, user_prompt = run_topic_generation_pipeline(config, args.run_id)
step1_end = time.time()
if run_id and tweet_topic_record:
if run_id is not None and topics_list is not None: # Check if step succeeded
# print(f"Step 1 completed successfully in {step1_end - step1_start:.2f} seconds. Run ID: {run_id}")
logging.info(f"Step 1 completed successfully in {step1_end - step1_start:.2f} seconds. Run ID: {run_id}")
# --- Use Output Handler to save results ---
output_handler.handle_topic_results(run_id, topics_list, system_prompt, user_prompt)
else:
# print("Critical: Topic Generation (Step 1) failed. Exiting.")
logging.critical("Topic Generation (Step 1) failed. Exiting.")
sys.exit(1)
# --- Step 2: Content & Poster Generation ---
if run_id and tweet_topic_record:
if run_id is not None and topics_list is not None:
# print("\n--- Executing Content and Poster Generation (Step 2) ---")
logging.info("Executing Content and Poster Generation (Step 2)...")
step2_start = time.time()
success = generate_content_and_posters_step(config, run_id, tweet_topic_record)
# TODO: Refactor generate_content_and_posters_step to accept topics_list
# and use the output_handler instead of saving files directly.
# For now, we might need to pass topics_list and handler, or adapt it.
# Let's tentatively adapt the call signature, assuming the function will be refactored.
success = generate_content_and_posters_step(config, run_id, topics_list, output_handler)
step2_end = time.time()
if success:
# print(f"Step 2 completed in {step2_end - step2_start:.2f} seconds.")
@@ -270,8 +344,13 @@ def main():
# print("Warning: Step 2 finished, but may have encountered errors or generated no output.")
logging.warning("Step 2 finished, but may have encountered errors or generated no output.")
else:
# print("Error: Cannot proceed to Step 2: Invalid run_id or tweet_topic_record from Step 1.")
logging.error("Cannot proceed to Step 2: Invalid run_id or tweet_topic_record from Step 1.")
# print("Error: Cannot proceed to Step 2: Invalid run_id or topics_list from Step 1.")
logging.error("Cannot proceed to Step 2: Invalid run_id or topics_list from Step 1.")
# --- Finalize Output ---
if run_id:
output_handler.finalize(run_id)
# --- End Finalize ---
pipeline_end_time = time.time()
# print(f"\nPipeline finished. Total execution time: {pipeline_end_time - pipeline_start_time:.2f} seconds.")

Binary file not shown.

utils/output_handler.py (new file, 145 lines)
View File

@@ -0,0 +1,145 @@
import os
import json
import logging
from abc import ABC, abstractmethod
class OutputHandler(ABC):
"""Abstract base class for handling the output of the generation pipeline."""
@abstractmethod
def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
"""Handles the results from the topic generation step."""
pass
@abstractmethod
def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
"""Handles the results for a single content variant."""
pass
@abstractmethod
def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
"""Handles the poster configuration generated for a topic."""
pass
@abstractmethod
def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str):
"""Handles a generated image (collage or final poster).
Args:
image_type: Either 'collage' or 'poster'.
image_data: The image data (e.g., PIL Image object or bytes).
output_filename: The desired filename for the output (e.g., 'poster.jpg').
"""
pass
@abstractmethod
def finalize(self, run_id: str):
"""Perform any final actions for the run (e.g., close files, upload manifests)."""
pass
class FileSystemOutputHandler(OutputHandler):
"""Handles output by saving results to the local file system."""
def __init__(self, base_output_dir: str = "result"):
self.base_output_dir = base_output_dir
logging.info(f"FileSystemOutputHandler initialized. Base output directory: {self.base_output_dir}")
def _get_run_dir(self, run_id: str) -> str:
"""Gets the specific directory for a run, creating it if necessary."""
run_dir = os.path.join(self.base_output_dir, run_id)
os.makedirs(run_dir, exist_ok=True)
return run_dir
def _get_variant_dir(self, run_id: str, topic_index: int, variant_index: int, subdir: str | None = None) -> str:
"""Gets the specific directory for a variant, optionally within a subdirectory (e.g., 'poster'), creating it if necessary."""
run_dir = self._get_run_dir(run_id)
variant_base_dir = os.path.join(run_dir, f"{topic_index}_{variant_index}")
target_dir = variant_base_dir
if subdir:
target_dir = os.path.join(variant_base_dir, subdir)
os.makedirs(target_dir, exist_ok=True)
return target_dir
def handle_topic_results(self, run_id: str, topics_list: list, system_prompt: str, user_prompt: str):
run_dir = self._get_run_dir(run_id)
# Save topics list
topics_path = os.path.join(run_dir, f"tweet_topic_{run_id}.json")
try:
with open(topics_path, "w", encoding="utf-8") as f:
json.dump(topics_list, f, ensure_ascii=False, indent=4)
logging.info(f"Topics list saved successfully to: {topics_path}")
except Exception as e:
logging.exception(f"Error saving topic JSON file to {topics_path}:")
# Save prompts
prompt_path = os.path.join(run_dir, f"topic_prompt_{run_id}.txt")
try:
with open(prompt_path, "w", encoding="utf-8") as f:
f.write("--- SYSTEM PROMPT ---\n")
f.write(system_prompt + "\n\n")
f.write("--- USER PROMPT ---\n")
f.write(user_prompt + "\n")
logging.info(f"Topic prompts saved successfully to: {prompt_path}")
except Exception as e:
logging.exception(f"Error saving topic prompts file to {prompt_path}:")
def handle_content_variant(self, run_id: str, topic_index: int, variant_index: int, content_data: dict, prompt_data: str):
"""Saves content JSON and prompt for a specific variant."""
variant_dir = self._get_variant_dir(run_id, topic_index, variant_index)
# Save content JSON
content_path = os.path.join(variant_dir, "article.json")
try:
with open(content_path, "w", encoding="utf-8") as f:
json.dump(content_data, f, ensure_ascii=False, indent=4)
logging.info(f"Content JSON saved to: {content_path}")
except Exception as e:
logging.exception(f"Failed to save content JSON to {content_path}: {e}")
# Save content prompt
prompt_path = os.path.join(variant_dir, "tweet_prompt.txt")
try:
with open(prompt_path, "w", encoding="utf-8") as f:
# Assuming prompt_data is the user prompt used for this variant
f.write(prompt_data + "\n")
logging.info(f"Content prompt saved to: {prompt_path}")
except Exception as e:
logging.exception(f"Failed to save content prompt to {prompt_path}: {e}")
def handle_poster_configs(self, run_id: str, topic_index: int, config_data: list | dict):
"""Saves the complete poster configuration list/dict for a topic."""
run_dir = self._get_run_dir(run_id)
config_path = os.path.join(run_dir, f"topic_{topic_index}_poster_configs.json")
try:
with open(config_path, 'w', encoding='utf-8') as f_cfg_topic:
json.dump(config_data, f_cfg_topic, ensure_ascii=False, indent=4)
logging.info(f"Saved complete poster configurations for topic {topic_index} to: {config_path}")
except Exception as save_err:
logging.error(f"Failed to save complete poster configurations for topic {topic_index} to {config_path}: {save_err}")
def handle_generated_image(self, run_id: str, topic_index: int, variant_index: int, image_type: str, image_data, output_filename: str):
"""Saves a generated image (PIL Image) to the appropriate variant subdirectory."""
subdir = None
if image_type == 'collage':
subdir = 'collage_img' # TODO: Make these subdir names configurable?
elif image_type == 'poster':
subdir = 'poster'
else:
logging.warning(f"Unknown image type '{image_type}'. Saving to variant root.")
subdir = None # Save directly in variant dir if type is unknown
target_dir = self._get_variant_dir(run_id, topic_index, variant_index, subdir=subdir)
save_path = os.path.join(target_dir, output_filename)
try:
# Assuming image_data is a PIL Image object based on posterGen/simple_collage
image_data.save(save_path)
logging.info(f"Saved {image_type} image to: {save_path}")
except Exception as e:
logging.exception(f"Failed to save {image_type} image to {save_path}: {e}")
def finalize(self, run_id: str):
logging.info(f"FileSystemOutputHandler finalizing run: {run_id}. No specific actions needed.")
pass # Nothing specific to do for file system finalize
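A brief, hypothetical usage sketch of the handler above; the data values are placeholders, and in the actual pipeline these calls are made from `main.py` and `utils/tweet_generator.py`.

def _example_handler_usage():
    # Placeholder data; the pipeline passes real generated results here.
    handler = FileSystemOutputHandler("result")
    run_id = "demo_run"
    handler.handle_topic_results(
        run_id,
        topics_list=[{"index": 1, "object": "示例选题"}],
        system_prompt="(system prompt)",
        user_prompt="(user prompt)",
    )
    handler.handle_content_variant(
        run_id, topic_index=1, variant_index=1,
        content_data={"title": "标题", "content": "正文"},
        prompt_data="(prompt used for this variant)",
    )
    handler.finalize(run_id)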

View File

@@ -12,49 +12,54 @@ from .resource_loader import ResourceLoader # Use relative import within the sam
class PromptManager:
"""Handles the loading and construction of prompts."""
def __init__(self, config):
"""Initializes the PromptManager with the global configuration."""
self.config = config
# Instantiate ResourceLoader once, assuming it's mostly static methods or stateless
self.resource_loader = ResourceLoader()
def __init__(self,
topic_system_prompt_path: str,
topic_user_prompt_path: str,
content_system_prompt_path: str,
prompts_dir: str,
resource_dir_config: list,
topic_gen_num: int = 1, # Default values if needed
topic_gen_date: str = ""
):
self.topic_system_prompt_path = topic_system_prompt_path
self.topic_user_prompt_path = topic_user_prompt_path
self.content_system_prompt_path = content_system_prompt_path
self.prompts_dir = prompts_dir
self.resource_dir_config = resource_dir_config
self.topic_gen_num = topic_gen_num
self.topic_gen_date = topic_gen_date
def get_topic_prompts(self):
"""Constructs the system and user prompts for topic generation."""
logging.info("Constructing prompts for topic generation...")
try:
# --- System Prompt ---
system_prompt_path = self.config.get("topic_system_prompt")
if not system_prompt_path:
logging.error("topic_system_prompt path not specified in config.")
if not self.topic_system_prompt_path:
logging.error("Topic system prompt path not provided during PromptManager initialization.")
return None, None
# Use ResourceLoader's static method directly
system_prompt = ResourceLoader.load_file_content(system_prompt_path)
system_prompt = ResourceLoader.load_file_content(self.topic_system_prompt_path)
if not system_prompt:
logging.error(f"Failed to load topic system prompt from '{system_prompt_path}'.")
logging.error(f"Failed to load topic system prompt from '{self.topic_system_prompt_path}'.")
return None, None
# --- User Prompt ---
user_prompt_path = self.config.get("topic_user_prompt")
if not user_prompt_path:
logging.error("topic_user_prompt path not specified in config.")
if not self.topic_user_prompt_path:
logging.error("Topic user prompt path not provided during PromptManager initialization.")
return None, None
base_user_prompt = ResourceLoader.load_file_content(user_prompt_path)
if base_user_prompt is None: # Check for None explicitly
logging.error(f"Failed to load base topic user prompt from '{user_prompt_path}'.")
base_user_prompt = ResourceLoader.load_file_content(self.topic_user_prompt_path)
if base_user_prompt is None:
logging.error(f"Failed to load base topic user prompt from '{self.topic_user_prompt_path}'.")
return None, None
# --- Build the dynamic part of the user prompt (Logic moved from prepare_topic_generation) ---
user_prompt_dynamic = "你拥有的创作资料如下:\n"
# Add genPrompts directory structure
gen_prompts_path = self.config.get("prompts_dir")
if gen_prompts_path and os.path.isdir(gen_prompts_path):
if self.prompts_dir and os.path.isdir(self.prompts_dir):
try:
gen_prompts_list = os.listdir(gen_prompts_path)
gen_prompts_list = os.listdir(self.prompts_dir)
for gen_prompt_folder in gen_prompts_list:
folder_path = os.path.join(gen_prompts_path, gen_prompt_folder)
folder_path = os.path.join(self.prompts_dir, gen_prompt_folder)
if os.path.isdir(folder_path):
try:
# List files, filter out subdirs if needed
@@ -63,13 +68,12 @@ class PromptManager:
except OSError as e:
logging.warning(f"Could not list directory {folder_path}: {e}")
except OSError as e:
logging.warning(f"Could not list base prompts directory {gen_prompts_path}: {e}")
logging.warning(f"Could not list base prompts directory {self.prompts_dir}: {e}")
else:
logging.warning(f"Prompts directory '{gen_prompts_path}' not found or invalid.")
logging.warning(f"Prompts directory '{self.prompts_dir}' not found or invalid.")
# Add resource directory contents
resource_dir_config = self.config.get("resource_dir", [])
for dir_info in resource_dir_config:
for dir_info in self.resource_dir_config:
source_type = dir_info.get("type", "UnknownType")
source_file_paths = dir_info.get("file_path", [])
for file_path in source_file_paths:
@@ -81,7 +85,7 @@
logging.warning(f"Could not load resource file {file_path}")
# Add dateline information (optional)
user_prompt_dir = os.path.dirname(user_prompt_path)
user_prompt_dir = os.path.dirname(self.topic_user_prompt_path)
dateline_path = os.path.join(user_prompt_dir, "2025各月节日宣传节点时间表.md") # Consider making this configurable
if os.path.exists(dateline_path):
dateline_content = ResourceLoader.load_file_content(dateline_path)
@@ -91,9 +95,7 @@
# Combine dynamic part, base template, and final parameters
user_prompt = user_prompt_dynamic + base_user_prompt
select_num = self.config.get("num", 1) # Default to 1 if not specified
select_date = self.config.get("date", "")
user_prompt += f"\n选题数量:{select_num}\n选题日期:{select_date}\n"
user_prompt += f"\n选题数量:{self.topic_gen_num}\n选题日期:{self.topic_gen_date}\n"
# --- End of moved logic ---
logging.info(f"Topic prompts constructed. System: {len(system_prompt)} chars, User: {len(user_prompt)} chars.")
@@ -108,20 +110,18 @@
logging.info(f"Constructing content prompts for topic: {topic_item.get('object', 'N/A')}...")
try:
# --- System Prompt ---
system_prompt_path = self.config.get("content_system_prompt")
if not system_prompt_path:
logging.error("content_system_prompt path not specified in config.")
if not self.content_system_prompt_path:
logging.error("Content system prompt path not provided during PromptManager initialization.")
return None, None
# Use ResourceLoader's static method. load_system_prompt was just load_file_content.
system_prompt = ResourceLoader.load_file_content(system_prompt_path)
system_prompt = ResourceLoader.load_file_content(self.content_system_prompt_path)
if not system_prompt:
logging.error(f"Failed to load content system prompt from '{system_prompt_path}'.")
logging.error(f"Failed to load content system prompt from '{self.content_system_prompt_path}'.")
return None, None
# --- User Prompt (Logic moved from ResourceLoader.build_user_prompt) ---
user_prompt = ""
prompts_dir = self.config.get("prompts_dir")
resource_dir_config = self.config.get("resource_dir", [])
prompts_dir = self.prompts_dir
resource_dir_config = self.resource_dir_config
if not prompts_dir or not os.path.isdir(prompts_dir):
logging.warning(f"Prompts directory '{prompts_dir}' not found or invalid. Content user prompt might be incomplete.")

View File

@ -26,6 +26,7 @@ from utils.prompt_manager import PromptManager # Keep this as it's importing fro
from core import contentGen as core_contentGen
from core import posterGen as core_posterGen
from core import simple_collage as core_simple_collage
from .output_handler import OutputHandler # <-- 添加导入
class tweetTopic:
def __init__(self, index, date, logic, object, product, product_logic, style, style_logic, target_audience, target_audience_logic):
@@ -41,60 +42,29 @@ class tweetTopic:
self.target_audience_logic = target_audience_logic
class tweetTopicRecord:
def __init__(self, topics_list, system_prompt, user_prompt, output_dir, run_id):
def __init__(self, topics_list, system_prompt, user_prompt, run_id):
self.topics_list = topics_list
self.system_prompt = system_prompt
self.user_prompt = user_prompt
self.output_dir = output_dir
self.run_id = run_id
def save_topics(self, path):
try:
with open(path, "w", encoding="utf-8") as f:
json.dump(self.topics_list, f, ensure_ascii=False, indent=4)
logging.info(f"Topics list successfully saved to {path}") # Change to logging
except Exception as e:
# Keep print for traceback, but add logging
logging.exception(f"保存选题失败到 {path}: {e}") # Log exception
# print(f"保存选题失败到 {path}: {e}")
# print("--- Traceback for save_topics error ---")
# traceback.print_exc()
# print("--- End Traceback ---")
return False
return True
def save_prompt(self, path):
try:
with open(path, "w", encoding="utf-8") as f:
f.write(self.system_prompt + "\n")
f.write(self.user_prompt + "\n")
# f.write(self.output_dir + "\n") # Output dir not needed in prompt file?
# f.write(self.run_id + "\n") # run_id not needed in prompt file?
logging.info(f"Prompts saved to {path}")
except Exception as e:
logging.exception(f"保存提示词失败: {e}")
# print(f"保存提示词失败: {e}")
return False
return True
class tweetContent:
def __init__(self, result, prompt, output_dir, run_id, article_index, variant_index):
def __init__(self, result, prompt, run_id, article_index, variant_index):
self.result = result
self.prompt = prompt
self.output_dir = output_dir
self.run_id = run_id
self.article_index = article_index
self.variant_index = variant_index
try:
self.title, self.content = self.split_content(result)
self.json_file = self.gen_result_json()
self.json_data = self.gen_result_json()
except Exception as e:
logging.error(f"Failed to parse AI result for {article_index}_{variant_index}: {e}")
logging.debug(f"Raw result: {result[:500]}...") # Log partial raw result
self.title = "[Parsing Error]"
self.content = "[Failed to parse AI content]"
self.json_file = {"title": self.title, "content": self.content, "error": True, "raw_result": result}
self.json_data = {"title": self.title, "content": self.content, "error": True, "raw_result": result}
def split_content(self, result):
# Assuming split logic might still fail, keep it simple or improve with regex/json
@@ -122,28 +92,19 @@ class tweetContent:
"title": self.title,
"content": self.content
}
# Add error flag if it exists
if hasattr(self, 'json_data') and self.json_data.get('error'):
json_file['error'] = True
json_file['raw_result'] = self.json_data.get('raw_result')
return json_file
def save_content(self, json_path):
try:
with open(json_path, "w", encoding="utf-8") as f:
# If parsing failed, save the error structure
json.dump(self.json_file, f, ensure_ascii=False, indent=4)
logging.info(f"Content JSON saved to: {json_path}")
except Exception as e:
logging.exception(f"Failed to save content JSON to {json_path}: {e}")
return None # Indicate failure
return json_path
def get_json_data(self):
"""Returns the generated JSON data dictionary."""
return self.json_data
def save_prompt(self, path):
try:
with open(path, "w", encoding="utf-8") as f:
f.write(self.prompt + "\n")
logging.info(f"Content prompt saved to: {path}")
except Exception as e:
logging.exception(f"Failed to save content prompt to {path}: {e}")
return None # Indicate failure
return path
def get_prompt(self):
"""Returns the user prompt used to generate this content."""
return self.prompt
def get_content(self):
return self.content
@@ -151,12 +112,9 @@
def get_title(self):
return self.title
def get_json_file(self):
return self.json_file
def generate_topics(ai_agent, system_prompt, user_prompt, output_dir, run_id, temperature=0.2, top_p=0.5, presence_penalty=1.5):
"""生成选题列表 (run_id is now passed in)"""
def generate_topics(ai_agent, system_prompt, user_prompt, run_id, temperature=0.2, top_p=0.5, presence_penalty=1.5):
"""生成选题列表 (run_id is now passed in, output_dir removed as argument)"""
logging.info("Starting topic generation...")
time_start = time.time()
@@ -171,26 +129,20 @@ def generate_topics(ai_agent, system_prompt, user_prompt, output_dir, run_id, te
result_list = TopicParser.parse_topics(result)
if not result_list:
logging.warning("Topic parsing resulted in an empty list.")
# Optionally save raw result here if parsing fails?
error_log_path = os.path.join(output_dir, run_id, f"topic_parsing_error_{run_id}.txt")
try:
os.makedirs(os.path.dirname(error_log_path), exist_ok=True)
with open(error_log_path, "w", encoding="utf-8") as f_err:
f_err.write("--- Topic Parsing Failed ---\n")
f_err.write(result)
logging.info(f"Saved raw AI output due to parsing failure to: {error_log_path}")
except Exception as log_err:
logging.error(f"Failed to save raw AI output on parsing failure: {log_err}")
# Optionally handle raw result logging here if needed, but saving is responsibility of OutputHandler
# error_log_path = os.path.join(output_dir, run_id, f"topic_parsing_error_{run_id}.txt") # output_dir is not available here
# try:
# # ... (save raw output logic) ...
# except Exception as log_err:
# logging.error(f"Failed to save raw AI output on parsing failure: {log_err}")
# Create record object (even if list is empty)
tweet_topic_record = tweetTopicRecord(result_list, system_prompt, user_prompt, output_dir, run_id)
return tweet_topic_record # Return only the record
# 直接返回解析后的列表
return result_list
def generate_single_content(ai_agent, system_prompt, user_prompt, item, output_dir, run_id,
def generate_single_content(ai_agent, system_prompt, user_prompt, item, run_id,
article_index, variant_index, temperature=0.3, top_p=0.4, presence_penalty=1.5):
"""生成单篇文章内容. Requires prompts to be passed in."""
"""Generates single content variant data. Returns (content_json, user_prompt) or (None, None)."""
logging.info(f"Generating content for topic {article_index}, variant {variant_index}")
try:
if not system_prompt or not user_prompt:
@@ -201,29 +153,37 @@ def generate_single_content(ai_agent, system_prompt, user_prompt, item, output_d
time.sleep(random.random() * 0.5)
# Generate content (updated return values)
# Generate content (non-streaming work returns result, tokens, time_cost)
result, tokens, time_cost = ai_agent.work(
system_prompt, user_prompt, "", temperature, top_p, presence_penalty
)
if result is None: # Check if AI call failed
logging.error(f"AI agent work failed for {article_index}_{variant_index}. No result returned.")
return None, None
logging.info(f"Content generation for {article_index}_{variant_index} completed in {time_cost:.2f}s. Estimated tokens: {tokens}")
# --- Correct directory structure ---
run_specific_output_dir = os.path.join(output_dir, run_id)
variant_result_dir = os.path.join(run_specific_output_dir, f"{article_index}_{variant_index}")
os.makedirs(variant_result_dir, exist_ok=True)
# --- Create tweetContent object (handles parsing) ---
# Pass user_prompt instead of full prompt? Yes, user_prompt is what we need later.
tweet_content = tweetContent(result, user_prompt, run_id, article_index, variant_index)
# Create tweetContent object (handles potential parsing errors inside its __init__)
tweet_content = tweetContent(result, user_prompt, output_dir, run_id, article_index, variant_index)
# Save content and prompt
content_save_path = os.path.join(variant_result_dir, "article.json")
prompt_save_path = os.path.join(variant_result_dir, "tweet_prompt.txt")
tweet_content.save_content(content_save_path)
tweet_content.save_prompt(prompt_save_path)
# logging.info(f" Saved article content to: {content_save_path}") # Already logged in save_content
# --- Remove Saving Logic ---
# run_specific_output_dir = os.path.join(output_dir, run_id) # output_dir no longer available
# variant_result_dir = os.path.join(run_specific_output_dir, f"{article_index}_{variant_index}")
# os.makedirs(variant_result_dir, exist_ok=True)
# content_save_path = os.path.join(variant_result_dir, "article.json")
# prompt_save_path = os.path.join(variant_result_dir, "tweet_prompt.txt")
# tweet_content.save_content(content_save_path) # Method removed
# tweet_content.save_prompt(prompt_save_path) # Method removed
# --- End Remove Saving Logic ---
# Return the data needed by the output handler
content_json = tweet_content.get_json_data()
prompt_data = tweet_content.get_prompt() # Get the stored user prompt
return content_json, prompt_data # Return data pair
return tweet_content, result # Return object and raw result
except Exception as e:
logging.exception(f"Error generating single content for {article_index}_{variant_index}:")
return None, None
@@ -256,7 +216,7 @@ def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompt
# 调用单篇文章生成函数
tweet_content, result = generate_single_content(
ai_agent, system_prompt, item, output_dir, run_id, i+1, j+1, temperature
ai_agent, system_prompt, item, run_id, i+1, j+1, temperature
)
if tweet_content:
@@ -270,207 +230,248 @@ def generate_content(ai_agent, system_prompt, topics, output_dir, run_id, prompt
return processed_results
def prepare_topic_generation(
config # Pass the whole config dictionary now
# select_date, select_num,
# system_prompt_path, user_prompt_path,
# base_url="vllm", model_name="qwenQWQ", api_key="EMPTY",
# gen_prompts_path="/root/autodl-tmp/TravelContentCreator/genPrompts",
# resource_dir="/root/autodl-tmp/TravelContentCreator/resource",
# output_dir="/root/autodl-tmp/TravelContentCreator/result"
):
"""准备选题生成的环境和参数. Returns agent and prompts."""
def prepare_topic_generation(prompt_manager: PromptManager,
api_url: str,
model_name: str,
api_key: str,
timeout: int,
max_retries: int):
"""准备选题生成的环境和参数. Returns agent, system_prompt, user_prompt.
# Initialize PromptManager
prompt_manager = PromptManager(config)
# Get prompts using PromptManager
Args:
prompt_manager: An initialized PromptManager instance.
api_url, model_name, api_key, timeout, max_retries: Parameters for AI_Agent.
"""
logging.info("Preparing for topic generation (using provided PromptManager)...")
# 从传入的 PromptManager 获取 prompts
system_prompt, user_prompt = prompt_manager.get_topic_prompts()
if not system_prompt or not user_prompt:
print("Error: Failed to get topic generation prompts.")
return None, None, None, None
logging.error("Failed to get topic generation prompts from PromptManager.")
return None, None, None # Return three Nones
# 创建AI Agent (still create agent here for the topic generation phase)
# 使用传入的参数初始化 AI Agent
try:
logging.info("Initializing AI Agent for topic generation...")
# --- Read timeout/retry from config ---
request_timeout = config.get("request_timeout", 30) # Default 30 seconds
max_retries = config.get("max_retries", 3) # Default 3 retries
# --- Pass values to AI_Agent ---
ai_agent = AI_Agent(
config["api_url"],
config["model"],
config["api_key"],
timeout=request_timeout,
max_retries=max_retries
api_url, # Use passed arg
model_name, # Use passed arg
api_key, # Use passed arg
timeout=timeout, # Use passed arg
max_retries=max_retries # Use passed arg
)
except Exception as e:
logging.exception("Error initializing AI Agent for topic generation:")
traceback.print_exc()
return None, None, None, None
return None, None, None # Return three Nones
# Removed prompt loading/building logic, now handled by PromptManager
# Return agent and the generated prompts
return ai_agent, system_prompt, user_prompt, config["output_dir"]
# 返回 agent 和 prompts
return ai_agent, system_prompt, user_prompt
def run_topic_generation_pipeline(config, run_id=None):
"""Runs the complete topic generation pipeline based on the configuration."""
"""
Runs the complete topic generation pipeline based on the configuration.
Returns: (run_id, topics_list, system_prompt, user_prompt) or (None, None, None, None) on failure.
"""
logging.info("Starting Step 1: Topic Generation Pipeline...")
# --- Handle run_id ---
if run_id is None:
logging.info("No run_id provided, generating one based on timestamp.")
run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
else:
logging.info(f"Using provided run_id: {run_id}")
# --- End run_id handling ---
# Prepare necessary inputs and the AI agent for topic generation
ai_agent, system_prompt, user_prompt, base_output_dir = None, None, None, None
ai_agent, system_prompt, user_prompt = None, None, None # Initialize
topics_list = None
prompt_manager = None # Initialize prompt_manager
try:
# Pass the config directly to prepare_topic_generation
ai_agent, system_prompt, user_prompt, base_output_dir = prepare_topic_generation(config)
# --- 读取 PromptManager 所需参数 ---
topic_sys_prompt_path = config.get("topic_system_prompt")
topic_user_prompt_path = config.get("topic_user_prompt")
content_sys_prompt_path = config.get("content_system_prompt") # 虽然这里不用,但 PromptManager 可能需要
prompts_dir_path = config.get("prompts_dir")
resource_config = config.get("resource_dir", [])
topic_num = config.get("num", 1)
topic_date = config.get("date", "")
# --- 创建 PromptManager 实例 ---
prompt_manager = PromptManager(
topic_system_prompt_path=topic_sys_prompt_path,
topic_user_prompt_path=topic_user_prompt_path,
content_system_prompt_path=content_sys_prompt_path,
prompts_dir=prompts_dir_path,
resource_dir_config=resource_config,
topic_gen_num=topic_num,
topic_gen_date=topic_date
)
logging.info("PromptManager instance created.")
# --- 读取 AI Agent 所需参数 ---
ai_api_url = config.get("api_url")
ai_model = config.get("model")
ai_api_key = config.get("api_key")
ai_timeout = config.get("request_timeout", 30)
ai_max_retries = config.get("max_retries", 3)
# 检查必需的 AI 参数是否存在
if not all([ai_api_url, ai_model, ai_api_key]):
raise ValueError("Missing required AI configuration (api_url, model, api_key) in config.")
# --- 调用修改后的 prepare_topic_generation ---
ai_agent, system_prompt, user_prompt = prepare_topic_generation(
prompt_manager, # Pass instance
ai_api_url,
ai_model,
ai_api_key,
ai_timeout,
ai_max_retries
)
if not ai_agent or not system_prompt or not user_prompt:
raise ValueError("Failed to prepare topic generation (agent or prompts missing).")
except Exception as e:
logging.exception("Error during topic generation preparation:")
traceback.print_exc()
return None, None
# Generate topics using the prepared agent and prompts
try:
# Pass the determined run_id to generate_topics
tweet_topic_record = generate_topics(
ai_agent, system_prompt, user_prompt, config["output_dir"],
run_id, # Pass the run_id
# --- Generate topics (保持不变) ---
topics_list = generate_topics(
ai_agent, system_prompt, user_prompt,
run_id, # Pass run_id
config.get("topic_temperature", 0.2),
config.get("topic_top_p", 0.5),
config.get("topic_presence_penalty", 1.5)
)
except Exception as e:
logging.exception("Error during topic generation API call:")
traceback.print_exc()
if ai_agent: ai_agent.close() # Ensure agent is closed on error
return None, None
logging.exception("Error during topic generation pipeline execution:")
# Ensure agent is closed even if generation fails mid-way
if ai_agent: ai_agent.close()
return None, None, None, None # Signal failure
finally:
# Ensure the AI agent is closed after generation attempt (if initialized)
if ai_agent:
logging.info("Closing topic generation AI Agent...")
ai_agent.close()
# Ensure the AI agent is closed after generation
if ai_agent:
logging.info("Closing topic generation AI Agent...")
ai_agent.close()
if topics_list is None: # Check if generate_topics returned None (though it currently returns list)
logging.error("Topic generation failed (generate_topics returned None or an error occurred).")
return None, None, None, None
elif not topics_list: # Check if list is empty
logging.warning(f"Topic generation completed for run {run_id}, but the resulting topic list is empty.")
# Return empty list and prompts anyway, let caller decide
# Process results
if not tweet_topic_record:
logging.error("Topic generation failed (generate_topics returned None).")
return None, None # Return None for run_id as well if record is None
# --- Saving logic removed previously ---
# Use the determined run_id for output directory
output_dir = os.path.join(config["output_dir"], run_id)
try:
os.makedirs(output_dir, exist_ok=True)
# --- Debug: Print the data before attempting to save ---
logging.info("--- Debug: Data to be saved in tweet_topic.json ---")
logging.info(tweet_topic_record.topics_list)
logging.info("--- End Debug ---")
# --- End Debug ---
# Save topics and prompt details
save_topics_success = tweet_topic_record.save_topics(os.path.join(output_dir, "tweet_topic.json"))
save_prompt_success = tweet_topic_record.save_prompt(os.path.join(output_dir, "tweet_prompt.txt"))
if not save_topics_success or not save_prompt_success:
logging.warning("Warning: Failed to save topic generation results or prompts.")
# Continue but warn user
except Exception as e:
logging.exception("Error saving topic generation results:")
traceback.print_exc()
# Return the generated data even if saving fails, but maybe warn more strongly?
# return run_id, tweet_topic_record # Decide if partial success is okay
return None, None # Or consider failure if saving is critical
logging.info(f"Topics generated successfully. Run ID: {run_id}")
# Return the determined run_id and the record
return run_id, tweet_topic_record
logging.info(f"Topic generation pipeline completed successfully. Run ID: {run_id}")
# Return the raw data needed by the OutputHandler
return run_id, topics_list, system_prompt, user_prompt
# --- Decoupled Functional Units (Moved from main.py) ---
def generate_content_for_topic(ai_agent, prompt_manager, config, topic_item, output_dir, run_id, topic_index):
"""Generates all content variants for a single topic item.
def generate_content_for_topic(ai_agent: AI_Agent,
prompt_manager: PromptManager,
topic_item: dict,
run_id: str,
topic_index: int,
output_handler: OutputHandler, # Changed name to match convention
# 添加具体参数,移除 config 和 output_dir
variants: int,
temperature: float,
top_p: float,
presence_penalty: float):
"""Generates all content variants for a single topic item and uses OutputHandler.
Args:
ai_agent: An initialized AI_Agent instance.
prompt_manager: An initialized PromptManager instance.
config: The global configuration dictionary.
topic_item: The dictionary representing a single topic.
output_dir: The base output directory for the entire run (e.g., ./result).
run_id: The ID for the current run.
topic_index: The 1-based index of the current topic.
ai_agent: Initialized AI_Agent instance.
prompt_manager: Initialized PromptManager instance.
topic_item: Dictionary representing the topic.
run_id: Current run ID.
topic_index: 1-based index of the topic.
output_handler: An instance of OutputHandler to process results.
variants: Number of variants to generate.
temperature, top_p, presence_penalty: AI generation parameters.
Returns:
bool: True if at least one variant was successfully generated and handled, False otherwise.
"""
logging.info(f"Generating content for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
success_flag = False # Track if any variant succeeded
# Use the 'variants' value passed in as a parameter (no longer read from config)
for j in range(variants):
variant_index = j + 1
logging.info(f"Generating Variant {variant_index}/{variants}...")
logging.info(f" Generating Variant {variant_index}/{variants}...")
# Get prompts for this specific topic item via the passed-in PromptManager
content_system_prompt, content_user_prompt = prompt_manager.get_content_prompts(topic_item)
if not content_system_prompt or not content_user_prompt:
logging.warning(f"Skipping Variant {variant_index} due to missing content prompts.")
continue # Skip this variant
logging.warning(f" Skipping Variant {variant_index} due to missing content prompts.")
continue
time.sleep(random.random() * 0.5)
try:
# Call generate_single_content with passed-in parameters
content_json, prompt_data = generate_single_content(
ai_agent, content_system_prompt, content_user_prompt, topic_item,
run_id, topic_index, variant_index,
temperature, # passed-in parameter
top_p, # passed-in parameter
presence_penalty # passed-in parameter
)
# Check if generation succeeded and parsing was okay (or error handled within json)
if content_json is not None and prompt_data is not None:
# Use the output handler to process/save the result
output_handler.handle_content_variant(
run_id, topic_index, variant_index, content_json, prompt_data
)
success_flag = True # Mark success for this topic
# Check specifically if the AI result itself indicated a parsing error internally
if content_json.get("error"):
logging.error(f" Content generation for Topic {topic_index}, Variant {variant_index} succeeded but response parsing failed (error flag set in content). Raw data logged by handler.")
else:
logging.info(f" Successfully generated and handled content for Topic {topic_index}, Variant {variant_index}.")
else:
logging.warning(f"Failed to generate content for Topic {topic_index}, Variant {variant_index}. Skipping.")
logging.error(f" Content generation failed for Topic {topic_index}, Variant {variant_index}. Skipping handling.")
except Exception as e:
logging.exception(f"Error during content generation for Topic {topic_index}, Variant {variant_index}:")
# traceback.print_exc()
logging.exception(f" Error during content generation call or handling for Topic {topic_index}, Variant {variant_index}:")
# Return the success flag for this topic
return success_flag
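# Illustrative usage sketch for generate_content_for_topic. The parameter values are
# examples (mirroring the defaults used before the refactor), and the
# FileSystemOutputHandler constructor argument is an assumption; only the function
# signature above is taken from this module.
#
#   handler = FileSystemOutputHandler("./result")              # assumed constructor
#   ok = generate_content_for_topic(
#       ai_agent, prompt_manager, topic_item,
#       run_id=run_id,                                          # from the topic-generation step
#       topic_index=1,
#       output_handler=handler,
#       variants=2, temperature=0.3, top_p=0.4, presence_penalty=1.5,
#   )
#   if not ok:
#       logging.warning("No content variant was generated for this topic.")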
def generate_posters_for_topic(topic_item: dict,
output_dir: str,
run_id: str,
topic_index: int,
output_handler: OutputHandler, # handler instance added in the refactor
variants: int,
poster_assets_base_dir: str,
image_base_dir: str,
modify_image_subdir: str,
resource_dir_config: list,
poster_target_size: tuple,
text_possibility: float,
output_collage_subdir: str,
output_poster_subdir: str,
output_poster_filename: str,
camera_image_subdir: str
):
"""Generates all posters for a single topic item, handling image data via OutputHandler.
Args:
topic_item: The dictionary representing a single topic.
output_dir: The base output directory for the entire run (e.g., ./result).
run_id: The ID for the current run.
topic_index: The 1-based index of the current topic.
variants: Number of variants.
poster_assets_base_dir: Path to poster assets (fonts, frames etc.).
image_base_dir: Base path for source images.
modify_image_subdir: Subdirectory for modified images.
resource_dir_config: Configuration for resource directories (used for Description).
poster_target_size: Target size tuple (width, height) for the poster.
text_possibility: Probability of adding secondary text.
output_collage_subdir: Subdirectory name for saving collages.
output_poster_subdir: Subdirectory name for saving posters.
output_poster_filename: Filename for the final poster.
camera_image_subdir: Subdirectory for camera images (currently unused in this function's logic).
output_handler: An instance of OutputHandler to process results.
Returns:
True if poster generation was attempted (regardless of individual variant success),
False otherwise.
"""
logging.info(f"Generating posters for Topic {topic_index} (Object: {topic_item.get('object', 'N/A')})...")
# --- Load content data from files ---
loaded_content_list = []
logging.info(f"Attempting to load content data for {variants} variants for topic {topic_index}...")
for j in range(variants):
variant_index = j + 1
variant_dir = os.path.join(output_dir, run_id, f"{topic_index}_{variant_index}")
content_path = os.path.join(variant_dir, "article.json")
try:
if os.path.exists(content_path):
with open(content_path, 'r', encoding='utf-8') as f_content:
content_data = json.load(f_content)
if isinstance(content_data, dict) and 'title' in content_data and 'content' in content_data:
loaded_content_list.append(content_data)
logging.debug(f" Successfully loaded content from: {content_path}")
else:
logging.warning(f" Content file {content_path} has invalid format. Skipping.")
else:
logging.warning(f" Content file not found for variant {variant_index}: {content_path}. Skipping.")
except json.JSONDecodeError:
logging.error(f" Error decoding JSON from content file: {content_path}. Skipping.")
except Exception as e:
logging.exception(f" Error loading content file {content_path}: {e}")
if not loaded_content_list:
logging.error(f"No valid content data loaded for topic {topic_index}. Cannot generate posters.")
return False
logging.info(f"Successfully loaded content data for {len(loaded_content_list)} variants.")
# --- End Load content data ---
# Initialize generators using parameters
try:
content_gen_instance = core_contentGen.ContentGenerator()
if not poster_assets_base_dir:
logging.error("Error: 'poster_assets_base_dir' not found in configuration. Cannot generate posters.")
return False # Cannot proceed without assets base dir
# --- Initialize PosterGenerator with the base dir ---
logging.error("Error: 'poster_assets_base_dir' not provided. Cannot generate posters.")
return False
poster_gen_instance = core_posterGen.PosterGenerator(base_dir=poster_assets_base_dir)
except Exception as e:
logging.exception("Error initializing generators for poster creation:")
return False
# --- Setup: Paths and Object Name ---
object_name = topic_item.get("object", "")
if not object_name:
logging.warning("Warning: Topic object name is missing. Cannot generate posters.")
if not object_name_cleaned:
logging.warning(f"Warning: Object name '{object_name}' resulted in empty string after cleaning. Skipping posters.")
return False
object_name = object_name_cleaned
except Exception as e:
logging.warning(f"Warning: Could not fully clean object name '{object_name}': {e}. Skipping posters.")
return False
# Construct and check INPUT image paths
input_img_dir_path = os.path.join(image_base_dir, modify_image_subdir, object_name)
if not os.path.exists(input_img_dir_path) or not os.path.isdir(input_img_dir_path):
logging.warning(f"Warning: Modify Image directory not found or not a directory: '{input_img_dir_path}'. Skipping posters for this topic.")
return False
# Locate Description File using resource_dir_config parameter
info_directory = []
description_file_path = None
found_description = False
for dir_info in resource_dir_config:
if dir_info.get("type") == "Description":
for file_path in dir_info.get("file_path", []):
# Match description file based on object name containment
if object_name in os.path.basename(file_path):
description_file_path = file_path
if os.path.exists(description_file_path):
info_directory = [description_file_path]
logging.info(f"Found and using description file from config: {description_file_path}")
found_description = True
else:
logging.warning(f"Warning: Description file specified in config not found: {description_file_path}")
break # Found the matching entry in this list
if found_description:
break
if not found_description:
logging.info(f"Warning: No matching description file found for object '{object_name}' in config resource_dir (type='Description').")
# Generate Text Configurations for All Variants
try:
poster_text_configs_raw = content_gen_instance.run(info_directory, variants, loaded_content_list)
if not poster_text_configs_raw:
logging.warning("Warning: ContentGenerator returned empty configuration data. Skipping posters.")
return False
# Save the complete poster configuration list via the OutputHandler
output_handler.handle_poster_configs(run_id, topic_index, poster_text_configs_raw)
poster_config_summary = core_posterGen.PosterConfig(poster_text_configs_raw)
except Exception as e:
logging.exception("Error running ContentGenerator or parsing poster configs:")
return False
# Poster Generation Loop for each variant
poster_num = variants
any_poster_attempted = False
for j_index in range(poster_num):
variant_index = j_index + 1
logging.info(f"Generating Poster {variant_index}/{poster_num}...")
any_poster_attempted = True
collage_img = None # To store the generated collage PIL Image
poster_img = None # To store the final poster PIL Image
try:
poster_config = poster_config_summary.get_config_by_index(j_index)
if not poster_config:
logging.warning(f"Warning: Could not get poster config for index {j_index}. Skipping.")
continue
# --- Image Collage ---
logging.info(f"Generating collage from: {input_img_dir_path}")
collage_images = core_simple_collage.process_directory(
input_img_dir_path,
target_size=poster_target_size,
output_count=1
)
if not collage_images: # check that the returned list is not empty
logging.warning(f"Warning: Failed to generate collage image for Variant {variant_index}. Skipping poster.")
continue
collage_img = collage_images[0] # take the first returned PIL Image
logging.info("Collage image generated successfully (in memory).")
# Save the collage image via the OutputHandler
output_handler.handle_generated_image(
run_id, topic_index, variant_index,
image_type='collage',
image_data=collage_img,
output_filename='collage.png' # or any other desired filename
)
# --- Create Poster ---
text_data = {
"title": poster_config.get('main_title', 'Default Title'),
"subtitle": "",
"additional_texts": []
}
texts = poster_config.get('texts', [])
if texts:
text_data["additional_texts"].append({"text": texts[0], "position": "bottom", "size_factor": 0.5})
if len(texts) > 1 and random.random() < text_possibility:
text_data["additional_texts"].append({"text": texts[1], "position": "bottom", "size_factor": 0.5})
# Call the refactored create_poster, which now returns a PIL Image
poster_img = poster_gen_instance.create_poster(collage_img, text_data)
if poster_img:
logging.info("Poster image generated successfully (in memory).")
# Save the poster image via the OutputHandler
output_handler.handle_generated_image(
run_id, topic_index, variant_index,
image_type='poster',
image_data=poster_img,
output_filename=output_poster_filename # filename supplied as a parameter
)
else:
logging.warning(f"Warning: Poster generation function did not return a valid path for {final_poster_path}.")
logging.warning(f"Warning: Poster generation function returned None for variant {variant_index}.")
except Exception as e:
logging.exception(f"Error during poster generation for Variant {variant_index}:")
continue
return any_poster_attempted
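# Illustrative usage sketch for generate_posters_for_topic. Paths are placeholders, the
# resource_dir_config entry is a minimal example of the "Description" structure the
# lookup above expects, and the subdirectory/filename values mirror the pre-refactor
# defaults; only the function signature is taken from this module.
#
#   attempted = generate_posters_for_topic(
#       topic_item=topic_item,
#       output_dir="./result", run_id=run_id, topic_index=1,
#       output_handler=handler,
#       variants=2,
#       poster_assets_base_dir="/path/to/poster_assets",       # placeholder path
#       image_base_dir="/path/to/images",                      # placeholder path
#       modify_image_subdir="modify",
#       resource_dir_config=[{"type": "Description", "file_path": []}],
#       poster_target_size=(900, 1200),
#       text_possibility=0.3,
#       output_collage_subdir="collage_img",
#       output_poster_subdir="poster",
#       output_poster_filename="poster.jpg",
#       camera_image_subdir="相机",
#   )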
# The main() function is no longer used; kept here commented out (or remove it).
# def main():
#     """Main entry point"""
#     # ... (old main() logic)
#
# if __name__ == "__main__":
#     main()
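# For reference, a sketch of the decoupled end-to-end flow now orchestrated by main.py,
# as described in the README. Names outside this module (FileSystemOutputHandler and
# the topic-generation entry point) are taken from the README and may differ from the
# actual code; treat this as an outline rather than a runnable script.
#
#   from utils.output_handler import FileSystemOutputHandler   # assumed import path
#
#   handler = FileSystemOutputHandler("./result")              # assumed constructor
#   run_id, topics_list, system_prompt, user_prompt = <topic generation step>
#   # hand the topic results to the handler, then generate per-topic outputs
#   for i, topic_item in enumerate(topics_list or [], start=1):
#       generate_content_for_topic(ai_agent, prompt_manager, topic_item,
#                                  run_id, i, handler,
#                                  variants=2, temperature=0.3,
#                                  top_p=0.4, presence_penalty=1.5)
#       generate_posters_for_topic(topic_item, "./result", run_id, i, handler,
#                                  variants=2, ...)            # remaining args as in the
#                                                              # sketch above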