diff --git a/core/__pycache__/ai_agent.cpython-312.pyc b/core/__pycache__/ai_agent.cpython-312.pyc
index 270aa14..57a5f52 100644
Binary files a/core/__pycache__/ai_agent.cpython-312.pyc and b/core/__pycache__/ai_agent.cpython-312.pyc differ
diff --git a/core/ai_agent.py b/core/ai_agent.py
index e894610..c1c4c44 100644
--- a/core/ai_agent.py
+++ b/core/ai_agent.py
@@ -1,12 +1,13 @@
import os
from openai import OpenAI
import time
+import random
class AI_Agent():
"""AI代理类,负责与AI模型交互生成文本内容"""
- def __init__(self, base_url, model_name, api):
+ def __init__(self, base_url, model_name, api, timeout=10, max_retries=3):
self.url_list = {
"ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"kimi": "https://api.moonshot.cn/v1",
@@ -18,11 +19,13 @@ class AI_Agent():
self.base_url = self.url_list[base_url] if base_url in self.url_list else base_url
self.api = api
self.model_name = model_name
+ self.timeout = timeout # 设置超时时间(秒)
+ self.max_retries = max_retries # 最大重试次数
self.client = OpenAI(
api_key=self.api,
base_url=self.base_url,
- # timeout=10
+ timeout=self.timeout # 设置OpenAI客户端超时
)
def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
@@ -35,29 +38,68 @@ class AI_Agent():
print(f"Base URL: {self.base_url}")
print(f"Model: {self.model_name}")
- response = self.client.chat.completions.create(
- model=self.model_name,
- messages=[{"role": "system", "content": system_prompt},
- {"role": "user", "content": user_prompt}],
- temperature=temperature,
- top_p=top_p,
- # presence_penalty=presence_penalty,
- stream=True,
- max_tokens=8192,
- extra_body={
- "repetition_penalty": 1.05,
- },
- )
+ retry_count = 0
+ max_retry_wait = 10 # 最大重试等待时间(秒)
- # 收集完整的输出内容
- full_response = ""
- for chunk in response:
- if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
- content = chunk.choices[0].delta.content
- full_response += content
- print(content, end="", flush=True)
- if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
- break
+ while retry_count <= self.max_retries:
+ try:
+ response = self.client.chat.completions.create(
+ model=self.model_name,
+ messages=[{"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt}],
+ temperature=temperature,
+ top_p=top_p,
+ # presence_penalty=presence_penalty,
+ stream=True,
+ max_tokens=8192,
+ timeout=self.timeout, # 设置请求超时
+ extra_body={
+ "repetition_penalty": 1.05,
+ },
+ )
+
+ # 收集完整的输出内容
+ full_response = ""
+ try:
+ for chunk in response:
+ if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
+ content = chunk.choices[0].delta.content
+ full_response += content
+ print(content, end="", flush=True)
+ if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
+ break
+
+ # 成功完成,跳出重试循环
+ break
+
+ except:
+ # 处理流式响应中的超时
+ print(f"\n接收响应时超时")
+ if len(full_response) > 0:
+ print(f"已接收部分响应({len(full_response)}字符)")
+ # 如果已接收足够内容,可以考虑使用已有内容
+ if len(full_response) > 100: # 假设至少需要100个字符才有意义
+ print("使用已接收的部分内容继续处理")
+ break
+
+ # 否则准备重试
+ retry_count += 1
+ if retry_count <= self.max_retries:
+ wait_time = min(2 ** retry_count + random.random(), max_retry_wait) # 指数退避
+ print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...")
+ time.sleep(wait_time)
+ continue
+
+ except Exception as e:
+ print(f"\n请求发生错误: {e}")
+ retry_count += 1
+ if retry_count <= self.max_retries:
+ wait_time = min(2 ** retry_count + random.random(), max_retry_wait) # 指数退避
+ print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...")
+ time.sleep(wait_time)
+ else:
+ print(f"已达到最大重试次数({self.max_retries}),放弃请求")
+ return "请求失败,无法生成内容。", 0
print("\n完成生成,正在处理结果...")
diff --git a/genPrompts/Style/攻略风文案提示词 b/genPrompts/Style/攻略风文案提示词
index 6c7450e..4913e8a 100644
--- a/genPrompts/Style/攻略风文案提示词
+++ b/genPrompts/Style/攻略风文案提示词
@@ -24,7 +24,7 @@
-(正文内容,不含带#的TAG内容)
+(正文内容)
-(TAG内容,如#周末去哪)
+(最后一段是TAG内容,如#周末去哪)
\ No newline at end of file
diff --git a/genPrompts/Style/极力推荐风文案提示词.txt b/genPrompts/Style/极力推荐风文案提示词.txt
index 255a024..d9acf54 100644
--- a/genPrompts/Style/极力推荐风文案提示词.txt
+++ b/genPrompts/Style/极力推荐风文案提示词.txt
@@ -24,7 +24,7 @@
-(正文内容,不含带#的TAG内容)
+(正文内容)
-(TAG内容,如#周末去哪)
+(最后一段是TAG内容,如#周末去哪)
\ No newline at end of file
diff --git a/genPrompts/Style/美食风文案提示词.txt b/genPrompts/Style/美食风文案提示词.txt
index 8eeb73a..db0d394 100644
--- a/genPrompts/Style/美食风文案提示词.txt
+++ b/genPrompts/Style/美食风文案提示词.txt
@@ -28,7 +28,7 @@
-(正文内容,不含带#的TAG内容)
+(正文内容)
-(TAG内容,如#周末去哪)
+(最后一段是TAG内容,如#周末去哪)
\ No newline at end of file
diff --git a/genPrompts/Style/轻奢风文案提示词.txt b/genPrompts/Style/轻奢风文案提示词.txt
index e4178e0..311e5bd 100644
--- a/genPrompts/Style/轻奢风文案提示词.txt
+++ b/genPrompts/Style/轻奢风文案提示词.txt
@@ -32,5 +32,5 @@
(正文内容)
-(最后一段是TAG内容,如#周末去哪)
+(最后一段是TAG内容,如#周末去哪)
diff --git a/main.py b/main.py
index 443eea3..07a3317 100644
--- a/main.py
+++ b/main.py
@@ -11,7 +11,7 @@ import core.posterGen as posterGen
import core.simple_collage as simple_collage
from utils.resource_loader import ResourceLoader
from utils.tweet_generator import prepare_topic_generation, generate_topics, generate_single_content
-
+import random
def main():
config_file = {
"date": "4月17日",
@@ -77,60 +77,56 @@ def main():
# 直接使用同一个AI Agent实例
for i in range(len(tweet_topic_record.topics_list)):
+ tweet_content_list = []
for j in range(config_file["variants"]):
tweet_content, gen_result = generate_single_content(
ai_agent, content_system_prompt, tweet_topic_record.topics_list[i],
config_file["prompts_dir"], config_file["resource_dir"],
output_dir, run_id, i+1, j+1, config_file["content_temperature"]
)
-
+ tweet_content_list.append(tweet_content.get_json_file())
if not tweet_content:
print(f"生成第{i+1}篇文章的第{j+1}个变体失败,跳过")
continue
-
- object_name = tweet_topic_record.topics_list[i]["object"]
- try:
- object_name = object_name.split(".")[0]
- except:
- pass
- try:
- object_name = object_name.split("景点信息-")[1]
- except:
- pass
+ object_name = tweet_topic_record.topics_list[i]["object"]
+ try:
+ object_name = object_name.split(".")[0]
+ except:
+ pass
+ try:
+ object_name = object_name.split("景点信息-")[1]
+ except:
+ pass
# 处理对象名称中可能包含的"+"等特殊字符
# if "+" in object_name:
# # 取第一个景点名称作为主要对象
# object_name = object_name.split("+")[0].strip()
# print(f"对象名称包含多个景点,使用第一个景点:{object_name}")
- # 检查图片路径是否存在
- img_dir_path = f"/root/autodl-tmp/sanming_img/modify/{object_name}"
- if not os.path.exists(img_dir_path):
- print(f"图片目录不存在:{img_dir_path},跳过该对象")
- continue
- img_dir = os.path.join(output_dir, f"{i+1}_{j+1}")
- info_directory = [
+ # 检查图片路径是否存在
+ img_dir_path = f"/root/autodl-tmp/sanming_img/modify/{object_name}"
+ if not os.path.exists(img_dir_path):
+ print(f"图片目录不存在:{img_dir_path},跳过该对象")
+ continue
+ img_dir = os.path.join(output_dir, f"{i+1}_{j_index+1}")
+ info_directory = [
f"/root/autodl-tmp/sanming_img/相机/{object_name}/description.txt"
]
- # 检查描述文件是否存在
- if not os.path.exists(info_directory[0]):
- print(f"描述文件不存在:{info_directory[0]},使用生成的内容替代")
- info_directory = []
+ # 检查描述文件是否存在
+ if not os.path.exists(info_directory[0]):
+ print(f"描述文件不存在:{info_directory[0]},使用生成的内容替代")
+ info_directory = []
- poster_num = 1
- tweet_content_json = tweet_content.get_json_file()
- tweet_content_str = f"""
- {tweet_content_json}
- """
- input_dir = img_dir_path # 使用前面检查过的目录路径
- # img_dir = output_dir
- target_size = (900, 1200)
- result_path = []
+ poster_num = config_file["variants"]
- content_gen = contentGen.ContentGenerator()
- response = content_gen.run(info_directory, poster_num, tweet_content_str)
- print(response)
-
+ input_dir = img_dir_path # 使用前面检查过的目录路径
+ target_size = (900, 1200)
+ result_path = []
+
+ content_gen = contentGen.ContentGenerator()
+ response = content_gen.run(info_directory, poster_num, tweet_content_list)
+ print(response)
+ for j_index in range(config_file["variants"]):
try:
# 创建输出目录
collage_output_dir = os.path.join(img_dir, "collage_img")
@@ -170,7 +166,7 @@ def main():
}
img_path = img_list[index]['path']
print(f"使用图片路径: {img_path}")
- output_path = os.path.join(poster_output_dir, f"img_{i+1}_{j+1}_{index}.jpg")
+ output_path = os.path.join(poster_output_dir, f"img_{i+1}_{j_index+1}_{index}.jpg")
result_path.append(poster_gen.create_poster(img_path, text_data, output_path))
except Exception as e:
print(f"海报生成过程中出错: {e}")
diff --git a/utils/__pycache__/resource_loader.cpython-312.pyc b/utils/__pycache__/resource_loader.cpython-312.pyc
index 577c0ee..2e3fb7f 100644
Binary files a/utils/__pycache__/resource_loader.cpython-312.pyc and b/utils/__pycache__/resource_loader.cpython-312.pyc differ
diff --git a/utils/__pycache__/tweet_generator.cpython-312.pyc b/utils/__pycache__/tweet_generator.cpython-312.pyc
index dd20a51..1027710 100644
Binary files a/utils/__pycache__/tweet_generator.cpython-312.pyc and b/utils/__pycache__/tweet_generator.cpython-312.pyc differ
diff --git a/utils/resource_loader.py b/utils/resource_loader.py
index 5545a78..0a601b8 100644
--- a/utils/resource_loader.py
+++ b/utils/resource_loader.py
@@ -1,5 +1,5 @@
import os
-
+import random
class ResourceLoader:
"""资源加载器,用于加载提示词和参考资料"""
@@ -20,7 +20,7 @@ class ResourceLoader:
return ""
@staticmethod
- def load_all_refer_files(refer_dir):
+ def load_all_refer_files(refer_dir, refer_content_length=50):
"""加载Refer目录下的所有文件内容"""
refer_content = ""
try:
@@ -30,6 +30,10 @@ class ResourceLoader:
file_path = os.path.join(refer_dir, file)
if os.path.isfile(file_path):
content = ResourceLoader.load_file_content(file_path)
+ # 用\n分割content,取前length条
+ content_lines = content.split("\n")
+ content_lines = random.sample(content_lines, refer_content_length)
+ content = "\n".join(content_lines)
refer_content += f"## {file}\n{content}\n\n"
return refer_content
except Exception as e: