#!/usr/bin/env python3
"""Test whether the fixed algorithm resolves the overlapping-boundary problem."""

import sys

# Allow importing project modules from the current directory when run as a script.
sys.path.append('.')
# Mock of spaCy's Doc object: just enough surface (text, len, indexing,
# slicing) for the splitting algorithm under test.
class MockDoc:
    """Minimal stand-in for a spaCy Doc: a text plus its token list.

    Indexing with an int returns the token; indexing with a slice returns a
    new MockDoc whose text is the concatenation of the sliced tokens.
    """

    def __init__(self, text, tokens):
        self.text = text
        self._tokens = tokens

    def __len__(self):
        return len(self._tokens)

    def __getitem__(self, key):
        if isinstance(key, slice):
            # Slicing a list with the slice object directly is equivalent to
            # expanding it via key.indices() — and never needs a step guard,
            # since slice.indices() can never produce a step of 0.
            tokens = self._tokens[key]
            return MockDoc(''.join(tokens), tokens)
        return self._tokens[key]
def proportionally_split_source(source_doc, translated_parts_docs):
    """Split source_doc by the token ratios of the translated parts.

    Boundaries are derived from the *cumulative* token ratio, so consecutive
    slices are guaranteed to be non-overlapping and to tile the whole source
    exactly once (the last slice absorbs any rounding remainder).

    Args:
        source_doc: Doc-like object supporting len() and slicing.
        translated_parts_docs: list of Doc-like translated parts.

    Returns:
        A list with one slice of source_doc per translated part.
    """
    if len(translated_parts_docs) <= 1:
        return [source_doc]

    # Token count of each translated part.
    tr_token_counts = [len(doc) for doc in translated_parts_docs]
    total_tr_tokens = sum(tr_token_counts)

    if total_tr_tokens == 0:
        # Degenerate case: nothing to apportion — give the whole source to
        # the first part and empty slices to the rest.
        return [source_doc] + [source_doc[:0] for _ in range(len(translated_parts_docs) - 1)]

    # Cumulative-ratio algorithm: each boundary comes from the running total,
    # so boundaries are monotonically non-decreasing and cannot overlap.
    src_parts = []
    total_src_tokens = len(source_doc)
    current_idx = 0
    cumulative_tr_tokens = 0

    print(f"源文本总token数: {total_src_tokens}")
    print(f"翻译部分token数: {tr_token_counts}, 总计: {total_tr_tokens}")

    for i, tr_token_count in enumerate(tr_token_counts):
        cumulative_tr_tokens += tr_token_count

        if i == len(tr_token_counts) - 1:
            # Last part takes all remaining tokens (absorbs rounding drift).
            next_idx = total_src_tokens
        else:
            # Next split point from the cumulative ratio.
            next_idx = round(total_src_tokens * cumulative_tr_tokens / total_tr_tokens)

        part_doc = source_doc[current_idx:next_idx]
        src_parts.append(part_doc)

        print(f"分片 {i}: [{current_idx}:{next_idx}] = '{part_doc.text}' ({len(part_doc)} tokens)")
        current_idx = next_idx

    return src_parts
def test_overlap_fix():
    """Exercise the cumulative-ratio splitter and verify no overlap occurs."""
    # Test data: a case that previously produced overlapping boundaries.
    source_text = "新疆早餐配篝火歌舞"
    source_tokens = list(source_text)  # one token per character
    source_doc = MockDoc(source_text, source_tokens)

    # Three translated parts with 3, 2 and 3 tokens respectively.
    tr_part1 = MockDoc("Part1", ["tok1", "tok2", "tok3"])  # 3 tokens
    tr_part2 = MockDoc("Part2", ["tok4", "tok5"])  # 2 tokens
    tr_part3 = MockDoc("Part3", ["tok6", "tok7", "tok8"])  # 3 tokens

    translated_parts_docs = [tr_part1, tr_part2, tr_part3]

    print("=== 测试累积比例算法 ===")
    src_parts = proportionally_split_source(source_doc, translated_parts_docs)

    # Overlap check: the slice lengths must add up to the source length.
    print("\n=== 检查重叠情况 ===")
    total_chars = 0

    for i, part in enumerate(src_parts):
        print(f"分片 {i}: '{part.text}' (长度: {len(part)})")
        total_chars += len(part)

        # Continuity check between consecutive slices (visual only).
        if i > 0:
            print(f"  检查连续性...")

    print(f"\n原始字符数: {len(source_text)}")
    print(f"分片总字符数: {total_chars}")
    print(f"是否相等: {len(source_text) == total_chars}")

    # Reassembly check: joining the slices must reproduce the original text.
    reconstructed = ''.join([part.text for part in src_parts])
    print(f"重组文本: '{reconstructed}'")
    print(f"是否与原文相同: {reconstructed == source_text}")
# Script entry point: run the overlap-fix check when executed directly.
if __name__ == '__main__':
    test_overlap_fix()