修改了生成逻辑
This commit is contained in:
parent
f547c3bfdc
commit
6d09ec65f8
Binary file not shown.
@ -1,12 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import time
|
import time
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class AI_Agent():
|
class AI_Agent():
|
||||||
"""AI代理类,负责与AI模型交互生成文本内容"""
|
"""AI代理类,负责与AI模型交互生成文本内容"""
|
||||||
|
|
||||||
def __init__(self, base_url, model_name, api):
|
def __init__(self, base_url, model_name, api, timeout=10, max_retries=3):
|
||||||
self.url_list = {
|
self.url_list = {
|
||||||
"ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
"ali": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
"kimi": "https://api.moonshot.cn/v1",
|
"kimi": "https://api.moonshot.cn/v1",
|
||||||
@ -18,11 +19,13 @@ class AI_Agent():
|
|||||||
self.base_url = self.url_list[base_url] if base_url in self.url_list else base_url
|
self.base_url = self.url_list[base_url] if base_url in self.url_list else base_url
|
||||||
self.api = api
|
self.api = api
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
self.timeout = timeout # 设置超时时间(秒)
|
||||||
|
self.max_retries = max_retries # 最大重试次数
|
||||||
|
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=self.api,
|
api_key=self.api,
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
# timeout=10
|
timeout=self.timeout # 设置OpenAI客户端超时
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
|
def generate_text(self, system_prompt, user_prompt, temperature, top_p, presence_penalty):
|
||||||
@ -35,6 +38,11 @@ class AI_Agent():
|
|||||||
print(f"Base URL: {self.base_url}")
|
print(f"Base URL: {self.base_url}")
|
||||||
print(f"Model: {self.model_name}")
|
print(f"Model: {self.model_name}")
|
||||||
|
|
||||||
|
retry_count = 0
|
||||||
|
max_retry_wait = 10 # 最大重试等待时间(秒)
|
||||||
|
|
||||||
|
while retry_count <= self.max_retries:
|
||||||
|
try:
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
messages=[{"role": "system", "content": system_prompt},
|
messages=[{"role": "system", "content": system_prompt},
|
||||||
@ -44,6 +52,7 @@ class AI_Agent():
|
|||||||
# presence_penalty=presence_penalty,
|
# presence_penalty=presence_penalty,
|
||||||
stream=True,
|
stream=True,
|
||||||
max_tokens=8192,
|
max_tokens=8192,
|
||||||
|
timeout=self.timeout, # 设置请求超时
|
||||||
extra_body={
|
extra_body={
|
||||||
"repetition_penalty": 1.05,
|
"repetition_penalty": 1.05,
|
||||||
},
|
},
|
||||||
@ -51,6 +60,7 @@ class AI_Agent():
|
|||||||
|
|
||||||
# 收集完整的输出内容
|
# 收集完整的输出内容
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
try:
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
|
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
|
||||||
content = chunk.choices[0].delta.content
|
content = chunk.choices[0].delta.content
|
||||||
@ -59,6 +69,38 @@ class AI_Agent():
|
|||||||
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
|
if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].finish_reason == "stop":
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 成功完成,跳出重试循环
|
||||||
|
break
|
||||||
|
|
||||||
|
except:
|
||||||
|
# 处理流式响应中的超时
|
||||||
|
print(f"\n接收响应时超时")
|
||||||
|
if len(full_response) > 0:
|
||||||
|
print(f"已接收部分响应({len(full_response)}字符)")
|
||||||
|
# 如果已接收足够内容,可以考虑使用已有内容
|
||||||
|
if len(full_response) > 100: # 假设至少需要100个字符才有意义
|
||||||
|
print("使用已接收的部分内容继续处理")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 否则准备重试
|
||||||
|
retry_count += 1
|
||||||
|
if retry_count <= self.max_retries:
|
||||||
|
wait_time = min(2 ** retry_count + random.random(), max_retry_wait) # 指数退避
|
||||||
|
print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n请求发生错误: {e}")
|
||||||
|
retry_count += 1
|
||||||
|
if retry_count <= self.max_retries:
|
||||||
|
wait_time = min(2 ** retry_count + random.random(), max_retry_wait) # 指数退避
|
||||||
|
print(f"\n等待 {wait_time:.2f} 秒后重试({retry_count}/{self.max_retries})...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
print(f"已达到最大重试次数({self.max_retries}),放弃请求")
|
||||||
|
return "请求失败,无法生成内容。", 0
|
||||||
|
|
||||||
print("\n完成生成,正在处理结果...")
|
print("\n完成生成,正在处理结果...")
|
||||||
|
|
||||||
# 由于使用流式输出,无法获取真实的token计数,因此返回估计值
|
# 由于使用流式输出,无法获取真实的token计数,因此返回估计值
|
||||||
|
|||||||
@ -24,7 +24,7 @@
|
|||||||
</title>
|
</title>
|
||||||
|
|
||||||
<content>
|
<content>
|
||||||
(正文内容,不含带#的TAG内容)
|
(正文内容)
|
||||||
|
|
||||||
(TAG内容,如#周末去哪)
|
(最后一段是TAG内容,如#周末去哪)
|
||||||
</content>
|
</content>
|
||||||
@ -24,7 +24,7 @@
|
|||||||
</title>
|
</title>
|
||||||
|
|
||||||
<content>
|
<content>
|
||||||
(正文内容,不含带#的TAG内容)
|
(正文内容)
|
||||||
|
|
||||||
(TAG内容,如#周末去哪)
|
(最后一段是TAG内容,如#周末去哪)
|
||||||
</content>
|
</content>
|
||||||
@ -28,7 +28,7 @@
|
|||||||
</title>
|
</title>
|
||||||
|
|
||||||
<content>
|
<content>
|
||||||
(正文内容,不含带#的TAG内容)
|
(正文内容)
|
||||||
|
|
||||||
(TAG内容,如#周末去哪)
|
(最后一段是TAG内容,如#周末去哪)
|
||||||
</content>
|
</content>
|
||||||
@ -32,5 +32,5 @@
|
|||||||
<content>
|
<content>
|
||||||
(正文内容)
|
(正文内容)
|
||||||
|
|
||||||
(最后一段是TAG内容,如#周末去哪)
|
(最后一段是TAG内容,如#周末去哪)
|
||||||
</content>
|
</content>
|
||||||
|
|||||||
22
main.py
22
main.py
@ -11,7 +11,7 @@ import core.posterGen as posterGen
|
|||||||
import core.simple_collage as simple_collage
|
import core.simple_collage as simple_collage
|
||||||
from utils.resource_loader import ResourceLoader
|
from utils.resource_loader import ResourceLoader
|
||||||
from utils.tweet_generator import prepare_topic_generation, generate_topics, generate_single_content
|
from utils.tweet_generator import prepare_topic_generation, generate_topics, generate_single_content
|
||||||
|
import random
|
||||||
def main():
|
def main():
|
||||||
config_file = {
|
config_file = {
|
||||||
"date": "4月17日",
|
"date": "4月17日",
|
||||||
@ -77,17 +77,17 @@ def main():
|
|||||||
|
|
||||||
# 直接使用同一个AI Agent实例
|
# 直接使用同一个AI Agent实例
|
||||||
for i in range(len(tweet_topic_record.topics_list)):
|
for i in range(len(tweet_topic_record.topics_list)):
|
||||||
|
tweet_content_list = []
|
||||||
for j in range(config_file["variants"]):
|
for j in range(config_file["variants"]):
|
||||||
tweet_content, gen_result = generate_single_content(
|
tweet_content, gen_result = generate_single_content(
|
||||||
ai_agent, content_system_prompt, tweet_topic_record.topics_list[i],
|
ai_agent, content_system_prompt, tweet_topic_record.topics_list[i],
|
||||||
config_file["prompts_dir"], config_file["resource_dir"],
|
config_file["prompts_dir"], config_file["resource_dir"],
|
||||||
output_dir, run_id, i+1, j+1, config_file["content_temperature"]
|
output_dir, run_id, i+1, j+1, config_file["content_temperature"]
|
||||||
)
|
)
|
||||||
|
tweet_content_list.append(tweet_content.get_json_file())
|
||||||
if not tweet_content:
|
if not tweet_content:
|
||||||
print(f"生成第{i+1}篇文章的第{j+1}个变体失败,跳过")
|
print(f"生成第{i+1}篇文章的第{j+1}个变体失败,跳过")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
object_name = tweet_topic_record.topics_list[i]["object"]
|
object_name = tweet_topic_record.topics_list[i]["object"]
|
||||||
try:
|
try:
|
||||||
object_name = object_name.split(".")[0]
|
object_name = object_name.split(".")[0]
|
||||||
@ -108,7 +108,7 @@ def main():
|
|||||||
if not os.path.exists(img_dir_path):
|
if not os.path.exists(img_dir_path):
|
||||||
print(f"图片目录不存在:{img_dir_path},跳过该对象")
|
print(f"图片目录不存在:{img_dir_path},跳过该对象")
|
||||||
continue
|
continue
|
||||||
img_dir = os.path.join(output_dir, f"{i+1}_{j+1}")
|
img_dir = os.path.join(output_dir, f"{i+1}_{j_index+1}")
|
||||||
info_directory = [
|
info_directory = [
|
||||||
f"/root/autodl-tmp/sanming_img/相机/{object_name}/description.txt"
|
f"/root/autodl-tmp/sanming_img/相机/{object_name}/description.txt"
|
||||||
]
|
]
|
||||||
@ -117,20 +117,16 @@ def main():
|
|||||||
print(f"描述文件不存在:{info_directory[0]},使用生成的内容替代")
|
print(f"描述文件不存在:{info_directory[0]},使用生成的内容替代")
|
||||||
info_directory = []
|
info_directory = []
|
||||||
|
|
||||||
poster_num = 1
|
poster_num = config_file["variants"]
|
||||||
tweet_content_json = tweet_content.get_json_file()
|
|
||||||
tweet_content_str = f"""
|
|
||||||
{tweet_content_json}
|
|
||||||
"""
|
|
||||||
input_dir = img_dir_path # 使用前面检查过的目录路径
|
input_dir = img_dir_path # 使用前面检查过的目录路径
|
||||||
# img_dir = output_dir
|
|
||||||
target_size = (900, 1200)
|
target_size = (900, 1200)
|
||||||
result_path = []
|
result_path = []
|
||||||
|
|
||||||
content_gen = contentGen.ContentGenerator()
|
content_gen = contentGen.ContentGenerator()
|
||||||
response = content_gen.run(info_directory, poster_num, tweet_content_str)
|
response = content_gen.run(info_directory, poster_num, tweet_content_list)
|
||||||
print(response)
|
print(response)
|
||||||
|
for j_index in range(config_file["variants"]):
|
||||||
try:
|
try:
|
||||||
# 创建输出目录
|
# 创建输出目录
|
||||||
collage_output_dir = os.path.join(img_dir, "collage_img")
|
collage_output_dir = os.path.join(img_dir, "collage_img")
|
||||||
@ -170,7 +166,7 @@ def main():
|
|||||||
}
|
}
|
||||||
img_path = img_list[index]['path']
|
img_path = img_list[index]['path']
|
||||||
print(f"使用图片路径: {img_path}")
|
print(f"使用图片路径: {img_path}")
|
||||||
output_path = os.path.join(poster_output_dir, f"img_{i+1}_{j+1}_{index}.jpg")
|
output_path = os.path.join(poster_output_dir, f"img_{i+1}_{j_index+1}_{index}.jpg")
|
||||||
result_path.append(poster_gen.create_poster(img_path, text_data, output_path))
|
result_path.append(poster_gen.create_poster(img_path, text_data, output_path))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"海报生成过程中出错: {e}")
|
print(f"海报生成过程中出错: {e}")
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
class ResourceLoader:
|
class ResourceLoader:
|
||||||
"""资源加载器,用于加载提示词和参考资料"""
|
"""资源加载器,用于加载提示词和参考资料"""
|
||||||
@ -20,7 +20,7 @@ class ResourceLoader:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_all_refer_files(refer_dir):
|
def load_all_refer_files(refer_dir, refer_content_length=50):
|
||||||
"""加载Refer目录下的所有文件内容"""
|
"""加载Refer目录下的所有文件内容"""
|
||||||
refer_content = ""
|
refer_content = ""
|
||||||
try:
|
try:
|
||||||
@ -30,6 +30,10 @@ class ResourceLoader:
|
|||||||
file_path = os.path.join(refer_dir, file)
|
file_path = os.path.join(refer_dir, file)
|
||||||
if os.path.isfile(file_path):
|
if os.path.isfile(file_path):
|
||||||
content = ResourceLoader.load_file_content(file_path)
|
content = ResourceLoader.load_file_content(file_path)
|
||||||
|
# 用\n分割content,取前length条
|
||||||
|
content_lines = content.split("\n")
|
||||||
|
content_lines = random.sample(content_lines, refer_content_length)
|
||||||
|
content = "\n".join(content_lines)
|
||||||
refer_content += f"## {file}\n{content}\n\n"
|
refer_content += f"## {file}\n{content}\n\n"
|
||||||
return refer_content
|
return refer_content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user