108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
import os
|
||
import logging
|
||
from utils.content_generator import ContentGenerator as NewContentGenerator
|
||
|
||
# 为了向后兼容,我们保留ContentGenerator类,但内部使用新的实现
|
||
class ContentGenerator:
    """Backward-compatible facade over the new content generator.

    The class keeps the legacy name and method set so existing callers keep
    working, while every operation is delegated to an internal
    ``NewContentGenerator`` instance stored in ``self._impl``.
    """

    def __init__(self,
                 model_name="qwenQWQ",
                 api_base_url="http://localhost:8000/v1",
                 api_key="EMPTY",
                 output_dir="/root/autodl-tmp/poster_generate_result"):
        """Initialize the poster content generator wrapper.

        Args:
            model_name: Name of the model; forwarded as ``model_name`` by
                ``generate_posters`` and ``run``.
            api_base_url: Base URL of the API; forwarded as ``api_url`` by
                ``generate_posters`` and ``run``.
            api_key: API key; forwarded as ``api_key`` by
                ``generate_posters`` and ``run``.
            output_dir: Directory where results are saved; handed to the
                underlying implementation.
        """
        # Build the delegate with the wrapper's fixed sampling defaults; they
        # can be changed later through the set_* methods.
        self._impl = NewContentGenerator(
            output_dir=output_dir,
            temperature=0.7,
            top_p=0.8,
            presence_penalty=1.2
        )

        # The API parameters live on the wrapper and are passed per call.
        self.model_name = model_name
        self.api_base_url = api_base_url
        self.api_key = api_key

        # Logging setup (basicConfig only has an effect while the root logger
        # has no handlers yet).
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

        # Fields kept only for backward compatibility with the old class.
        self.df = None
        self.all_images_info = []
        self.structured_prompt = ""
        self.current_img_info = None
        self.add_description = ""

    # The methods below delegate to the new implementation.

    def load_infomation(self, info_directory_path):
        """Load the additional description files.

        Delegates to the implementation and then mirrors its
        ``add_description`` onto this wrapper's legacy field.

        Args:
            info_directory_path: Path handed to the implementation's
                ``load_infomation``.
        """
        self._impl.load_infomation(info_directory_path)
        self.add_description = self._impl.add_description

    def split_content(self, content):
        """Split JSON content; returns the implementation's result."""
        return self._impl.split_content(content)

    def generate_posters(self, poster_num, tweet_content, system_prompt=None,
                         max_retries=3, timeout=60):
        """Generate poster content.

        Args:
            poster_num: Forwarded positionally to the implementation.
            tweet_content: Forwarded positionally to the implementation.
            system_prompt: Optional system prompt, forwarded positionally.
            max_retries: Forwarded as ``max_retries``.
            timeout: Forwarded as ``timeout``. Defaults to 60, the value this
                wrapper previously hard-coded, so existing calls behave the
                same.

        Returns:
            Whatever the implementation's ``generate_posters`` returns.
        """
        return self._impl.generate_posters(
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key,
            timeout=timeout,
            max_retries=max_retries
        )

    def _generate_fallback_content(self, poster_num):
        """Generate fallback content; returns the implementation's result."""
        return self._impl._generate_fallback_content(poster_num)

    def save_result(self, full_response):
        """Save the generated result to a file via the implementation."""
        return self._impl.save_result(full_response)

    def _validate_and_fix_data(self, data):
        """Validate and repair the data format via the implementation."""
        return self._impl._validate_and_fix_data(data)

    def run(self, info_directory, poster_num, tweet_content, system_prompt=None):
        """Run the poster content generation workflow.

        Args:
            info_directory: Forwarded positionally to the implementation.
            poster_num: Forwarded positionally to the implementation.
            tweet_content: Forwarded positionally to the implementation.
            system_prompt: Optional system prompt, forwarded positionally.

        Returns:
            Whatever the implementation's ``run`` returns.
        """
        result = self._impl.run(
            info_directory,
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key
        )
        # NOTE(review): presumably the implementation's run() loads the info
        # directory itself - confirm. Mirror add_description afterwards so the
        # legacy field stays in step, as load_infomation() already does. The
        # getattr fallback keeps the current value if the attribute is absent.
        self.add_description = getattr(
            self._impl, "add_description", self.add_description
        )
        return result

    def set_temperature(self, temperature):
        """Set the temperature parameter on the implementation."""
        self._impl.set_temperature(temperature)

    def set_top_p(self, top_p):
        """Set the top_p parameter on the implementation."""
        self._impl.set_top_p(top_p)

    def set_presence_penalty(self, presence_penalty):
        """Set the presence penalty parameter on the implementation."""
        self._impl.set_presence_penalty(presence_penalty)

    def set_model_para(self, temperature, top_p, presence_penalty):
        """Set all model parameters on the implementation in one call."""
        self._impl.set_model_para(temperature, top_p, presence_penalty)
|