调整了主要流程
This commit is contained in:
parent
9e8a63be54
commit
549223ab53
Binary file not shown.
@ -16,18 +16,42 @@ class PosterInfo:
|
|||||||
class PosterConfig:
|
class PosterConfig:
|
||||||
def __init__(self, config_path):
|
def __init__(self, config_path):
|
||||||
self.config_path = config_path
|
self.config_path = config_path
|
||||||
self.config = json.load(open(config_path, "r", encoding="utf-8"))
|
try:
|
||||||
|
if isinstance(config_path, str) and os.path.exists(config_path):
|
||||||
self.img_list = []
|
# 如果是文件路径,从文件读取
|
||||||
for item in self.config:
|
print(f"从文件加载配置: {config_path}")
|
||||||
print(item)
|
self.config = json.load(open(config_path, "r", encoding="utf-8"))
|
||||||
self.img_list.append([item['index'], item["main_title"], item["texts"]])
|
else:
|
||||||
# self.img_list.append(PosterInfo(item['img_url'], item['main_title'], item['texts']))
|
# 如果是字符串但不是文件路径,尝试直接解析
|
||||||
|
print("尝试直接解析配置字符串")
|
||||||
|
self.config = json.loads(config_path)
|
||||||
|
|
||||||
|
self.img_list = []
|
||||||
|
for item in self.config:
|
||||||
|
print(item)
|
||||||
|
self.img_list.append([item['index'], item["main_title"], item["texts"]])
|
||||||
|
# self.img_list.append(PosterInfo(item['img_url'], item['main_title'], item['texts']))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"JSON解析错误: {e}")
|
||||||
|
print(f"尝试解析的内容: {config_path[:100]}...") # 只打印前100个字符
|
||||||
|
# 创建一个默认配置
|
||||||
|
self.config = [{"index": 0, "main_title": "景点风光", "texts": ["自然美景", "人文体验"]}]
|
||||||
|
self.img_list = [[0, "景点风光", ["自然美景", "人文体验"]]]
|
||||||
|
print("使用默认配置")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载配置时出错: {e}")
|
||||||
|
# 创建一个默认配置
|
||||||
|
self.config = [{"index": 0, "main_title": "景点风光", "texts": ["自然美景", "人文体验"]}]
|
||||||
|
self.img_list = [[0, "景点风光", ["自然美景", "人文体验"]]]
|
||||||
|
print("使用默认配置")
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return self.config
|
return self.config
|
||||||
|
|
||||||
def get_config_by_index(self, index):
|
def get_config_by_index(self, index):
|
||||||
|
if index >= len(self.config):
|
||||||
|
print(f"警告: 索引 {index} 超出配置范围,使用默认配置")
|
||||||
|
return self.config[0]
|
||||||
return self.config[index]
|
return self.config[index]
|
||||||
|
|
||||||
class PosterGenerator:
|
class PosterGenerator:
|
||||||
|
|||||||
117
main.py
117
main.py
@ -18,8 +18,8 @@ TEXT_POSBILITY = 0.3
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
config_file = {
|
config_file = {
|
||||||
"date": "4月17日",
|
"date": "4月24日",
|
||||||
"num": 5,
|
"num": 10,
|
||||||
"model": "qwenQWQ",
|
"model": "qwenQWQ",
|
||||||
"api_url": "vllm",
|
"api_url": "vllm",
|
||||||
"api_key": "EMPTY",
|
"api_key": "EMPTY",
|
||||||
@ -42,7 +42,7 @@ def main():
|
|||||||
],
|
],
|
||||||
"prompts_dir": "/root/autodl-tmp/TravelContentCreator/genPrompts",
|
"prompts_dir": "/root/autodl-tmp/TravelContentCreator/genPrompts",
|
||||||
"output_dir": "/root/autodl-tmp/TravelContentCreator/result",
|
"output_dir": "/root/autodl-tmp/TravelContentCreator/result",
|
||||||
"variants": 5,
|
"variants": 10,
|
||||||
"topic_temperature": 0.2,
|
"topic_temperature": 0.2,
|
||||||
"content_temperature": 0.3
|
"content_temperature": 0.3
|
||||||
}
|
}
|
||||||
@ -133,60 +133,65 @@ def main():
|
|||||||
content_gen = contentGen.ContentGenerator()
|
content_gen = contentGen.ContentGenerator()
|
||||||
response = content_gen.run(info_directory, poster_num, tweet_content_list)
|
response = content_gen.run(info_directory, poster_num, tweet_content_list)
|
||||||
print(response)
|
print(response)
|
||||||
poster_config_summary = posterGen.PosterConfig(response)
|
try:
|
||||||
for j_index in range(config_file["variants"]):
|
poster_config_summary = posterGen.PosterConfig(response)
|
||||||
poster_config = poster_config_summary.get_config_by_index(j_index)
|
for j_index in range(config_file["variants"]):
|
||||||
img_dir = os.path.join(output_dir, f"{i+1}_{j_index+1}")
|
poster_config = poster_config_summary.get_config_by_index(j_index)
|
||||||
try:
|
img_dir = os.path.join(output_dir, f"{i+1}_{j_index+1}")
|
||||||
# 创建输出目录
|
try:
|
||||||
collage_output_dir = os.path.join(img_dir, "collage_img")
|
# 创建输出目录
|
||||||
os.makedirs(collage_output_dir, exist_ok=True)
|
collage_output_dir = os.path.join(img_dir, "collage_img")
|
||||||
poster_output_dir = os.path.join(img_dir, "poster")
|
os.makedirs(collage_output_dir, exist_ok=True)
|
||||||
os.makedirs(poster_output_dir, exist_ok=True)
|
poster_output_dir = os.path.join(img_dir, "poster")
|
||||||
|
os.makedirs(poster_output_dir, exist_ok=True)
|
||||||
# 处理图片目录
|
|
||||||
img_list = simple_collage.process_directory(
|
# 处理图片目录
|
||||||
input_dir,
|
img_list = simple_collage.process_directory(
|
||||||
target_size=target_size,
|
input_dir,
|
||||||
output_count=1,
|
target_size=target_size,
|
||||||
output_dir=collage_output_dir
|
output_count=1,
|
||||||
)
|
output_dir=collage_output_dir
|
||||||
print(img_list)
|
)
|
||||||
|
print(img_list)
|
||||||
if not img_list or len(img_list) == 0:
|
|
||||||
print(f"未能生成拼贴图片,跳过海报生成")
|
if not img_list or len(img_list) == 0:
|
||||||
|
print(f"未能生成拼贴图片,跳过海报生成")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 生成海报
|
||||||
|
poster_gen = posterGen.PosterGenerator()
|
||||||
|
|
||||||
|
if random.random() < TEXT_POSBILITY:
|
||||||
|
text_data = {
|
||||||
|
"title": f"{poster_config['main_title']}",
|
||||||
|
"subtitle": "",
|
||||||
|
"additional_texts": [
|
||||||
|
{"text": f"{poster_config['texts'][0]}", "position": "bottom", "size_factor": 0.5},
|
||||||
|
{"text": f"{poster_config['texts'][1]}", "position": "bottom", "size_factor": 0.5}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
text_data = {
|
||||||
|
"title": f"{poster_config['main_title']}",
|
||||||
|
"subtitle": "",
|
||||||
|
"additional_texts": [
|
||||||
|
{"text": f"{poster_config['texts'][0]}", "position": "bottom", "size_factor": 0.5},
|
||||||
|
# {"text": f"{poster_config['texts'][1]}", "position": "bottom", "size_factor": 0.5}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
print(text_data)
|
||||||
|
img_path = img_list[0]['path']
|
||||||
|
print(f"使用图片路径: {img_path}")
|
||||||
|
output_path = os.path.join(poster_output_dir, f"poster.jpg")
|
||||||
|
result_path.append(poster_gen.create_poster(img_path, text_data, output_path))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"海报生成过程中出错: {e}")
|
||||||
|
traceback.print_exc() # 打印完整的堆栈跟踪信息
|
||||||
continue
|
continue
|
||||||
|
except Exception as e:
|
||||||
# 生成海报
|
print(f"配置解析失败,跳过当前项目: {e}")
|
||||||
poster_gen = posterGen.PosterGenerator()
|
traceback.print_exc() # 打印完整的堆栈跟踪信息
|
||||||
|
continue
|
||||||
if random.random() < TEXT_POSBILITY:
|
|
||||||
text_data = {
|
|
||||||
"title": f"{poster_config['main_title']}",
|
|
||||||
"subtitle": "",
|
|
||||||
"additional_texts": [
|
|
||||||
{"text": f"{poster_config['texts'][0]}", "position": "bottom", "size_factor": 0.5},
|
|
||||||
{"text": f"{poster_config['texts'][1]}", "position": "bottom", "size_factor": 0.5}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
text_data = {
|
|
||||||
"title": f"{poster_config['main_title']}",
|
|
||||||
"subtitle": "",
|
|
||||||
"additional_texts": [
|
|
||||||
{"text": f"{poster_config['texts'][0]}", "position": "bottom", "size_factor": 0.5},
|
|
||||||
# {"text": f"{poster_config['texts'][1]}", "position": "bottom", "size_factor": 0.5}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
print(text_data)
|
|
||||||
img_path = img_list[0]['path']
|
|
||||||
print(f"使用图片路径: {img_path}")
|
|
||||||
output_path = os.path.join(poster_output_dir, f"poster.jpg")
|
|
||||||
result_path.append(poster_gen.create_poster(img_path, text_data, output_path))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"海报生成过程中出错: {e}")
|
|
||||||
traceback.print_exc() # 打印完整的堆栈跟踪信息
|
|
||||||
continue
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@ -215,7 +215,7 @@ def prepare_topic_generation(
|
|||||||
"""准备选题生成的环境和参数"""
|
"""准备选题生成的环境和参数"""
|
||||||
# 创建AI Agent
|
# 创建AI Agent
|
||||||
ai_agent = AI_Agent(base_url, model_name, api_key)
|
ai_agent = AI_Agent(base_url, model_name, api_key)
|
||||||
|
|
||||||
# 加载系统提示词
|
# 加载系统提示词
|
||||||
with open(system_prompt_path, "r", encoding="utf-8") as f:
|
with open(system_prompt_path, "r", encoding="utf-8") as f:
|
||||||
system_prompt = f.read()
|
system_prompt = f.read()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user