Compare commits

...

3 Commits

Author SHA1 Message Date
91ac7ca65a 合并poster_update to reinstruct 2025-07-26 18:42:53 +08:00
cead3be01a 封装了海报模块 2025-07-25 18:39:16 +08:00
436d1917ea 更新海报模块api 2025-07-25 17:13:37 +08:00
67 changed files with 377 additions and 246 deletions

Binary file not shown.

View File

@ -5,6 +5,7 @@
海报API模型定义 - 使用 camelCase 命名约定
"""
from re import T
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
@ -12,7 +13,7 @@ from pydantic import BaseModel, Field
class PosterGenerateRequest(BaseModel):
"""海报生成请求模型"""
templateId: str = Field("vibrant", description="模板ID")
imageIds: Optional[List[int]] = Field(None, description="图像ID列表")
imagesBase64: Optional[str] = Field(None, description="图像base64编码")
posterContent: Optional[Dict[str, Any]] = Field(None, description="海报内容,如果提供则直接使用此内容")
contentId: Optional[int] = Field(None, description="内容ID用于AI生成内容")
productId: Optional[int] = Field(None, description="产品ID用于AI生成内容")
@ -24,11 +25,15 @@ class PosterGenerateRequest(BaseModel):
json_schema_extra = {
"example": {
"templateId": "vibrant",
"imageIds": [1, 2, 3],
"numVariations": 3,
"posterContent": {
"title": "直接提供的内容标题",
"slogan": "这是一个预先准备好的口号"
"imagesBase64": "",
"numVariations": 1,
"forceLlmGeneration":False,
"contentId":1,
"productId":1,
"scenicSpotId":1,
"posterContent":{
"title":"天津冒险湾",
"slogan":"天津冒险湾,让你体验不一样的冒险之旅"
}
}
}
@ -59,9 +64,7 @@ class PosterGenerateResponse(BaseModel):
"""海报生成响应模型"""
requestId: str
templateId: str
resultImagesBase64: List[str] = Field(description="生成的海报图像(base64编码)列表")
usedImageIds: List[int] = Field(default_factory=list)
imageUsageInfo: List[ImageUsageInfo] = Field(default_factory=list)
resultImagesBase64: List[Dict[str, Any]] = Field(description="生成的海报图像(base64编码)列表")
metadata: Dict[str, Any] = Field(default_factory=dict)
class ImageUsageRequest(BaseModel):
@ -73,3 +76,4 @@ class ImageUsageResponse(BaseModel):
requestId: str
imageUsageInfo: List[ImageUsageInfo]
summary: Dict[str, Any]

View File

@ -95,14 +95,14 @@ async def generate_poster(
"""
try:
result = await poster_service.generate_poster(
template_id=request.template_id,
poster_content=request.poster_content,
content_id=request.content_id,
product_id=request.product_id,
scenic_spot_id=request.scenic_spot_id,
image_ids=request.image_ids,
generate_collage=request.generate_collage,
force_llm_generation=request.force_llm_generation
template_id=request.templateId,
poster_content=request.posterContent,
content_id=request.contentId,
product_id=request.productId,
scenic_spot_id=request.scenicSpotId,
images_base64=request.imagesBase64,
num_variations=request.numVariations,
force_llm_generation=request.forceLlmGeneration
)
return PosterGenerateResponse(**result)

View File

