TravelContentCreator/core/contentGen.py

108 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import logging
from utils.content_generator import ContentGenerator as NewContentGenerator
# Kept for backward compatibility: the ContentGenerator class remains, but it
# delegates all work to the new implementation (NewContentGenerator).
class ContentGenerator:
    """Backward-compatible facade over the new ``ContentGenerator`` implementation.

    Stores the API connection settings (model name, base URL, key) and passes
    them to the wrapped implementation on each generation call; the remaining
    public methods forward directly to the wrapped implementation.
    """

    def __init__(self,
                 model_name="qwenQWQ",
                 api_base_url="http://localhost:8000/v1",
                 api_key="EMPTY",
                 output_dir="/root/autodl-tmp/poster_generate_result",
                 temperature=0.7,
                 top_p=0.8,
                 presence_penalty=1.2):
        """Initialize the poster content generator (delegating wrapper).

        Args:
            model_name: Name of the model to request from the API.
            api_base_url: Base URL of the API.
            api_key: API key.
            output_dir: Directory where generated results are saved.
            temperature: Sampling temperature passed to the implementation
                (defaults to the previously hard-coded 0.7).
            top_p: ``top_p`` value passed to the implementation
                (defaults to the previously hard-coded 0.8).
            presence_penalty: Presence penalty passed to the implementation
                (defaults to the previously hard-coded 1.2).
        """
        # Create the underlying (new) implementation instance.
        self._impl = NewContentGenerator(
            output_dir=output_dir,
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
        )
        # API settings, forwarded on every generate_posters / run call.
        self.model_name = model_name
        self.api_base_url = api_base_url
        self.api_key = api_key
        # Logging setup kept from the original wrapper. basicConfig does nothing
        # if the root logger already has handlers configured.
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)
        # Legacy attributes kept so that older callers reading them still work.
        self.df = None
        self.all_images_info = []
        self.structured_prompt = ""
        self.current_img_info = None
        self.add_description = ""

    # ---- Methods delegated to the new implementation ----

    def load_infomation(self, info_directory_path):
        """Load additional description files via the implementation.

        The (misspelled) method name is kept for compatibility. After loading,
        the legacy ``add_description`` attribute mirrors the implementation's.
        """
        self._impl.load_infomation(info_directory_path)
        self.add_description = self._impl.add_description

    def split_content(self, content):
        """Split JSON content; delegates to the implementation."""
        return self._impl.split_content(content)

    def generate_posters(self, poster_num, tweet_content, system_prompt=None,
                         max_retries=3, timeout=60):
        """Generate poster content using the stored API settings.

        Args:
            poster_num: Forwarded unchanged to the implementation.
            tweet_content: Forwarded unchanged to the implementation.
            system_prompt: Optional system prompt, forwarded unchanged.
            max_retries: Retry count forwarded to the implementation.
            timeout: Request timeout forwarded to the implementation
                (defaults to the previously hard-coded 60).

        Returns:
            Whatever the implementation's ``generate_posters`` returns.
        """
        return self._impl.generate_posters(
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key,
            timeout=timeout,
            max_retries=max_retries,
        )

    def _generate_fallback_content(self, poster_num):
        """Generate fallback content; delegates to the implementation."""
        return self._impl._generate_fallback_content(poster_num)

    def save_result(self, full_response):
        """Save the generation result to a file; delegates to the implementation."""
        return self._impl.save_result(full_response)

    def _validate_and_fix_data(self, data):
        """Validate and fix the data format; delegates to the implementation."""
        return self._impl._validate_and_fix_data(data)

    def run(self, info_directory, poster_num, tweet_content, system_prompt=None):
        """Run the poster content generation pipeline with the stored API settings.

        After the implementation finishes, the legacy ``add_description``
        attribute is re-synced from the implementation (if it exposes one), so
        it stays consistent with ``load_infomation``.

        Returns:
            Whatever the implementation's ``run`` returns.
        """
        result = self._impl.run(
            info_directory,
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key,
        )
        # Keep the compatibility field in sync; fall back to the current value
        # if the implementation does not expose the attribute.
        self.add_description = getattr(self._impl, "add_description", self.add_description)
        return result

    def set_temperature(self, temperature):
        """Set the temperature parameter on the implementation."""
        self._impl.set_temperature(temperature)

    def set_top_p(self, top_p):
        """Set the top_p parameter on the implementation."""
        self._impl.set_top_p(top_p)

    def set_presence_penalty(self, presence_penalty):
        """Set the presence penalty parameter on the implementation."""
        self._impl.set_presence_penalty(presence_penalty)

    def set_model_para(self, temperature, top_p, presence_penalty):
        """Set temperature, top_p and presence penalty in one call."""
        self._impl.set_model_para(temperature, top_p, presence_penalty)