#!/usr/bin/env python3
"""测试修复后的算法是否解决重叠边界问题"""
import sys
sys.path.append('.')
# Minimal stand-in for a spaCy Doc object
class MockDoc:
    """Mock of the spaCy ``Doc``/``Span`` interface used by the splitter.

    Holds a surface ``text`` plus the list of token strings backing it.
    Supports ``len()`` (token count), integer indexing (single token),
    and slicing, which returns a new :class:`MockDoc` whose text is the
    concatenation of the sliced tokens.
    """

    def __init__(self, text, tokens):
        self.text = text        # surface text of this doc/span
        self._tokens = tokens   # token strings backing the doc

    def __len__(self):
        # Length is the token count, not the character count.
        return len(self._tokens)

    def __getitem__(self, key):
        if isinstance(key, slice):
            # NOTE: the original unpacked key.indices() and special-cased
            # step == 0, but slice.indices() never yields a zero step --
            # that branch was dead code. Slice the token list directly.
            tokens = self._tokens[key]
            return MockDoc(''.join(tokens), tokens)
        return self._tokens[key]
def proportionally_split_source(source_doc, translated_parts_docs):
    """Split ``source_doc`` into contiguous, non-overlapping spans.

    The split produces one span per translated part, with boundaries
    placed at the *cumulative* token-ratio of the translated parts, so
    rounding can never make adjacent spans overlap; the final span
    absorbs any remainder.
    """
    n_parts = len(translated_parts_docs)
    if n_parts <= 1:
        return [source_doc]

    # Token weight of each translated part.
    part_sizes = [len(doc) for doc in translated_parts_docs]
    size_total = sum(part_sizes)
    if size_total == 0:
        # Degenerate case: nothing to apportion -- first span gets the
        # whole source, the remaining spans are empty.
        return [source_doc] + [source_doc[:0] for _ in range(n_parts - 1)]

    src_total = len(source_doc)
    print(f"源文本总token数: {src_total}")
    print(f"翻译部分token数: {part_sizes}, 总计: {size_total}")

    pieces = []
    lo = 0        # start index of the current span
    consumed = 0  # translated tokens covered so far (cumulative)
    last = len(part_sizes) - 1
    for i, size in enumerate(part_sizes):
        consumed += size
        # Boundary sits at the cumulative proportion of translated tokens;
        # the last span always runs to the end of the source.
        hi = src_total if i == last else round(src_total * consumed / size_total)
        span = source_doc[lo:hi]
        pieces.append(span)
        print(f"分片 {i}: [{lo}:{hi}] = '{span.text}' ({len(span)} tokens)")
        lo = hi
    return pieces
def test_overlap_fix():
    """Verify the cumulative-ratio split yields contiguous, non-overlapping
    parts that exactly cover the source text.

    The old version only *printed* "checking continuity" without checking
    anything, and left an unused local behind; both are fixed here with a
    real per-part assertion.
    """
    # Test data: a case that used to trigger overlapping boundaries.
    source_text = "新疆早餐配篝火歌舞"
    source_tokens = list(source_text)  # one token per character
    source_doc = MockDoc(source_text, source_tokens)

    # Three translated parts with token counts 3, 2, 3.
    tr_part1 = MockDoc("Part1", ["tok1", "tok2", "tok3"])  # 3 tokens
    tr_part2 = MockDoc("Part2", ["tok4", "tok5"])          # 2 tokens
    tr_part3 = MockDoc("Part3", ["tok6", "tok7", "tok8"])  # 3 tokens
    translated_parts_docs = [tr_part1, tr_part2, tr_part3]

    print("=== 测试累积比例算法 ===")
    src_parts = proportionally_split_source(source_doc, translated_parts_docs)

    print("\n=== 检查重叠情况 ===")
    total_chars = 0
    offset = 0  # where the next part must begin inside source_text
    for i, part in enumerate(src_parts):
        print(f"分片 {i}: '{part.text}' (长度: {len(part)})")
        total_chars += len(part)
        # Real continuity check: each part must start exactly where the
        # previous one ended (no gap, no overlap).
        assert source_text[offset:offset + len(part)] == part.text, \
            f"分片 {i} 不连续"
        offset += len(part)

    print(f"\n原始字符数: {len(source_text)}")
    print(f"分片总字符数: {total_chars}")
    print(f"是否相等: {len(source_text) == total_chars}")

    # Round-trip check: concatenating the parts must reproduce the source.
    reconstructed = ''.join(part.text for part in src_parts)
    print(f"重组文本: '{reconstructed}'")
    print(f"是否与原文相同: {reconstructed == source_text}")
# Script entry point: run the overlap-fix demonstration when executed directly.
if __name__ == '__main__':
    test_overlap_fix()