@ -902,32 +902,35 @@ class DatabaseService:
return
try:
with self.db_pool.get_connection() as conn:
with conn.cursor() as cursor:
# 更新或插入统计记录
cursor.execute("""
INSERT INTO templateUsageStats (templateId, usageCount, successCount, errorCount, avgProcessingTime, lastUsedAt)
VALUES (%s, 1, %s, %s, %s, NOW())
ON DUPLICATE KEY UPDATE
usageCount = usageCount + 1,
successCount = successCount + %s,
errorCount = errorCount + %s,
avgProcessingTime = (avgProcessingTime * (usageCount - 1) + %s) / usageCount,
lastUsedAt = NOW(),
updateTime = NOW()
""", (
template_id,
1 if success else 0,
0 if success else 1,
processing_time,
1 if success else 0,
0 if success else 1,
processing_time
))
conn = self.db_pool.get_connection()
cursor = conn.cursor()
conn.commit()
# 更新或插入统计记录
cursor.execute("""
INSERT INTO templateUsageStats (templateId, usageCount, successCount, errorCount, avgProcessingTime, lastUsedAt)
VALUES (%s, 1, %s, %s, %s, NOW())
ON DUPLICATE KEY UPDATE
usageCount = usageCount + 1,
successCount = successCount + %s,
errorCount = errorCount + %s,
avgProcessingTime = (avgProcessingTime * (usageCount - 1) + %s) / usageCount,
lastUsedAt = NOW(),
updatedAt = NOW()
""", (
template_id,
1 if success else 0,
0 if success else 1,
processing_time,
1 if success else 0,
0 if success else 1,
processing_time
))
logger.info(f"更新模板使用统计: template_id={template_id}, success={success}, time={processing_time:.3f}s")
conn.commit()
cursor.close()
conn.close()
logger.info(f"更新模板使用统计: template_id={template_id}, success={success}, time={processing_time:.3f}s")
except Exception as e:
logger.error(f"更新模板使用统计失败: {e}")

View File

