#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Prompt management and construction module.

Defines a file-backed prompt template base class and the builders that
assemble the topic-generation, content-generation and content-review prompts.
"""

import os
import logging
import random
from typing import Dict, Any, Optional, List
from pathlib import Path
import json

from core.config import (
    ConfigManager,
    GenerateTopicConfig,
    GenerateContentConfig,
    SystemConfig,
    ResourceConfig
)
from .file_io import ResourceLoader

logger = logging.getLogger(__name__)


class PromptTemplate:
    """
    Base class for prompt templates.

    Holds one system-prompt template and one user-prompt template, each read
    from a text file and rendered with ``str.format``.
    """

    def __init__(self, system_prompt_path: str, user_prompt_path: str):
        """
        Initialise from explicit system and user prompt file paths.

        Args:
            system_prompt_path (str): Path of the system prompt template file.
            user_prompt_path (str): Path of the user prompt template file.
        """
        self.system_path = Path(system_prompt_path)
        self.user_path = Path(user_prompt_path)

        self.system_template = self._read_template(self.system_path)
        self.user_template = self._read_template(self.user_path)

    def _read_template(self, path: Path) -> str:
        """
        Read a template file as UTF-8 text.

        Returns:
            The file content, or an empty string (after logging a warning)
            when ``path`` is not an existing regular file.
        """
        if not path.is_file():
            logger.warning(f"模板文件不存在,已跳过: {path}")
            return ""
        return path.read_text(encoding='utf-8')

    def _render(self, template: str, values: Dict[str, Any], label: str, path: Path) -> str:
        """
        Render ``template`` with ``values``; log the template path on failure.

        Args:
            template: The ``str.format`` template text.
            values: Keyword values substituted into the template.
            label: Human-readable template kind used in the log message.
            path: File the template was read from (for the log message only).

        Raises:
            KeyError: A named placeholder has no matching value.
            IndexError: The template contains a positional placeholder.
            ValueError: The template has malformed/unbalanced braces.
        """
        try:
            return template.format(**values)
        except KeyError as e:
            logger.error(f"渲染{label} '{path}' 时缺少键: {e}")
            raise
        except (IndexError, ValueError) as e:
            # str.format also fails on positional fields ("{}") and on stray
            # braces; log which template is broken before propagating.
            logger.error(f"渲染{label} '{path}' 失败: {e}")
            raise

    def get_system_prompt(self, **kwargs) -> str:
        """
        Return the formatted system prompt.

        An empty template yields ``""``. When no keyword arguments are given
        the raw template is returned without calling ``str.format``, so its
        braces are left untouched.
        """
        if not self.system_template:
            return ""
        if not kwargs:
            return self.system_template
        return self._render(self.system_template, kwargs, "系统提示", self.system_path)

    def build_user_prompt(self, **kwargs) -> str:
        """
        Build the user prompt by formatting the user template with ``kwargs``.

        An empty template yields ``""``. Unlike ``get_system_prompt`` the
        template is always passed through ``str.format``, even without
        keyword arguments.
        """
        if not self.user_template:
            return ""
        return self._render(self.user_template, kwargs, "用户提示", self.user_path)


class BasePromptBuilder(PromptTemplate):
    """
    Builder base class providing shared resource-loading helpers.

    Resource paths are resolved against the first entry of
    ``resource_config.resource_dirs`` unless they are absolute.
    """

    def __init__(self, config_manager: ConfigManager, system_prompt_path: str, user_prompt_path: str):
        """
        Load both prompt templates and fetch the system/resource configs.

        Args:
            config_manager: Source of the 'system' and 'resource' configs.
            system_prompt_path: Path of the system prompt template file.
            user_prompt_path: Path of the user prompt template file.
        """
        super().__init__(system_prompt_path, user_prompt_path)
        self.config_manager = config_manager
        self.system_config: SystemConfig = config_manager.get_config('system', SystemConfig)
        self.resource_config: ResourceConfig = config_manager.get_config('resource', ResourceConfig)

    def _get_full_path(self, path_str: str) -> Path:
        """
        Resolve a relative or absolute path against the base resource directory.

        Raises:
            ValueError: ``resource_config.resource_dirs`` is empty (raised even
                for absolute paths, since the check happens first).
        """
        if not self.resource_config.resource_dirs:
            raise ValueError("Resource directories list is empty in config.")
        base_path = Path(self.resource_config.resource_dirs[0])
        file_path = Path(path_str)
        return file_path if file_path.is_absolute() else (base_path / file_path).resolve()

    def _list_files(self, full_path: Path) -> List[Path]:
        """Expand a path into files: itself if a file, its sorted direct files if a directory, else nothing."""
        if full_path.is_file():
            return [full_path]
        if full_path.is_dir():
            return sorted(p for p in full_path.iterdir() if p.is_file())
        return []

    def _load_resource_as_list(self, name: str, paths: list[str]) -> str:
        """
        Load resources and format them as a bullet list of file names.

        Paths that fail to resolve are logged and skipped.
        """
        if not paths:
            return f"{name}:\n无"

        content_parts = [f"{name}:"]
        for path_str in paths:
            try:
                full_path = self._get_full_path(path_str)
                content_parts.extend(f"- {f.name}" for f in self._list_files(full_path))
            except Exception as e:
                # Best effort: one bad path must not prevent listing the others.
                logger.error(f"处理路径 '{path_str}' 失败: {e}", exc_info=True)

        return "\n".join(content_parts)

    def _load_resource_as_content(self, name: str, paths: list[str]) -> str:
        """
        Load resources and format them as blocks containing each file's text.

        A failure while reading one path is logged; files already read for
        that path are kept and the remaining paths are still processed.
        """
        if not paths:
            return f"{name}:\n无"

        content_parts = [f"{name}:"]
        for path_str in paths:
            try:
                full_path = self._get_full_path(path_str)
                for f_path in self._list_files(full_path):
                    file_content = f_path.read_text(encoding='utf-8')
                    content_parts.append(f"--- {f_path.name} ---\n{file_content}")
            except Exception as e:
                logger.error(f"加载资源 '{path_str}' 失败: {e}", exc_info=True)

        return "\n\n".join(content_parts)

    def _format_json_file(self, path: Path, sampling_rate: Optional[float] = None) -> str:
        """
        Load a JSON file and format it for inclusion in a prompt.

        If the root object has an ``"examples"`` list, each item's
        ``"content"`` is rendered as a bullet; otherwise the whole document is
        pretty-printed. When ``sampling_rate`` is given, a random subset of
        the examples is used.

        Returns:
            The formatted text, or a failure message if the file cannot be
            read, parsed or formatted (the error is logged, never raised).
        """
        try:
            with path.open('r', encoding='utf-8') as f:
                data = json.load(f)

            # Only a dict root can carry an "examples" key; a list/scalar root
            # falls through to the plain dump.
            examples = data.get("examples") if isinstance(data, dict) else None
            if not isinstance(examples, list):
                return json.dumps(data, ensure_ascii=False, indent=2)

            if sampling_rate is not None:
                total = len(examples)
                # At least one example when any exist, never more than the
                # population: random.sample raises ValueError otherwise
                # (empty list, or sampling_rate > 1.0).
                sample_size = min(total, max(1, int(total * sampling_rate)))
                examples = random.sample(examples, sample_size)
                logger.info(f"文件 '{path.name}' 中的examples采样: {sample_size}/{total} (采样率: {sampling_rate:.2f})")

            return "参考标题列表:\n" + "\n".join(f"- {item.get('content', '')}" for item in examples)
        except Exception as e:
            logger.error(f"解析或格式化JSON文件 '{path}' 失败: {e}")
            return f"加载文件 '{path.name}' 失败。"

    def _load_and_format_content(self, path: Path) -> str:
        """Load and format a file according to its type (JSON is formatted, anything else is read as UTF-8 text)."""
        if path.suffix == '.json':
            return self._format_json_file(path)
        return path.read_text('utf-8')

    def _load_and_format_content_with_sampling(self, path: Path, sampling_rate: float) -> str:
        """
        Load and format a file according to its type, applying a sampling rate.

        The rate only affects JSON files that contain an ``"examples"`` list;
        other files are returned in full.
        """
        if path.suffix == '.json':
            return self._format_json_file(path, sampling_rate)
        return path.read_text('utf-8')

    def _load_refer_content(self, name: str, refer_configs: list, step: str = "") -> str:
        """
        Load Refer content according to per-item sampling-rate configs.

        Args:
            name: Resource name used as the heading of the returned text.
            refer_configs: Refer config items; each is read for its ``path``,
                ``sampling_rate`` and ``step`` attributes.
            step: Current stage, used to filter the refer items. Items with an
                empty ``step`` apply to every stage.

        Returns:
            The heading followed by one block per loaded file. Items that fail
            to load are logged and skipped.
        """
        if not refer_configs:
            return f"{name}:\n无"

        content_parts = [f"{name}:"]

        # Filter the refer items by stage.
        filtered_configs = refer_configs
        if step:
            filtered_configs = [
                item for item in refer_configs
                if not item.step or item.step == step  # no stage set, or stage matches
            ]

            if not filtered_configs:
                return f"{name}:\n无 (当前阶段: {step})"

            logger.info(f"阶段 '{step}' 过滤后剩余 {len(filtered_configs)}/{len(refer_configs)} 个引用项")

        # Re-seed so that every call draws a fresh random sample.
        random.seed()

        for ref_item in filtered_configs:
            try:
                path_str = ref_item.path
                sampling_rate = ref_item.sampling_rate

                full_path = self._get_full_path(path_str)

                if full_path.is_file():
                    if full_path.suffix == '.json':
                        # JSON files are always included; sampling is applied
                        # to their examples instead.
                        file_content = self._load_and_format_content_with_sampling(full_path, sampling_rate)
                        content_parts.append(f"--- {full_path.name} ---\n{file_content}")
                        logger.info(f"加载JSON文件 '{path_str}' 并应用内部采样")
                    elif random.random() < sampling_rate:
                        # Other files are included whole with probability
                        # equal to the sampling rate.
                        file_content = self._load_and_format_content(full_path)
                        content_parts.append(f"--- {full_path.name} ---\n{file_content}")
                        logger.info(f"文件 '{path_str}' 采样成功 (采样率: {sampling_rate:.2f})")
                    else:
                        logger.info(f"文件 '{path_str}' 采样失败 (采样率: {sampling_rate:.2f})")
                elif full_path.is_dir():
                    # For a directory, pick the configured fraction of files.
                    all_files = sorted(p for p in full_path.iterdir() if p.is_file())
                    if all_files:
                        if sampling_rate < 1.0:
                            sample_size = max(1, int(len(all_files) * sampling_rate))
                            files_to_read = random.sample(all_files, sample_size)
                            logger.info(f"目录 '{path_str}' 采样: {sample_size}/{len(all_files)} 个文件 (采样率: {sampling_rate:.2f})")
                        else:
                            files_to_read = all_files
                            logger.info(f"目录 '{path_str}' 全部加载: {len(all_files)} 个文件")

                        for f_path in files_to_read:
                            file_content = self._load_and_format_content(f_path)
                            content_parts.append(f"--- {f_path.name} ---\n{file_content}")
                    else:
                        logger.warning(f"目录 '{path_str}' 中没有文件")

            except Exception as e:
                logger.error(f"加载Refer资源 '{ref_item}' 失败: {e}", exc_info=True)

        return "\n\n".join(content_parts)

    def _find_and_read_file(self, filename_to_find: str, search_paths: List[str]) -> Optional[str]:
        """
        Find a file in the given search paths and return its text.

        Each search path may be a directory or a file. An exact file-name
        match is tried first, then a match on the name without its suffix.

        Returns:
            The file's UTF-8 text, or ``None`` if nothing matched.
            NOTE(review): when ``resource_config.resource_dirs`` is empty the
            literal string "文件未找到" is returned instead of ``None``; kept
            as-is for compatibility, but callers using ``or <default>`` will
            not fall back in that case.
        """
        if not self.resource_config.resource_dirs:
            return "文件未找到"
        base_path = Path(self.resource_config.resource_dirs[0])

        # File name without suffix, used for the fuzzy match.
        filename_base = Path(filename_to_find).stem

        for p_str in search_paths:
            search_path = Path(p_str)
            if not search_path.is_absolute():
                search_path = base_path / search_path

            if search_path.is_dir():
                # Exact match.
                potential_path = search_path / filename_to_find
                if potential_path.is_file():
                    return potential_path.read_text('utf-8')

                # Fuzzy match on the file name without suffix.
                for file_path in search_path.iterdir():
                    if file_path.is_file() and file_path.stem == filename_base:
                        logger.info(f"通过基础名称匹配找到文件: {filename_to_find} -> {file_path.name}")
                        return file_path.read_text('utf-8')

            elif search_path.is_file():
                # Exact match.
                if search_path.name == filename_to_find:
                    return search_path.read_text('utf-8')
                # Fuzzy match on the file name without suffix.
                if search_path.stem == filename_base:
                    logger.info(f"通过基础名称匹配找到文件: {filename_to_find} -> {search_path.name}")
                    return search_path.read_text('utf-8')

        logger.warning(f"在路径 {search_paths} 中未能找到文件: {filename_to_find}")
        return None


class TopicPromptBuilder(BasePromptBuilder):
    """
    Prompt builder for topic generation.

    All resources, including the sampled Refer content, are loaded once in
    the constructor and reused for every prompt built afterwards.
    """

    def __init__(self, config_manager: ConfigManager):
        """
        Load the topic-generation templates and pre-format every resource.

        Args:
            config_manager: Source of the 'topic_gen', 'system' and
                'resource' configs.
        """
        # The topic config must be fetched first: it supplies the template
        # paths handed to the base-class constructor.
        self.topic_config: GenerateTopicConfig = config_manager.get_config('topic_gen', GenerateTopicConfig)
        super().__init__(
            config_manager,
            system_prompt_path=self.topic_config.topic_system_prompt,
            user_prompt_path=self.topic_config.topic_user_prompt
        )

        # Load and format all required resources.
        self.style_content = self._load_resource_as_list("Style文件列表", self.resource_config.style.paths)
        self.demand_content = self._load_resource_as_list("Demand文件列表", self.resource_config.demand.paths)
        self.refer_content = self._load_refer_content("Refer信息", self.resource_config.refer.refer_list, step="topic")
        self.object_content = self._load_resource_as_content("Object信息", self.resource_config.object.paths)
        self.product_content = self._load_resource_as_content("Product信息", self.resource_config.product.paths)

    def build_user_prompt(self, numTopics: int, month: str) -> str:
        """
        Build the user prompt for topic generation.

        Args:
            numTopics: Value substituted for the template's ``numTopics`` field.
            month: Value substituted for the template's ``month`` field.

        Returns:
            The rendered user template, with the pre-loaded resources joined
            into its ``creative_materials`` field.
        """
        sections = [
            self.style_content,
            self.demand_content,
            self.refer_content,
            self.object_content,
            self.product_content,
        ]
        creative_materials = "你拥有的创作资料如下:\n" + "\n\n".join(sections)
        return super().build_user_prompt(
            creative_materials=creative_materials,
            numTopics=numTopics,
            month=month
        )


class ContentPromptBuilder(BasePromptBuilder):
    """
    Prompt builder for content generation.
    """

    def __init__(self, config_manager: ConfigManager):
        """
        Load the content-generation templates and the static Product content.

        Args:
            config_manager: Source of the 'content_gen', 'system' and
                'resource' configs.
        """
        # Fetched before super().__init__ because it supplies the template paths.
        self.content_config: GenerateContentConfig = config_manager.get_config('content_gen', GenerateContentConfig)
        super().__init__(
            config_manager,
            system_prompt_path=self.content_config.content_system_prompt,
            user_prompt_path=self.content_config.content_user_prompt
        )

        # Pre-load the static Product content.
        self.product_content = self._load_resource_as_content("Product信息", self.resource_config.product.paths)

    def _load_object_content(self, object_name: str) -> str:
        """
        Return the text of the configured object file whose path contains ``object_name``.

        Returns "无" when the name is empty, no configured path matches, the
        matched path is not a regular file, or reading fails (failures are
        logged, never raised).
        """
        object_content = "无"
        if not object_name:
            return object_content

        try:
            # First configured object path that contains the name as a substring.
            object_path_str = next((p for p in self.resource_config.object.paths if object_name in p), None)
            if not object_path_str:
                logger.warning(f"在配置中找不到与 '{object_name}' 匹配的object路径")
                return object_content

            full_path = self._get_full_path(object_path_str)
            # is_file() rather than exists(): a directory exists but cannot
            # be read as text.
            if full_path.is_file():
                object_content = full_path.read_text('utf-8')
            else:
                logger.warning(f"找不到Object文件: {full_path}")
        except Exception as e:
            logger.error(f"加载Object内容 '{object_name}' 失败: {e}")

        return object_content

    def build_user_prompt(self, topic: Dict[str, Any]) -> str:
        """
        Build the complete user prompt for one topic.

        Args:
            topic: Topic dict; its "style", "targetAudience" and "object"
                entries select the style, demand and object resources.

        Returns:
            The rendered user template.
        """
        # Reload the Refer content on every call so each prompt gets a fresh
        # random sample.
        refer_content = self._load_refer_content("Refer信息", self.resource_config.refer.refer_list, step="content")

        style_filename = topic.get("style", "")
        style_content = self._find_and_read_file(style_filename, self.resource_config.style.paths) or "通用风格"

        demand_filename = topic.get("targetAudience", "")
        demand_content = self._find_and_read_file(demand_filename, self.resource_config.demand.paths) or "通用用户画像"

        object_name = topic.get("object", "")
        object_content = self._load_object_content(object_name)

        return super().build_user_prompt(
            style_content=f"{style_filename}\n{style_content}",
            demand_content=f"{demand_filename}\n{demand_content}",
            object_content=f"{object_name}\n{object_content}",
            refer_content=refer_content,
            product_content=self.product_content,
        )


class JudgerPromptBuilder(BasePromptBuilder):
    """
    Prompt builder for content review (the "judger").
    """

    def __init__(self, config_manager: ConfigManager):
        """
        Load the judger templates and the static Product content.

        Args:
            config_manager: Source of the 'content_gen', 'system' and
                'resource' configs.
        """
        # Fetched before super().__init__ because it supplies the template paths.
        self.content_config = config_manager.get_config('content_gen', GenerateContentConfig)
        super().__init__(
            config_manager,
            system_prompt_path=self.content_config.judger_system_prompt,
            user_prompt_path=self.content_config.judger_user_prompt
        )

        self.product_content = self._load_resource_as_content("Product信息", self.resource_config.product.paths)

    def _load_object_content(self, object_name: str) -> str:
        """
        Return the text of the configured object file whose path contains ``object_name``.

        Returns "无" when the name is empty, no configured path matches, the
        matched path is not a regular file, or reading fails (failures are
        logged, never raised).
        """
        object_content = "无"
        if not object_name:
            return object_content

        try:
            # First configured object path that contains the name as a substring.
            object_path_str = next((p for p in self.resource_config.object.paths if object_name in p), None)
            if not object_path_str:
                logger.warning(f"在配置中找不到与 '{object_name}' 匹配的object路径")
                return object_content

            full_path = self._get_full_path(object_path_str)
            # is_file() rather than exists(): a directory exists but cannot
            # be read as text.
            if full_path.is_file():
                object_content = full_path.read_text('utf-8')
            else:
                logger.warning(f"找不到Object文件: {full_path}")
        except Exception as e:
            logger.error(f"加载Object内容 '{object_name}' 失败: {e}")

        return object_content

    def build_user_prompt(self, generated_content: str, topic: Dict[str, Any]) -> str:
        """
        Build the user prompt for reviewing generated content.

        Args:
            generated_content: The content to review. Annotated as ``str``,
                but a dict is also accepted and serialised to indented JSON;
                any other value is converted with ``str()``.
            topic: Topic dict; its "object" entry selects the object resource.

        Returns:
            The rendered user template.
        """
        # Reload the Refer content on every call so each prompt gets a fresh
        # random sample.
        refer_content = self._load_refer_content("Refer信息", self.resource_config.refer.refer_list, step="judge")

        if isinstance(generated_content, dict):
            tweet_content = json.dumps(generated_content, ensure_ascii=False, indent=4)
        else:
            tweet_content = str(generated_content)

        object_content = self._load_object_content(topic.get("object", ""))

        return super().build_user_prompt(
            tweet_content=tweet_content,
            object_content=object_content,
            product_content=self.product_content,
            refer_content=refer_content
        )