417 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
提示词管理和构建模块
"""
import os
import logging
import random
from typing import Dict, Any, Optional, List
from pathlib import Path
import json
from core.config import (
ConfigManager,
GenerateTopicConfig,
GenerateContentConfig,
SystemConfig,
ResourceConfig
)
from .file_io import ResourceLoader
# Module-level logger used by every template/builder class in this file.
logger = logging.getLogger(__name__)
class PromptTemplate:
    """
    Base prompt template: a system template and a user template read from disk.

    Both template files are read once, as UTF-8 text, when the object is created.
    A template path that is not an existing file is logged and treated as an
    empty template, which always renders as "".
    """

    def __init__(self, system_prompt_path: str, user_prompt_path: str):
        """
        Args:
            system_prompt_path: Path of the system prompt template file.
            user_prompt_path: Path of the user prompt template file.
        """
        self.system_path = Path(system_prompt_path)
        self.user_path = Path(user_prompt_path)
        self.system_template = self._read_template(self.system_path)
        self.user_template = self._read_template(self.user_path)

    def _read_template(self, path: Path) -> str:
        """Return the UTF-8 text of *path*, or "" (with a warning) if it is not a file."""
        if path.is_file():
            return path.read_text(encoding='utf-8')
        logger.warning(f"模板文件不存在,已跳过: {path}")
        return ""

    def get_system_prompt(self, **kwargs) -> str:
        """
        Return the system prompt.

        With no keyword arguments the raw template is returned untouched (it is
        not passed through ``str.format``, so literal braces survive as written).
        Otherwise the template is rendered with ``str.format(**kwargs)``.

        Raises:
            KeyError: If the template references a placeholder absent from kwargs.
        """
        template = self.system_template
        if not template:
            return ""
        if not kwargs:
            return template
        try:
            return template.format(**kwargs)
        except KeyError as missing_key:
            logger.error(f"渲染系统提示 '{self.system_path}' 时缺少键: {missing_key}")
            raise

    def build_user_prompt(self, **kwargs) -> str:
        """
        Render the user prompt with ``str.format(**kwargs)``.

        Returns "" when the user template is empty or missing.

        Raises:
            KeyError: If the template references a placeholder absent from kwargs.
        """
        template = self.user_template
        if not template:
            return ""
        try:
            return template.format(**kwargs)
        except KeyError as missing_key:
            logger.error(f"渲染用户提示 '{self.user_path}' 时缺少键: {missing_key}")
            raise
class BasePromptBuilder(PromptTemplate):
    """
    Base class for prompt builders, holding the shared resource-loading helpers.

    Relative resource paths are resolved against the first entry of
    ``resource_config.resource_dirs``; absolute paths are used as given.
    """

    def __init__(self, config_manager: ConfigManager, system_prompt_path: str, user_prompt_path: str):
        """
        Read the prompt templates and fetch the 'system' and 'resource' configs.

        Args:
            config_manager: Provider of the 'system' and 'resource' configs.
            system_prompt_path: Path of the system prompt template file.
            user_prompt_path: Path of the user prompt template file.
        """
        super().__init__(system_prompt_path, user_prompt_path)
        self.config_manager = config_manager
        self.system_config: SystemConfig = config_manager.get_config('system', SystemConfig)
        self.resource_config: ResourceConfig = config_manager.get_config('resource', ResourceConfig)

    @staticmethod
    def _list_dir_files(directory: Path) -> List[Path]:
        """Return the regular files directly inside *directory*, sorted by path."""
        return sorted(p for p in directory.iterdir() if p.is_file())

    @staticmethod
    def _format_examples(examples: list) -> str:
        """Render example items (dicts with a 'content' key) as a bulleted title list."""
        return "参考标题列表:\n" + "\n".join(f"- {item.get('content', '')}" for item in examples)

    def _get_full_path(self, path_str: str) -> Path:
        """
        Resolve *path_str* against the base resource directory.

        Absolute paths are returned unchanged; relative paths are joined onto
        ``resource_config.resource_dirs[0]`` and resolved.

        Raises:
            ValueError: If ``resource_config.resource_dirs`` is empty.
        """
        if not self.resource_config.resource_dirs:
            raise ValueError("Resource directories list is empty in config.")
        base_path = Path(self.resource_config.resource_dirs[0])
        file_path = Path(path_str)
        return file_path if file_path.is_absolute() else (base_path / file_path).resolve()

    def _load_resource_as_list(self, name: str, paths: list[str]) -> str:
        """
        Build a "<name>:" header followed by one "- <filename>" line per file.

        A file path contributes its own name; a directory contributes the names
        of its direct child files. Paths that exist as neither contribute nothing.
        A failure on one path is logged and the remaining paths are still processed.
        """
        if not paths:
            return f"{name}:\n"
        content_parts = [f"{name}:"]
        for path_str in paths:
            try:
                full_path = self._get_full_path(path_str)
                if full_path.is_file():
                    content_parts.append(f"- {full_path.name}")
                elif full_path.is_dir():
                    content_parts.extend(f"- {f.name}" for f in self._list_dir_files(full_path))
            except Exception as e:
                logger.error(f"处理路径 '{path_str}' 失败: {e}", exc_info=True)
        return "\n".join(content_parts)

    def _load_resource_as_content(self, name: str, paths: list[str]) -> str:
        """
        Build a "<name>:" header followed by "--- <filename> ---" blocks of file text.

        Each file path, or each direct child file of a directory path, is read as
        UTF-8. Blocks are joined with blank lines. A file that cannot be read is
        logged and skipped without dropping the other files of the same directory.
        """
        if not paths:
            return f"{name}:\n"
        content_parts = [f"{name}:"]
        for path_str in paths:
            try:
                full_path = self._get_full_path(path_str)
                if full_path.is_file():
                    files_to_read = [full_path]
                elif full_path.is_dir():
                    files_to_read = self._list_dir_files(full_path)
                else:
                    files_to_read = []
            except Exception as e:
                logger.error(f"加载资源 '{path_str}' 失败: {e}", exc_info=True)
                continue
            for f_path in files_to_read:
                try:
                    file_content = f_path.read_text(encoding='utf-8')
                except (OSError, ValueError) as e:
                    # Per-file isolation: e.g. a UnicodeDecodeError (a ValueError)
                    # must not discard the remaining files of this directory.
                    logger.error(f"加载资源 '{f_path}' 失败: {e}", exc_info=True)
                    continue
                content_parts.append(f"--- {f_path.name} ---\n{file_content}")
        return "\n\n".join(content_parts)

    def _load_and_format_content(self, path: Path) -> str:
        """
        Load *path* and render it as prompt text.

        JSON objects with an ``examples`` list are rendered as a bulleted list of
        each example's ``content``; any other JSON is pretty-printed. JSON parse or
        format failures are logged and replaced by a short failure message.
        Non-JSON files are returned verbatim (UTF-8); their read errors propagate.
        """
        if path.suffix != '.json':
            return path.read_text('utf-8')
        try:
            with path.open('r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, dict) and isinstance(data.get("examples"), list):
                return self._format_examples(data["examples"])
            return json.dumps(data, ensure_ascii=False, indent=2)
        except Exception as e:
            logger.error(f"解析或格式化JSON文件 '{path}' 失败: {e}")
            return f"加载文件 '{path.name}' 失败。"

    def _load_and_format_content_with_sampling(
        self, path: Path, sampling_rate: float, rng: Optional[random.Random] = None
    ) -> str:
        """
        Like `_load_and_format_content`, but randomly sample a JSON ``examples`` list.

        The sample size is ``int(len(examples) * sampling_rate)`` clamped to
        ``[1, len(examples)]``. A rate above 1.0 therefore keeps every example
        rather than failing, and an empty list yields just the header.

        Args:
            path: File to load.
            sampling_rate: Fraction of examples to keep.
            rng: Random source; the module-level ``random`` functions when None.
        """
        if path.suffix != '.json':
            return path.read_text('utf-8')
        chooser = random if rng is None else rng
        try:
            with path.open('r', encoding='utf-8') as f:
                data = json.load(f)
            if not (isinstance(data, dict) and isinstance(data.get("examples"), list)):
                return json.dumps(data, ensure_ascii=False, indent=2)
            examples = data["examples"]
            if not examples:
                # random.sample(population, 1) would raise on an empty population.
                return self._format_examples(examples)
            sample_size = min(len(examples), max(1, int(len(examples) * sampling_rate)))
            sampled_examples = chooser.sample(examples, sample_size)
            logger.info(f"文件 '{path.name}' 中的examples采样: {sample_size}/{len(examples)} (采样率: {sampling_rate:.2f})")
            return self._format_examples(sampled_examples)
        except Exception as e:
            logger.error(f"解析或格式化JSON文件 '{path}' 失败: {e}")
            return f"加载文件 '{path.name}' 失败。"

    def _load_refer_content(self, name: str, refer_configs: list, step: str = "") -> str:
        """
        Load Refer material, applying each item's sampling rate.

        When *step* is non-empty, only items whose ``step`` is empty or equal to
        *step* are kept. The items are read via their ``path`` and ``sampling_rate``
        attributes:

        - JSON file: always included; its ``examples`` list is sampled.
        - Other file: included with probability ``sampling_rate``.
        - Directory: ``max(1, int(n * sampling_rate))`` of its n files are picked
          when the rate is below 1.0, otherwise all files.

        Args:
            name: Section title placed on the first line.
            refer_configs: Refer config items.
            step: Current stage, used to filter the items.

        Returns:
            "<name>:" plus "--- <filename> ---" blocks joined by blank lines.
            Returns "<name>:\\n" when *refer_configs* is empty, and a "无 (当前阶段: …)"
            note when the step filter leaves nothing.
        """
        if not refer_configs:
            return f"{name}:\n"
        content_parts = [f"{name}:"]
        filtered_configs = refer_configs
        if step:
            # Keep items with no step or a step matching the current stage.
            filtered_configs = [
                item for item in refer_configs
                if not item.step or item.step == step
            ]
            if not filtered_configs:
                return f"{name}:\n无 (当前阶段: {step})"
            logger.info(f"阶段 '{step}' 过滤后剩余 {len(filtered_configs)}/{len(refer_configs)} 个引用项")
        # A fresh generator (seeded from OS entropy) gives independent samples on
        # every call without reseeding the global `random` state other code may use.
        rng = random.Random()
        for ref_item in filtered_configs:
            try:
                path_str = ref_item.path
                sampling_rate = ref_item.sampling_rate
                full_path = self._get_full_path(path_str)
                if full_path.is_file():
                    if full_path.suffix == '.json':
                        # JSON files are always included; their examples are sampled.
                        file_content = self._load_and_format_content_with_sampling(full_path, sampling_rate, rng)
                        content_parts.append(f"--- {full_path.name} ---\n{file_content}")
                        logger.info(f"加载JSON文件 '{path_str}' 并应用内部采样")
                    elif rng.random() < sampling_rate:
                        file_content = self._load_and_format_content(full_path)
                        content_parts.append(f"--- {full_path.name} ---\n{file_content}")
                        logger.info(f"文件 '{path_str}' 采样成功 (采样率: {sampling_rate:.2f})")
                    else:
                        logger.info(f"文件 '{path_str}' 采样失败 (采样率: {sampling_rate:.2f})")
                elif full_path.is_dir():
                    all_files = self._list_dir_files(full_path)
                    if not all_files:
                        logger.warning(f"目录 '{path_str}' 中没有文件")
                        continue
                    if sampling_rate < 1.0:
                        sample_size = max(1, int(len(all_files) * sampling_rate))
                        files_to_read = rng.sample(all_files, sample_size)
                        logger.info(f"目录 '{path_str}' 采样: {sample_size}/{len(all_files)} 个文件 (采样率: {sampling_rate:.2f})")
                    else:
                        files_to_read = all_files
                        logger.info(f"目录 '{path_str}' 全部加载: {len(all_files)} 个文件")
                    for f_path in files_to_read:
                        try:
                            file_content = self._load_and_format_content(f_path)
                        except (OSError, ValueError) as e:
                            # One unreadable file must not drop the rest of the sample.
                            logger.error(f"加载Refer资源 '{f_path}' 失败: {e}", exc_info=True)
                            continue
                        content_parts.append(f"--- {f_path.name} ---\n{file_content}")
            except Exception as e:
                logger.error(f"加载Refer资源 '{ref_item}' 失败: {e}", exc_info=True)
        return "\n\n".join(content_parts)

    def _find_and_read_file(self, filename_to_find: str, search_paths: List[str]) -> Optional[str]:
        """
        Locate *filename_to_find* in *search_paths* and return its UTF-8 text.

        Each search path may be a directory or a single file; relative paths are
        joined onto ``resource_dirs[0]``. An exact file-name match is tried first.
        Failing that, a file whose stem equals the requested name's stem is
        accepted (e.g. "casual" or "casual.txt" can match "casual.md").

        Returns:
            The file contents, or None when nothing matches or the name is
            empty/None. Returns the literal "文件未找到" when no resource
            directory is configured.
        """
        if not self.resource_config.resource_dirs:
            # NOTE(review): unlike the None returned below, this sentinel is truthy,
            # so callers using `... or <default>` embed it in the prompt.
            return "文件未找到"
        if not filename_to_find:
            # Path(None) would raise TypeError, and an empty name can never match.
            logger.warning(f"在路径 {search_paths} 中未能找到文件: {filename_to_find}")
            return None
        base_path = Path(self.resource_config.resource_dirs[0])
        # Stem (name without suffix) used for the fuzzy fallback match.
        filename_base = Path(filename_to_find).stem
        for p_str in search_paths:
            search_path = Path(p_str)
            if not search_path.is_absolute():
                search_path = base_path / search_path
            if search_path.is_dir():
                potential_path = search_path / filename_to_find
                if potential_path.is_file():
                    return potential_path.read_text('utf-8')
                # Sorted so the stem fallback picks the same file on every platform.
                for file_path in sorted(search_path.iterdir()):
                    if file_path.is_file() and file_path.stem == filename_base:
                        logger.info(f"通过基础名称匹配找到文件: {filename_to_find} -> {file_path.name}")
                        return file_path.read_text('utf-8')
            elif search_path.is_file():
                if search_path.name == filename_to_find:
                    return search_path.read_text('utf-8')
                if search_path.stem == filename_base:
                    logger.info(f"通过基础名称匹配找到文件: {filename_to_find} -> {search_path.name}")
                    return search_path.read_text('utf-8')
        logger.warning(f"在路径 {search_paths} 中未能找到文件: {filename_to_find}")
        return None
class TopicPromptBuilder(BasePromptBuilder):
    """
    Prompt builder for topic generation.

    The creative materials are loaded once, when the builder is created, and
    reused for every prompt. They are: the style and demand file lists, the
    topic-stage Refer content, and the object and product contents.
    """

    def __init__(self, config_manager: ConfigManager):
        topic_config: GenerateTopicConfig = config_manager.get_config('topic_gen', GenerateTopicConfig)
        self.topic_config = topic_config
        super().__init__(
            config_manager,
            system_prompt_path=topic_config.topic_system_prompt,
            user_prompt_path=topic_config.topic_user_prompt,
        )
        resources = self.resource_config
        self.style_content = self._load_resource_as_list("Style文件列表", resources.style.paths)
        self.demand_content = self._load_resource_as_list("Demand文件列表", resources.demand.paths)
        self.refer_content = self._load_refer_content("Refer信息", resources.refer.refer_list, step="topic")
        self.object_content = self._load_resource_as_content("Object信息", resources.object.paths)
        self.product_content = self._load_resource_as_content("Product信息", resources.product.paths)

    def build_user_prompt(self, num_topics: int, month: str) -> str:
        """
        Render the topic-generation user prompt.

        Args:
            num_topics: Passed to the template as ``num_topics``.
            month: Passed to the template as ``month``.

        Returns:
            The user template rendered with ``creative_materials``, ``num_topics``
            and ``month``.
        """
        material_sections = (
            self.style_content,
            self.demand_content,
            self.refer_content,
            self.object_content,
            self.product_content,
        )
        creative_materials = "你拥有的创作资料如下:\n" + "\n\n".join(material_sections)
        return super().build_user_prompt(
            creative_materials=creative_materials,
            num_topics=num_topics,
            month=month,
        )
class ContentPromptBuilder(BasePromptBuilder):
    """
    Prompt builder for content generation.

    Product information is loaded once at construction. Refer content
    (content stage), style, target-audience (demand) and object material are
    resolved each time a prompt is built.
    """

    def __init__(self, config_manager: ConfigManager):
        self.content_config: GenerateContentConfig = config_manager.get_config('content_gen', GenerateContentConfig)
        super().__init__(
            config_manager,
            system_prompt_path=self.content_config.content_system_prompt,
            user_prompt_path=self.content_config.content_user_prompt
        )
        # Product material does not depend on the topic, so it is loaded once.
        self.product_content = self._load_resource_as_content("Product信息", self.resource_config.product.paths)

    def _load_object_content(self, object_name: str) -> str:
        """
        Return the text of the configured object file matching *object_name*.

        A path in ``resource_config.object.paths`` whose file name or stem equals
        *object_name* is preferred. Otherwise the first path containing
        *object_name* as a substring is used. Returns "" for an empty name. Also
        returns "" (after logging) when nothing matches, the match is not a
        regular file, or reading fails.
        """
        if not object_name:
            return ""
        try:
            paths = self.resource_config.object.paths
            # Exact match first, so e.g. "A" does not pick "ABC.txt" over "A.txt".
            object_path_str = next(
                (p for p in paths if object_name in (Path(p).name, Path(p).stem)), None
            )
            if object_path_str is None:
                object_path_str = next((p for p in paths if object_name in str(p)), None)
            if object_path_str is None:
                logger.warning(f"在配置中找不到与 '{object_name}' 匹配的object路径")
                return ""
            full_path = self._get_full_path(object_path_str)
            if not full_path.is_file():
                logger.warning(f"找不到Object文件: {full_path}")
                return ""
            return full_path.read_text('utf-8')
        except Exception as e:
            logger.error(f"加载Object内容 '{object_name}' 失败: {e}")
            return ""

    def build_user_prompt(self, topic: Dict[str, Any]) -> str:
        """
        Build the full user prompt for one topic.

        Args:
            topic: Topic dict; reads the ``style``, ``target_audience`` and
                ``object`` keys. Missing or None values are treated as "".

        Returns:
            The user template rendered with style, demand, object, refer and
            product content. Each of style, demand and object is prefixed by
            its name on its own line.
        """
        # Reload Refer content on every call so each prompt gets a fresh sample.
        refer_content = self._load_refer_content("Refer信息", self.resource_config.refer.refer_list, step="content")
        # `or ""` guards against explicit None values, which would crash Path().
        style_filename = topic.get("style") or ""
        style_content = self._find_and_read_file(style_filename, self.resource_config.style.paths) or "通用风格"
        demand_filename = topic.get("target_audience") or ""
        demand_content = self._find_and_read_file(demand_filename, self.resource_config.demand.paths) or "通用用户画像"
        object_name = topic.get("object") or ""
        object_content = self._load_object_content(object_name)
        return super().build_user_prompt(
            style_content=f"{style_filename}\n{style_content}",
            demand_content=f"{demand_filename}\n{demand_content}",
            object_content=f"{object_name}\n{object_content}",
            refer_content=refer_content,
            product_content=self.product_content,
        )
class JudgerPromptBuilder(BasePromptBuilder):
    """
    Prompt builder for reviewing (judging) generated content.

    It uses the judger templates from the 'content_gen' config. Product
    information is loaded once at construction. Refer content (judge stage)
    and object material are resolved per call.
    """

    def __init__(self, config_manager: ConfigManager):
        self.content_config = config_manager.get_config('content_gen', GenerateContentConfig)
        super().__init__(
            config_manager,
            system_prompt_path=self.content_config.judger_system_prompt,
            user_prompt_path=self.content_config.judger_user_prompt
        )
        # Product material does not depend on the input, so it is loaded once.
        self.product_content = self._load_resource_as_content("Product信息", self.resource_config.product.paths)

    def _load_object_content(self, object_name: str) -> str:
        """
        Return the text of the configured object file matching *object_name*.

        A path in ``resource_config.object.paths`` whose file name or stem equals
        *object_name* is preferred. Otherwise the first path containing
        *object_name* as a substring is used. Returns "" for an empty name. Also
        returns "" (after logging) when nothing matches, the match is not a
        regular file, or reading fails.
        """
        if not object_name:
            return ""
        try:
            paths = self.resource_config.object.paths
            # Exact match first, so e.g. "A" does not pick "ABC.txt" over "A.txt".
            object_path_str = next(
                (p for p in paths if object_name in (Path(p).name, Path(p).stem)), None
            )
            if object_path_str is None:
                object_path_str = next((p for p in paths if object_name in str(p)), None)
            if object_path_str is None:
                logger.warning(f"在配置中找不到与 '{object_name}' 匹配的object路径")
                return ""
            full_path = self._get_full_path(object_path_str)
            if not full_path.is_file():
                logger.warning(f"找不到Object文件: {full_path}")
                return ""
            return full_path.read_text('utf-8')
        except Exception as e:
            logger.error(f"加载Object内容 '{object_name}' 失败: {e}")
            return ""

    def build_user_prompt(self, generated_content: str, topic: Dict[str, Any]) -> str:
        """
        Build the user prompt for reviewing generated content.

        Args:
            generated_content: Content to review. A dict is rendered as JSON
                (``ensure_ascii=False, indent=4``); anything else goes through ``str()``.
            topic: Topic dict; only its ``object`` key is read (None treated as "").

        Returns:
            The judger user template rendered with tweet, object, product and
            refer content.
        """
        # Reload Refer content on every call so each prompt gets a fresh sample.
        refer_content = self._load_refer_content("Refer信息", self.resource_config.refer.refer_list, step="judge")
        if isinstance(generated_content, dict):
            tweet_content = json.dumps(generated_content, ensure_ascii=False, indent=4)
        else:
            tweet_content = str(generated_content)
        object_content = self._load_object_content(topic.get("object") or "")
        return super().build_user_prompt(
            tweet_content=tweet_content,
            object_content=object_content,
            product_content=self.product_content,
            refer_content=refer_content
        )