2025-04-17 11:05:46 +08:00
|
|
|
|
import os
|
2025-04-22 21:26:56 +08:00
|
|
|
|
import logging
|
2025-04-24 20:35:25 +08:00
|
|
|
|
from utils.content_generator import ContentGenerator as NewContentGenerator
|
2025-04-17 11:05:46 +08:00
|
|
|
|
|
2025-04-24 20:35:25 +08:00
|
|
|
|
# 为了向后兼容,我们保留ContentGenerator类,但内部使用新的实现
|
2025-04-17 11:05:46 +08:00
|
|
|
|
class ContentGenerator:
    """Backward-compatible wrapper around the new content generator.

    The legacy class name, method names and signatures are kept so existing
    callers continue to work; every method forwards to the implementation
    object stored in ``self._impl``. The API connection settings are kept on
    the wrapper and passed along on each generation call.
    """

    def __init__(self,
                 model_name="qwenQWQ",
                 api_base_url="http://localhost:8000/v1",
                 api_key="EMPTY",
                 output_dir="/root/autodl-tmp/poster_generate_result",
                 temperature=0.7,
                 top_p=0.8,
                 presence_penalty=1.2):
        """Initialize the poster generator, backed by the new implementation.

        Args:
            model_name: Name of the model to use.
            api_base_url: Base URL of the API.
            api_key: API key.
            output_dir: Directory where output results are saved.
            temperature: Initial temperature passed to the implementation.
            top_p: Initial top_p passed to the implementation.
            presence_penalty: Initial presence penalty passed to the
                implementation.
        """
        # Create the new implementation instance.
        self._impl = NewContentGenerator(
            output_dir=output_dir,
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
        )

        # Store the API parameters; they are forwarded on every
        # generate_posters() / run() call.
        self.model_name = model_name
        self.api_base_url = api_base_url
        self.api_key = api_key

        # Set up logging. basicConfig() leaves the root logger alone when it
        # already has handlers, so repeated instantiation is harmless.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        )
        self.logger = logging.getLogger(__name__)

        # Fields kept only for backward compatibility with the old class.
        self.df = None
        self.all_images_info = []
        self.structured_prompt = ""
        self.current_img_info = None
        self.add_description = ""

    # ------------------------------------------------------------------
    # Methods delegated to the new implementation
    # ------------------------------------------------------------------

    def load_infomation(self, info_directory_path):
        """Load the additional description file(s).

        Forwards to the implementation, then mirrors its ``add_description``
        onto this wrapper so legacy readers of the attribute see the value.
        (The misspelled method name is part of the legacy interface.)
        """
        self._impl.load_infomation(info_directory_path)
        self.add_description = self._impl.add_description

    def split_content(self, content):
        """Split JSON content; returns whatever the implementation returns."""
        return self._impl.split_content(content)

    def generate_posters(self, poster_num, tweet_content, system_prompt=None,
                         max_retries=3, timeout=60):
        """Generate poster content.

        Args:
            poster_num: Number of posters to generate.
            tweet_content: Source tweet content.
            system_prompt: Optional system prompt; ``None`` is passed through.
            max_retries: Retry count forwarded to the implementation.
            timeout: Timeout value forwarded to the implementation
                (previously fixed at 60).

        Returns:
            The implementation's ``generate_posters`` result.
        """
        return self._impl.generate_posters(
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key,
            timeout=timeout,
            max_retries=max_retries,
        )

    def _generate_fallback_content(self, poster_num):
        """Generate fallback content via the implementation."""
        return self._impl._generate_fallback_content(poster_num)

    def save_result(self, full_response):
        """Save the generation result to a file via the implementation."""
        return self._impl.save_result(full_response)

    def _validate_and_fix_data(self, data):
        """Validate and repair the data format via the implementation."""
        return self._impl._validate_and_fix_data(data)

    def run(self, info_directory, poster_num, tweet_content, system_prompt=None):
        """Run the poster content generation workflow.

        Forwards the arguments plus the stored API settings to the
        implementation's ``run`` and returns its result.
        """
        return self._impl.run(
            info_directory,
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key,
        )

    def set_temperature(self, temperature):
        """Set the temperature parameter on the implementation."""
        self._impl.set_temperature(temperature)

    def set_top_p(self, top_p):
        """Set the top_p parameter on the implementation."""
        self._impl.set_top_p(top_p)

    def set_presence_penalty(self, presence_penalty):
        """Set the presence penalty parameter on the implementation."""
        self._impl.set_presence_penalty(presence_penalty)

    def set_model_para(self, temperature, top_p, presence_penalty):
        """Set all model parameters at once on the implementation."""
        self._impl.set_model_para(temperature, top_p, presence_penalty)