TravelContentCreator/utils/prompt_manager.py

585 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Manages the construction of prompts for different AI generation tasks.
"""
import os
import traceback
import logging # Add logging
from .resource_loader import ResourceLoader # Use relative import within the same package
class PromptManager:
    """Load prompt template files and assemble prompts for topic and content generation.

    File contents are cached per instance. The topic/content system prompts, the
    topic user prompt, the dateline file and every Style/Demand/Refer file listed
    in ``prompts_config`` are read once at construction and then served from
    memory. ``prompts_dir`` is the legacy layout (``<dir>/Style``,
    ``<dir>/Demand``, ``<dir>/Refer``) and is only consulted when a lookup
    misses the cache.
    """

    # Default name of the dateline (festival / promotion calendar) file that is
    # looked up in the same directory as the topic user prompt.
    DEFAULT_DATELINE_FILENAME = "2025各月节日宣传节点时间表.md"

    # Fuzzy-match rules, each rule is
    #   (keywords that may appear in the requested name,
    #    keywords of which one must then appear in the candidate file name).
    # Names are lower-cased before matching.
    _STYLE_KEYWORD_RULES = (
        (("攻略", "干货"), ("攻略",)),
        (("轻奢",), ("轻奢",)),
        (("推荐", "种草"), ("推荐",)),
        (("美食",), ("美食",)),
    )
    _DEMAND_KEYWORD_RULES = (
        (("亲子",), ("亲子",)),
        (("情侣",), ("情侣",)),
        (("职场",), ("职场",)),
        (("学生",), ("学生",)),
        # "银发" and "夕阳红" are treated as synonyms in either direction.
        (("银发", "夕阳红"), ("银发", "夕阳红")),
        (("周边",), ("周边",)),
    )

    def __init__(self,
                 topic_system_prompt_path: str,
                 topic_user_prompt_path: str,
                 content_system_prompt_path: str,
                 prompts_dir: str = None,
                 prompts_config: list = None,
                 resource_dir_config: list = None,
                 topic_gen_num: int = 1,
                 topic_gen_date: str = "",
                 dateline_filename: str = DEFAULT_DATELINE_FILENAME,
                 ):
        """Store configuration and pre-load every configured prompt file.

        Args:
            topic_system_prompt_path: System prompt file for topic generation.
            topic_user_prompt_path: Base user prompt template for topic
                generation; its directory is also where the dateline file is
                looked up.
            content_system_prompt_path: System prompt file for content generation.
            prompts_dir: Legacy prompt directory with Style/Demand/Refer
                sub-folders (kept for backward compatibility).
            prompts_config: List of ``{"type": ..., "file_path": ...}`` entries;
                ``type`` is "style", "demand" or "refer" (case-insensitive) and
                ``file_path`` is a path or a list of paths.
            resource_dir_config: List of ``{"type": ..., "file_path": ...}``
                resource entries (e.g. type "Object" or "Product").
            topic_gen_num: Number of topics requested in the topic prompt.
            topic_gen_date: Date string appended to the topic prompt.
            dateline_filename: File name of the dateline file located next to
                the topic user prompt.
        """
        self.topic_system_prompt_path = topic_system_prompt_path
        self.topic_user_prompt_path = topic_user_prompt_path
        self.content_system_prompt_path = content_system_prompt_path
        self.prompts_dir = prompts_dir  # legacy configuration, kept for compatibility
        self.prompts_config = prompts_config or []
        self.resource_dir_config = resource_dir_config or []
        self.topic_gen_num = topic_gen_num
        self.topic_gen_date = topic_gen_date
        self.dateline_filename = dateline_filename
        # Per-instance content caches.
        self._style_cache = {}
        self._demand_cache = {}
        self._refer_cache = {}
        self._system_prompt_cache = {}
        self._user_prompt_cache = {}
        self._dateline_cache = None
        self._preload_prompt_files()

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_path_list(value):
        """Normalise a config ``file_path`` value to a list of paths.

        A single path string is wrapped in a list (iterating it directly would
        yield single characters). ``None``/empty gives ``[]``; any other
        iterable is returned as a list.
        """
        if not value:
            return []
        if isinstance(value, (str, os.PathLike)):
            return [value]
        return list(value)

    @staticmethod
    def _keyword_match(name_lower, candidate_lower, rules):
        """Return True if some rule links a keyword in the name to one in the candidate."""
        return any(
            any(word in name_lower for word in wanted)
            and any(word in candidate_lower for word in required)
            for wanted, required in rules
        )

    def _dateline_path(self):
        """Return the dateline file path next to the topic user prompt, or None."""
        if not self.topic_user_prompt_path:
            return None
        return os.path.join(os.path.dirname(self.topic_user_prompt_path), self.dateline_filename)

    def _load_dateline(self):
        """Return the dateline text, loading and caching it on the first successful read."""
        if self._dateline_cache:
            return self._dateline_cache
        path = self._dateline_path()
        if path and os.path.exists(path):
            content = ResourceLoader.load_file_content(path)
            if content:
                self._dateline_cache = content
        return self._dateline_cache

    def _cached_prompt(self, cache, key, path, name, load_label):
        """Return ``cache[key]``, loading it from ``path`` on a miss.

        Logs an error and returns None when ``path`` is not set or the file
        cannot be loaded / is empty.
        """
        content = cache.get(key)
        if content:
            return content
        if not path:
            logging.error(f"{name} path not provided during PromptManager initialization.")
            return None
        content = ResourceLoader.load_file_content(path)
        if not content:
            logging.error(f"Failed to load {load_label} from '{path}'.")
            return None
        cache[key] = content
        return content

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    def _preload_prompt_files(self):
        """Read the configured prompt files into the caches; missing or empty files are skipped."""
        for cache, key, path, label in (
            (self._system_prompt_cache, "topic", self.topic_system_prompt_path, "预加载系统提示词"),
            (self._user_prompt_cache, "topic", self.topic_user_prompt_path, "预加载用户提示词"),
            (self._system_prompt_cache, "content", self.content_system_prompt_path, "预加载内容系统提示词"),
        ):
            if path and os.path.exists(path):
                content = ResourceLoader.load_file_content(path)
                if content:
                    cache[key] = content
                    logging.info(f"{label}: {path}")

        if self._load_dateline():
            logging.info(f"预加载日期线文件: {self._dateline_path()}")

        # Style/demand files are cached under both "name.txt" and "name" so that
        # lookups work with or without the extension; refer files only by file name.
        targets = {
            "style": (self._style_cache, True),
            "demand": (self._demand_cache, True),
            "refer": (self._refer_cache, False),
        }
        for config_item in self.prompts_config:
            target = targets.get((config_item.get("type") or "").lower())
            if target is None:
                continue
            cache, also_stem = target
            for path in self._as_path_list(config_item.get("file_path")):
                if not os.path.exists(path):
                    continue
                content = ResourceLoader.load_file_content(path)
                if not content:
                    continue
                filename = os.path.basename(path)
                cache[filename] = content
                if also_stem:
                    cache[os.path.splitext(filename)[0]] = content

    def _lookup_named_prompt(self, name, cache, subdir, rules):
        """Find the content of a Style/Demand prompt file by (possibly fuzzy) name.

        Resolution order:
          1. ``name`` as an exact cache key;
          2. a cache key equal to the name case-insensitively (the name is
             compared without a trailing ``.txt``);
          3. a cache key matching by keyword ``rules``;
          4. ``<prompts_dir>/<subdir>/<name>.txt``;
          5. a file in ``<prompts_dir>/<subdir>`` whose name or stem equals the
             name case-insensitively, then one matching by keyword.
        Exact matches are always tried before keyword matches. Content found on
        disk is added to ``cache``.

        Returns:
            The file content, or None if nothing matched or loaded.
        """
        if name in cache:
            return cache[name]
        if name.lower().endswith(".txt"):
            file_name = name
            name = os.path.splitext(name)[0]
        else:
            file_name = f"{name}.txt"
        name_lower = name.lower()

        # Check exact matches over ALL keys first, so a keyword hit on an earlier
        # key cannot shadow an exact hit on a later one.
        for key in cache:
            if key.lower() == name_lower:
                return cache[key]
        for key in cache:
            if self._keyword_match(name_lower, key.lower(), rules):
                logging.info(f"模糊匹配 - 找到部分匹配的{subdir}文件: '{key}' 匹配 '{name}'")
                return cache[key]

        if not self.prompts_dir:
            return None

        direct_path = os.path.join(self.prompts_dir, subdir, file_name)
        if os.path.exists(direct_path):
            content = ResourceLoader.load_file_content(direct_path)
            if content:
                cache[name] = content
                cache[file_name] = content
                return content

        folder = os.path.join(self.prompts_dir, subdir)
        if not os.path.isdir(folder):
            return None
        try:
            # Sorted so the chosen file is deterministic when several match.
            files = sorted(f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f)))
            exact = [f for f in files
                     if f.lower() == name_lower or os.path.splitext(f)[0].lower() == name_lower]
            fuzzy = [f for f in files
                     if f not in exact and self._keyword_match(name_lower, f.lower(), rules)]
            candidates = [(f, False) for f in exact] + [(f, True) for f in fuzzy]
            for file, is_fuzzy in candidates:
                if is_fuzzy:
                    logging.info(f"模糊匹配 - 在目录中找到部分匹配的{subdir}文件: '{file}' 匹配 '{name}'")
                content = ResourceLoader.load_file_content(os.path.join(folder, file))
                if content:
                    cache[name] = content
                    cache[file] = content
                    return content
        except Exception as e:
            logging.warning(f"尝试列出{subdir}目录内容时出错: {e}")
        return None

    def _get_style_content(self, style_name):
        """Return the Style file content for ``style_name`` (exact or fuzzy match), or None."""
        return self._lookup_named_prompt(style_name, self._style_cache, "Style", self._STYLE_KEYWORD_RULES)

    def _get_demand_content(self, demand_name):
        """Return the Demand file content for ``demand_name`` (exact or fuzzy match), or None.

        Logs a warning when no file could be found.
        """
        content = self._lookup_named_prompt(demand_name, self._demand_cache, "Demand", self._DEMAND_KEYWORD_RULES)
        if content is None:
            logging.warning(f"未能找到Demand文件: '{demand_name}',尝试过以下位置: 缓存, {self.prompts_dir}/Demand/")
        return content

    def _get_all_refer_contents(self):
        """Concatenate every Refer file as ``--- Refer File: <name> ---`` sections.

        Uses the refer cache when it is non-empty. Otherwise it loads every
        file in ``<prompts_dir>/Refer`` (in sorted order) into the cache.
        Returns "" when nothing is available.
        """
        if not self._refer_cache and self.prompts_dir:
            refer_dir = os.path.join(self.prompts_dir, "Refer")
            if os.path.isdir(refer_dir):
                for refer_file in sorted(os.listdir(refer_dir)):
                    refer_path = os.path.join(refer_dir, refer_file)
                    if not os.path.isfile(refer_path):
                        continue
                    content = ResourceLoader.load_file_content(refer_path)
                    if content:
                        self._refer_cache[refer_file] = content
        return "".join(
            f"--- Refer File: {filename} ---\n{content}\n\n"
            for filename, content in self._refer_cache.items()
        )

    # ------------------------------------------------------------------
    # Topic prompts
    # ------------------------------------------------------------------
    def _describe_prompt_files(self):
        """Yield the listing of available prompt files for the topic prompt."""
        if self.prompts_config:
            for config_item in self.prompts_config:
                prompt_type = (config_item.get("type") or "").lower()
                file_paths = self._as_path_list(config_item.get("file_path"))
                if file_paths:
                    lines = [f"{prompt_type.capitalize()}文件列表:\n"]
                    lines.extend(f"- {os.path.basename(path)}\n" for path in file_paths)
                    lines.append("\n")
                    yield "".join(lines)
        elif self.prompts_dir and os.path.isdir(self.prompts_dir):
            # Legacy layout: one "<folder>\n[<files>]\n" entry per sub-directory.
            try:
                folders = sorted(os.listdir(self.prompts_dir))
            except OSError as e:
                logging.warning(f"Could not list base prompts directory {self.prompts_dir}: {e}")
                return
            for folder in folders:
                folder_path = os.path.join(self.prompts_dir, folder)
                if not os.path.isdir(folder_path):
                    continue
                try:
                    files = sorted(f for f in os.listdir(folder_path)
                                   if os.path.isfile(os.path.join(folder_path, f)))
                except OSError as e:
                    logging.warning(f"Could not list directory {folder_path}: {e}")
                    continue
                yield f"{folder}\n{files}\n"
        else:
            logging.warning("Neither prompts_config nor prompts_dir provided or valid.")

    def _resource_sections(self):
        """Yield one ``<type>信息:`` section with the full text of each resource file."""
        for dir_info in self.resource_dir_config:
            source_type = dir_info.get("type", "UnknownType")
            for file_path in self._as_path_list(dir_info.get("file_path")):
                file_content = ResourceLoader.load_file_content(file_path)
                if file_content:
                    yield f"{source_type}信息:\n{os.path.basename(file_path)}\n{file_content}\n\n"
                else:
                    logging.warning(f"Could not load resource file {file_path}")

    def get_topic_prompts(self):
        """Build ``(system_prompt, user_prompt)`` for topic generation.

        The user prompt consists of, in order:
          - the listing of available prompt files;
          - the full text of every resource file;
          - the dateline text, if any;
          - the base user prompt template;
          - the requested topic count and date.

        Returns:
            ``(system_prompt, user_prompt)``, or ``(None, None)`` if a required
            prompt file is missing or any error occurs.
        """
        logging.info("Constructing prompts for topic generation...")
        try:
            system_prompt = self._cached_prompt(
                self._system_prompt_cache, "topic", self.topic_system_prompt_path,
                "Topic system prompt", "topic system prompt")
            if not system_prompt:
                return None, None
            base_user_prompt = self._cached_prompt(
                self._user_prompt_cache, "topic", self.topic_user_prompt_path,
                "Topic user prompt", "base topic user prompt")
            if not base_user_prompt:
                return None, None

            parts = ["你拥有的创作资料如下:\n"]
            parts.extend(self._describe_prompt_files())
            parts.extend(self._resource_sections())
            dateline = self._load_dateline()
            if dateline:
                parts.append(f"\n{dateline}")
            parts.append(base_user_prompt)
            parts.append(f"\n选题数量:{self.topic_gen_num}\n选题日期:{self.topic_gen_date}\n")
            user_prompt = "".join(parts)

            logging.info(f"Topic prompts constructed. System: {len(system_prompt)} chars, User: {len(user_prompt)} chars.")
            return system_prompt, user_prompt
        except Exception:
            logging.exception("Error constructing topic prompts:")
            return None, None

    # ------------------------------------------------------------------
    # Content prompts
    # ------------------------------------------------------------------
    @staticmethod
    def _demand_logic_section(topic_item):
        """Return the ``Demand Logic`` section built from ``topic_item['logic']``."""
        demand_description = topic_item.get('logic')
        if demand_description:
            return f"Demand Logic:\n{demand_description}\n"
        logging.warning("Warning: 'logic' key missing or empty in topic_item for Demand prompt.")
        return ""

    def _object_section(self, topic_item):
        """List every Object resource file and inject the one matching ``topic_item['object']``.

        A file matches when its name, stripped of a ``景点信息-`` prefix and a
        ``.txt`` suffix, occurs inside the topic's object name. The first such
        file wins.
        """
        object_name = topic_item.get('object')
        basenames = []
        matched_path = None
        matched_basename = None
        for dir_info in self.resource_dir_config:
            if dir_info.get("type") != "Object":
                continue
            for file_path in self._as_path_list(dir_info.get("file_path")):
                basename = os.path.basename(file_path)
                basenames.append(basename)
                # Keep scanning after a match so every file name is still listed.
                if object_name and matched_path is None:
                    key = basename
                    if key.startswith("景点信息-"):
                        key = key[len("景点信息-"):]
                    if key.endswith(".txt"):
                        key = key[:-len(".txt")]
                    if key and key in object_name:
                        matched_path, matched_basename = file_path, basename

        parts = []
        if basenames:
            parts.append("Object信息:\n")
            parts.extend(f"- {fname}\n" for fname in basenames)
            parts.append("\n")
            logging.info(f"Listed {len(basenames)} available object files.")
        else:
            logging.warning("No resource directory entry found with type 'Object', or it has no file paths.")

        if matched_path:
            logging.info(f"Attempting to load content for matched object file: {matched_basename}")
            content = ResourceLoader.load_file_content(matched_path)
            if content:
                parts.append(f"{matched_basename}\n{content}\n\n")
                logging.info(f"Successfully loaded and injected content for: {matched_basename}")
            else:
                logging.warning(f"Object file matched ({matched_basename}) but could not be loaded or is empty.")
        elif object_name:
            # Only warn when the topic named an object but no file matched it.
            logging.warning(f"Could not find a matching Object resource file to inject content for '{object_name}'. Only the list of files was provided.")
        return "".join(parts)

    def _product_section(self, topic_item):
        """Return the Product Logic / Product Info sections; empty when the topic has no product."""
        product_name = topic_item.get('product')
        if not product_name:
            return ""
        parts = []
        product_logic = topic_item.get('product_logic')
        if product_logic:
            parts.append(f"Product Logic:\n{product_logic}\n")
        else:
            logging.warning(f"Warning: 'product_logic' key missing or empty for product '{product_name}'.")

        # First Product resource whose file name contains the product name.
        product_file_path = next(
            (path
             for dir_info in self.resource_dir_config if dir_info.get("type") == "Product"
             for path in self._as_path_list(dir_info.get("file_path"))
             if product_name in os.path.basename(path)),
            None,
        )
        if product_file_path:
            product_content = ResourceLoader.load_file_content(product_file_path)
            if product_content:
                parts.append(f"Product Info:\n{product_content}\n")
            else:
                logging.warning(f"Product file could not be loaded: {product_file_path}")
        else:
            logging.warning(f"Product file path not found in config for: {product_name}")
        return "".join(parts)

    def _style_section(self, topic_item):
        """Return the ``Style Info`` section for ``topic_item['style']``."""
        style_name = topic_item.get('style')
        if not style_name:
            logging.warning("Warning: 'style' key missing or empty in topic_item.")
            return ""
        style_content = self._get_style_content(style_name)
        if style_content:
            return f"Style Info:\n{style_content}\n"
        logging.warning(f"Style file not found or empty for: {style_name}")
        return ""

    def _target_audience_section(self, topic_item):
        """Return the ``Target Audience Info`` section for ``topic_item['target_audience']``."""
        target_audience_name = topic_item.get('target_audience')
        if not target_audience_name:
            logging.warning("Warning: 'target_audience' key missing or empty in topic_item.")
            return ""
        target_audience_content = self._get_demand_content(target_audience_name)
        if target_audience_content:
            return f"Target Audience Info:\n{target_audience_content}\n"
        logging.warning(f"Target Audience file not found or empty for: {target_audience_name}")
        return ""

    def _refer_section(self, _topic_item):
        """Return the ``Refer Info`` section with the content of all Refer files."""
        refer_content_all = self._get_all_refer_contents()
        if refer_content_all:
            return f"Refer Info:\n{refer_content_all}"
        logging.warning("No content loaded from Refer files.")
        return ""

    def get_content_prompts(self, topic_item):
        """Build ``(system_prompt, user_prompt)`` for content generation from a topic item.

        The user prompt is made of these sections, in order:
          - Demand logic;
          - Object file list, plus the matching object's content;
          - Product logic/info;
          - Style;
          - Target audience;
          - Refer files.

        Each section is built best-effort: if one fails, the error is logged
        and that section is omitted.

        Args:
            topic_item: Mapping with optional keys ``logic``, ``object``,
                ``product``, ``product_logic``, ``style``, ``target_audience``.

        Returns:
            ``(system_prompt, user_prompt)``, or ``(None, None)`` if the system
            prompt is unavailable or ``topic_item`` is unusable.
        """
        try:
            logging.info(f"Constructing content prompts for topic: {topic_item.get('object', 'N/A')}...")
            system_prompt = self._cached_prompt(
                self._system_prompt_cache, "content", self.content_system_prompt_path,
                "Content system prompt", "content system prompt")
            if not system_prompt:
                return None, None

            sections = (
                (self._demand_logic_section, "Error processing Demand description:"),
                (self._object_section, "Error processing Object prompt section:"),
                (self._product_section, "Error processing Product prompt:"),
                (self._style_section, "Error processing Style prompt:"),
                (self._target_audience_section, "Error processing Target Audience prompt:"),
                (self._refer_section, "Error processing Refer files:"),
            )
            parts = []
            for build, error_message in sections:
                try:
                    parts.append(build(topic_item))
                except Exception:
                    logging.exception(error_message)
            user_prompt = "".join(parts)

            logging.info(f"Content prompts constructed. System: {len(system_prompt)} chars, User: {len(user_prompt)} chars.")
            return system_prompt, user_prompt
        except Exception:
            logging.exception("Error constructing content prompts:")
            return None, None