@ -6,6 +6,7 @@
封装核心功能支持基于模板的动态内容生成和海报创建
"""
import logging
import uuid
import time
@ -62,18 +63,18 @@ class PosterService:
'vibrant': {
'id': 'vibrant',
'name': '活力风格',
'handler_path': 'poster.templates.vibrant_template',
'class_name': 'VibrantTemplate',
'handlerPath': 'poster.templates.vibrant_template',
'className': 'VibrantTemplate',
'description': '适合景点、活动等充满活力的场景',
'is_active': True
'isActive': True
},
'business': {
'id': 'business',
'name': '商务风格',
'handler_path': 'poster.templates.business_template',
'class_name': 'BusinessTemplate',
'handlerPath': 'poster.templates.business_template',
'className': 'BusinessTemplate',
'description': '适合酒店、房地产等商务场景',
'is_active': True
'isActive': True
}
}
@ -88,11 +89,11 @@ class PosterService:
return self._template_instances[template_id]
template_info = self._templates[template_id]
handler_path = template_info.get('handler_path')
class_name = template_info.get('class_name')
handler_path = template_info.get('handlerPath')
class_name = template_info.get('className')
if not handler_path or not class_name:
logger.error(f"模板 {template_id} 缺少 handler_path 或 class_name")
logger.error(f"模板 {template_id} 缺少 handlerPath 或 className")
return None
try:
@ -105,11 +106,11 @@ class PosterService:
# 设置字体目录(如果配置了)
from core.config import PosterConfig
poster_config = self.config_manager.get_config('poster', PosterConfig)
if poster_config:
font_dir = poster_config.font_dir
if font_dir and hasattr(template_instance, 'set_font_dir'):
template_instance.set_font_dir(font_dir)
# poster_config = self.config_manager.get_config('poster', PosterConfig)
# if poster_config:
# font_dir = poster_config.font_dir
# if font_dir and hasattr(template_instance, 'set_font_dir'):
# template_instance.set_font_dir(font_dir)
# 缓存实例以便重用
self._template_instances[template_id] = template_instance
@ -158,7 +159,7 @@ class PosterService:
content_id: Optional[int],
product_id: Optional[int],
scenic_spot_id: Optional[int],
image_ids: Optional[List[int]],
images_base64: Optional[List[str]] ,
num_variations: int = 1,
force_llm_generation: bool = False) -> Dict[str, Any]:
"""
@ -193,12 +194,28 @@ class PosterService:
if not final_content:
raise ValueError("无法获取用于生成海报的内容")
# 3. 准备图片
images = []
if image_ids:
images = self.db_service.get_images_by_ids(image_ids)
if not images:
raise ValueError("无法获取指定的图片")
# # 3. 准备图片
# images = []
# if image_ids:
# images = self.db_service.get_images_by_ids(image_ids)
# if not images:
# raise ValueError("无法获取指定的图片")
# # 3. 图片解码
try:
# 移除可能存在的MIME类型前缀
if images_base64.startswith("data:"):
images_base64 = images_base64.split(",", 1)[1]
# 解码base64
image_bytes = base64.b64decode(images_base64)
# 创建PIL Image对象
images = Image.open(BytesIO(image_bytes))
except Exception as e:
print(f"解码失败: {e}")
# 4. 调用模板生成海报
try:
@ -213,22 +230,22 @@ class PosterService:
# 5. 保存海报并返回结果
variations = []
for i, poster in enumerate(posters):
output_path = self._save_poster(poster, template_id, i)
if output_path:
variations.append({
"variation_id": i,
"poster_path": str(output_path),
"base64": self._image_to_base64(poster)
})
i=0 ## 用于多个海报时,指定海报的编号,此时只有一个没有用上,但是接口开放着。
output_path = self._save_poster(posters, template_id, i)
if output_path:
variations.append({
"variation_id": i,
"poster_path": str(output_path),
"base64": self._image_to_base64(posters)
})
# 记录模板使用情况
self._update_template_stats(template_id, bool(variations), time.time() - start_time)
return {
"request_id": f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}",
"template_id": template_id,
"variations": variations,
"requestId": f"poster-{datetime.now().strftime('%Y%m%d-%H%M%S')}-{str(uuid.uuid4())[:8]}",
"templateId": template_id,
"resultImagesBase64": variations,
"metadata": {
"generation_time": f"{time.time() - start_time:.2f}s",
"model_used": self.ai_agent.config.model if force_llm_generation or not poster_content else None,
@ -240,7 +257,7 @@ class PosterService:
self._update_template_stats(template_id, False, time.time() - start_time)
raise ValueError(f"生成海报失败: {str(e)}")
def _save_poster(self, poster: Image.Image, template_id: str, variation_id: int) -> Optional[Path]:
def _save_poster(self, poster: Image.Image, template_id: str, variation_id: int=1) -> Optional[Path]:
"""保存海报到文件系统"""
try:
# 创建唯一的主题ID用于保存
@ -285,9 +302,8 @@ class PosterService:
"""使用LLM生成海报内容"""
# 获取提示词
template_info = self._templates.get(template_id, {})
system_prompt = template_info.get('system_prompt', "")
user_prompt_template = template_info.get('user_prompt_template', "")
system_prompt = template_info.get('systemPrompt', "")
user_prompt_template = template_info.get('userPromptTemplate', "")
if not system_prompt or not user_prompt_template:
logger.error(f"模板 {template_id} 缺少提示词配置")
return None
@ -300,16 +316,18 @@ class PosterService:
data['product'] = self.db_service.get_product_by_id(product_id)
if scenic_spot_id:
data['scenic_spot'] = self.db_service.get_scenic_spot_by_id(scenic_spot_id)
logger.info(f"data: {data}")
# 格式化提示词
try:
user_prompt = user_prompt_template.format(**data)
logger.info(f"user_prompt: {user_prompt}")
except KeyError as e:
logger.warning(f"格式化提示词时缺少键: {e}")
user_prompt = user_prompt_template + f"\n可用数据: {json.dumps(data, ensure_ascii=False)}"
try:
response, _, _, _ = await self.ai_agent.generate_text(system_prompt=system_prompt, user_prompt=user_prompt)
response, _, _, _ = await self.ai_agent.generate_text(system_prompt=system_prompt, user_prompt=user_prompt,use_stream=True)
json_start = response.find('{')
json_end = response.rfind('}') + 1
if json_start != -1 and json_end != -1:

View File

@ -1,8 +1,13 @@
{
"host": "localhost",
"user": "root",
<<<<<<< HEAD
"password": "Kj#9mP2$",
"database": "travel_content",
=======
"password": "mysql2025.",
"database": "bangbang",
>>>>>>> poster_update
"port": 3306,
"charset": "utf8mb4"
}

139
main.py
View File

@ -1,139 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Travel Content Creator
主入口文件
"""
import os
import sys
import time
import logging
import asyncio
import argparse
from datetime import datetime
from pathlib import Path
from core.config import get_config_manager
from core.ai import AIAgent
from utils.file_io import OutputManager
from tweet.topic_generator import TopicGenerator
from tweet.content_generator import ContentGenerator
from tweet.content_judger import ContentJudger
from poster.poster_generator import PosterGenerator
# 配置日志
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger(__name__)
class Pipeline:
"""
内容生成流水线
协调各个模块的工作
"""
def __init__(self):
# 初始化配置
self.config_manager = get_config_manager()
self.config_manager.load_from_directory("config")
# 初始化输出管理器
run_id = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self.output_manager = OutputManager("result", run_id)
# 初始化AI代理
ai_config = self.config_manager.get_config('ai_model', AIModelConfig)
self.ai_agent = AIAgent(ai_config)
# 初始化各个组件
self.topic_generator = TopicGenerator(self.ai_agent, self.config_manager, self.output_manager)
self.content_generator = ContentGenerator(self.ai_agent, self.config_manager, self.output_manager)
self.content_judger = ContentJudger(self.ai_agent, self.config_manager, self.output_manager)
self.poster_generator = PosterGenerator(self.config_manager, self.output_manager)
async def run(self):
"""运行完整流水线"""
start_time = time.time()
logger.info("--- 开始执行内容生成流水线 ---")
# 步骤1: 生成选题
logger.info("--- 步骤 1: 开始生成选题 ---")
topics = await self.topic_generator.generate_topics()
if not topics:
logger.error("未能生成任何选题,流程终止。")
return
logger.info(f"成功生成 {len(topics)} 个选题")
# 步骤2: 为每个选题生成内容
logger.info("--- 步骤 2: 开始生成内容 ---")
contents = {}
for topic in topics:
topic_index = topic.get('index', 'unknown')
logger.info(f"--- 步骤 2: 开始为选题 {topic_index} 生成内容 ---")
content = await self.content_generator.generate_content_for_topic(topic)
contents[topic_index] = content
# 步骤3: 审核内容
logger.info("--- 步骤 3: 开始审核内容 ---")
judged_contents = {}
for topic_index, content in contents.items():
topic = next((t for t in topics if t.get('index') == topic_index), None)
if not topic:
logger.warning(f"找不到选题 {topic_index} 的原始数据,跳过审核")
continue
logger.info(f"--- 步骤 3: 开始审核选题 {topic_index} 的内容 ---")
try:
judged_data = await self.content_judger.judge_content(content, topic)
judged_contents[topic_index] = judged_data
except Exception as e:
logger.critical(f"为选题 {topic_index} 处理内容审核时发生意外错误: {e}", exc_info=True)
# 步骤4: 生成海报
# logger.info("--- 步骤 4: 开始生成海报 ---")
# posters = {}
# for topic_index, content in judged_contents.items():
# if not content.get('judge_success', False):
# logger.warning(f"选题 {topic_index} 的内容审核未通过,跳过海报生成")
# continue
# logger.info(f"--- 步骤 4: 开始为选题 {topic_index} 生成海报 ---")
# poster_path = self.poster_generator.generate_poster(content, topic_index)
# if poster_path:
# posters[topic_index] = poster_path
# 完成
logger.info("--- 所有任务已完成 ---")
end_time = time.time()
logger.info(f"--- 运行结束 --- 耗时: {end_time - start_time:.2f} 秒 ---")
async def main():
"""主函数"""
parser = argparse.ArgumentParser(description="Travel Content Creator")
parser.add_argument("--config-dir", default="config", help="配置目录路径")
args = parser.parse_args()
# 检查配置目录
if not os.path.isdir(args.config_dir):
logger.error(f"配置目录不存在: {args.config_dir}")
sys.exit(1)
# 运行流水线
pipeline = Pipeline()
await pipeline.run()
if __name__ == "__main__":
# 导入这里避免循环导入
from core.config import AIModelConfig
asyncio.run(main())

View File

@ -4,6 +4,7 @@
"""
Vibrant风格活力风格海报模板
"""
from ast import List
import logging
import math
from typing import Dict, Any, Optional, Tuple
@ -44,19 +45,21 @@ class VibrantTemplate(BaseTemplate):
}
def generate(self,
image_path: str,
images: List,
content: Optional[Dict[str, Any]] = None,
theme_color: Optional[str] = None,
glass_intensity: float = 1.5,
num_variations: int = 1,
**kwargs) -> Image.Image:
"""
生成Vibrant风格海报
Args:
image_path (str): 主图路径
images (List): 主图
content (Optional[Dict[str, Any]]): 包含所有文本信息的字典
theme_color (Optional[str]): 预设颜色主题的名称
glass_intensity (float): 毛玻璃效果强度
num_variations (int): 生成海报数量
Returns:
Image.Image: 生成的海报图像
@ -66,12 +69,13 @@ class VibrantTemplate(BaseTemplate):
self.config['glass_effect']['intensity_multiplier'] = glass_intensity
main_image = self.image_processor.load_image(image_path)
main_image = images
logger.info(f"main_image的类型: {np.shape(main_image)}")
if not main_image:
logger.error(f"无法加载图片: {image_path}")
logger.error(f"无法加载图片: ")
return None
main_image = self.image_processor.resize_and_crop(main_image, self.size)
main_image = self.image_processor.resize_image(image=main_image, target_size=self.size)
estimated_height = self._estimate_content_height(content)
gradient_start = self._detect_gradient_start_position(main_image, estimated_height)
@ -254,13 +258,13 @@ class VibrantTemplate(BaseTemplate):
def _calculate_content_margins(self, content: Dict[str, Any], width: int, center_x: int) -> Tuple[int, int]:
"""计算内容区域的左右边距"""
title_text = content.get("title", "")
title_size, title_width = self.text_renderer.calculate_font_size_and_width(
title_text, int(width * 0.95), max_size=130)
title_size=self.text_renderer.calculate_optimal_font_size(title_text,int(width * 0.95),max_size=130)
title_width,title_height=self.text_renderer.get_text_size(title_text,self.text_renderer._load_default_font(title_size))
title_x = center_x - title_width // 2
slogan_text = content.get("slogan", "")
subtitle_size, subtitle_width = self.text_renderer.calculate_font_size_and_width(
slogan_text, int(width * 0.9), max_size=50)
subtitle_size=self.text_renderer.calculate_optimal_font_size(slogan_text,int(width * 0.9),max_size=50)
subtitle_width,subtitle_height=self.text_renderer.get_text_size(slogan_text,self.text_renderer._load_default_font(subtitle_size))
subtitle_x = center_x - subtitle_width // 2
padding = 20
@ -276,7 +280,7 @@ class VibrantTemplate(BaseTemplate):
def _render_footer(self, draw: ImageDraw.Draw, content: Dict[str, Any], y: int, left: int, right: int):
"""渲染页脚文本"""
font = self.text_renderer.load_font(18)
font = self.text_renderer._load_default_font(18)
if tag := content.get("tag"):
draw.text((left, y), tag, font=font, fill=(255, 255, 255))
if pagination := content.get("pagination"):
@ -288,8 +292,8 @@ class VibrantTemplate(BaseTemplate):
# 标题
title_text = content.get("title", "默认标题")
title_target_width = int((right - left) * 0.98)
title_size, _ = self.text_renderer.calculate_font_size_and_width(title_text, title_target_width, max_size=140, min_size=40)
title_font = self.text_renderer.load_font(title_size)
title_size=self.text_renderer.calculate_optimal_font_size(title_text,title_target_width,max_size=140,min_size=40)
title_font = self.text_renderer._load_default_font(title_size)
text_w, text_h = self.text_renderer.get_text_size(title_text, title_font)
title_x = center_x - text_w // 2
@ -299,8 +303,8 @@ class VibrantTemplate(BaseTemplate):
# 副标题 (slogan)
subtitle_text = content.get("slogan", "")
subtitle_target_width = int((right - left) * 0.95)
subtitle_size, _ = self.text_renderer.calculate_font_size_and_width(subtitle_text, subtitle_target_width, max_size=75, min_size=20)
subtitle_font = self.text_renderer.load_font(subtitle_size)
subtitle_size=self.text_renderer.calculate_optimal_font_size(subtitle_text,subtitle_target_width,max_size=75,min_size=20)
subtitle_font = self.text_renderer._load_default_font(subtitle_size)
sub_text_w, sub_text_h = self.text_renderer.get_text_size(subtitle_text, subtitle_font)
subtitle_x = center_x - sub_text_w // 2
@ -321,18 +325,18 @@ class VibrantTemplate(BaseTemplate):
def _render_left_column(self, draw: ImageDraw.Draw, content: Dict[str, Any], y: int, x: int, width: int, canvas_height: int):
"""渲染左栏内容:按钮和项目列表"""
button_font = self.text_renderer.load_font(30)
button_font = self.text_renderer._load_default_font(30)
button_text = content.get("content_button", "套餐内容")
button_width, _ = self.text_renderer.get_text_size(button_text, button_font)
button_width += 40
button_height = 50
self.text_renderer.draw_rounded_rectangle(draw, (x, y), (button_width, button_height), 20, fill=(0, 140, 210, 180), outline=(255, 255, 255, 255), width=1)
self.text_renderer.draw_rounded_rectangle(draw=draw, position=(x, y), size=(button_width, button_height), radius=20, fill_color=(0, 140, 210, 180), outline_color=(255, 255, 255, 255), outline_width=1)
draw.text((x + 20, y + (button_height - 30) // 2), button_text, font=button_font, fill=(255, 255, 255))
items = content.get("content_items", [])
if not items: return
font = self.text_renderer.load_font(28)
font = self.text_renderer._load_default_font(28)
list_y = y + button_height + 20
available_h = canvas_height - 30 - (len(content.get("remarks", [])) * 25 + 10) - list_y - 20
total_items_h = len(items) * 36
@ -346,11 +350,12 @@ class VibrantTemplate(BaseTemplate):
def _render_right_column(self, draw: ImageDraw.Draw, content: Dict[str, Any], y: int, x: int, right_margin: int):
"""渲染右栏内容:价格、票种和备注"""
price_text = content.get('price', '')
price_size, price_width = self.text_renderer.calculate_font_size_and_width(
price_text, int((right_margin - x) * 0.7), max_size=120, min_size=40)
price_font = self.text_renderer.load_font(price_size)
price_size=self.text_renderer.calculate_optimal_font_size(price_text,int((right_margin - x) * 0.7),max_size=120,min_size=40)
price_width,_=self.text_renderer.get_text_size(price_text,self.text_renderer._load_default_font(price_size))
suffix_font = self.text_renderer.load_font(int(price_size * 0.3))
price_font = self.text_renderer._load_default_font(price_size)
suffix_font = self.text_renderer._load_default_font(int(price_size * 0.3))
_, price_height = self.text_renderer.get_text_size(price_text, price_font)
suffix_width, suffix_height = self.text_renderer.get_text_size("CNY起", suffix_font)
@ -364,9 +369,9 @@ class VibrantTemplate(BaseTemplate):
draw.line([(price_x - 10, underline_y), (right_margin, underline_y)], fill=(255, 255, 255, 80), width=2)
ticket_text = content.get("ticket_type", "")
ticket_size, ticket_width = self.text_renderer.calculate_font_size_and_width(
ticket_text, int((right_margin - x) * 0.7), max_size=60, min_size=30)
ticket_font = self.text_renderer.load_font(ticket_size)
ticket_size=self.text_renderer.calculate_optimal_font_size(ticket_text,int((right_margin - x) * 0.7),max_size=60,min_size=30)
ticket_width,_=self.text_renderer.get_text_size(ticket_text,self.text_renderer._load_default_font(ticket_size))
ticket_font = self.text_renderer._load_default_font(ticket_size)
ticket_x = right_margin - ticket_width
ticket_y = y + price_height + 35
self.text_renderer.draw_text_with_shadow(draw, (ticket_x, ticket_y), ticket_text, ticket_font)
@ -374,7 +379,7 @@ class VibrantTemplate(BaseTemplate):
remarks = content.get("remarks", [])
if remarks:
remarks_font = self.text_renderer.load_font(16)
remarks_font = self.text_renderer._load_default_font(16)
remarks_y = ticket_y + ticket_height + 30
for i, remark in enumerate(remarks):
remark_width, _ = self.text_renderer.get_text_size(remark, remarks_font)

View File

@ -213,6 +213,141 @@ class TextRenderer:
self.logger.warning(f"加载默认字体失败: {e}")
return ImageFont.load_default()
def draw_text_with_outline(self, draw: ImageDraw.Draw,
position: Tuple[int, int],
text: str,
font: ImageFont.FreeTypeFont,
text_color: Tuple[int, int, int, int] = (255, 255, 255, 255),
outline_color: Tuple[int, int, int, int] = (0, 0, 0, 255),
outline_width: int = 2):
"""
绘制带描边的文字
Args:
draw: PIL绘图对象
position: 文字位置
text: 文字内容
font: 字体对象
text_color: 文字颜色
outline_color: 描边颜色
outline_width: 描边宽度
"""
x, y = position
# 绘制描边
for offset_x in range(-outline_width, outline_width + 1):
for offset_y in range(-outline_width, outline_width + 1):
if offset_x == 0 and offset_y == 0:
continue
draw.text((x + offset_x, y + offset_y), text, font=font, fill=outline_color)
# 绘制文字
draw.text(position, text, font=font, fill=text_color)
def draw_rounded_rectangle(self, draw: ImageDraw.Draw,
position: Tuple[int, int],
size: Tuple[int, int],
radius: int,
fill_color: Tuple[int, int, int, int],
outline_color: Optional[Tuple[int, int, int, int]] = None,
outline_width: int = 0):
"""
绘制圆角矩形
Args:
draw: PIL绘图对象
position: 左上角位置
size: 矩形大小
radius: 圆角半径
fill_color: 填充颜色
outline_color: 边框颜色
outline_width: 边框宽度
"""
x, y = position
width, height = size
# 确保尺寸有效
if width <= 0 or height <= 0:
return
# 限制圆角半径
radius = min(radius, width // 2, height // 2)
# 创建圆角矩形路径
# 这是一个简化版本PIL的较新版本有更好的圆角矩形支持
if radius > 0:
# 绘制中心矩形
draw.rectangle([x + radius, y, x + width - radius, y + height], fill=fill_color)
draw.rectangle([x, y + radius, x + width, y + height - radius], fill=fill_color)
# 绘制四个圆角
draw.pieslice([x, y, x + 2*radius, y + 2*radius], 180, 270, fill=fill_color)
draw.pieslice([x + width - 2*radius, y, x + width, y + 2*radius], 270, 360, fill=fill_color)
draw.pieslice([x, y + height - 2*radius, x + 2*radius, y + height], 90, 180, fill=fill_color)
draw.pieslice([x + width - 2*radius, y + height - 2*radius, x + width, y + height], 0, 90, fill=fill_color)
else:
# 普通矩形
draw.rectangle([x, y, x + width, y + height], fill=fill_color)
# 绘制边框(如果需要)
if outline_color and outline_width > 0:
# 简化的边框绘制 - 使用线条而不是矩形避免坐标错误
for i in range(outline_width):
offset = i
# 确保坐标有效
if radius > 0:
# 上边
if x + radius + offset < x + width - radius - offset:
draw.line([x + radius + offset, y + offset,
x + width - radius - offset, y + offset],
fill=outline_color, width=1)
# 下边
if x + radius + offset < x + width - radius - offset and y + height - offset >= y + offset:
draw.line([x + radius + offset, y + height - offset,
x + width - radius - offset, y + height - offset],
fill=outline_color, width=1)
# 左边
if y + radius + offset < y + height - radius - offset:
draw.line([x + offset, y + radius + offset,
x + offset, y + height - radius - offset],
fill=outline_color, width=1)
# 右边
if y + radius + offset < y + height - radius - offset:
draw.line([x + width - offset, y + radius + offset,
x + width - offset, y + height - radius - offset],
fill=outline_color, width=1)
else:
# 普通矩形边框
draw.rectangle([x + offset, y + offset, x + width - offset, y + height - offset],
outline=outline_color, width=1)
def draw_text_with_shadow(self, draw: ImageDraw.Draw,
position: Tuple[int, int],
text: str,
font: ImageFont.FreeTypeFont,
text_color: Tuple[int, int, int, int] = (255, 255, 255, 255),
shadow_color: Tuple[int, int, int, int] = (0, 0, 0, 128),
shadow_offset: Tuple[int, int] = (2, 2)):
"""
绘制带阴影的文字
Args:
draw: PIL绘图对象
position: 文字位置
text: 文字内容
font: 字体对象
text_color: 文字颜色
shadow_color: 阴影颜色
shadow_offset: 阴影偏移
"""
x, y = position
shadow_x, shadow_y = shadow_offset
# 绘制阴影
draw.text((x + shadow_x, y + shadow_y), text, font=font, fill=shadow_color)
# 绘制文字
draw.text(position, text, font=font, fill=text_color)
def get_font(self, font_name: Optional[str] = None, size: int = 24) -> ImageFont.FreeTypeFont:
"""
获取指定字体
@ -318,6 +453,70 @@ class TextRenderer:
return image
def calculate_optimal_font_size(self, text: str,
target_width: int,
font_name: Optional[str] = None,
max_size: int = 120,
min_size: int = 10) -> int:
"""
计算最适合的字体大小
Args:
text: 文字内容
target_width: 目标宽度
font_name: 字体文件名
max_size: 最大字体大小
min_size: 最小字体大小
Returns:
最适合的字体大小
"""
if not text.strip():
return min_size
# 二分查找最佳字体大小
left, right = min_size, max_size
best_size = min_size
while left <= right:
mid_size = (left + right) // 2
try:
if font_path:
font = self._load_default_font(mid_size)
else:
font = ImageFont.load_default()
# 获取文字边界框
bbox = font.getbbox(text)
text_width = bbox[2] - bbox[0]
if text_width <= target_width:
best_size = mid_size
left = mid_size + 1
else:
right = mid_size - 1
except Exception:
right = mid_size - 1
return best_size
def get_text_size(self, text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]:
"""
获取文字的尺寸
Args:
text: 文字内容
font: 字体对象
Returns:
文字尺寸 (width, height)
"""
bbox = font.getbbox(text)
return bbox[2] - bbox[0], bbox[3] - bbox[1]
def draw_multiline_text(self, image: Image.Image, text: str, position: Tuple[int, int],
font_name: Optional[str] = None, font_size: int = 24,
color: Tuple[int, int, int] = (0, 0, 0),

36
requirements_complete.txt Normal file
View File

@ -0,0 +1,36 @@
# 原有依赖
json_repair==0.47.6
numpy==2.3.1
openai==1.93.3
opencv_python==4.11.0.86
opencv_python_headless==4.11.0.86
Pillow==11.3.0
psutil==6.1.0
pydantic==2.11.7
scikit_learn==1.7.0
scipy==1.16.0
simplejson==3.20.1
tiktoken==0.9.0
# Web框架相关
fastapi>=0.104.1
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
# 数据库相关
sqlalchemy>=2.0.0
asyncpg>=0.29.0
aiosqlite>=0.19.0
# 认证和安全
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
# 工具库
python-dotenv>=1.0.0
aiofiles>=23.2.1
requests>=2.31.0
# 可选依赖
pyyaml>=6.0.1
jinja2>=3.1.2