调整了主要流程

This commit is contained in:
jinye_huang 2025-04-21 09:33:25 +08:00
parent 9e8a63be54
commit 549223ab53
4 changed files with 93 additions and 64 deletions

View File

@ -16,18 +16,42 @@ class PosterInfo:
class PosterConfig:
def __init__(self, config_path):
self.config_path = config_path
try:
if isinstance(config_path, str) and os.path.exists(config_path):
# 如果是文件路径,从文件读取
print(f"从文件加载配置: {config_path}")
self.config = json.load(open(config_path, "r", encoding="utf-8"))
else:
# 如果是字符串但不是文件路径,尝试直接解析
print("尝试直接解析配置字符串")
self.config = json.loads(config_path)
self.img_list = []
for item in self.config:
print(item)
self.img_list.append([item['index'], item["main_title"], item["texts"]])
# self.img_list.append(PosterInfo(item['img_url'], item['main_title'], item['texts']))
except json.JSONDecodeError as e:
print(f"JSON解析错误: {e}")
print(f"尝试解析的内容: {config_path[:100]}...") # 只打印前100个字符
# 创建一个默认配置
self.config = [{"index": 0, "main_title": "景点风光", "texts": ["自然美景", "人文体验"]}]
self.img_list = [[0, "景点风光", ["自然美景", "人文体验"]]]
print("使用默认配置")
except Exception as e:
print(f"加载配置时出错: {e}")
# 创建一个默认配置
self.config = [{"index": 0, "main_title": "景点风光", "texts": ["自然美景", "人文体验"]}]
self.img_list = [[0, "景点风光", ["自然美景", "人文体验"]]]
print("使用默认配置")
def get_config(self):
return self.config
def get_config_by_index(self, index):
if index >= len(self.config):
print(f"警告: 索引 {index} 超出配置范围,使用默认配置")
return self.config[0]
return self.config[index]
class PosterGenerator:

11
main.py
View File

@ -18,8 +18,8 @@ TEXT_POSBILITY = 0.3
def main():
config_file = {
"date": "4月17",
"num": 5,
"date": "4月24",
"num": 10,
"model": "qwenQWQ",
"api_url": "vllm",
"api_key": "EMPTY",
@ -42,7 +42,7 @@ def main():
],
"prompts_dir": "/root/autodl-tmp/TravelContentCreator/genPrompts",
"output_dir": "/root/autodl-tmp/TravelContentCreator/result",
"variants": 5,
"variants": 10,
"topic_temperature": 0.2,
"content_temperature": 0.3
}
@ -133,6 +133,7 @@ def main():
content_gen = contentGen.ContentGenerator()
response = content_gen.run(info_directory, poster_num, tweet_content_list)
print(response)
try:
poster_config_summary = posterGen.PosterConfig(response)
for j_index in range(config_file["variants"]):
poster_config = poster_config_summary.get_config_by_index(j_index)
@ -187,6 +188,10 @@ def main():
print(f"海报生成过程中出错: {e}")
traceback.print_exc() # 打印完整的堆栈跟踪信息
continue
except Exception as e:
print(f"配置解析失败,跳过当前项目: {e}")
traceback.print_exc() # 打印完整的堆栈跟踪信息
continue
if __name__ == "__main__":
main()