diff --git a/code/README_CnOCR.md b/code/README_CnOCR.md new file mode 100644 index 0000000..5cf922d --- /dev/null +++ b/code/README_CnOCR.md @@ -0,0 +1,164 @@ +# 视频字幕OCR提取器 - CnOCR集成 + +## 概述 + +字幕提取器现在支持三种OCR引擎: +- **PaddleOCR**: 百度开源OCR引擎 +- **EasyOCR**: 轻量级OCR引擎 +- **CnOCR**: 中文OCR专用引擎(新增) + +## CnOCR安装和配置 + +### 1. 自动安装(推荐) + +```bash +cd code +python install_cnocr.py +``` + +### 2. 手动安装 + +```bash +# 安装CnOCR +pip install cnocr[ort-cpu] -i https://pypi.tuna.tsinghua.edu.cn/simple + +# 创建模型目录 +mkdir -p /root/autodl-tmp/llm/cnocr + +# 设置环境变量 +export CNOCR_HOME=/root/autodl-tmp/llm/cnocr +``` + +## 使用方法 + +### 1. 单独使用CnOCR + +```bash +python ocr_subtitle_extractor.py your_video.mp4 -e cnocr +``` + +### 2. 使用所有OCR引擎 + +```bash +python ocr_subtitle_extractor.py your_video.mp4 -e all +``` + +### 3. 完整参数示例 + +```bash +python ocr_subtitle_extractor.py your_video.mp4 \ + -e cnocr \ + -l ch \ + -i 30 \ + -c 0.5 \ + -o results \ + -f json \ + --position bottom +``` + +## 参数说明 + +- `-e, --engine`: OCR引擎选择 + - `paddleocr`: 仅使用PaddleOCR + - `easyocr`: 仅使用EasyOCR + - `cnocr`: 仅使用CnOCR(新增) + - `all`: 使用所有三种引擎 + +- `-l, --language`: 语言设置 + - `ch`: 中文 + - `en`: 英文 + - `ch_en`: 中英文混合 + +- `-i, --interval`: 帧采样间隔(默认30帧) +- `-c, --confidence`: 置信度阈值(默认0.5) +- `-o, --output`: 输出目录 +- `-f, --format`: 输出格式(json/txt/srt) +- `--position`: 字幕区域位置(full/center/bottom) + +## CnOCR特点 + +1. **专为中文优化**: 对中文识别效果更好 +2. **轻量级**: 模型体积较小,运行速度快 +3. **易于部署**: 安装简单,依赖少 +4. **多种模型**: 支持多种检测和识别模型 + +## 测试CnOCR集成 + +```bash +python test_cnocr.py +``` + +这个脚本会: +1. 测试CnOCR安装 +2. 测试模型下载 +3. 测试字幕提取器集成 +4. 显示测试结果 + +## 模型存储位置 + +所有CnOCR模型文件都会下载到: +``` +/root/autodl-tmp/llm/cnocr/ +``` + +首次使用时会自动下载所需模型,请耐心等待。 + +## 输出格式 + +使用CnOCR时,识别结果中的`engine`字段会标记为`"CnOCR"`,便于区分不同引擎的结果。 + +## 性能对比 + +| 引擎 | 中文识别 | 英文识别 | 速度 | 模型大小 | +|------|----------|----------|------|----------| +| PaddleOCR | 优秀 | 优秀 | 中等 | 大 | +| EasyOCR | 良好 | 优秀 | 较慢 | 大 | +| CnOCR | 优秀 | 良好 | 较快 | 中等 | + +## 故障排除 + +### 1. 安装失败 +```bash +# 更新pip +pip install --upgrade pip + +# 使用国内源 +pip install cnocr[ort-cpu] -i https://pypi.tuna.tsinghua.edu.cn/simple +``` + +### 2. 模型下载失败 +```bash +# 检查网络连接 +# 确保有足够的磁盘空间 +# 重新运行安装脚本 +python install_cnocr.py +``` + +### 3. 
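环境变量问题
+```bash
+# 在脚本开头添加
+export CNOCR_HOME=/root/autodl-tmp/llm/cnocr
+```
+
+### 4. 验证安装
+安装完成后,可用下面几行Python做快速自检(示意用法,基于CnOCR 2.x的常见接口,返回字段以实际安装版本为准;图片路径为假设值):
+```python
+from cnocr import CnOcr
+
+ocr = CnOcr()                       # 首次运行会把模型下载到 CNOCR_HOME 指定的目录
+result = ocr.ocr('test_frame.png')  # 任意一张包含中文字幕的测试图片
+for line in result:
+    print(line['text'], line['score'])
+```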
+
+## 示例输出
+
+```json
+{
+  "video_path": "test_video.mp4",
+  "subtitles": [
+    {
+      "timestamp": 1.5,
+      "text": "这是一个测试字幕",
+      "confidence": 0.95,
+      "bbox": [[10, 20], [200, 20], [200, 50], [10, 50]],
+      "engine": "CnOCR"
+    }
+  ],
+  "stats": {
+    "total_detections": 150,
+    "filtered_detections": 120,
+    "unique_texts": 50,
+    "average_confidence": 0.87
+  }
+}
+```
\ No newline at end of file
diff --git a/code/api_video_with_monitor.py b/code/api_video_with_monitor.py
new file mode 100644
index 0000000..78d2381
--- /dev/null
+++ b/code/api_video_with_monitor.py
@@ -0,0 +1,484 @@
+#!/usr/bin/env python3
+from openai import OpenAI
+import os
+import base64
+import time
+import psutil
+import subprocess
+from datetime import datetime
+
+
+class MemoryMonitor:
+    def __init__(self):
+        self.checkpoints = []
+        self.initial_memory = self.get_memory_info()
+
+    def get_memory_info(self):
+        """获取当前内存使用情况"""
+        memory = psutil.virtual_memory()
+        gpu_info = self.get_gpu_memory()
+        process = psutil.Process()
+        memory_info = process.memory_info()
+
+        return {
+            "timestamp": datetime.now().isoformat(),
+            "system_memory_gb": memory.used / 1024**3,
+            "system_memory_percent": memory.percent,
+            "gpu_memory": gpu_info,
+            "process_memory_mb": memory_info.rss / 1024 / 1024
+        }
+
+    def get_gpu_memory(self):
+        """获取GPU内存使用情况"""
+        try:
+            result = subprocess.run(['nvidia-smi', '--query-gpu=memory.total,memory.used,memory.free',
+                                     '--format=csv,noheader,nounits'],
+                                    capture_output=True, text=True, check=True)
+            lines = result.stdout.strip().split('\n')
+            gpu_info = []
+            for i, line in enumerate(lines):
+                parts = line.split(', ')
+                if len(parts) == 3:
+                    total, used, free = map(int, parts)
+                    gpu_info.append({
+                        "gpu_id": i,
+                        "total_mb": total,
+                        "used_mb": used,
+                        "free_mb": free,
+                        "usage_percent": round(used / total * 100, 2)
+                    })
+            return gpu_info
+        except Exception:
+            # nvidia-smi 不存在或查询失败时,视为无可用GPU
+            return []
+
+    def checkpoint(self, name=""):
+        """创建内存检查点"""
+        current_memory = self.get_memory_info()
+
+        if self.checkpoints:
+            last_memory = self.checkpoints[-1]["memory"]
+            memory_diff = {
+                "system_memory_gb": current_memory["system_memory_gb"] - last_memory["system_memory_gb"],
+                "process_memory_mb": current_memory["process_memory_mb"] - last_memory["process_memory_mb"],
+            }
+
+            # GPU内存差异
+            gpu_diff = []
+            if current_memory["gpu_memory"] and last_memory["gpu_memory"]:
+                for i in range(min(len(current_memory["gpu_memory"]), len(last_memory["gpu_memory"]))):
+                    current_gpu = current_memory["gpu_memory"][i]["used_mb"]
+                    last_gpu = last_memory["gpu_memory"][i]["used_mb"]
+                    gpu_diff.append({
+                        "gpu_id": i,
+                        "used_mb_diff": current_gpu - last_gpu
+                    })
+            memory_diff["gpu_memory"] = gpu_diff
+        else:
+            memory_diff = None
+
+        checkpoint = {
+            "name": name,
+            "memory": current_memory,
+            "memory_diff": memory_diff
+        }
+
+        self.checkpoints.append(checkpoint)
+        return checkpoint
+
+    def check_memory_risk(self):
+        """检查内存风险等级"""
+        current = self.get_memory_info()
+
+        # 系统内存风险
+        sys_risk = "低"
+        if current["system_memory_percent"] > 90:
+            sys_risk = "高"
+        elif current["system_memory_percent"] > 80:
+            sys_risk = "中"
+
+        # GPU内存风险
+        gpu_risk = "低"
+        if current["gpu_memory"]:
+            max_gpu_usage = max(gpu["usage_percent"] for gpu in current["gpu_memory"])
+            if max_gpu_usage > 95:
+                gpu_risk = "高"
+            elif max_gpu_usage > 85:
+                gpu_risk = "中"
+
+        return {
+            "system_risk": sys_risk,
+            "gpu_risk": gpu_risk,
+            "current_memory": current
+        }
+
+    def print_memory_status(self, title=""):
+        """打印当前内存状态"""
+        current = 
self.get_memory_info() + risk = self.check_memory_risk() + + print(f"\n{'='*50}") + print(f"🔍 {title if title else '内存状态检查'}") + print(f"{'='*50}") + + # 系统内存 + risk_icon = {"低": "✅", "中": "⚠️", "高": "🚨"}[risk["system_risk"]] + print(f"💾 系统内存: {current['system_memory_gb']:.1f} GB ({current['system_memory_percent']:.1f}%) {risk_icon}") + + # GPU内存 + if current["gpu_memory"]: + risk_icon = {"低": "✅", "中": "⚠️", "高": "🚨"}[risk["gpu_risk"]] + for gpu in current["gpu_memory"]: + print(f"🎮 GPU {gpu['gpu_id']}: {gpu['used_mb']:.0f}/{gpu['total_mb']:.0f} MB ({gpu['usage_percent']:.1f}%) {risk_icon}") + + # 进程内存 + print(f"🔧 当前进程: {current['process_memory_mb']:.1f} MB") + + return risk + +def analyze_file_sizes(video_path, audio_path=None, txt_content=""): + """分析文件大小和预估内存占用""" + print(f"\n{'='*50}") + print("📊 文件大小分析") + print(f"{'='*50}") + + total_estimated_mb = 0 + warnings = [] + + # 视频文件分析 + if os.path.exists(video_path): + video_size = os.path.getsize(video_path) + video_size_mb = video_size / 1024 / 1024 + base64_size_mb = video_size_mb * 1.33 # Base64编码增加约33% + memory_estimate_mb = base64_size_mb * 2 # 编码过程需要双倍内存 + + print(f"🎥 视频文件: {os.path.basename(video_path)}") + print(f" 原始大小: {video_size_mb:.2f} MB") + print(f" Base64后: {base64_size_mb:.2f} MB") + print(f" 内存估算: {memory_estimate_mb:.2f} MB") + + total_estimated_mb += memory_estimate_mb + + if base64_size_mb > 100: + warnings.append("视频文件过大(>100MB Base64)") + elif base64_size_mb > 50: + warnings.append("视频文件较大(>50MB Base64)") + + # 音频文件分析 + if audio_path and os.path.exists(audio_path): + audio_size = os.path.getsize(audio_path) + audio_size_mb = audio_size / 1024 / 1024 + base64_size_mb = audio_size_mb * 1.33 + memory_estimate_mb = base64_size_mb * 2 + + print(f"\n🎵 音频文件: {os.path.basename(audio_path)}") + print(f" 原始大小: {audio_size_mb:.2f} MB") + print(f" Base64后: {base64_size_mb:.2f} MB") + print(f" 内存估算: {memory_estimate_mb:.2f} MB") + + total_estimated_mb += memory_estimate_mb + + if base64_size_mb > 50: + warnings.append("音频文件过大(>50MB Base64)") + + # 文本内容分析 + if txt_content: + text_size_mb = len(txt_content.encode('utf-8')) / 1024 / 1024 + print(f"\n📝 文本内容: {len(txt_content)} 字符 ({text_size_mb:.3f} MB)") + total_estimated_mb += text_size_mb + + if len(txt_content) > 50000: + warnings.append("文本内容过长(>50k字符)") + + print(f"\n📋 总估算内存: {total_estimated_mb:.2f} MB") + + # 风险评估 + if total_estimated_mb > 500: + print("🚨 高风险: 内容过大,强烈建议压缩或分段处理") + warnings.append("总内存占用过高(>500MB)") + elif total_estimated_mb > 200: + print("⚠️ 中风险: 建议监控内存使用") + warnings.append("总内存占用较高(>200MB)") + else: + print("✅ 低风险: 内存占用在可接受范围内") + + return total_estimated_mb, warnings + +# Base64 编码格式 +def encode_video(video_path): + with open(video_path, "rb") as video_file: + return base64.b64encode(video_file.read()).decode("utf-8") + +def encode_audio(audio_path): + with open(audio_path, "rb") as audio_file: + return base64.b64encode(audio_file.read()).decode("utf-8") + +def read_txt_file(txt_path): + """读取txt文件内容""" + try: + with open(txt_path, 'r', encoding='utf-8') as file: + content = file.read() + print(f"成功读取txt文件: {txt_path}") + print(f"文件内容长度: {len(content)} 字符") + return content + except FileNotFoundError: + print(f"错误: 找不到文件 {txt_path}") + return "" + except Exception as e: + print(f"读取文件时出错: {e}") + return "" + +def save_result_to_txt(response_text, video_path, save_dir="results"): + """将分析结果保存为TXT文件""" + os.makedirs(save_dir, exist_ok=True) + + video_name = os.path.splitext(os.path.basename(video_path))[0] + timestamp = 
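datetime.now().strftime("%Y%m%d_%H%M%S")
+    txt_filename = f"{video_name}_analysis_{timestamp}.txt"
+    txt_path = os.path.join(save_dir, txt_filename)
+
+    content = f"""视频分析结果
+=====================================
+视频文件: {video_path}
+分析时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
+=====================================
+
+{response_text}
+"""
+
+    try:
+        with open(txt_path, 'w', encoding='utf-8') as f:
+            f.write(content)
+        print(f"\n✅ 分析结果已保存到: {txt_path}")
+        return txt_path
+    except Exception as e:
+        print(f"\n❌ 保存TXT文件失败: {e}")
+        return None
+
+# 内存估算口径说明(示意):Base64 把每 3 字节原始数据编码成 4 个字符,
+# 体积约为原始的 4/3 ≈ 1.33 倍;编码期间原始字节串和编码结果会同时驻留内存,
+# 所以峰值内存按 Base64 后体积的约 2 倍估算。
+# 例:30 MB 视频 → Base64 后约 40 MB → 编码峰值约 80 MB。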
datetime.now().strftime("%Y%m%d_%H%M%S") + txt_filename = f"{video_name}_analysis_{timestamp}.txt" + txt_path = os.path.join(save_dir, txt_filename) + + content = f"""视频分析结果 +===================================== +视频文件: {video_path} +分析时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} +===================================== + +{response_text} +""" + + try: + with open(txt_path, 'w', encoding='utf-8') as f: + f.write(content) + print(f"\n✅ 分析结果已保存到: {txt_path}") + return txt_path + except Exception as e: + print(f"\n❌ 保存TXT文件失败: {e}") + return None + +# 初始化内存监控器 +monitor = MemoryMonitor() + +STREAM_MODE = True + +# 文件路径配置 +video_path = "/root/autodl-tmp/video2audio/sample_demo_6.mp4" +audio_path = "/root/autodl-tmp/video2audio/sample_demo_6.wav" +#txt_path = "/root/autodl-tmp/hot_video_analyse/source/example_reference.txt" + +# 初始内存检查 +monitor.checkpoint("程序启动") +monitor.print_memory_status("程序启动时内存状态") + +# 分析文件大小和预估内存占用 +txt_content = "" +estimated_memory, warnings = analyze_file_sizes(video_path, txt_content=txt_content) + +# 如果有警告,询问是否继续 +if warnings: + print(f"\n⚠️ 发现以下潜在问题:") + for warning in warnings: + print(f" - {warning}") + print(f"\n建议:") + print(f" - 使用更小的测试文件") + print(f" - 监控内存使用情况") + print(f" - 如遇到错误,尝试压缩文件") + +# 编码前内存检查 +monitor.checkpoint("开始编码前") +risk_before = monitor.check_memory_risk() + +if risk_before["system_risk"] == "高" or risk_before["gpu_risk"] == "高": + print(f"\n🚨 警告: 当前内存使用率已经很高,继续可能导致内存溢出!") + print(f" 系统内存风险: {risk_before['system_risk']}") + print(f" GPU内存风险: {risk_before['gpu_risk']}") + +print("\n开始编码文件...") +encode_start_time = time.time() + +try: + base64_video = encode_video(video_path) + print(f"✅ 视频编码完成") +except Exception as e: + print(f"❌ 视频编码失败: {e}") + monitor.print_memory_status("编码失败时内存状态") + exit(1) +base64_audio = encode_audio(audio_path) +# 编码后内存检查 +monitor.checkpoint("编码完成") +encode_end_time = time.time() +encode_duration = encode_end_time - encode_start_time + +print(f"📁 文件编码完成,耗时: {encode_duration:.2f} 秒") + +# 检查编码后内存变化 +last_checkpoint = monitor.checkpoints[-1] +if last_checkpoint["memory_diff"]: + diff = last_checkpoint["memory_diff"] + print(f"📊 编码过程内存变化:") + print(f" 进程内存增加: {diff['process_memory_mb']:+.1f} MB") + if diff["gpu_memory"]: + for gpu_diff in diff["gpu_memory"]: + print(f" GPU {gpu_diff['gpu_id']} 内存变化: {gpu_diff['used_mb_diff']:+.0f} MB") + +client = OpenAI( + api_key="EMPTY", + base_url="http://localhost:8000/v1", +) + +# 构建content列表 +content_list = [ + { + "type": "video_url", + "video_url": {"url": f"data:video/mp4;base64,{base64_video}"}, + "type": "audio_url", + "audio_url": {"url": f"data:audio/wav;base64,{base64_audio}"}, + } +] + +# 如果txt文件有内容,添加到content中 +if txt_content.strip(): + content_list.append({ + "type": "text", + "text": f"参考文档内容:\n{txt_content}\n\n" + }) + +# 添加主要提示文本(简化版以减少内存使用) +content_list.append({ + "type": "text", + "text": """请分析这个抖音短视频的内容: + +1. **口播内容**:转录视频中的语音内容 +2. **字幕文字**:识别画面中的文字和字幕 +3. 
**勾子分析**:分析视频的开头勾子策略 + +请用JSON格式输出结果: +{ + "口播分析": {"是否有口播": "", "口播内容": "", "讲话时长": ""}, + "字幕分析": {"是否有字幕": "", "字幕内容": "", "字幕位置": ""}, + "勾子分析": {"勾子类型": "", "勾子公式": "", "勾子内容": ""} +}""" +}) + +# API请求前内存检查 +monitor.checkpoint("API请求前") +monitor.print_memory_status("API请求前内存状态") + +print(f"\n🚀 开始请求API...") +print(f"📅 请求时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") +print(f"🔄 Stream模式: {STREAM_MODE}") +print(f"📋 Content项目数量: {len(content_list)}") + +# 计算请求大小 +total_request_size = sum(len(str(content)) for content in content_list) +print(f"📏 请求总大小: {total_request_size/1024/1024:.2f} MB") + +api_start_time = time.time() + +try: + completion = client.chat.completions.create( + model="/root/autodl-tmp/llm/Qwen-omni", + messages=[ + { + "role": "system", + "content": [{"type":"text","text": "You are a helpful assistant."}] + }, + { + "role": "user", + "content": content_list + } + ], + stream=STREAM_MODE, + stream_options={"include_usage": True} if STREAM_MODE else None, + max_tokens=1024, # 限制输出长度以节省内存 + ) + + if STREAM_MODE: + full_response = "" + usage_info = None + first_token_time = None + token_count = 0 + + print("✨ 正在生成回复...") + for chunk in completion: + if chunk.choices: + delta = chunk.choices[0].delta + if delta.content: + if first_token_time is None: + first_token_time = time.time() + first_token_delay = first_token_time - api_start_time + print(f"🚀 首个token延迟: {first_token_delay:.2f} 秒") + + full_response += delta.content + token_count += 1 + else: + usage_info = chunk.usage + + api_end_time = time.time() + total_duration = api_end_time - api_start_time + + print("\n" + "="*50) + print("📝 完整回复:") + print("="*50) + print(full_response) + + # 保存结果为TXT文件 + txt_file_path = save_result_to_txt(full_response, video_path) + + # API完成后内存检查 + monitor.checkpoint("API完成") + + # 输出时间统计信息 + print("\n" + "="*50) + print("⏱️ 时间统计:") + print("="*50) + print(f"📁 文件编码时间: {encode_duration:.2f} 秒") + if first_token_time: + print(f"🚀 首个token延迟: {first_token_delay:.2f} 秒") + generation_time = api_end_time - first_token_time + print(f"⚡ 内容生成时间: {generation_time:.2f} 秒") + print(f"🕐 API总响应时间: {total_duration:.2f} 秒") + print(f"📊 生成token数量: {token_count}") + if first_token_time and token_count > 0: + tokens_per_second = token_count / generation_time + print(f"🔥 生成速度: {tokens_per_second:.2f} tokens/秒") + print(f"⏰ 完成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + + if usage_info: + print(f"\n📈 使用情况: {usage_info}") + +except Exception as e: + print(f"\n❌ API请求失败!") + print(f"错误类型: {type(e)}") + print(f"错误信息: {e}") + + # 错误时进行内存检查 + monitor.checkpoint("API错误") + monitor.print_memory_status("API错误时内存状态") + + # 分析可能的原因 + if "Internal Server Error" in str(e) or "OutOfMemoryError" in str(e): + print(f"\n💡 可能的内存溢出原因:") + print(f" - 视频文件过大 ({estimated_memory:.1f} MB)") + print(f" - GPU内存不足") + print(f" - 系统内存不足") + print(f"\n建议解决方案:") + print(f" - 使用更小的视频文件") + print(f" - 重启vLLM服务释放GPU内存") + print(f" - 降低max_tokens限制") + +# 最终内存状态报告 +print(f"\n{'='*60}") +print("📊 最终内存使用报告") +print(f"{'='*60}") + +for i, checkpoint in enumerate(monitor.checkpoints): + print(f"{i+1}. 
{checkpoint['name']}") + if checkpoint['memory_diff']: + diff = checkpoint['memory_diff'] + if abs(diff['process_memory_mb']) > 10: # 只显示显著变化 + print(f" 进程内存变化: {diff['process_memory_mb']:+.1f} MB") + if diff['gpu_memory']: + for gpu_diff in diff['gpu_memory']: + if abs(gpu_diff['used_mb_diff']) > 50: # 只显示显著变化 + print(f" GPU {gpu_diff['gpu_id']} 变化: {gpu_diff['used_mb_diff']:+.0f} MB") + +monitor.print_memory_status("程序结束时内存状态") \ No newline at end of file diff --git a/code/batch_subtitle_extractor.py b/code/batch_subtitle_extractor.py new file mode 100644 index 0000000..4183c1a --- /dev/null +++ b/code/batch_subtitle_extractor.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +批量视频字幕提取器 +支持批量处理多个视频文件,提取字幕 +支持PaddleOCR、EasyOCR和CnOCR三种引擎 +""" + +import os +import sys +import time +import json +import argparse +from pathlib import Path +from datetime import datetime +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed + +# 添加当前目录到路径以导入OCR模块 +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from ocr_subtitle_extractor import VideoSubtitleExtractor + +# 设置OCR模型路径环境变量 +os.environ['EASYOCR_MODULE_PATH'] = '/root/autodl-tmp/llm/easyocr' +os.environ['CNOCR_HOME'] = '/root/autodl-tmp/llm/cnocr' + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class BatchSubtitleExtractor: + """批量视频字幕提取器""" + + def __init__(self, ocr_engine="paddleocr", language="ch", max_workers=2): + """ + 初始化批量提取器 + + Args: + ocr_engine: OCR引擎 ("paddleocr", "easyocr", "cnocr", "all") + language: 语言设置 ("ch", "en", "ch_en") + max_workers: 最大并行工作数 + """ + self.ocr_engine = ocr_engine + self.language = language + self.max_workers = max_workers + self.extractor = VideoSubtitleExtractor(ocr_engine=ocr_engine, language=language) + + def find_video_files(self, input_dir): + """查找目录中的所有视频文件""" + video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'] + video_files = [] + + input_path = Path(input_dir) + + if input_path.is_file(): + # 单个文件 + if input_path.suffix.lower() in video_extensions: + video_files.append(input_path) + elif input_path.is_dir(): + # 目录中的所有视频文件 + for ext in video_extensions: + video_files.extend(input_path.glob(f"*{ext}")) + video_files.extend(input_path.glob(f"*{ext.upper()}")) + + return sorted(video_files) + + def extract_single_video(self, video_path, output_dir, **kwargs): + """ + 处理单个视频文件 + + Args: + video_path: 视频文件路径 + output_dir: 输出目录 + **kwargs: 其他参数 + + Returns: + dict: 处理结果 + """ + video_path = Path(video_path) + video_name = video_path.stem + + logger.info(f"开始处理视频: {video_path}") + start_time = time.time() + + try: + # 提取字幕 + results = self.extractor.extract_subtitles_from_video( + str(video_path), + sample_interval=kwargs.get('interval', 30), + confidence_threshold=kwargs.get('confidence', 0.5), + subtitle_position=kwargs.get('position', 'bottom') + ) + + # 保存结果 + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + for format_type in kwargs.get('formats', ['json']): + output_file = output_path / f"{video_name}_subtitles.{format_type}" + self.extractor.save_results(results, output_file, format_type) + + process_time = time.time() - start_time + results['process_time'] = process_time + results['video_path'] = str(video_path) + results['success'] = True + + # 统计位置信息 + subtitles_with_bbox = [s for s in results['subtitles'] if s.get('bbox')] + bbox_coverage = len(subtitles_with_bbox) / 
len(results['subtitles']) * 100 if results['subtitles'] else 0
+
+            logger.info(f"完成处理视频: {video_path} (耗时: {process_time:.2f}秒)")
+            logger.info(f"  字幕总数: {len(results['subtitles'])}")
+            logger.info(f"  有位置信息: {len(subtitles_with_bbox)}")
+            logger.info(f"  位置信息覆盖率: {bbox_coverage:.1f}%")
+
+            return {
+                'video_path': str(video_path),
+                'success': True,
+                'process_time': process_time,
+                'subtitle_count': results['stats']['filtered_detections'],
+                'text_length': results['stats']['text_length'],
+                'total_subtitles': len(results['subtitles']),
+                'subtitles_with_bbox': len(subtitles_with_bbox),
+                'bbox_coverage': bbox_coverage,
+                'output_files': [str(output_path / f"{video_name}_subtitles.{fmt}") for fmt in kwargs.get('formats', ['json'])]
+            }
+
+        except Exception as e:
+            error_msg = f"处理视频 {video_path} 时出错: {str(e)}"
+            logger.error(error_msg)
+
+            return {
+                'video_path': str(video_path),
+                'success': False,
+                'error': error_msg,
+                'process_time': time.time() - start_time
+            }
+
+    def extract_batch(self, input_dir, output_dir, parallel=True, **kwargs):
+        """
+        批量提取字幕
+
+        Args:
+            input_dir: 输入目录或文件
+            output_dir: 输出目录
+            parallel: 是否并行处理
+            **kwargs: 其他参数
+
+        Returns:
+            dict: 批量处理结果
+        """
+        logger.info(f"开始批量字幕提取")
+        logger.info(f"输入: {input_dir}")
+        logger.info(f"输出目录: {output_dir}")
+        logger.info(f"OCR引擎: {self.ocr_engine}")
+        logger.info(f"字幕位置: {kwargs.get('position', 'bottom')}")
+        logger.info(f"并行处理: {parallel}")
+
+        start_time = time.time()
+
+        # 查找视频文件
+        video_files = self.find_video_files(input_dir)
+
+        if not video_files:
+            logger.warning(f"在 {input_dir} 中未找到视频文件")
+            return {
+                'success': False,
+                'message': '未找到视频文件',
+                'total_files': 0,
+                'results': []
+            }
+
+        logger.info(f"找到 {len(video_files)} 个视频文件")
+
+        results = []
+
+        if parallel and len(video_files) > 1:
+            # 并行处理
+            # 注意:各工作线程共享同一个 self.extractor 实例;
+            # 若底层OCR引擎不是线程安全的,可改为在每个任务内单独创建提取器
+            logger.info(f"使用 {self.max_workers} 个并行工作线程")
+
+            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+                # 提交任务
+                future_to_video = {
+                    executor.submit(self.extract_single_video, video_file, output_dir, **kwargs): video_file
+                    for video_file in video_files
+                }
+
+                # 收集结果
+                for future in as_completed(future_to_video):
+                    video_file = future_to_video[future]
+                    try:
+                        result = future.result()
+                        results.append(result)
+
+                        # 显示进度
+                        progress = len(results) / len(video_files) * 100
+                        logger.info(f"批量处理进度: {progress:.1f}% ({len(results)}/{len(video_files)})")
+
+                    except Exception as e:
+                        logger.error(f"处理视频 {video_file} 时发生异常: {str(e)}")
+                        results.append({
+                            'video_path': str(video_file),
+                            'success': False,
+                            'error': str(e)
+                        })
+        else:
+            # 串行处理
+            for i, video_file in enumerate(video_files, 1):
+                logger.info(f"处理第 {i}/{len(video_files)} 个视频")
+                result = self.extract_single_video(video_file, output_dir, **kwargs)
+                results.append(result)
+
+                # 显示进度
+                progress = i / len(video_files) * 100
+                logger.info(f"批量处理进度: {progress:.1f}%")
+
+        total_time = time.time() - start_time
+
+        # 统计结果
+        success_count = sum(1 for r in results if r['success'])
+        failed_count = len(results) - success_count
+
+        total_subtitles = sum(r.get('subtitle_count', 0) for r in results if r['success'])
+        total_text_length = sum(r.get('text_length', 0) for r in results if r['success'])
+
+        # 统计位置信息
+        total_subtitles_raw = sum(r.get('total_subtitles', 0) for r in results if r['success'])
+        total_subtitles_with_bbox = sum(r.get('subtitles_with_bbox', 0) for r in results if r['success'])
+        overall_bbox_coverage = total_subtitles_with_bbox / total_subtitles_raw * 100 if total_subtitles_raw > 0 else 0
+
+        batch_result = {
+            'success': True,
+            'total_time': total_time,
+            'total_files': len(video_files),
+            'success_count': success_count,
+            'failed_count': failed_count,
+            'total_subtitles': total_subtitles,
+            'total_text_length': total_text_length,
+            'total_subtitles_raw': total_subtitles_raw,
+            'total_subtitles_with_bbox': total_subtitles_with_bbox,
+            'overall_bbox_coverage': overall_bbox_coverage,
+            'output_directory': output_dir,
+            'ocr_engine': self.ocr_engine,
+            'timestamp': datetime.now().isoformat(),
+            'results': results
+        }
+
+        # 保存批量处理报告
+        report_file = Path(output_dir) / "batch_report.json"
+        with open(report_file, 'w', encoding='utf-8') as f:
+            json.dump(batch_result, f, ensure_ascii=False, indent=2)
+
+        logger.info(f"批量处理完成!")
+        logger.info(f"总文件数: {len(video_files)}")
+        logger.info(f"成功: {success_count}, 失败: {failed_count}")
+        logger.info(f"总耗时: {total_time:.2f} 秒")
+        logger.info(f"提取字幕: {total_subtitles} 个")
+        logger.info(f"文本长度: {total_text_length} 字符")
+        logger.info(f"位置信息统计:")
+        logger.info(f"  总字幕数: {total_subtitles_raw}")
+        logger.info(f"  有位置信息: {total_subtitles_with_bbox}")
+        logger.info(f"  位置信息覆盖率: {overall_bbox_coverage:.1f}%")
+        logger.info(f"处理报告: {report_file}")
+
+        return batch_result
+
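+# 用法示例(示意,仅供参考;路径为假设值,按实际环境替换):
+#
+#   from batch_subtitle_extractor import BatchSubtitleExtractor
+#
+#   extractor = BatchSubtitleExtractor(ocr_engine="cnocr", language="ch", max_workers=2)
+#   result = extractor.extract_batch(
+#       input_dir="/path/to/videos", output_dir="batch_subtitles",
+#       parallel=True, interval=30, confidence=0.5,
+#       formats=["json"], position="bottom")
+#   print(f"{result['success_count']}/{result['total_files']} 个视频处理成功")
+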
+def main():
+    """主函数"""
+    parser = argparse.ArgumentParser(description="批量视频字幕提取器")
+    parser.add_argument("input", help="输入视频文件或目录")
+    parser.add_argument("-e", "--engine", default="cnocr",
+                       choices=["paddleocr", "easyocr", "cnocr", "all"],
+                       help="OCR引擎 (默认: cnocr)")
+    parser.add_argument("-l", "--language", default="ch",
+                       choices=["ch", "en", "ch_en"],
+                       help="语言设置 (默认: ch)")
+    parser.add_argument("-i", "--interval", type=int, default=30,
+                       help="帧采样间隔 (默认: 30)")
+    parser.add_argument("-c", "--confidence", type=float, default=0.5,
+                       help="置信度阈值 (默认: 0.5)")
+    parser.add_argument("-o", "--output", default="batch_subtitles",
+                       help="输出目录 (默认: batch_subtitles)")
+    parser.add_argument("-f", "--formats", nargs='+', default=["json"],
+                       choices=["json", "txt", "srt"],
+                       help="输出格式 (默认: json)")
+    parser.add_argument("--position", default="full",
+                       choices=["full", "center", "bottom"],
+                       help="字幕区域位置 (full=全屏, center=居中0.5-0.8, bottom=居下0.7-1.0)")
+    parser.add_argument("--workers", type=int, default=2,
+                       help="并行工作线程数 (默认: 2)")
+    parser.add_argument("--no-parallel", action="store_true",
+                       help="禁用并行处理")
+
+    args = parser.parse_args()
+
+    # 创建批量提取器
+    batch_extractor = BatchSubtitleExtractor(
+        ocr_engine=args.engine,
+        language=args.language,
+        max_workers=args.workers
+    )
+
+    try:
+        # 执行批量提取
+        result = batch_extractor.extract_batch(
+            input_dir=args.input,
+            output_dir=args.output,
+            parallel=not args.no_parallel,
+            interval=args.interval,
+            confidence=args.confidence,
+            formats=args.formats,
+            position=args.position
+        )
+
+        if result['success']:
+            print(f"\n✅ 批量字幕提取完成!")
+            print(f"📁 输出目录: {args.output}")
+            print(f"📊 成功处理: {result['success_count']}/{result['total_files']} 个视频")
+            if result['failed_count'] > 0:
+                print(f"❌ 失败: {result['failed_count']} 个")
+            print(f"⏱️ 总耗时: {result['total_time']:.2f} 秒")
+            print(f"📝 字幕片段: {result['total_subtitles']} 个")
+            print(f"📏 文本长度: {result['total_text_length']} 字符")
+            print(f"📍 位置信息统计:")
+            print(f"   总字幕数: {result['total_subtitles_raw']}")
+            print(f"   有位置信息: {result['total_subtitles_with_bbox']}")
+            print(f"   位置信息覆盖率: {result['overall_bbox_coverage']:.1f}%")
+        else:
+            print(f"\n❌ 批量处理失败: {result.get('message', '未知错误')}")
+
+    except Exception as e:
+        logger.error(f"批量处理出错: {str(e)}")
+        print(f"\n❌ 批量处理出错: {str(e)}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git 
a/code/copy_video.py b/code/copy_video.py new file mode 100644 index 0000000..9415a7f --- /dev/null +++ b/code/copy_video.py @@ -0,0 +1,545 @@ +from openai import OpenAI +import os +import base64 +import time +from datetime import datetime + + +# Base64 编码格式 +def encode_video(video_path): + with open(video_path, "rb") as video_file: + return base64.b64encode(video_file.read()).decode("utf-8") + +def encode_audio(audio_path): + with open(audio_path, "rb") as audio_file: + return base64.b64encode(audio_file.read()).decode("utf-8") + +def read_txt_file(txt_path): + """读取txt文件内容""" + try: + with open(txt_path, 'r', encoding='utf-8') as file: + content = file.read() + print(f"成功读取txt文件: {txt_path}") + print(f"文件内容长度: {len(content)} 字符") + return content + except FileNotFoundError: + print(f"错误: 找不到文件 {txt_path}") + return "" + except Exception as e: + print(f"读取文件时出错: {e}") + return "" + +def read_json_file(json_path): + """读取JSON文件内容""" + try: + import json + with open(json_path, 'r', encoding='utf-8') as file: + data = json.load(file) + print(f"成功读取JSON文件: {json_path}") + return data + except FileNotFoundError: + print(f"错误: 找不到文件 {json_path}") + return None + except json.JSONDecodeError as e: + print(f"JSON解析错误: {e}") + return None + except Exception as e: + print(f"读取JSON文件时出错: {e}") + return None + +def format_speech_json(speech_data): + """格式化口播转文字JSON数据(支持SenseVoice格式)""" + if not speech_data: + return "" + + formatted_text = "【口播转文字内容】\n" + + if isinstance(speech_data, dict): + # 新SenseVoice格式 - 处理raw_result + if 'raw_result' in speech_data: + raw_result = speech_data['raw_result'] + if isinstance(raw_result, list) and len(raw_result) > 0: + # 提取所有文本内容 + all_texts = [] + for item in raw_result: + if isinstance(item, dict) and 'text' in item: + text = item['text'] + # 清理SenseVoice的特殊标签 + import re + clean_text = re.sub(r'<\|[^|]+\|>', '', text) + clean_text = ' '.join(clean_text.split()) + if clean_text.strip(): + all_texts.append(clean_text.strip()) + + if all_texts: + formatted_text += f"完整转录文本: {' '.join(all_texts)}\n" + + # 基本信息 + if 'model' in speech_data: + formatted_text += f"转录模型: {speech_data['model']}\n" + + if 'transcribe_time' in speech_data: + formatted_text += f"转录耗时: {speech_data['transcribe_time']:.3f}秒\n" + + if 'file_path' in speech_data: + formatted_text += f"音频文件: {speech_data['file_path']}\n" + + # 旧SenseVoice格式(兼容) + elif 'clean_text' in speech_data: + formatted_text += f"完整转录文本: {speech_data['clean_text']}\n" + + if 'model' in speech_data: + formatted_text += f"转录模型: {speech_data['model']}\n" + + if 'transcribe_time' in speech_data: + formatted_text += f"转录耗时: {speech_data['transcribe_time']:.3f}秒\n" + + # 情绪分析 + if 'emotions' in speech_data and speech_data['emotions']: + emotions = [emotion.get('emotion', '') for emotion in speech_data['emotions']] + formatted_text += f"情绪分析: {', '.join(emotions)}\n" + + # 背景事件 + if 'events' in speech_data and speech_data['events']: + events = [event.get('event', '') for event in speech_data['events']] + formatted_text += f"音频事件: {', '.join(events)}\n" + + # 如果是字幕提取器的格式(备用) + elif 'continuous_text' in speech_data: + formatted_text += f"完整文本: {speech_data['continuous_text']}\n" + + if 'stats' in speech_data: + stats = speech_data['stats'] + formatted_text += f"统计信息: 检测数量{stats.get('filtered_detections', 0)}个," + formatted_text += f"平均置信度{stats.get('average_confidence', 0):.3f}\n" + + return formatted_text + +def format_whisper_json(whisper_data): + """格式化Whisper口播转文字JSON数据""" + if not whisper_data: + return "" + + formatted_text = 
"【Whisper口播转文字内容】\n" + + if isinstance(whisper_data, dict): + # 基本信息 + # 详细时间轴 - 显示所有片段 + if 'segments' in whisper_data and len(whisper_data['segments']) > 0: + formatted_text += "\n详细时间轴:\n" + for segment in whisper_data['segments']: + segment_id = segment.get('id', 0) + start_time = segment.get('start', 0) + end_time = segment.get('end', 0) + text = segment.get('text', '') + formatted_text += f" id:{segment_id}, start:{start_time:.2f}, end:{end_time:.2f}, text:{text}\n" + + return formatted_text + +def format_ocr_json(ocr_data): + """格式化OCR字幕转文字JSON数据""" + if not ocr_data: + return "" + + formatted_text = "【OCR字幕识别内容】\n" + + # 如果是字幕提取器的格式 + if isinstance(ocr_data, dict): + # 显示使用的OCR引擎 + # if 'ocr_engine' in ocr_data: + # formatted_text += f"OCR引擎: {ocr_data['ocr_engine']}\n" + + if 'continuous_text' in ocr_data: + formatted_text += f"完整字幕文本: {ocr_data['continuous_text']}\n" + + # if 'subtitles' in ocr_data and len(ocr_data['subtitles']) > 0: + # formatted_text += "详细字幕时间轴:\n" + # for subtitle in ocr_data['subtitles'][:10]: # 只显示前10个,避免过长 + # timestamp = subtitle.get('timestamp', 0) + # text = subtitle.get('text', '') + # engine = subtitle.get('engine', '') + # confidence = subtitle.get('confidence', 0) + # formatted_text += f" {timestamp:.2f}s [{engine}|{confidence:.3f}]: {text}\n" + + # if len(ocr_data['subtitles']) > 10: + # formatted_text += f" ... (还有{len(ocr_data['subtitles']) - 10}个字幕片段)\n" + + return formatted_text + +def format_clip_json(clip_data): + """格式化视频转场分析JSON数据""" + if not clip_data: + return "" + + formatted_text = "【视频转场分析内容】\n" + + if isinstance(clip_data, dict): + # 显示视频基本信息 + if 'video_name' in clip_data: + formatted_text += f"视频名称: {clip_data['video_name']}\n" + + if 'analysis_time' in clip_data: + formatted_text += f"分析时间: {clip_data['analysis_time']}\n" + + if 'total_scenes' in clip_data: + formatted_text += f"检测到场景数: {clip_data['total_scenes']} 个\n" + + # 详细场景信息 + if 'scenes' in clip_data and len(clip_data['scenes']) > 0: + formatted_text += "\n详细场景信息:\n" + for i, scene in enumerate(clip_data['scenes'], 1): + formatted_text += f"scenes {i}:\n" + formatted_text += f" start_time: {scene.get('start_time', 0):.2f}秒\n" + formatted_text += f" end_time: {scene.get('end_time', 0):.2f}秒\n" + formatted_text += f" duration: {scene.get('duration', 0):.2f}秒\n" + formatted_text += f" type: {scene.get('type')}\n" + formatted_text += "\n" + + return formatted_text + +def save_result_to_txt(response_text, video_path, save_dir="/root/autodl-tmp/final_output"): + """将分析结果保存为TXT文件""" + # 创建保存目录 + os.makedirs(save_dir, exist_ok=True) + + # 生成文件名(基于视频文件名和时间戳) + video_name = os.path.splitext(os.path.basename(video_path))[0] + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + txt_filename = f"{video_name}_analysis_{timestamp}.txt" + txt_path = os.path.join(save_dir, txt_filename) + + # 准备保存内容(添加头部信息) + content = f"""视频分析结果 +===================================== +视频文件: {video_path} +分析时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} +===================================== + +{response_text} +""" + + # 保存到文件 + try: + with open(txt_path, 'w', encoding='utf-8') as f: + f.write(content) + print(f"\n✅ 分析结果已保存到: {txt_path}") + return txt_path + except Exception as e: + print(f"\n❌ 保存TXT文件失败: {e}") + return None + +STREAM_MODE = True + +# 文件路径配置 +video_path = "/root/autodl-tmp/new/老挝泼水节.mp4" +#audio_path = "/root/autodl-tmp/video2audio/sample_demo_6.wav" +#txt_path = "/root/autodl-tmp/hot_video_analyse/source/example_reference.txt" # 使用示例参考文档 + +# JSON文件路径配置 +speech_json_path = 
"/root/autodl-tmp/new_sensevoice/老挝泼水节_sensevoice.json" # 口播转文字JSON文件 +ocr_json_path = "/root/autodl-tmp/new_cnocr/老挝泼水节_subtitles.json" # OCR字幕转文字JSON文件 +#clip_json_path = "/root/autodl-tmp/02_VideoSplitter/VideoSplitter_output/shou_gonglve_3_scenes.json" +whisper_json_path = "/root/autodl-tmp/new_whisper/老挝泼水节_transcript.json" # Whisper转文字JSON文件 + +# 编码文件 +print("开始编码文件...") +encode_start_time = time.time() + +base64_video = encode_video(video_path) +#base64_audio = encode_audio(audio_path) +#txt_content = read_txt_file(txt_path) + +#读取JSON文件内容 +print("读取JSON文件...") +speech_data = read_json_file(speech_json_path) +ocr_data = read_json_file(ocr_json_path) +#clip_data = read_json_file(clip_json_path) +whisper_data = read_json_file(whisper_json_path) + +# 格式化JSON内容 +speech_content = format_speech_json(speech_data) +ocr_content = format_ocr_json(ocr_data) +#clip_content = format_clip_json(clip_data) +whisper_content = format_whisper_json(whisper_data) + +# # 合并内容 +txt_content = "" +# if speech_content: +# txt_content += speech_content + "\n\n" +if ocr_content: + txt_content += ocr_content + "\n\n" +# if clip_content: +# txt_content += clip_content + "\n\n" +if whisper_content: + txt_content += whisper_content + "\n\n" + +print(f"合并后的参考内容长度: {len(txt_content)} 字符") +print(txt_content) +encode_end_time = time.time() +encode_duration = encode_end_time - encode_start_time +print(f"文件编码完成,耗时: {encode_duration:.2f} 秒") + + +client = OpenAI( + # 若没有配置环境变量,请用百炼API Key将下行替换为:api_key="sk-xxx" + api_key="EMPTY", + base_url="http://localhost:8000/v1", +) + +# 构建content列表 +content_list = [ + { + # 直接传入视频文件时,请将type的值设置为video_url + "type": "video_url", + "video_url": {"url": f"data:video/mp4;base64,{base64_video}"}, + } + + # , + # { + # "type": "audio_url", + # "audio_url": {"url": f"data:audio/wav;base64,{base64_audio}"}, + # } +] + +# 如果txt文件有内容,添加到content中 +if txt_content.strip(): + content_list.append({ + "type": "text", + "text": f"参考资料内容:\n{txt_content}\n\n", + "need": "第一部分是视频内容,第二部分是视频的字幕时间轴内容,第三部分是口播的字幕时间轴内容" + }) + +# 添加主要提示文本 +content_list.append({ + "type": "text", + "text": """🎥 **抖音短视频内容分析专家** + ## 任务背景 +您是一位经验丰富的视频导演和编辑,需要基于以上两个时间轴数据,和视频内容。为视频写一个完整、流畅的脚本。 +请对这个抖音短视频进行详细的内容分析,重点关注以下三个方面: +## 🎤 一、口播内容提取 +请仔细听取视频中的语音内容,完整转录: +- **完整口播转录**:参考口播的字幕时间轴内容和视频内容,逐字逐句转录所有口语表达 +- **语音时长**:估算总的讲话时长 +## 📝 二、字幕文字识别 +请识别视频画面中出现的所有文字内容: +- **屏幕字幕**:参考口播的字幕时间轴内容和视频内容,识别字幕 +- **标题文字**:识别停靠时间稍长的,视频开头、中间、结尾出现的大标题和贴图。 + +## 🎬 三、转场效果分析 +请仔细观察视频中的转场效果,并且结合参考资料中的时间轴内容,请你整体分析一下视频。比如几个画面出现第一个转场等. +转场的time_start","time_end","textIdx"请严格按照参考资料中的口播内容的时间戳start,end,id和字幕内容的时间戳“开始时间“,”结束时间“ +填写,不要自己生成。 + + +## 📊 输出格式要求 + +## 视频内容分析 +请按照以下JSON格式输出视频描述: + +{ + "total_Oral broadcasting":"请你生成一个完整的口播内容。", + "summary": "请用一句话总结视频的核心内容,突出视频的主要卖点和价值主张", + "content": [ + { + "type": "cut", + "scenes": 1, + "time_start": 0.0, + "time_end": 2.0, + "talk": "请将对应时间的口播或字幕信息,填入此", + "description": "详细描述这个镜头的画面内容、人物动作、场景特点等" + }, + + { + "type": "cut", + "scenes": 2, + "time_start": 2.0, + "time_end": 4.5, + "talk": "请将对应时间的口播或字幕信息,填入此", + "description": "描述这个镜头的具体内容,包括画面细节、转场效果等" + }, + + { + "type": "cut", + "scenes": 3, + "time_start": 4.5, + "time_end": 6.0, + "talk": "请将对应时间的口播或字幕信息,填入此", + "description": "描述这个镜头的具体内容,包括画面细节、转场效果等" + } + ] +} + +## 输出要求 +1. summary:用一句话概括视频核心内容,突出主要卖点 +2. content:按时间顺序交替描述镜头和转场 + - 镜头(lens)描述: + * textIdx:镜头序号,从1开始递增 + * time_start:开始时间(秒),精确到小数点后一位 + * time_end:结束时间(秒),精确到小数点后一位 + * talk:该镜头中的对话或文字内容 + * description:详细描述镜头内容,包括: + - 画面构图和场景 + - 人物动作和表情 + - 重要道具和元素 + - 特殊效果和转场 + + +## 注意事项 +1. 
保持描述简洁明了,但要有足够的细节 +2. 突出视频的亮点和特色 +3. 确保时间戳的准确性 +4. 对话内容要符合视频画面 +5. 整体风格要统一连贯 +6. 每个镜头的描述要包含关键信息 + +## 示例内容描述 +1. 镜头1: + - 开场特写镜头,展示产品外观 + - 画面从模糊到清晰,突出产品细节 + - 背景音乐渐入,营造氛围 + - 文字提示:"全新升级,品质保证" + +2. 转场1-2: + - 类型:平滑滑动 + - 目的:自然过渡到使用场景 + - 效果:画面从产品特写平滑滑向使用场景 + +3. 镜头2: + - 中景展示使用场景 + - 人物自然流畅的动作展示 + - 光线明亮,突出产品效果 + - 文字说明:"简单操作,轻松上手" + +4. 转场2-3: + - 类型:快速缩放 + - 目的:突出产品核心功能 + - 效果:画面快速聚焦到产品关键部位 + +5. 镜头3: + - 特写展示产品核心功能 + - 慢动作展示关键细节 + - 画面色彩鲜明,对比强烈 + - 文字强调:"专业性能,值得信赖" + +请根据以上要求,分析视频并输出JSON格式的描述。 + +请开始详细分析这个抖音短视频:""" +}) + +print(f"\n开始请求API...") +print(f"请求时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") +print(f"Stream模式: {STREAM_MODE}") +print(f"Content项目数量: {len(content_list)}") + +# 记录API请求开始时间 +api_start_time = time.time() +completion = client.chat.completions.create( + model="/root/autodl-tmp/llm/Qwen-omni", + messages=[ + { + "role": "system", + "content": [{"type":"text","text": "You are a helpful assistant."}] + }, + { + "role": "user", + "content": content_list + } + ], + stream=STREAM_MODE, + stream_options={"include_usage": True} if STREAM_MODE else None, + temperature=0.3 +) + +if STREAM_MODE: + # 流式输出 - 拼接完整回复 + full_response = "" + usage_info = None + + # 记录第一个token的时间 + first_token_time = None + token_count = 0 + + print("正在生成回复...") + for chunk in completion: + if chunk.choices: + delta = chunk.choices[0].delta + if delta.content: + # 记录第一个token的时间 + if first_token_time is None: + first_token_time = time.time() + first_token_delay = first_token_time - api_start_time + print(f"首个token延迟: {first_token_delay:.2f} 秒") + + # 拼接内容 + full_response += delta.content + token_count += 1 + + # 实时显示(可选) + #print(delta.content, end='', flush=True) + else: + # 保存使用情况信息 + usage_info = chunk.usage + + # 记录API请求结束时间 + api_end_time = time.time() + total_duration = api_end_time - api_start_time + + # 输出完整的响应 + print("\n" + "="*50) + print("完整回复:") + print("="*50) + print(full_response) + + # 保存结果为TXT文件 + txt_file_path = save_result_to_txt(full_response + "total_duration:" + str(total_duration), video_path) + + # 输出时间统计信息 + print("\n" + "="*50) + print("⏱️ 时间统计:") + print("="*50) + print(f"📁 文件编码时间: {encode_duration:.2f} 秒") + if first_token_time: + print(f"🚀 首个token延迟: {first_token_delay:.2f} 秒") + generation_time = api_end_time - first_token_time + print(f"⚡ 内容生成时间: {generation_time:.2f} 秒") + print(f"🕐 API总响应时间: {total_duration:.2f} 秒") + print(f"📊 生成token数量: {token_count}") + if first_token_time and token_count > 0: + tokens_per_second = token_count / generation_time + print(f"🔥 生成速度: {tokens_per_second:.2f} tokens/秒") + print(f"⏰ 完成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + + # 输出使用情况信息 + if usage_info: + print("\n" + "="*50) + print("📈 使用情况:") + print("="*50) + print(usage_info) + +else: + # 非流式输出 - 直接输出完整响应 + api_end_time = time.time() + total_duration = api_end_time - api_start_time + + print("非流式输出模式:") + print("完整回复:") + print("="*50) + print(completion.choices[0].message.content) + + # 保存结果为TXT文件 + txt_file_path = save_result_to_txt(completion.choices[0].message.content, video_path) + + # 输出时间统计信息 + print("\n" + "="*50) + print("⏱️ 时间统计:") + print("="*50) + print(f"📁 文件编码时间: {encode_duration:.2f} 秒") + print(f"🕐 API总响应时间: {total_duration:.2f} 秒") + print(f"⏰ 完成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + + # 输出使用情况信息 + if hasattr(completion, 'usage') and completion.usage: + print("\n" + "="*50) + print("📈 使用情况:") + print("="*50) + print(completion.usage) \ No newline at end of file diff --git a/code/director_prompt.py b/code/director_prompt.py new file 
mode 100644 index 0000000..495c8c7 --- /dev/null +++ b/code/director_prompt.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +导演编辑提示词生成器 +整合OCR字幕和Whisper口播时间轴,生成专业视频脚本 +""" + +def generate_director_prompt(ocr_timeline, whisper_timeline, video_info=None): + """ + 生成导演编辑提示词 + + Args: + ocr_timeline: OCR字幕时间轴数据 + whisper_timeline: Whisper口播时间轴数据 + video_info: 视频基本信息(可选) + + Returns: + str: 导演编辑提示词 + """ + + prompt = f""" +# 导演编辑任务:视频脚本整合与优化 + +## 任务背景 +您是一位经验丰富的视频导演和编辑,需要基于以下两个时间轴数据,为视频创作一个完整、流畅的脚本。 + +## 输入数据 + +### OCR字幕时间轴(视觉文字内容) +{ocr_timeline} + +### Whisper口播时间轴(音频内容) +{whisper_timeline} + +{video_info if video_info else ""} + +## 导演编辑要求 + +### 1. 内容整合策略 +- **时间同步**:将OCR字幕与Whisper口播按时间轴精确对齐 +- **内容互补**:识别字幕与口播的重复、补充和差异部分 +- **信息完整性**:确保重要信息不遗漏,避免重复冗余 + +### 2. 脚本结构设计 +- **开场设计**:基于前3秒内容设计引人入胜的开场 +- **节奏控制**:根据时间轴密度调整内容节奏 +- **高潮设置**:识别关键信息点,设计内容高潮 +- **结尾收束**:基于最后内容设计有力结尾 + +### 3. 语言风格优化 +- **口语化处理**:将OCR识别文字转换为自然口语表达 +- **情感表达**:根据内容调整语调、语速和情感色彩 +- **文化适配**:考虑目标受众,调整表达方式 + +### 4. 视觉与音频协调 +- **字幕时机**:优化字幕出现时机,与音频节奏配合 +- **重点突出**:识别关键信息,在脚本中重点标注 +- **转场设计**:设计自然的内容转场和过渡 + +## 输出格式要求 + +请按以下格式输出脚本: + +### 完整脚本 +``` +[时间戳] [角色/场景] [内容] +``` + +### 脚本分析 +- **内容概览**:简要总结视频核心内容 +- **关键信息点**:列出3-5个最重要的信息 +- **目标受众**:分析适合的观众群体 +- **传播建议**:提供传播和推广建议 + +### 技术参数 +- **总时长**:基于时间轴计算 +- **内容密度**:评估信息密度是否合适 +- **节奏分析**:分析内容节奏变化 + +## 创作原则 +1. **真实性**:保持原始内容的真实性,不添加虚构信息 +2. **流畅性**:确保脚本逻辑清晰,表达流畅 +3. **吸引力**:增强内容的吸引力和传播性 +4. **专业性**:体现专业导演的编辑水平 + +请基于以上要求,创作一个完整、专业的视频脚本。 +""" + + return prompt + +def format_timeline_for_prompt(timeline_data, timeline_type): + """ + 格式化时间轴数据用于提示词 + + Args: + timeline_data: 时间轴数据 + timeline_type: 时间轴类型 ("OCR" 或 "Whisper") + + Returns: + str: 格式化的时间轴文本 + """ + if timeline_type == "OCR": + formatted = "OCR字幕识别结果:\n" + for entry in timeline_data: + timestamp = entry.get('timestamp', 0) + contents = entry.get('contents', []) + formatted += f"时间点 {timestamp:.2f}s:\n" + for content in contents: + text = content.get('text', '') + bbox = content.get('bbox', []) + formatted += f" - 文字: '{text}'\n" + if bbox: + formatted += f" 位置: {bbox}\n" + formatted += "\n" + + elif timeline_type == "Whisper": + formatted = "Whisper语音识别结果:\n" + for i, entry in enumerate(timeline_data): + start_time = entry.get('start', 0) + end_time = entry.get('end', 0) + text = entry.get('text', '') + formatted += f" id:{i}, start:{start_time:.2f}, end:{end_time:.2f}, text:{text}\n" + + return formatted + +def create_video_script_prompt(ocr_json_path, whisper_data=None): + """ + 创建完整的视频脚本提示词 + + Args: + ocr_json_path: OCR JSON文件路径 + whisper_data: Whisper识别数据(可选) + + Returns: + str: 完整的导演编辑提示词 + """ + # 读取OCR数据 + from pre_data_1 import read_json_file, format_ocr_json + + ocr_data = read_json_file(ocr_json_path) + if not ocr_data: + return "错误:无法读取OCR数据文件" + + # 格式化OCR时间轴 + _, subtitle_array = format_ocr_json(ocr_data) + ocr_timeline = format_timeline_for_prompt(subtitle_array, "OCR") + + # 格式化Whisper时间轴(如果有) + whisper_timeline = "" + if whisper_data: + whisper_timeline = format_timeline_for_prompt(whisper_data, "Whisper") + else: + whisper_timeline = "(暂无Whisper数据)" + + # 视频基本信息 + video_info = f""" +### 视频基本信息 +- 文件路径: {ocr_json_path} +- OCR引擎: {ocr_data.get('ocr_engine', 'Unknown')} +- 视频时长: {ocr_data.get('duration', 0):.2f}秒 +- 视频分辨率: {ocr_data.get('frame_width', 0)}x{ocr_data.get('frame_height', 0)} +- 视频帧率: {ocr_data.get('fps', 0):.2f}FPS +""" + + # 生成导演提示词 + prompt = generate_director_prompt(ocr_timeline, whisper_timeline, video_info) + + return prompt + +# 示例使用 +if __name__ 
== "__main__": + # 示例Whisper数据(实际使用时应该从文件读取) + example_whisper_data = [ + {"start": 0.00, "end": 1.80, "text": "潑水街不只有云南"}, + {"start": 1.80, "end": 3.56, "text": "老窝更远更传统"}, + {"start": 3.56, "end": 5.64, "text": "快来接触这份湿身快乐"}, + # ... 更多数据 + ] + + # 生成提示词 + prompt = create_video_script_prompt( + "/root/autodl-tmp/new_cnocr/老挝泼水节_subtitles.json", + example_whisper_data + ) + + # 保存提示词到文件 + import os + output_path = "/root/autodl-tmp/new_cnocr/director_prompt.txt" + with open(output_path, 'w', encoding='utf-8') as f: + f.write(prompt) + + print(f"导演编辑提示词已保存到: {output_path}") + print("\n提示词预览(前500字符):") + print(prompt[:500] + "...") \ No newline at end of file diff --git a/code/pre_data_1.py b/code/pre_data_1.py new file mode 100644 index 0000000..97251ba --- /dev/null +++ b/code/pre_data_1.py @@ -0,0 +1,316 @@ +import os + +def read_json_file(json_path): + """读取JSON文件内容""" + try: + import json + with open(json_path, 'r', encoding='utf-8') as file: + data = json.load(file) + print(f"成功读取JSON文件: {json_path}") + return data + except FileNotFoundError: + print(f"错误: 找不到文件 {json_path}") + return None + except json.JSONDecodeError as e: + print(f"JSON解析错误: {e}") + return None + except Exception as e: + print(f"读取JSON文件时出错: {e}") + return None + +def calculate_text_similarity(text1, text2): + """ + 计算两个文本的相似度(使用Jaccard相似度) + + Args: + text1: 第一个文本 + text2: 第二个文本 + + Returns: + float: 相似度值 (0-1之间) + """ + # 检查空文本 + if not text1 or not text2: + return 0.0 + + # 清理文本,移除空白字符 + text1 = text1.strip() + text2 = text2.strip() + + if not text1 or not text2: + return 0.0 + + # 如果两个文本完全相同 + if text1 == text2: + return 1.0 + + # 将文本转换为字符集合 + chars1 = set(text1) + chars2 = set(text2) + + # 计算Jaccard相似度 + intersection = len(chars1.intersection(chars2)) + union = len(chars1.union(chars2)) + + similarity = intersection / union if union > 0 else 0.0 + return similarity + +def calculate_iou(box1, box2): + """ + 计算两个边界框的IoU (Intersection over Union) + + Args: + box1: 第一个边界框 [x1, y1, x2, y2] 或 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + box2: 第二个边界框 [x1, y1, x2, y2] 或 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + + Returns: + float: IoU值 (0-1之间) + """ + # 处理不同的输入格式 + if len(box1) == 4 and isinstance(box1[0], (int, float)): + # 格式: [x1, y1, x2, y2] + x1_1, y1_1, x2_1, y2_1 = box1 + elif len(box1) == 4 and isinstance(box1[0], list): + # 格式: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] - 取最小和最大坐标 + x_coords = [point[0] for point in box1] + y_coords = [point[1] for point in box1] + x1_1, x2_1 = min(x_coords), max(x_coords) + y1_1, y2_1 = min(y_coords), max(y_coords) + else: + raise ValueError("box1格式错误,应为[x1,y1,x2,y2]或[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]") + + if len(box2) == 4 and isinstance(box2[0], (int, float)): + # 格式: [x1, y1, x2, y2] + x1_2, y1_2, x2_2, y2_2 = box2 + elif len(box2) == 4 and isinstance(box2[0], list): + # 格式: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] - 取最小和最大坐标 + x_coords = [point[0] for point in box2] + y_coords = [point[1] for point in box2] + x1_2, x2_2 = min(x_coords), max(x_coords) + y1_2, y2_2 = min(y_coords), max(y_coords) + else: + raise ValueError("box2格式错误,应为[x1,y1,x2,y2]或[[x1,y1],[x2,y2],[x3,y3],[x4,y4]]") + + # 计算交集区域 + x_left = max(x1_1, x1_2) + y_top = max(y1_1, y1_2) + x_right = min(x2_1, x2_2) + y_bottom = min(y2_1, y2_2) + + # 检查是否有交集 + if x_right < x_left or y_bottom < y_top: + return 0.0 + + # 计算交集面积 + intersection_area = (x_right - x_left) * (y_bottom - y_top) + + # 计算并集面积 + box1_area = (x2_1 - x1_1) * (y2_1 - y1_1) + box2_area = (x2_2 - x1_2) * (y2_2 - y1_2) + union_area = box1_area + box2_area - 
+def format_ocr_json(ocr_data):
+    """格式化OCR字幕转文字JSON数据"""
+    if not ocr_data:
+        return "", []
+
+    formatted_text = "【OCR字幕识别内容】\n"
+
+    # 如果是字幕提取器的格式
+    if isinstance(ocr_data, dict):
+        # 基本信息
+        if 'ocr_engine' in ocr_data:
+            formatted_text += f"OCR引擎: {ocr_data['ocr_engine']}\n"
+
+        if 'video_path' in ocr_data:
+            formatted_text += f"视频文件: {ocr_data['video_path']}\n"
+
+        if 'duration' in ocr_data:
+            formatted_text += f"视频时长: {ocr_data['duration']:.2f}秒\n"
+
+        if 'fps' in ocr_data:
+            formatted_text += f"视频帧率: {ocr_data['fps']:.2f}FPS\n"
+
+        if 'frame_width' in ocr_data and 'frame_height' in ocr_data:
+            formatted_text += f"视频分辨率: {ocr_data['frame_width']}x{ocr_data['frame_height']}\n"
+
+        # 字幕区域信息
+        if 'subtitle_position' in ocr_data:
+            formatted_text += f"字幕区域: {ocr_data['subtitle_position']}\n"
+
+        if 'subtitle_region' in ocr_data:
+            region = ocr_data['subtitle_region']
+            formatted_text += f"字幕区域坐标: {region}\n"
+
+        # 处理参数
+        if 'sample_interval' in ocr_data:
+            formatted_text += f"采样间隔: {ocr_data['sample_interval']}帧\n"
+
+        if 'confidence_threshold' in ocr_data:
+            formatted_text += f"置信度阈值: {ocr_data['confidence_threshold']}\n"
+
+        # 完整字幕文本
+        if 'continuous_text' in ocr_data:
+            formatted_text += f"\n📄 完整字幕文本:\n"
+            formatted_text += f"{ocr_data['continuous_text']}\n"
+
+        # 详细字幕时间轴 - 按三层嵌套数组结构组织
+        if 'subtitles' in ocr_data and len(ocr_data['subtitles']) > 0:
+            subtitles = ocr_data['subtitles']
+
+            # 按时间戳分组存储
+            timestamp_groups = {}
+            for subtitle in subtitles:
+                timestamp = subtitle.get('timestamp', 0)
+                text = subtitle.get('text', '')
+                confidence = subtitle.get('confidence', 0)
+                engine = subtitle.get('engine', 'Unknown')
+                bbox = subtitle.get('bbox', [])
+
+                if timestamp not in timestamp_groups:
+                    timestamp_groups[timestamp] = []
+
+                # 第三层:内容和位置(同时保留每条内容自身的置信度,供展示使用)
+                subtitle_content = {
+                    'text': text,
+                    'bbox': bbox,
+                    'confidence': confidence,
+                    "timestamp": timestamp
+                }
+
+                timestamp_groups[timestamp].append(subtitle_content)
+
+            # 转换为三层嵌套数组结构
+            subtitle_array = []
+            sorted_timestamps = sorted(timestamp_groups.keys())
+
+            for timestamp in sorted_timestamps:
+                # 第一层:时间戳
+                timestamp_entry = {
+                    'timestamp': timestamp,
+                    'contents': timestamp_groups[timestamp]  # 第二层:同一时间戳内的各个内容
+                }
+                subtitle_array.append(timestamp_entry)
+
+            # 显示三层嵌套数组结构
+            formatted_text += f"\n⏰ 详细字幕时间轴 (三层嵌套数组结构):\n"
+
+            # 只显示前10个时间戳,避免过长
+            display_count = min(10, len(subtitle_array))
+            for i, timestamp_entry in enumerate(subtitle_array[:display_count], 1):
+                timestamp = timestamp_entry['timestamp']
+                contents = timestamp_entry['contents']
+
+                formatted_text += f"  {i}. {timestamp:.2f}s:\n"
+
+                # 显示该时间戳下的所有字幕(第二层)
+                for j, content in enumerate(contents, 1):
+                    text = content['text']
+                    bbox = content['bbox']
+
+                    # 使用该条内容自身的置信度,而不是分组循环遗留的变量
+                    formatted_text += f"      {j}. [{timestamp:.2f}s|{content['confidence']:.3f}]: {text}\n"
+
+                    # 如果有位置信息,显示bbox(第三层)
+                    if bbox:
+                        formatted_text += f"          位置: {bbox}\n"
+
+                formatted_text += "\n"
+
+            if len(subtitle_array) > display_count:
+                formatted_text += f"   ... (还有{len(subtitle_array) - display_count}个时间戳)\n"
+
+            # 返回三层嵌套数组结构
+            return formatted_text, subtitle_array
+
+    return formatted_text, []
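+
+# subtitle_array 返回结构示例(示意):
+# [
+#   {"timestamp": 1.0, "contents": [
+#       {"text": "这是一个测试字幕", "bbox": [[10, 20], [200, 20], [200, 50], [10, 50]],
+#        "confidence": 0.95, "timestamp": 1.0},
+#   ]},
+#   {"timestamp": 2.0, "contents": [...]},
+# ]
+# 第一层是时间戳条目,第二层是同一时间戳下的多条内容,第三层是每条内容的文本与位置。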
+
+def merge_and_filter_subtitles(subtitle_array, iou_threshold=0.7, text_similarity_threshold=0.7):
+    """
+    合并并过滤字幕内容,去除重复和空内容,返回格式化字符串和处理后的数组
+    """
+    # 深拷贝,避免原地修改
+    import copy
+    subtitle_array = copy.deepcopy(subtitle_array)
+    formatted_text = []
+
+    for i in range(len(subtitle_array)):
+        for j in range(len(subtitle_array[i]["contents"])):
+            # k从1开始,只与后续时间戳条目比较;range上界已保证i+k不会越界
+            for k in range(1, len(subtitle_array) - i):
+                for l in range(len(subtitle_array[i+k]["contents"])):
+                    text = subtitle_array[i]["contents"][j]["text"]
+                    bbox = subtitle_array[i]["contents"][j]["bbox"]
+                    text_1 = subtitle_array[i+k]["contents"][l]["text"]
+                    bbox_1 = subtitle_array[i+k]["contents"][l]["bbox"]
+
+                    iou = calculate_iou(bbox, bbox_1)
+                    text_similarity = calculate_text_similarity(text, text_1)
+
+                    if iou > iou_threshold and text_similarity > text_similarity_threshold:
+                        # 将后续重复项置空;当前条目的timestamp按持续时间后移,近似作为结束时间
+                        subtitle_array[i+k]["contents"][l]["text"] = ''
+                        subtitle_array[i]["contents"][j]["timestamp"] += 1
+
+    # 删除text为空字符串的contents
+    for i in range(len(subtitle_array)):
+        subtitle_array[i]["contents"] = [content for content in subtitle_array[i]["contents"] if content["text"] != '']
+
+    # 删除contents为空的时间戳条目
+    subtitle_array = [entry for entry in subtitle_array if len(entry["contents"]) > 0]
+
+    #formatted_text.append("处理后的字幕数组:")
+    for i, timestamp_entry in enumerate(subtitle_array, 1):
+        formatted_text.append(f"\n开始时间 {timestamp_entry['timestamp']:.2f}s:")
+        #formatted_text.append(f"   包含 {len(timestamp_entry['contents'])} 个字幕内容")
+        for j, content in enumerate(timestamp_entry['contents'], 1):
+            formatted_text.append(f"    {j}. 文本: '{content['text']}'")
+            if content['bbox']:
+                formatted_text.append(f"       位置: {content['bbox']}")
+            if 'timestamp' in content and content['timestamp']:
+                formatted_text.append(f"       结束时间: {content['timestamp']:.2f}s")
+
+    #formatted_text.append("\n完整数组结构:")
+    #formatted_text.append(str(subtitle_array))
+
+    return '\n'.join(formatted_text), subtitle_array
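+
+# 合并判定示例(示意):同一条字幕在相邻采样帧中重复出现时,
+# 其bbox几乎重合(IoU接近1)、文本Jaccard相似度也接近1,
+# 只有IoU与文本相似度同时超过阈值(默认0.7)才判为重复;
+# 例如文末注释中的a、b两框IoU仅约0.47,即使文本相同也不会被合并。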
+
+
+if __name__ == "__main__":
+    # 直接运行本文件时执行下面的示例流程;被其他模块导入时(如 director_prompt.py)不会执行
+    ocr_json_path = "/root/autodl-tmp/new_cnocr/哈尔滨_subtitles.json"
+
+    ocr_data = read_json_file(ocr_json_path)
+    pre_data, subtitle_array = format_ocr_json(ocr_data)
+
+    iou_threshold = 0.8
+    text_similarity_threshold = 0.8
+    a, b = merge_and_filter_subtitles(subtitle_array, iou_threshold, text_similarity_threshold)
+    #print("\n完整数组结构:")
+    print(a)
+    print(b)
+
+    # 保存输出结果到txt文件
+    output_dir = os.path.dirname(ocr_json_path)
+    output_filename = os.path.splitext(os.path.basename(ocr_json_path))[0] + "_processed.txt"
+    output_path = os.path.join(output_dir, output_filename)
+
+    try:
+        with open(output_path, 'w', encoding='utf-8') as f:
+            f.write(a)
+        print(f"\n处理结果已保存到: {output_path}")
+    except Exception as e:
+        print(f"保存文件时出错: {e}")
+
+#验证 "/root/autodl-tmp/douyin_ocr/兰州_subtitles.json" 里面重复的两个内容,确实是bbox不重叠
+# a = [[303, 243], [442, 243], [442, 303], [303, 303]]
+# b = [[339, 231], [495, 241], [490, 304], [335, 294]]
+# c = [[482, 273], [660, 276], [660, 303], [481, 300]]
+# d = [[536, 268], [732, 273], [731, 300], [535, 295]]
+
+# iou = calculate_iou(a, b)  # 0.47
+# d = calculate_iou(c, d)    # 0.40
\ No newline at end of file
diff --git a/code/token_counter.py b/code/token_counter.py
new file mode 100644
index 0000000..83033ed
--- /dev/null
+++ b/code/token_counter.py
@@ -0,0 +1,127 @@
+import tiktoken
+import os
+import cv2
+
+def count_tokens(text, model="gpt-4"):
+    """统计文本的token数量"""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+        tokens = encoding.encode(text)
+        return len(tokens)
+    except Exception as e:
+        print(f"Token统计出错: {e}")
+        # 简单估算:中文字符约1.5个token,英文单词约1.3个token
+        chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff')
+        english_words = len([word for word in text.split() if word.isascii()])
+        estimated_tokens = int(chinese_chars * 1.5 + english_words * 1.3)
+        return estimated_tokens
+
+def get_video_token_estimate(video_path):
+    """估算视频的token数量(基于文件大小和时长)"""
+    try:
+        cap = cv2.VideoCapture(video_path)
+        if not cap.isOpened():
+            return {'estimated_tokens': 0, 'duration': 0, 'frame_count': 0, 'fps': 0, 'file_size_mb': 0, 'frames_used': 0}
+
+        # 获取视频信息
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        duration = frame_count / fps if fps > 0 else 0
+
+        # 获取文件大小
+        file_size = os.path.getsize(video_path)
+
+        cap.release()
+
+        # 基于GPT-4V的token估算规则
+        # 视频token = 基础token + 帧数 * 每帧token
+        base_tokens = 85  # 基础token
+        frames_per_second = min(fps, 1)  # 每秒最多1帧
+        total_frames = min(frame_count, int(duration * frames_per_second))
+        tokens_per_frame = 170  # 每帧约170个token
+
+        estimated_tokens = base_tokens + total_frames * tokens_per_frame
+
+        return {
+            'estimated_tokens': int(estimated_tokens),
+            'duration': duration,
+            'frame_count': frame_count,
+            'fps': fps,
+            'file_size_mb': file_size / (1024 * 1024),
+            'frames_used': total_frames
+        }
+    except Exception as e:
+        print(f"视频token估算出错: {e}")
+        return {'estimated_tokens': 0, 'duration': 0, 'frame_count': 0, 'fps': 0, 'file_size_mb': 0, 'frames_used': 0}
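+
+# 估算公式示例:estimated_tokens = 85 + 使用帧数 × 170,
+# 其中 使用帧数 = min(总帧数, 时长 × min(fps, 1)),即每秒最多取1帧。
+# 例:30秒、30fps的视频 → 使用帧数 = min(900, 30×1) = 30,
+#     估算token ≈ 85 + 30 × 170 = 5185。
+# (85/170是本文件采用的经验常数,并非官方计费规则,结果仅供参考。)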
+
+def analyze_input_tokens(video_path, text_content="", prompt_text=""):
+    """分析输入token统计"""
+    print("\n" + "="*50)
+    print("📊 Token统计信息:")
+    print("="*50)
+
+    # 统计视频token
+    video_token_info = get_video_token_estimate(video_path)
+    print(f"🎬 视频Token统计:")
+    print(f"   估算Token数量: {video_token_info['estimated_tokens']:,}")
+    print(f"   视频时长: {video_token_info['duration']:.2f}秒")
+    print(f"   总帧数: {video_token_info['frame_count']:,}")
+    print(f"   帧率: {video_token_info['fps']:.2f} fps")
+    print(f"   文件大小: {video_token_info['file_size_mb']:.2f} MB")
+    print(f"   使用帧数: {video_token_info['frames_used']:,}")
+
+    # 统计文本token
+    text_tokens = 0
+    if text_content.strip():
+        text_tokens = count_tokens(text_content)
+        print(f"\n📝 文本Token统计:")
+        print(f"   文本内容Token: {text_tokens:,}")
+        print(f"   文本字符数: {len(text_content):,}")
+
+    # 统计提示词token
+    prompt_tokens = 0
+    if prompt_text.strip():
+        prompt_tokens = count_tokens(prompt_text)
+        print(f"   提示词Token: {prompt_tokens:,}")
+
+    # 单价假设:元/1000 token 的经验值,仅用于粗略估算
+    video_cost = 0.0015
+    text_cost = 0.0004
+    total_cost = (video_token_info['estimated_tokens']*video_cost + text_tokens*text_cost + prompt_tokens*text_cost)/1000
+    # 计算总输入token
+    total_input_tokens = (video_token_info['estimated_tokens'] + text_tokens + prompt_tokens)
+    print(f"\n📈 总输入Token统计:")
+    print(f"   视频Token: {video_token_info['estimated_tokens']:,}")
+    print(f"   文本Token: {text_tokens:,}")
+    print(f"   提示词Token: {prompt_tokens:,}")
+    print(f"   🔥 总输入Token: {total_input_tokens:,}")
+    print(f"   💰 总费用: {total_cost:.4f}元")
+    print("="*50)
+
+    return {
+        'video_tokens': video_token_info['estimated_tokens'],
+        'text_tokens': text_tokens,
+        'prompt_tokens': prompt_tokens,
+        'total_input_tokens': total_input_tokens,
+        'video_info': video_token_info,
+        'total_cost': total_cost
+    }
+
+if __name__ == "__main__":
+    # 测试token统计功能
+    test_video = "/root/autodl-tmp/new/哈尔滨.mp4"
+    test_text = "这是一个测试文本,包含中英文内容。This is a test text with Chinese and English content."
+    test_prompt = "请分析这个视频的内容。"
+
+    result = analyze_input_tokens(test_video, test_text, test_prompt)
+    print(f"\n测试结果: {result}")
+
+    # video_token = result['video_tokens']
+    # video_cost = 0.0015
+
+    # prompt_token = result['prompt_tokens']
+    # text_token = result['text_tokens']
+    # text_cost = 0.0004
+
+    # total_cost = video_token*video_cost + prompt_token*text_cost + text_token*text_cost
+
+    # print(total_cost)
\ No newline at end of file