diff --git a/core/__pycache__/ai_agent.cpython-312.pyc b/core/__pycache__/ai_agent.cpython-312.pyc index 270aa14..57a5f52 100644 Binary files a/core/__pycache__/ai_agent.cpython-312.pyc and b/core/__pycache__/ai_agent.cpython-312.pyc differ diff --git a/core/ai_agent.py b/core/ai_agent.py index e894610..c1c4c44 100644 --- a/core/ai_agent.py +++ b/core/ai_agent.py @@ -1,12 +1,13 @@ import os from openai import OpenAI import time +import random class AI_Agent(): """AI代理类,负责与AI模型交互生成文本内容""" - def __init__(self, base_url, model_name, api): + def __init__(self, base_url, model_name, api, timeout=10, max_retries=3): self.url_list = { "ali": "https://dashscope.aliyuncs.com/compatible-mode/v1", "kimi": "https://api.moonshot.cn/v1", @@ -18,11 +19,13 @@ class AI_Agent(): self.base_url = self.url_list[base_url] if base_url in self.url_list else base_url self.api = api self.model_name = model_name + self.timeout = timeout # 设置超时时间(秒) + self.max_retries = max_retries # 最大重试次数 self.client = OpenAI( api_key=self.api, base_url=self.base_url, - # timeout=10 + timeout=self.timeout # 设置OpenAI客户端超时 ) def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty): @@ -35,29 +38,68 @@ class AI_Agent(): print(f"Base URL: {self.base_url}") print(f"Model: {self.model_name}") - response = self.client.chat.completions.create( - model=self.model_name, - messages=[{"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}], - temperature=temperature, - top_p=top_p, - # presence_penalty=presence_penalty, - stream=True, - max_tokens=8192, - extra_body={ - "repetition_penalty": 1.05, - }, - ) + retry_count = 0 + max_retry_wait = 10 # 最大重试等待时间(秒) - # 收集完整的输出内容 - full_response = "" - for chunk in response: - if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: - content = chunk.choices[0].delta.content - full_response += content - print(content, end="", flush=True) - if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop": - break + while retry_count <= self.max_retries: + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=[{"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}], + temperature=temperature, + top_p=top_p, + # presence_penalty=presence_penalty, + stream=True, + max_tokens=8192, + timeout=self.timeout, # 设置请求超时 + extra_body={ + "repetition_penalty": 1.05, + }, + ) + + # 收集完整的输出内容 + full_response = "" + try: + for chunk in response: + if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + full_response += content + print(content, end="", flush=True) + if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop": + break + + # 成功完成,跳出重试循环 + break + + except: + # 处理流式响应中的超时 + print(f"\n接收响应时超时") + if len(full_response) > 0: + print(f"已接收部分响应({len(full_response)}字符)") + # 如果已接收足够内容,可以考虑使用已有内容 + if len(full_response) > 100: # 假设至少需要100个字符才有意义 + print("使用已接收的部分内容继续处理") + break + + # 否则准备重试 + retry_count += 1 + if retry_count <= self.max_retries: + wait_time = min(2 ** retry_count + random.random(), max_retry_wait) # 指数退避 + print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...") + time.sleep(wait_time) + continue + + except Exception as e: + print(f"\n请求发生错误: {e}") + retry_count += 1 + if retry_count <= self.max_retries: + wait_time = min(2 ** retry_count + random.random(), max_retry_wait) # 指数退避 + print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...") + time.sleep(wait_time) + else: + print(f"已达到最大重试次数({self.max_retries}),放弃请求") + return "请求失败,无法生成内容。", 0 print("\n完成生成,正在处理结果...") diff --git a/genPrompts/Style/攻略风文案提示词 b/genPrompts/Style/攻略风文案提示词 index 6c7450e..4913e8a 100644 --- a/genPrompts/Style/攻略风文案提示词 +++ b/genPrompts/Style/攻略风文案提示词 @@ -24,7 +24,7 @@ -(正文内容,不含带#的TAG内容) +(正文内容) -(TAG内容,如#周末去哪) +(最后一段是TAG内容,如#周末去哪) \ No newline at end of file diff --git a/genPrompts/Style/极力推荐风文案提示词.txt b/genPrompts/Style/极力推荐风文案提示词.txt index 255a024..d9acf54 100644 --- a/genPrompts/Style/极力推荐风文案提示词.txt +++ b/genPrompts/Style/极力推荐风文案提示词.txt @@ -24,7 +24,7 @@ -(正文内容,不含带#的TAG内容) +(正文内容) -(TAG内容,如#周末去哪) +(最后一段是TAG内容,如#周末去哪) \ No newline at end of file diff --git a/genPrompts/Style/美食风文案提示词.txt b/genPrompts/Style/美食风文案提示词.txt index 8eeb73a..db0d394 100644 --- a/genPrompts/Style/美食风文案提示词.txt +++ b/genPrompts/Style/美食风文案提示词.txt @@ -28,7 +28,7 @@ -(正文内容,不含带#的TAG内容) +(正文内容) -(TAG内容,如#周末去哪) +(最后一段是TAG内容,如#周末去哪) \ No newline at end of file diff --git a/genPrompts/Style/轻奢风文案提示词.txt b/genPrompts/Style/轻奢风文案提示词.txt index e4178e0..311e5bd 100644 --- a/genPrompts/Style/轻奢风文案提示词.txt +++ b/genPrompts/Style/轻奢风文案提示词.txt @@ -32,5 +32,5 @@ (正文内容) -(最后一段是TAG内容,如#周末去哪) +(最后一段是TAG内容,如#周末去哪) diff --git a/main.py b/main.py index 443eea3..07a3317 100644 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ import core.posterGen as posterGen import core.simple_collage as simple_collage from utils.resource_loader import ResourceLoader from utils.tweet_generator import prepare_topic_generation, generate_topics, generate_single_content - +import random def main(): config_file = { "date": "4月17日", @@ -77,60 +77,56 @@ def main(): # 直接使用同一个AI Agent实例 for i in range(len(tweet_topic_record.topics_list)): + tweet_content_list = [] for j in range(config_file["variants"]): tweet_content, gen_result = generate_single_content( ai_agent, content_system_prompt, tweet_topic_record.topics_list[i], config_file["prompts_dir"], config_file["resource_dir"], output_dir, run_id, i+1, j+1, config_file["content_temperature"] ) - + tweet_content_list.append(tweet_content.get_json_file()) if not tweet_content: print(f"生成第{i+1}篇文章的第{j+1}个变体失败,跳过") continue - - object_name = tweet_topic_record.topics_list[i]["object"] - try: - object_name = object_name.split(".")[0] - except: - pass - try: - object_name = object_name.split("景点信息-")[1] - except: - pass + object_name = tweet_topic_record.topics_list[i]["object"] + try: + object_name = object_name.split(".")[0] + except: + pass + try: + object_name = object_name.split("景点信息-")[1] + except: + pass # 处理对象名称中可能包含的"+"等特殊字符 # if "+" in object_name: # # 取第一个景点名称作为主要对象 # object_name = object_name.split("+")[0].strip() # print(f"对象名称包含多个景点,使用第一个景点:{object_name}") - # 检查图片路径是否存在 - img_dir_path = f"/root/autodl-tmp/sanming_img/modify/{object_name}" - if not os.path.exists(img_dir_path): - print(f"图片目录不存在:{img_dir_path},跳过该对象") - continue - img_dir = os.path.join(output_dir, f"{i+1}_{j+1}") - info_directory = [ + # 检查图片路径是否存在 + img_dir_path = f"/root/autodl-tmp/sanming_img/modify/{object_name}" + if not os.path.exists(img_dir_path): + print(f"图片目录不存在:{img_dir_path},跳过该对象") + continue + img_dir = os.path.join(output_dir, f"{i+1}_{j_index+1}") + info_directory = [ f"/root/autodl-tmp/sanming_img/相机/{object_name}/description.txt" ] - # 检查描述文件是否存在 - if not os.path.exists(info_directory[0]): - print(f"描述文件不存在:{info_directory[0]},使用生成的内容替代") - info_directory = [] + # 检查描述文件是否存在 + if not os.path.exists(info_directory[0]): + print(f"描述文件不存在:{info_directory[0]},使用生成的内容替代") + info_directory = [] - poster_num = 1 - tweet_content_json = tweet_content.get_json_file() - tweet_content_str = f""" - {tweet_content_json} - """ - input_dir = img_dir_path # 使用前面检查过的目录路径 - # img_dir = output_dir - target_size = (900, 1200) - result_path = [] + poster_num = config_file["variants"] - content_gen = contentGen.ContentGenerator() - response = content_gen.run(info_directory, poster_num, tweet_content_str) - print(response) - + input_dir = img_dir_path # 使用前面检查过的目录路径 + target_size = (900, 1200) + result_path = [] + + content_gen = contentGen.ContentGenerator() + response = content_gen.run(info_directory, poster_num, tweet_content_list) + print(response) + for j_index in range(config_file["variants"]): try: # 创建输出目录 collage_output_dir = os.path.join(img_dir, "collage_img") @@ -170,7 +166,7 @@ def main(): } img_path = img_list[index]['path'] print(f"使用图片路径: {img_path}") - output_path = os.path.join(poster_output_dir, f"img_{i+1}_{j+1}_{index}.jpg") + output_path = os.path.join(poster_output_dir, f"img_{i+1}_{j_index+1}_{index}.jpg") result_path.append(poster_gen.create_poster(img_path, text_data, output_path)) except Exception as e: print(f"海报生成过程中出错: {e}") diff --git a/utils/__pycache__/resource_loader.cpython-312.pyc b/utils/__pycache__/resource_loader.cpython-312.pyc index 577c0ee..2e3fb7f 100644 Binary files a/utils/__pycache__/resource_loader.cpython-312.pyc and b/utils/__pycache__/resource_loader.cpython-312.pyc differ diff --git a/utils/__pycache__/tweet_generator.cpython-312.pyc b/utils/__pycache__/tweet_generator.cpython-312.pyc index dd20a51..1027710 100644 Binary files a/utils/__pycache__/tweet_generator.cpython-312.pyc and b/utils/__pycache__/tweet_generator.cpython-312.pyc differ diff --git a/utils/resource_loader.py b/utils/resource_loader.py index 5545a78..0a601b8 100644 --- a/utils/resource_loader.py +++ b/utils/resource_loader.py @@ -1,5 +1,5 @@ import os - +import random class ResourceLoader: """资源加载器,用于加载提示词和参考资料""" @@ -20,7 +20,7 @@ class ResourceLoader: return "" @staticmethod - def load_all_refer_files(refer_dir): + def load_all_refer_files(refer_dir, refer_content_length=50): """加载Refer目录下的所有文件内容""" refer_content = "" try: @@ -30,6 +30,10 @@ class ResourceLoader: file_path = os.path.join(refer_dir, file) if os.path.isfile(file_path): content = ResourceLoader.load_file_content(file_path) + # 用\n分割content,取前length条 + content_lines = content.split("\n") + content_lines = random.sample(content_lines, refer_content_length) + content = "\n".join(content_lines) refer_content += f"## {file}\n{content}\n\n" return refer_content except Exception as e: