# video_translation/test_fix_overlap.py
#!/usr/bin/env python3
"""测试修复后的算法是否解决重叠边界问题"""
import sys
sys.path.append('.')
# Minimal stand-in for a spaCy Doc object, used only by this test.
class MockDoc:
    """Mock of spaCy's ``Doc`` for testing token-based splitting.

    Supports ``len()`` (token count), integer indexing (returns the raw
    token) and slicing (returns a new ``MockDoc`` whose ``text`` is the
    concatenation of the sliced tokens).
    """

    def __init__(self, text, tokens):
        self.text = text
        self._tokens = tokens

    def __len__(self):
        # Token count, mirroring len(spacy.Doc).
        return len(self._tokens)

    def __getitem__(self, key):
        if isinstance(key, slice):
            # A list already honors start/stop/step when indexed with a
            # slice object; the original unpacking via key.indices() and
            # the `if step` guard were dead code (indices() never yields
            # a step of 0).
            tokens = self._tokens[key]
            return MockDoc(''.join(tokens), tokens)
        return self._tokens[key]
def proportionally_split_source(source_doc, translated_parts_docs):
    """Split ``source_doc`` into contiguous slices, one per translated part.

    Boundaries are placed at cumulative token-count ratios of the translated
    parts, so slices never overlap and together cover the whole source (the
    last slice absorbs any rounding remainder).

    Args:
        source_doc: Doc-like object supporting ``len()`` and slicing.
        translated_parts_docs: list of Doc-like translated fragments.

    Returns:
        A list of ``source_doc`` slices with the same length as
        ``translated_parts_docs``. An empty input list yields ``[]``
        (previously the ``<= 1`` guard wrongly returned one part).
    """
    # Robustness fix: zero translated parts -> zero source parts.
    if not translated_parts_docs:
        return []
    if len(translated_parts_docs) == 1:
        return [source_doc]
    # Token weight of each translated part.
    tr_token_counts = [len(doc) for doc in translated_parts_docs]
    total_tr_tokens = sum(tr_token_counts)
    if total_tr_tokens == 0:
        # Degenerate case: every translated part is empty. Keep the whole
        # source in the first slice and pad with empty slices.
        return [source_doc] + [source_doc[:0] for _ in range(len(translated_parts_docs) - 1)]
    # Cumulative-ratio algorithm: boundaries are monotone non-decreasing,
    # which is what guarantees the slices cannot overlap.
    src_parts = []
    total_src_tokens = len(source_doc)
    current_idx = 0
    cumulative_tr_tokens = 0
    print(f"源文本总token数: {total_src_tokens}")
    print(f"翻译部分token数: {tr_token_counts}, 总计: {total_tr_tokens}")
    for i, tr_token_count in enumerate(tr_token_counts):
        cumulative_tr_tokens += tr_token_count
        if i == len(tr_token_counts) - 1:
            # Last part takes all remaining tokens so coverage is exact.
            next_idx = total_src_tokens
        else:
            # Next boundary from the cumulative translated-token ratio.
            next_idx = round(total_src_tokens * cumulative_tr_tokens / total_tr_tokens)
        part_doc = source_doc[current_idx:next_idx]
        src_parts.append(part_doc)
        print(f"分片 {i}: [{current_idx}:{next_idx}] = '{part_doc.text}' ({len(part_doc)} tokens)")
        current_idx = next_idx
    return src_parts
def test_overlap_fix():
    """Verify the cumulative-ratio split yields non-overlapping, exhaustive slices.

    Builds a case that previously produced overlapping boundaries, splits it,
    and asserts the slices reassemble exactly into the original text.
    """
    # Test data: one token per character.
    source_text = "新疆早餐配篝火歌舞"
    source_tokens = list(source_text)
    source_doc = MockDoc(source_text, source_tokens)
    # Three translated fragments with token counts 3, 2 and 3.
    tr_part1 = MockDoc("Part1", ["tok1", "tok2", "tok3"])
    tr_part2 = MockDoc("Part2", ["tok4", "tok5"])
    tr_part3 = MockDoc("Part3", ["tok6", "tok7", "tok8"])
    translated_parts_docs = [tr_part1, tr_part2, tr_part3]
    print("=== 测试累积比例算法 ===")
    src_parts = proportionally_split_source(source_doc, translated_parts_docs)
    # Overlap check: the slice lengths must add up to the source length.
    print("\n=== 检查重叠情况 ===")
    total_chars = 0
    for i, part in enumerate(src_parts):
        print(f"分片 {i}: '{part.text}' (长度: {len(part)})")
        total_chars += len(part)
        if i > 0:
            # NOTE(review): the original computed part.text[0] here but never
            # used it (dead code, removed); continuity is proven by the
            # reassembly assert below.
            print(f" 检查连续性...")
    print(f"\n原始字符数: {len(source_text)}")
    print(f"分片总字符数: {total_chars}")
    print(f"是否相等: {len(source_text) == total_chars}")
    # Reassembly check: concatenated slices must equal the original text.
    reconstructed = ''.join([part.text for part in src_parts])
    print(f"重组文本: '{reconstructed}'")
    print(f"是否与原文相同: {reconstructed == source_text}")
    # Fail loudly on regression instead of only printing the comparison.
    assert total_chars == len(source_text)
    assert reconstructed == source_text
# Run the overlap-fix check when executed as a script.
if __name__ == '__main__':
    test_overlap_fix()