221 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
文件输入输出工具模块
"""
import os
import random
import json
import json_repair
import logging
from pathlib import Path
from typing import Optional, List, Dict, Any
# Module-level logger named after this module; the classes and the
# process_llm_json_text function below all report through it.
logger = logging.getLogger(__name__)
class ResourceLoader:
    """Resource loader: static helpers for locating files and reading their contents."""

    @staticmethod
    def load_text_file(file_path: str) -> Optional[str]:
        """Read a text file as UTF-8 and return its contents.

        Args:
            file_path: Path of the file to read.

        Returns:
            The file contents, or None when the path does not exist or the
            read fails (both cases are logged).
        """
        if not os.path.exists(file_path):
            logger.warning(f"文件不存在: {file_path}")
            return None
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            logger.error(f"加载文件 '{file_path}' 失败: {e}")
            return None

    @staticmethod
    def load_json_file(file_path: str) -> Optional[Dict[str, Any]]:
        """Load a file and parse it as JSON using ``json_repair``.

        Args:
            file_path: Path of the JSON file.

        Returns:
            Whatever ``json_repair.loads`` produces for the file contents, or
            None when the file cannot be read or parsing raises.
        """
        content = ResourceLoader.load_text_file(file_path)
        if content is None:
            return None
        try:
            return json_repair.loads(content)
        except Exception as e:
            # The parser is json_repair, not json: catching only
            # json.JSONDecodeError would let any other parser failure escape
            # the log-and-return-None contract this class uses everywhere else.
            logger.error(f"解析JSON文件 '{file_path}' 失败: {e}")
            return None

    @staticmethod
    def find_file(directory: str, file_name: str, exact_match: bool = True) -> Optional[str]:
        """Find a file in ``directory`` by exact name, optionally falling back to a fuzzy match.

        A ``file_name`` without an extension is looked up with ``.txt``
        appended. When ``exact_match`` is False and the exact path does not
        exist, the first regular file (in sorted name order) whose name
        contains the extension-less base name is returned.

        Args:
            directory: Directory to search (not recursive).
            file_name: File name, with or without extension.
            exact_match: When True, only the exact path is considered.

        Returns:
            The joined path of the match, or None (after a warning) when the
            directory does not exist or nothing matches.
        """
        if not os.path.isdir(directory):
            logger.warning(f"目录不存在: {directory}")
            return None

        # Default to a .txt suffix when the caller gave no extension.
        base, ext = os.path.splitext(file_name)
        file_name_with_ext = file_name if ext else f"{base}.txt"

        # Exact match.
        exact_path = os.path.join(directory, file_name_with_ext)
        if os.path.exists(exact_path):
            return exact_path

        # Fuzzy match. os.listdir returns entries in arbitrary order, so sort
        # to make the chosen file reproducible across runs and platforms.
        # An empty base would match every entry, so it is never fuzzy-matched.
        if not exact_match and base:
            for entry in sorted(os.listdir(directory)):
                if base not in entry:
                    continue
                candidate = os.path.join(directory, entry)
                # Skip sub-directories: this helper is meant to return files.
                if os.path.isfile(candidate):
                    return candidate

        logger.warning(f"'{directory}' 中找不到文件 '{file_name}'")
        return None
class OutputManager:
    """Handles the output files of one run, e.g. saving articles and summary reports.

    Everything is written beneath ``<output_dir>/<run_id>``, which is created
    when the manager is constructed.
    """

    def __init__(self, output_dir: str, run_id: str):
        """Create the run directory ``output_dir/run_id`` if it does not exist.

        Args:
            output_dir: Base directory for all runs.
            run_id: Name of this run; used as the sub-directory name.
        """
        self.base_output_dir = Path(output_dir)
        self.run_id = run_id
        self.run_dir = self.base_output_dir / self.run_id
        self.run_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"OutputManager initialized for run '{run_id}' in '{self.run_dir}'")

    def _ensure_target_dir(self, subdir: Optional[str]) -> Path:
        """Return the run directory, or ``run_dir/subdir``, creating it if needed."""
        target_dir = self.run_dir / subdir if subdir else self.run_dir
        # parents=True so a nested subdir such as "a/b" works; without it
        # mkdir raises FileNotFoundError when an intermediate level is missing.
        target_dir.mkdir(parents=True, exist_ok=True)
        return target_dir

    def get_topic_dir(self, topic_index: Any) -> Path:
        """Create (if needed) and return the directory ``run_dir/topic_<topic_index>``."""
        topic_dir = self.run_dir / f"topic_{topic_index}"
        topic_dir.mkdir(parents=True, exist_ok=True)
        return topic_dir

    def get_variant_dir(self, topic_index: int, variant_index: int) -> Path:
        """Create (if needed) and return ``run_dir/<topic_index>_<variant_index>``."""
        variant_dir = self.run_dir / f"{topic_index}_{variant_index}"
        # parents=True keeps this consistent with get_topic_dir.
        variant_dir.mkdir(parents=True, exist_ok=True)
        return variant_dir

    def save_json(self, data: Any, filename: str, subdir: Optional[str] = None):
        """Serialize ``data`` as indented, non-ASCII-escaped JSON into the run directory.

        Args:
            data: Object passed to ``json.dump``.
            filename: Name of the file to write.
            subdir: Optional sub-directory (may be nested) below the run directory.

        Write failures are logged, not raised.
        """
        file_path = self._ensure_target_dir(subdir) / filename
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=4)
            logger.info(f"JSON data saved to: {file_path}")
        except Exception as e:
            logger.error(f"Failed to save JSON to {file_path}: {e}")

    def save_text(self, content: str, filename: str, subdir: Optional[str] = None):
        """Write ``content`` as UTF-8 text into the run directory.

        Args:
            content: Text to write.
            filename: Name of the file to write.
            subdir: Optional sub-directory (may be nested) below the run directory.

        Write failures are logged, not raised.
        """
        file_path = self._ensure_target_dir(subdir) / filename
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(content)
            logger.info(f"Text data saved to: {file_path}")
        except Exception as e:
            logger.error(f"Failed to save text to {file_path}: {e}")

    def save_image(self, image_data, filename: str, subdir: Optional[str] = None):
        """Save an image into the run directory by calling ``image_data.save(path)``.

        Args:
            image_data: Object exposing ``save(path)``; the original author
                notes that Pillow is required, so presumably a PIL image.
            filename: Name of the file to write.
            subdir: Optional sub-directory (may be nested) below the run directory.

        Save failures are logged, not raised.
        """
        file_path = self._ensure_target_dir(subdir) / filename
        try:
            image_data.save(file_path)
            logger.info(f"Image saved to: {file_path}")
        except Exception as e:
            logger.error(f"Failed to save image to {file_path}: {e}")

    def finalize(self):
        """Final hook for the run; currently it only logs."""
        logger.info(f"Finalizing run: {self.run_id}")
        # Nothing else to do yet; kept as an extension point (e.g. writing a
        # manifest file).
def _parse_json_dict_candidate(candidate: str) -> Optional[Dict[str, Any]]:
    """Parse one candidate string; return the result only when it is a dict."""
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        # Strict parsing failed: fall back to json_repair. The import sits
        # directly before its use so the name is always bound when called.
        try:
            import json_repair
            parsed = json_repair.loads(candidate)
        except Exception:
            return None
    return parsed if isinstance(parsed, dict) else None


def process_llm_json_text(text: Any) -> Optional[Dict[str, Any]]:
    """Extract a JSON object (dict) from an LLM response.

    Candidate spans are tried in this order, and the first one that yields a
    dict wins:

    1. the content after the first ``</think>`` tag,
    2. the content of each fenced code block (with or without a ``json`` tag),
    3. the whole text.

    Each candidate is parsed strictly with ``json.loads`` first and, when that
    fails, with ``json_repair.loads``. A candidate that parses to something
    other than a dict (list, string, number, ...) is skipped so that the
    declared return type holds.

    Args:
        text: Raw LLM text, or an already-parsed object. A dict is returned
            unchanged; a list is rejected; other non-string values are
            converted with ``str()``; ``None`` is treated as an empty response.

    Returns:
        The parsed dict, or None when the input is empty, is a list, or no
        candidate yields a dict.
    """
    # re is imported locally, as in the original implementation.
    import re

    # Already a dict: nothing to parse.
    if isinstance(text, dict):
        return text

    # A list can never satisfy the dict contract.
    if isinstance(text, list):
        logger.warning("输入是列表类型,但期望返回字典类型")
        return None

    if text is None:
        # str(None) would give the literal "None", which is not a response.
        text = ""
    elif not isinstance(text, str):
        try:
            text = str(text)
        except Exception as e:
            logger.error(f"无法将输入转换为字符串: {e}")
            return None

    if not text.strip():
        logger.warning("收到空的LLM响应")
        return None

    raw_candidates = []

    # 1. Content after the reasoning section.
    if "</think>" in text:
        raw_candidates.append(text.split("</think>", 1)[1])

    # 2. Content of fenced code blocks.
    raw_candidates.extend(re.findall(r"```(?:json)?\s*([\s\S]*?)```", text))

    # 3. The whole text goes last so that an extracted span takes precedence
    #    over surrounding prose or reasoning.
    raw_candidates.append(text)

    # Drop empty and duplicate candidates while keeping their order.
    json_candidates = []
    for raw in raw_candidates:
        cleaned = raw.strip()
        if cleaned and cleaned not in json_candidates:
            json_candidates.append(cleaned)

    for candidate in json_candidates:
        result = _parse_json_dict_candidate(candidate)
        if result is not None:
            return result

    # Every attempt failed: log and return None.
    logger.error(f"无法解析LLM返回的JSON尝试了{len(json_candidates)}种提取方式")
    logger.debug(f"原始响应: {text[:200]}...")  # first 200 chars only, to keep logs small
    return None