TravelContentCreator/core/contentGen.py

108 lines
3.6 KiB
Python
Raw Normal View History

import os
2025-04-22 21:26:56 +08:00
import logging
2025-04-24 20:35:25 +08:00
from utils.content_generator import ContentGenerator as NewContentGenerator
2025-04-24 20:35:25 +08:00
# For backward compatibility the ContentGenerator class is kept, but all work
# is delegated to the new implementation (utils.content_generator).
class ContentGenerator:
    """Backward-compatible facade over the new ``ContentGenerator``.

    Stores the API connection settings (model name, base URL, key) and
    forwards them to the wrapped implementation on every generation call.
    """

    def __init__(self,
                 model_name="qwenQWQ",
                 api_base_url="http://localhost:8000/v1",
                 api_key="EMPTY",
                 output_dir="/root/autodl-tmp/poster_generate_result",
                 *,
                 temperature=0.7,
                 top_p=0.8,
                 presence_penalty=1.2):
        """Initialize the poster generator, backed by the new implementation.

        Args:
            model_name: Name of the model to use.
            api_base_url: Base URL of the API.
            api_key: API key.
            output_dir: Directory where generated results are saved.
            temperature: Sampling temperature for the implementation
                (keyword-only; defaults to the previously hard-coded 0.7).
            top_p: top_p sampling parameter (keyword-only; default 0.8).
            presence_penalty: Presence penalty (keyword-only; default 1.2).
        """
        # Create the new implementation instance that does the real work.
        self._impl = NewContentGenerator(
            output_dir=output_dir,
            temperature=temperature,
            top_p=top_p,
            presence_penalty=presence_penalty,
        )
        # API parameters; forwarded on every generate_posters()/run() call.
        self.model_name = model_name
        self.api_base_url = api_base_url
        self.api_key = api_key

        # Logging setup. NOTE: basicConfig configures the *root* logger and is
        # a no-op once the root logger has handlers; kept for existing callers.
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

        # Fields kept only for backward compatibility. Only add_description is
        # mirrored from the implementation (after load_infomation() and run());
        # the others are never updated by this wrapper.
        self.df = None
        self.all_images_info = []
        self.structured_prompt = ""
        self.current_img_info = None
        self.add_description = ""

    # ---- Methods below delegate to the new implementation ----

    def load_infomation(self, info_directory_path):
        """Load the additional description file from *info_directory_path*.

        The historical misspelling is kept for compatibility; prefer the
        correctly spelled load_information().
        """
        self._impl.load_infomation(info_directory_path)
        self.add_description = self._impl.add_description

    def load_information(self, info_directory_path):
        """Correctly spelled alias of load_infomation()."""
        return self.load_infomation(info_directory_path)

    def split_content(self, content):
        """Split (JSON) content using the implementation's split_content()."""
        return self._impl.split_content(content)

    def generate_posters(self, poster_num, tweet_content, system_prompt=None,
                         max_retries=3, timeout=60):
        """Generate poster content.

        Args:
            poster_num: Forwarded positionally to the implementation.
            tweet_content: Forwarded positionally to the implementation.
            system_prompt: Optional system prompt, forwarded positionally.
            max_retries: Maximum number of retries, forwarded as a keyword.
            timeout: Timeout value forwarded to the implementation
                (defaults to the previously hard-coded 60).

        Returns:
            Whatever the implementation's generate_posters() returns.
        """
        return self._impl.generate_posters(
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key,
            timeout=timeout,
            max_retries=max_retries,
        )

    def _generate_fallback_content(self, poster_num):
        """Generate fallback content via the implementation."""
        return self._impl._generate_fallback_content(poster_num)

    def save_result(self, full_response):
        """Save the generation result to a file via the implementation."""
        return self._impl.save_result(full_response)

    def _validate_and_fix_data(self, data):
        """Validate and repair the data format via the implementation."""
        return self._impl._validate_and_fix_data(data)

    def run(self, info_directory, poster_num, tweet_content, system_prompt=None):
        """Run the full poster-content generation pipeline.

        The stored API settings are forwarded to the implementation's run().
        Afterwards add_description is mirrored from the implementation, as
        load_infomation() does, so the compatibility field is not left stale.

        Returns:
            Whatever the implementation's run() returns.
        """
        result = self._impl.run(
            info_directory,
            poster_num,
            tweet_content,
            system_prompt,
            api_url=self.api_base_url,
            model_name=self.model_name,
            api_key=self.api_key,
        )
        # getattr keeps the previous value if the implementation lacks the field.
        self.add_description = getattr(self._impl, "add_description",
                                       self.add_description)
        return result

    def set_temperature(self, temperature):
        """Set the temperature parameter on the implementation."""
        self._impl.set_temperature(temperature)

    def set_top_p(self, top_p):
        """Set the top_p parameter on the implementation."""
        self._impl.set_top_p(top_p)

    def set_presence_penalty(self, presence_penalty):
        """Set the presence-penalty parameter on the implementation."""
        self._impl.set_presence_penalty(presence_penalty)

    def set_model_para(self, temperature, top_p, presence_penalty):
        """Set all model sampling parameters at once on the implementation."""
        self._impl.set_model_para(temperature, top_p, presence_penalty)