diff --git a/core/__pycache__/posterGen.cpython-312.pyc b/core/__pycache__/posterGen.cpython-312.pyc index 293e5a7..7d4de3b 100644 Binary files a/core/__pycache__/posterGen.cpython-312.pyc and b/core/__pycache__/posterGen.cpython-312.pyc differ diff --git a/core/posterGen.py b/core/posterGen.py index 0d61f83..dde2be7 100644 --- a/core/posterGen.py +++ b/core/posterGen.py @@ -16,18 +16,42 @@ class PosterInfo: class PosterConfig: def __init__(self, config_path): self.config_path = config_path - self.config = json.load(open(config_path, "r", encoding="utf-8")) - - self.img_list = [] - for item in self.config: - print(item) - self.img_list.append([item['index'], item["main_title"], item["texts"]]) - # self.img_list.append(PosterInfo(item['img_url'], item['main_title'], item['texts'])) + try: + if isinstance(config_path, str) and os.path.exists(config_path): + # 如果是文件路径,从文件读取 + print(f"从文件加载配置: {config_path}") + self.config = json.load(open(config_path, "r", encoding="utf-8")) + else: + # 如果是字符串但不是文件路径,尝试直接解析 + print("尝试直接解析配置字符串") + self.config = json.loads(config_path) + self.img_list = [] + for item in self.config: + print(item) + self.img_list.append([item['index'], item["main_title"], item["texts"]]) + # self.img_list.append(PosterInfo(item['img_url'], item['main_title'], item['texts'])) + except json.JSONDecodeError as e: + print(f"JSON解析错误: {e}") + print(f"尝试解析的内容: {config_path[:100]}...") # 只打印前100个字符 + # 创建一个默认配置 + self.config = [{"index": 0, "main_title": "景点风光", "texts": ["自然美景", "人文体验"]}] + self.img_list = [[0, "景点风光", ["自然美景", "人文体验"]]] + print("使用默认配置") + except Exception as e: + print(f"加载配置时出错: {e}") + # 创建一个默认配置 + self.config = [{"index": 0, "main_title": "景点风光", "texts": ["自然美景", "人文体验"]}] + self.img_list = [[0, "景点风光", ["自然美景", "人文体验"]]] + print("使用默认配置") + def get_config(self): return self.config def get_config_by_index(self, index): + if index >= len(self.config): + print(f"警告: 索引 {index} 超出配置范围,使用默认配置") + return self.config[0] return self.config[index] class PosterGenerator: diff --git a/main.py b/main.py index 82dbeb4..273569e 100644 --- a/main.py +++ b/main.py @@ -18,8 +18,8 @@ TEXT_POSBILITY = 0.3 def main(): config_file = { - "date": "4月17日", - "num": 5, + "date": "4月24日", + "num": 10, "model": "qwenQWQ", "api_url": "vllm", "api_key": "EMPTY", @@ -42,7 +42,7 @@ def main(): ], "prompts_dir": "/root/autodl-tmp/TravelContentCreator/genPrompts", "output_dir": "/root/autodl-tmp/TravelContentCreator/result", - "variants": 5, + "variants": 10, "topic_temperature": 0.2, "content_temperature": 0.3 } @@ -133,60 +133,65 @@ def main(): content_gen = contentGen.ContentGenerator() response = content_gen.run(info_directory, poster_num, tweet_content_list) print(response) - poster_config_summary = posterGen.PosterConfig(response) - for j_index in range(config_file["variants"]): - poster_config = poster_config_summary.get_config_by_index(j_index) - img_dir = os.path.join(output_dir, f"{i+1}_{j_index+1}") - try: - # 创建输出目录 - collage_output_dir = os.path.join(img_dir, "collage_img") - os.makedirs(collage_output_dir, exist_ok=True) - poster_output_dir = os.path.join(img_dir, "poster") - os.makedirs(poster_output_dir, exist_ok=True) - - # 处理图片目录 - img_list = simple_collage.process_directory( - input_dir, - target_size=target_size, - output_count=1, - output_dir=collage_output_dir - ) - print(img_list) - - if not img_list or len(img_list) == 0: - print(f"未能生成拼贴图片,跳过海报生成") + try: + poster_config_summary = posterGen.PosterConfig(response) + for j_index in range(config_file["variants"]): + poster_config = poster_config_summary.get_config_by_index(j_index) + img_dir = os.path.join(output_dir, f"{i+1}_{j_index+1}") + try: + # 创建输出目录 + collage_output_dir = os.path.join(img_dir, "collage_img") + os.makedirs(collage_output_dir, exist_ok=True) + poster_output_dir = os.path.join(img_dir, "poster") + os.makedirs(poster_output_dir, exist_ok=True) + + # 处理图片目录 + img_list = simple_collage.process_directory( + input_dir, + target_size=target_size, + output_count=1, + output_dir=collage_output_dir + ) + print(img_list) + + if not img_list or len(img_list) == 0: + print(f"未能生成拼贴图片,跳过海报生成") + continue + + # 生成海报 + poster_gen = posterGen.PosterGenerator() + + if random.random() < TEXT_POSBILITY: + text_data = { + "title": f"{poster_config['main_title']}", + "subtitle": "", + "additional_texts": [ + {"text": f"{poster_config['texts'][0]}", "position": "bottom", "size_factor": 0.5}, + {"text": f"{poster_config['texts'][1]}", "position": "bottom", "size_factor": 0.5} + ] + } + else: + text_data = { + "title": f"{poster_config['main_title']}", + "subtitle": "", + "additional_texts": [ + {"text": f"{poster_config['texts'][0]}", "position": "bottom", "size_factor": 0.5}, + # {"text": f"{poster_config['texts'][1]}", "position": "bottom", "size_factor": 0.5} + ] + } + print(text_data) + img_path = img_list[0]['path'] + print(f"使用图片路径: {img_path}") + output_path = os.path.join(poster_output_dir, f"poster.jpg") + result_path.append(poster_gen.create_poster(img_path, text_data, output_path)) + except Exception as e: + print(f"海报生成过程中出错: {e}") + traceback.print_exc() # 打印完整的堆栈跟踪信息 continue - - # 生成海报 - poster_gen = posterGen.PosterGenerator() - - if random.random() < TEXT_POSBILITY: - text_data = { - "title": f"{poster_config['main_title']}", - "subtitle": "", - "additional_texts": [ - {"text": f"{poster_config['texts'][0]}", "position": "bottom", "size_factor": 0.5}, - {"text": f"{poster_config['texts'][1]}", "position": "bottom", "size_factor": 0.5} - ] - } - else: - text_data = { - "title": f"{poster_config['main_title']}", - "subtitle": "", - "additional_texts": [ - {"text": f"{poster_config['texts'][0]}", "position": "bottom", "size_factor": 0.5}, - # {"text": f"{poster_config['texts'][1]}", "position": "bottom", "size_factor": 0.5} - ] - } - print(text_data) - img_path = img_list[0]['path'] - print(f"使用图片路径: {img_path}") - output_path = os.path.join(poster_output_dir, f"poster.jpg") - result_path.append(poster_gen.create_poster(img_path, text_data, output_path)) - except Exception as e: - print(f"海报生成过程中出错: {e}") - traceback.print_exc() # 打印完整的堆栈跟踪信息 - continue + except Exception as e: + print(f"配置解析失败,跳过当前项目: {e}") + traceback.print_exc() # 打印完整的堆栈跟踪信息 + continue if __name__ == "__main__": main() diff --git a/utils/tweet_generator.py b/utils/tweet_generator.py index 62490c1..4719fc8 100644 --- a/utils/tweet_generator.py +++ b/utils/tweet_generator.py @@ -215,7 +215,7 @@ def prepare_topic_generation( """准备选题生成的环境和参数""" # 创建AI Agent ai_agent = AI_Agent(base_url, model_name, api_key) - + # 加载系统提示词 with open(system_prompt_path, "r", encoding="utf-8") as f: system_prompt = f.read()