import os
import logging

from utils.content_generator import ContentGenerator as NewContentGenerator


# Kept for backward compatibility: the ContentGenerator class and its public
# methods are preserved, but every call is delegated to the new implementation
# in utils.content_generator.
class ContentGenerator:
    """Backward-compatible facade over ``utils.content_generator.ContentGenerator``.

    The API connection settings (model name, base URL, key) are stored on this
    wrapper and forwarded to the new implementation on every call that needs
    them; everything else is a thin delegation.
    """

    def __init__(
        self,
        model_name="qwenQWQ",
        api_base_url="http://localhost:8000/v1",
        api_key="EMPTY",
        output_dir="/root/autodl-tmp/poster_generate_result",
        temperature=0.7,
        top_p=0.8,
        presence_penalty=1.2,
    ):
        """Initialize the poster content generator (backed by the new implementation).

        Args:
            model_name: Name of the model to use.
            api_base_url: Base URL of the API.
            api_key: API key.
            output_dir: Directory where generated results are saved.
            temperature: Sampling temperature passed to the new implementation
                (default 0.7, the value that used to be hard-coded).
            top_p: top_p sampling parameter (default 0.8, previously hard-coded).
            presence_penalty: Presence penalty (default 1.2, previously hard-coded).
        """
        # Create the new implementation instance that does the real work.
        self._impl = NewContentGenerator(
            output_dir=output_dir,
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
        )

        # API parameters; forwarded to the implementation per call.
        self.model_name = model_name
        self.api_base_url = api_base_url
        self.api_key = api_key

        # NOTE(review): basicConfig configures the *root* logger, a process-wide
        # side effect. It is a no-op once the root logger has handlers, and is
        # kept here so existing log output is unchanged.
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

        # Fields retained only for backward compatibility with older callers.
        self.df = None
        self.all_images_info = []
        self.structured_prompt = ""
        self.current_img_info = None
        self.add_description = ""

    # ---- Methods delegated to the new implementation ----

    def load_infomation(self, info_directory_path):
        """Load the additional description file(s) from ``info_directory_path``.

        The method name's spelling is kept for backward compatibility; see
        ``load_information`` for a correctly spelled alias. After loading, the
        implementation's ``add_description`` is mirrored onto this wrapper.
        """
        self._impl.load_infomation(info_directory_path)
        self.add_description = self._impl.add_description

    def load_information(self, info_directory_path):
        """Correctly spelled alias of :meth:`load_infomation`."""
        self.load_infomation(info_directory_path)

    def split_content(self, content):
        """Split JSON content; delegates to the new implementation."""
        return self._impl.split_content(content)

    def generate_posters(self, poster_num, tweet_content, system_prompt=None,
                         max_retries=3, timeout=60):
        """Generate poster content using the stored API settings.

        Args:
            poster_num: Passed through to the implementation.
            tweet_content: Passed through to the implementation.
            system_prompt: Optional system prompt, passed through.
            max_retries: Retry count forwarded to the implementation.
            timeout: Request timeout forwarded to the implementation
                (default 60, the value that used to be hard-coded).

        Returns:
            Whatever the new implementation's ``generate_posters`` returns.
        """
        return self._impl.generate_posters(
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key,
            timeout=timeout,
            max_retries=max_retries,
        )

    def _generate_fallback_content(self, poster_num):
        """Generate fallback content; delegates to the new implementation."""
        return self._impl._generate_fallback_content(poster_num)

    def save_result(self, full_response):
        """Save a generation result to file; delegates to the new implementation."""
        return self._impl.save_result(full_response)

    def _validate_and_fix_data(self, data):
        """Validate and fix the data format; delegates to the new implementation."""
        return self._impl._validate_and_fix_data(data)

    def run(self, info_directory, poster_num, tweet_content, system_prompt=None):
        """Run the full poster content generation pipeline.

        Forwards the stored API settings to the implementation's ``run`` and
        returns its result unchanged.
        """
        result = self._impl.run(
            info_directory,
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key,
        )
        # Keep the legacy add_description field in sync in case the
        # implementation's run() loaded descriptions itself (presumably it does,
        # since it receives info_directory -- verify against utils.content_generator).
        self.add_description = getattr(self._impl, "add_description", self.add_description)
        return result

    def set_temperature(self, temperature):
        """Set the temperature parameter on the implementation."""
        self._impl.set_temperature(temperature)

    def set_top_p(self, top_p):
        """Set the top_p parameter on the implementation."""
        self._impl.set_top_p(top_p)

    def set_presence_penalty(self, presence_penalty):
        """Set the presence penalty parameter on the implementation."""
        self._impl.set_presence_penalty(presence_penalty)

    def set_model_para(self, temperature, top_p, presence_penalty):
        """Set temperature, top_p and presence penalty in one call."""
        self._impl.set_model_para(temperature, top_p, presence_penalty)