Files
2025-10-21 23:00:48 +08:00

172 lines
6.8 KiB
Python

#!/usr/bin/env python
import re
from typing import List
from collections import deque
from loguru import logger
class TextSplitter:
    """Recursively split text into chunks of at most ``chunk_size`` characters,
    with ``chunk_overlap`` characters shared between adjacent chunks."""

    # Sentence-ending punctuation used to split oversized chunks while keeping
    # sentences intact.
    # NOTE(review): the original list contained five empty strings — almost
    # certainly fullwidth CJK punctuation lost to an encoding error. Restored
    # here as 。!?;, — confirm against the upstream file.
    COMMON_PUNCTUATIONS = ["。", "!", "?", ";", ",", ",", "?", "!", ";"]
    # Character class matching any single punctuation mark above.
    PUNC_PATTERN = f"[{re.escape(''.join(COMMON_PUNCTUATIONS))}]"

    def __init__(self, max_characters: int, chunk_size: int, chunk_overlap: int):
        """Initialize the splitter.

        Args:
            max_characters: maximum characters of the document; longer input is
                truncated. A value <= 0 disables truncation.
            chunk_size: target maximum size of each chunk.
            chunk_overlap: number of characters shared between adjacent chunks.

        Raises:
            ValueError: if ``chunk_size <= chunk_overlap``. ``args`` keep the
                original 83000 error code for upstream handlers.
        """
        if chunk_size <= chunk_overlap:
            logger.error(f"param chunk_size should larger than chunk_overlap, "
                         f"current chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}")
            # Bug fix: the original raised Exception(83000, str(ValueError)),
            # which stringifies the ValueError *class*, not a message.
            raise ValueError(83000, f"chunk_size ({chunk_size}) must be larger than "
                                    f"chunk_overlap ({chunk_overlap})") from None
        self.max_characters = max_characters
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # Separator priority: paragraph break first, then single newline.
        self.separators = ["\n\n", "\n"]

    @staticmethod
    def split_text_by_separator(text: str, separator: str) -> List[str]:
        """Split ``text`` on the literal ``separator``, re-attaching each
        separator as a prefix of the fragment that follows it.

        ``separator`` is regex-escaped internally. (Bug fix: the recursive
        caller used to pass an already-escaped separator, so the
        ``separator == "\\n"`` comparison below could never be true.)
        """
        pattern = f"({re.escape(separator)})"
        # Resolve the conflict between one and two newlines: a chunk that
        # starts with "\n\n" must keep that prefix when re-split on "\n".
        if text.startswith("\n\n") and separator == "\n":
            chunks = re.split(pattern, text.strip())
            chunks[0] = f"\n\n{chunks[0]}"
        else:
            chunks = re.split(pattern, text)
        # re.split with a capturing group yields [frag, sep, frag, sep, ...];
        # glue each captured separator onto the fragment after it.
        new_chunks = [chunks[idx] + chunks[idx + 1] for idx in range(1, len(chunks), 2)]
        new_chunks = [chunks[0]] + new_chunks
        return [chunk for chunk in new_chunks if chunk.strip() != ""]

    @staticmethod
    def split_sentences(chunk: str) -> List[str]:
        """Split a chunk into sentences on ``PUNC_PATTERN``, keeping each
        terminating punctuation mark attached to its sentence.

        The final element may be an empty string when ``chunk`` ends with a
        punctuation mark; downstream merging tolerates that.
        """
        sentences = re.split(TextSplitter.PUNC_PATTERN, chunk)
        # Delimiters in order of appearance; re.split yields exactly
        # len(delimiters) + 1 fragments, so the zip below pairs them 1:1.
        delimiters = [ch for ch in chunk if ch in TextSplitter.COMMON_PUNCTUATIONS]
        restored = [frag + delim for frag, delim in zip(sentences[:-1], delimiters)]
        return restored + [sentences[-1]]

    def split_text(self, input_data: str) -> List[str]:
        """Full pipeline: truncate, recursively split by separators, merge
        small pieces with overlap, then hard-cut any still-oversized chunk.

        Returns:
            List[str]: stripped, non-empty chunks.
        """
        if self.max_characters > 0:
            logger.info(f"The document characters should be within: {self.max_characters}")
            input_data = input_data[:self.max_characters]
        logger.info(f"characters of the document: {len(input_data)}")
        chunks = self.split_text_recursive(input_data, self.separators)
        final_chunks = self.merge_chunks(chunks)
        final_chunks = self.split_text_by_chunk_size(final_chunks)
        return [chunk.strip() for chunk in final_chunks if chunk]

    def split_text_recursive(self, input_data: str, separators: List[str]) -> List[str]:
        """Recursively split the text by separator priority:
        1. Pieces within ``chunk_size`` are kept as-is.
        2. Oversized pieces are re-split with the remaining separators.

        Args:
            input_data: text to split.
            separators: separators, highest priority first (raw strings; they
                are escaped only where used in a regex).

        Returns:
            List[str]: split fragments.
        """
        cur_separator = ""
        next_separators: List[str] = []
        for idx, sep in enumerate(separators):
            # Search the *escaped* separator but keep the raw one, so the
            # "\n\n"-prefix handling in split_text_by_separator can trigger.
            if re.search(re.escape(sep), input_data.strip()):
                cur_separator = sep
                next_separators = separators[idx + 1:]
                break
        if not cur_separator:
            # No separator present: the text is an atomic fragment.
            return [input_data]
        chunks: List[str] = []
        cur_chunks = TextSplitter.split_text_by_separator(input_data, cur_separator)
        for piece in cur_chunks:
            if len(piece.strip()) <= self.chunk_size:
                chunks.append(piece)
            elif not next_separators:
                # Oversized but no finer separator left; keep for the
                # sentence-level pass in split_text_by_chunk_size.
                chunks.append(piece)
            else:
                chunks.extend(self.split_text_recursive(piece, next_separators))
        return chunks

    def merge_chunks(self, chunks: List[str]) -> List[str]:
        """Merge adjacent small fragments into chunks, honoring overlap."""
        final_chunks: List[str] = []
        idx = 0
        while idx < len(chunks):
            if len(chunks[idx]) >= self.chunk_size:
                # Already at/over the limit: emit unchanged.
                final_chunks.append(chunks[idx])
                idx += 1
                continue
            merge_idxes = self.get_merge_idxes(idx, chunks)
            final_chunks.append("".join(chunks[i] for i in merge_idxes))
            # Resume after the last merged fragment; earlier indices in
            # merge_idxes are overlap borrowed from the previous chunk.
            idx = merge_idxes[-1] + 1
        return final_chunks

    def get_merge_idxes(self, cur_idx: int, chunks: List[str]) -> deque:
        """Indices to merge around ``cur_idx``: extend backward up to
        ``chunk_overlap`` and forward up to ``chunk_size``."""
        idxes = deque([cur_idx])
        overlap_idx = cur_idx - 1
        cur_len = len(chunks[cur_idx])
        cur_idx += 1
        # Backward pass: borrow preceding fragments as overlap while both the
        # overlap budget and the total chunk budget allow it.
        over_lap_len = 0
        while overlap_idx >= 0:
            over_lap_len += len(chunks[overlap_idx])
            if over_lap_len > self.chunk_overlap or (cur_len + over_lap_len) > self.chunk_size:
                over_lap_len -= len(chunks[overlap_idx])
                break
            idxes.appendleft(overlap_idx)
            overlap_idx -= 1
        cur_len += over_lap_len
        # Forward pass: absorb following fragments until chunk_size is hit.
        while cur_idx < len(chunks):
            cur_len += len(chunks[cur_idx])
            if cur_len > self.chunk_size:
                break
            idxes.append(cur_idx)
            cur_idx += 1
        return idxes

    def split_chunks(self, chunks: List[str]) -> List[str]:
        """Hard-truncate chunks exceeding ``chunk_size`` into overlapping
        windows of size ``chunk_size`` stepping by ``chunk_size - chunk_overlap``."""
        final_chunks: List[str] = []
        for chunk in chunks:
            if len(chunk) <= self.chunk_size:
                final_chunks.append(chunk)
                continue
            start = 0
            end = self.chunk_size
            while end < len(chunk):
                final_chunks.append(chunk[start: end])
                start += self.chunk_size - self.chunk_overlap
                end = start + self.chunk_size
            # Remaining tail (<= chunk_size, still overlapping the previous window).
            final_chunks.append(chunk[start:])
        return final_chunks

    def split_text_by_chunk_size(self, chunks: List[str]) -> List[str]:
        """Second pass over oversized chunks: split into sentences, re-merge
        with overlap, then hard-truncate whatever is still too long."""
        final_chunks: List[str] = []
        for chunk in chunks:
            if len(chunk) <= self.chunk_size:
                final_chunks.append(chunk)
                continue
            sentences = TextSplitter.split_sentences(chunk)
            sub_chunks = self.merge_chunks(sentences)
            final_chunks.extend(self.split_chunks(sub_chunks))
        return final_chunks