TravelContentCreator/scripts/distribution/extract_and_render.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
import shutil
import csv
import traceback
import re
import argparse
from datetime import datetime
import sqlite3
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)


# Built-in database recording helpers
def init_database(db_path):
    """Initialize the database and create the table schema."""
    try:
        conn = sqlite3.connect(db_path)
        conn.execute("PRAGMA foreign_keys = OFF")  # disable foreign-key constraints
        cursor = conn.cursor()
        # Create the contents table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS contents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                entry_id TEXT NOT NULL UNIQUE,
                output_txt_path TEXT,
                poster_path TEXT,
                article_json_path TEXT,
                product TEXT,
                object TEXT,
                date TEXT,
                logic TEXT,
                judge_status INTEGER,
                is_distributed INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Create an index on entry_id
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_contents_entry_id ON contents(entry_id)")
        conn.commit()
        logger.info("Database initialized successfully")
        return conn
    except sqlite3.Error as e:
        logger.error(f"Failed to initialize database: {e}")
        return None


def record_to_database(
    db_path,
    entry_id,
    output_txt_path=None,
    poster_path=None,
    article_json_path=None,
    product=None,
    object=None,
    date=None,
    logic=None,
    judge_status=None,
    is_distributed=0
):
    """Record a content entry in the database."""
    try:
        # Check whether the database file exists; initialize it if not
        if not os.path.exists(db_path):
            logger.info(f"Database file does not exist: {db_path}; creating it")
            conn = init_database(db_path)
            if not conn:
                return False
        else:
            try:
                conn = sqlite3.connect(db_path)
                conn.execute("PRAGMA foreign_keys = OFF")  # disable foreign-key constraints
            except sqlite3.Error as e:
                logger.error(f"Failed to connect to database: {e}")
                return False
        try:
            cursor = conn.cursor()
            # Prepare the row data
            data = (
                entry_id,
                output_txt_path or '',
                poster_path or '',
                article_json_path or '',
                product or '',
                object or '',
                date or '',
                logic or '',
                judge_status,  # may legitimately be None (unknown)
                is_distributed
            )
            # Insert or update the content entry
            cursor.execute("""
                INSERT OR REPLACE INTO contents
                (entry_id, output_txt_path, poster_path, article_json_path,
                 product, object, date, logic, judge_status, is_distributed)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, data)
            conn.commit()
            logger.info(f"Recorded content {entry_id} in the database")
            return True
        except Exception as e:
            logger.error(f"Failed to record content in the database: {e}")
            try:
                conn.rollback()
            except sqlite3.Error:
                pass
            return False
        finally:
            try:
                conn.close()
            except sqlite3.Error:
                pass
    except Exception as e:
        logger.error(f"Error while recording extracted content: {e}")
        return False
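
# A minimal usage sketch of the two helpers above (the /tmp path and the
# entry values are hypothetical, chosen only for illustration):
#
#   ok = record_to_database(
#       "/tmp/distribution.db",                    # created on first use via init_database()
#       entry_id="1_0",
#       output_txt_path="/tmp/out/1_0/article.txt",
#       judge_status=1,                            # 1 = judged True, 0 = judged False, None = unknown
#   )
#
# Because the insert uses INSERT OR REPLACE on the UNIQUE entry_id column,
# re-running the script for the same entry overwrites the previous row
# instead of raising a constraint error.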


def convert_json_to_txt_content(json_path, prefer_original=False):
    """
    Read a JSON file, extract the title, content, and tags, strip Markdown
    formatting, and return the formatted text.

    Whether the original or the judged (reviewed) content is used depends on
    the judge_success field:
    - judge_success is True: use title/content (unless prefer_original=True)
    - judge_success is False: use original_title/original_content

    Args:
        json_path: path to the JSON file
        prefer_original: prefer the original content regardless of judge_success
    """
    print(f" - Reading JSON: {json_path}")
    if not os.path.exists(json_path):
        print(f" - Warning: JSON file does not exist: {json_path}")
        return None, f"File not found: {json_path}"
    try:
        with open(json_path, 'r', encoding='utf-8') as f_json:
            data = json.load(f_json)
        # Choose title and content according to judge_success
        judge_success = data.get('judge_success', None)
        if prefer_original and 'original_title' in data and 'original_content' in data:
            # Prefer the original content
            title = data.get('original_title', 'Original title not found')
            content = data.get('original_content', 'Original content not found')
            # Prefer the original tags
            tags = data.get('original_tags', data.get('tags', 'Tags not found'))
            print(" - Preferring original content (prefer_original=True)")
        elif judge_success is True and not prefer_original:
            # Use the judged content
            title = data.get('title', 'Title not found')
            content = data.get('content', 'Content not found')
            tags = data.get('tags', 'Tags not found')
            print(" - Using judged content (judge_success=True)")
        elif 'original_title' in data and 'original_content' in data:
            # Use the original content
            title = data.get('original_title', 'Original title not found')
            content = data.get('original_content', 'Original content not found')
            # Prefer the original tags
            tags = data.get('original_tags', data.get('tags', 'Tags not found'))
            print(f" - Using original content (judge_success={judge_success})")
        else:
            # No original_* fields; fall back to the regular fields
            title = data.get('title', 'Title not found')
            content = data.get('content', 'Content not found')
            tags = data.get('tags', 'Tags not found')
            print(" - Using regular content (no judge result)")
        # Work around the duplicated tag/tags fields: per the corrected logic,
        # only tags is used, falling back to tag when tags is empty
        if not tags and 'tag' in data:
            tags = data.get('tag', 'Tags not found')
            print(" - Using the tag field for tags (to be unified as tags in a later version)")
        # Strip Markdown bold formatting
        content_no_format = re.sub(r'\*\*(.*?)\*\*', r'\1', content)
        # Assemble the output text
        return f"{title}\n\n{content_no_format}\n\n{tags}", None
    except json.JSONDecodeError:
        print(f" - Error: invalid JSON format: {json_path}")
        return None, f"Invalid JSON format: {json_path}"
    except Exception as e:
        print(f" - Error while processing JSON: {e}")
        return None, f"Error while processing JSON: {e}"


def load_topic_data(source_dir, run_id):
    """
    Load the topic-selection data.

    Args:
        source_dir: source directory path
        run_id: run ID

    Returns:
        dict: topic data keyed by topic index
    """
    topic_file_path = os.path.join(source_dir, f"tweet_topic_{run_id}.json")
    topic_data = {}
    if os.path.exists(topic_file_path):
        try:
            with open(topic_file_path, 'r', encoding='utf-8') as f:
                topics = json.load(f)
            # Key the topic data by index; normalize keys to str so lookups
            # with the directory-derived topic_index (a string) always match
            for topic in topics:
                index = topic.get("index")
                if index is not None:
                    topic_data[str(index)] = topic
            print(f"Loaded topic data successfully: {len(topic_data)} topics")
        except Exception as e:
            print(f"Error while loading topic data: {e}")
    else:
        print(f"Warning: topic file not found: {topic_file_path}")
    return topic_data
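
# Judging from the fields consumed in process_result_directory() below, an
# entry in tweet_topic_<run_id>.json plausibly looks like this (structure
# inferred, values illustrative):
#
#   [
#     {
#       "index": "1",
#       "date": "2025-05-14",
#       "logic": "...",
#       "object": "...",
#       "product": "...",          "product_logic": "...",
#       "style": "...",            "style_logic": "...",
#       "target_audience": "...",  "target_audience_logic": "..."
#     }
#   ]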


def process_result_directory(source_dir, output_dir, run_id=None, prefer_original=False, db_path=None, db_enabled=True):
    """
    Process the given result directory, extracting its content and rendering
    it into the output directory.

    Args:
        source_dir: source directory path (contains i_j subdirectories)
        output_dir: output directory path
        run_id: optional run ID; the source directory name is used if not provided
        prefer_original: prefer the original content regardless of judge_success
        db_path: database path; a default path is used if not provided
        db_enabled: whether to record processed entries in the database
    """
    if not os.path.isdir(source_dir):
        print(f"Error: source directory does not exist: {source_dir}")
        return

    # Default database path
    if db_path is None:
        db_path = '/root/autodl-tmp/TravelContentCreator/distribution.db'

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)
    print(f"Ensured output directory exists: {output_dir}")
    # Derive run_id
    if not run_id:
        run_id = os.path.basename(source_dir)
    # Load the topic data
    topic_data = load_topic_data(source_dir, run_id)
    # Create the CSV manifest (with the topic-related fields included)
    csv_path = os.path.join(output_dir, f"manifest_{run_id}.csv")
    csv_data = [
        [
            "EntryID",
            "TopicIndex",
            "VariantIndex",
            "Date",
            "Logic",
            "Object",
            "Product",
            "ProductLogic",
            "Style",
            "StyleLogic",
            "TargetAudience",
            "TargetAudienceLogic",
            "SourcePath",
            "ArticleJsonPath",
            "OutputTxtPath",
            "PosterPath",
            "AdditionalImagesCount",
            "Status",
            "Details",
            "JudgeStatus",
            "ContentSource",
            "RecordedInDB",
            "IsDistributed"
        ]
    ]
    # Find all i_j directories
    entry_pattern = re.compile(r"^(\d+)_(\d+)$")
    entries = []
    for item in os.listdir(source_dir):
        item_path = os.path.join(source_dir, item)
        match = entry_pattern.match(item)
        if os.path.isdir(item_path) and match:
            entries.append(item)
    if not entries:
        print("Warning: no i_j-style subdirectories found in the source directory")
        return
    print(f"Found {len(entries)} entry directories")
    # Process each entry
    for entry in sorted(entries):
        entry_path = os.path.join(source_dir, entry)
        output_entry_path = os.path.join(output_dir, entry)
        print(f"\nProcessing entry: {entry}")
        # Parse topic_index and variant_index
        match = entry_pattern.match(entry)
        topic_index = match.group(1)
        variant_index = match.group(2)
        # Look up the topic info for this entry
        topic_info = topic_data.get(topic_index, {})
        # Build the manifest record
        record = {
            "EntryID": entry,
            "TopicIndex": topic_index,
            "VariantIndex": variant_index,
            "Date": topic_info.get("date", ""),
            "Logic": topic_info.get("logic", ""),
            "Object": topic_info.get("object", ""),
            "Product": topic_info.get("product", ""),
            "ProductLogic": topic_info.get("product_logic", ""),
            "Style": topic_info.get("style", ""),
            "StyleLogic": topic_info.get("style_logic", ""),
            "TargetAudience": topic_info.get("target_audience", ""),
            "TargetAudienceLogic": topic_info.get("target_audience_logic", ""),
            "SourcePath": entry_path,
            "ArticleJsonPath": "",
            "OutputTxtPath": "",
            "PosterPath": "",
            "AdditionalImagesCount": 0,
            "Status": "Processing",
            "Details": "",
            "JudgeStatus": "",
            "ContentSource": "unknown",
            "RecordedInDB": "No",
            "IsDistributed": "No"
        }
        # Create the output entry directory
        try:
            os.makedirs(output_entry_path, exist_ok=True)
        except Exception as e:
            record["Status"] = "Failed"
            record["Details"] = f"Failed to create output directory: {e}"
            csv_data.append([record[col] for col in csv_data[0]])
            print(f" - Error: {record['Details']}")
            continue
        # 1. Convert article.json -> txt
        json_path = os.path.join(entry_path, "article.json")
        txt_path = os.path.join(output_entry_path, "article.txt")
        record["ArticleJsonPath"] = json_path
        record["OutputTxtPath"] = txt_path
        # Read article.json
        article_data = {}
        if os.path.exists(json_path):
            try:
                with open(json_path, 'r', encoding='utf-8') as f_json:
                    article_data = json.load(f_json)
                # Extract the judge_success status
                if "judge_success" in article_data:
                    record["JudgeStatus"] = str(article_data["judge_success"])
                elif "judged" in article_data:
                    record["JudgeStatus"] = "Judged" if article_data["judged"] else "Not judged"
            except Exception as e:
                print(f" - Error: failed to read article.json: {e}")
            txt_content, error = convert_json_to_txt_content(json_path, prefer_original)
            if error:
                record["Status"] = "Partial"
                record["Details"] += f"Article processing failed: {error}; "
                print(f" - Error: {record['Details']}")
            else:
                try:
                    with open(txt_path, 'w', encoding='utf-8') as f_txt:
                        f_txt.write(txt_content)
                    print(f" - Wrote text file: {txt_path}")
                    # Record the content source
                    if prefer_original:
                        record["ContentSource"] = "original_preferred"
                    elif article_data.get("judge_success") is True:
                        record["ContentSource"] = "judged"
                    elif "original_title" in article_data:
                        record["ContentSource"] = "original"
                    else:
                        record["ContentSource"] = "default"
                except Exception as e:
                    record["Status"] = "Partial"
                    record["Details"] += f"Failed to write text file: {e}; "
                    print(f" - Error: {record['Details']}")
        else:
            record["Status"] = "Partial"
            record["Details"] += "Article JSON file does not exist; "
            print(f" - Warning: {record['Details']}")
        # 2. Copy the poster image
        poster_dir = os.path.join(entry_path, "poster")
        poster_jpg_path = os.path.join(poster_dir, "poster.jpg")
        output_poster_path = os.path.join(output_entry_path, "poster.jpg")
        record["PosterPath"] = output_poster_path
        if os.path.exists(poster_jpg_path):
            try:
                shutil.copy2(poster_jpg_path, output_poster_path)
                print(f" - Copied poster image: {output_poster_path}")
            except Exception as e:
                record["Status"] = "Partial"
                record["Details"] += f"Failed to copy poster image: {e}; "
                print(f" - Error: {record['Details']}")
        else:
            record["Status"] = "Partial"
            record["Details"] += "Poster image does not exist; "
            print(f" - Warning: {record['Details']}")
        # 3. Copy the additional images
        image_dir = os.path.join(entry_path, "image")
        output_image_dir = os.path.join(output_entry_path, "additional_images")
        if os.path.exists(image_dir) and os.path.isdir(image_dir):
            try:
                os.makedirs(output_image_dir, exist_ok=True)
                image_count = 0
                for filename in os.listdir(image_dir):
                    if filename.startswith("additional_") and filename.endswith(".jpg"):
                        source_file = os.path.join(image_dir, filename)
                        dest_file = os.path.join(output_image_dir, filename)
                        # Copy the image
                        shutil.copy2(source_file, dest_file)
                        image_count += 1
                record["AdditionalImagesCount"] = image_count
                print(f" - Copied {image_count} additional images to: {output_image_dir}")
            except Exception as e:
                record["Status"] = "Partial"
                record["Details"] += f"Error while processing additional images: {e}; "
                print(f" - Error: {record['Details']}")
        else:
            record["AdditionalImagesCount"] = 0
            print(" - No additional image directory found")
        # Update the status
        if record["Status"] == "Processing":
            record["Status"] = "Success"
            record["Details"] = "Processing completed successfully"
        # 4. Record the entry in the database
        if db_enabled:
            try:
                # Map JudgeStatus onto the integer judge_status column
                if record["JudgeStatus"] == "True":
                    judge_status = 1
                elif record["JudgeStatus"] == "False":
                    judge_status = 0
                else:
                    judge_status = None
                # Call the database recording helper
                success = record_to_database(
                    db_path,
                    entry_id=record["EntryID"],
                    output_txt_path=record["OutputTxtPath"],
                    poster_path=record["PosterPath"],
                    article_json_path=record["ArticleJsonPath"],
                    product=record["Product"],
                    object=record["Object"],
                    date=record["Date"],
                    logic=record["Logic"],
                    judge_status=judge_status,
                    is_distributed=0  # not yet distributed
                )
                if success:
                    record["RecordedInDB"] = "Yes"
                    print(" - Recorded entry in the database")
                else:
                    record["RecordedInDB"] = "Failed"
                    print(" - Warning: failed to record entry in the database")
            except Exception as e:
                record["RecordedInDB"] = "Error"
                print(f" - Error: exception while recording to the database: {e}")
                traceback.print_exc()  # print the full traceback
        else:
            record["RecordedInDB"] = "Disabled"
            print(" - Info: database recording is disabled")
        # Append the record to the CSV data
        csv_data.append([record[col] for col in csv_data[0]])
    # Write the CSV manifest
    try:
        print(f"\nWriting manifest CSV: {csv_path}")
        with open(csv_path, 'w', newline='', encoding='utf-8-sig') as f_csv:
            writer = csv.writer(f_csv)
            writer.writerows(csv_data)
        print("Manifest CSV generated successfully")
    except Exception as e:
        print(f"Error while writing CSV file: {e}")
        traceback.print_exc()
    print(f"\nDone. Processed {len(entries)} entries.")
    print(f"Results saved in: {output_dir}")


def main():
    parser = argparse.ArgumentParser(description="Extract content from a TravelContentCreator result directory and render it into a target directory")
    parser.add_argument("--source", type=str, help="source directory path")
    parser.add_argument("--output", type=str, help="output directory path")
    parser.add_argument("--run-id", type=str, help="custom run ID")
    parser.add_argument("--prefer-original", action="store_true", help="prefer the original content, ignoring the judge result")
    parser.add_argument("--db-path", type=str, help="database path; a default path is used if not provided")
    parser.add_argument("--disable-db", action="store_true", help="disable database recording")
    args = parser.parse_args()
    # Defaults
    source = args.source if args.source else "/root/autodl-tmp/TravelContentCreator/result/2025-05-14_22-10-37"
    output = args.output if args.output else "/root/autodl-tmp/TravelContentCreator/output/2025-05-14_22-10-37"
    run_id = args.run_id if args.run_id else os.path.basename(source)
    prefer_original = args.prefer_original
    db_path = args.db_path if args.db_path else '/root/autodl-tmp/TravelContentCreator/distribution.db'
    print("-" * 60)
    print("Starting extract-and-render run")
    print(f"Source directory: {source}")
    print(f"Output directory: {output}")
    print(f"Run ID: {run_id}")
    if prefer_original:
        print("Content mode: prefer original content")
    else:
        print("Content mode: choose content based on the judge result")
    if args.disable_db:
        print("Database recording: disabled")
    else:
        print(f"Database recording: enabled (path: {db_path})")
    print("-" * 60)
    # Pass db_enabled through so that --disable-db actually takes effect
    process_result_directory(source, output, run_id, prefer_original, db_path,
                             db_enabled=not args.disable_db)
    print("\nScript finished.")


if __name__ == "__main__":
    main()
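
# Example invocations (paths are illustrative; all flags are defined above):
#
#   python extract_and_render.py \
#       --source /root/autodl-tmp/TravelContentCreator/result/<run_id> \
#       --output /root/autodl-tmp/TravelContentCreator/output/<run_id>
#
#   # Keep the pre-judge text and skip database recording:
#   python extract_and_render.py --source <src> --output <out> \
#       --prefer-original --disable-db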