init datamate

This commit is contained in:
Dallas98
2025-10-21 23:00:48 +08:00
commit 1c97afed7d
692 changed files with 135442 additions and 0 deletions

View File

@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
from datetime import datetime
import cv2
import numpy as np
import os
import pytz
from loguru import logger
def check_valid_path(file_path):
    """Return True when *file_path* (resolved to an absolute path) exists."""
    return os.path.exists(os.path.abspath(file_path))
def get_realpath_with_prefix_check(path, prefix):
    """Resolve *path* with os.path.realpath and ensure it lives under *prefix*.

    Raises:
        ValueError: when the resolved path does not start with *prefix*.
    """
    realpath = os.path.realpath(path)
    if not realpath.startswith(prefix):
        raise ValueError(f"The path {realpath} does not start with the prefix '{prefix}'.")
    return realpath
def bytes_to_numpy(image_bytes):
    """Decode raw image bytes into an OpenCV BGR ndarray."""
    buffer = np.frombuffer(image_bytes, dtype=np.uint8)
    return cv2.imdecode(buffer, cv2.IMREAD_COLOR)
def numpy_to_bytes(image_np, file_type):
    """Encode an image ndarray to bytes; *file_type* is an OpenCV extension."""
    if image_np.size == 0:
        return b""
    encoded = cv2.imencode(file_type, image_np)[1]
    return encoded.tobytes()
def get_now_time(timezone, time_format, file_name, method):
    """Return the current time in *timezone* formatted with *time_format*.

    On failure the error is logged with *file_name*/*method* for context and
    an empty string is returned.
    """
    timestamp = ""
    try:
        china_tz = pytz.timezone(timezone)        # resolve the target timezone
        china_time = datetime.now(tz=china_tz)    # timezone-aware current time
        timestamp = china_time.strftime(time_format)
    except (ValueError, pytz.UnknownTimeZoneError) as e:
        # loguru formats with str.format-style "{}" placeholders, not "%s",
        # and has no exc_info kwarg; logger.exception records the traceback.
        # pytz raises UnknownTimeZoneError (a KeyError subclass) for bad
        # timezone names, which the old `except ValueError` let escape.
        logger.exception("fileName: {}, method: {}, formatting time failed: {}",
                         file_name, method, e)
    return timestamp
def decrypt(enc_pass):
    """Decrypt *enc_pass* with the KMC library via environment variables.

    The KMC API reads the ciphertext from KMC_PYTHON_ENCRYPT_DATA and the
    user identity from KMC_DATA_USER; the ciphertext variable is cleared
    afterwards so the secret does not linger in the environment.
    NOTE(review): not thread-safe — concurrent calls race on the shared
    environment variables; confirm single-threaded usage.
    """
    import kmc.kmc as K
    os.environ['KMC_DATA_USER'] = 'modelenginepublic'
    os.environ['KMC_PYTHON_ENCRYPT_DATA'] = enc_pass
    dec_pass = K.API().decrypt(0)
    os.environ['KMC_PYTHON_ENCRYPT_DATA'] = ""
    return dec_pass

View File

@@ -0,0 +1,115 @@
# -- encoding: utf-8 --
from collections import deque
class TrieNode:
    """One trie node: a character value, its children, and Aho-Corasick links."""

    def __init__(self, value):
        self.value = value   # character stored at this node
        self.child = {}      # mapping: next character -> TrieNode
        self.fail = None     # failure link, populated by add_fail_pointer
        self.word = None     # set of complete words ending here, or None
class AhoCorasic:
    """Aho-Corasick automaton for multi-pattern (sensitive word) search."""

    def __init__(self, words):
        # Build the trie from the word list, then wire up the failure links.
        self._root = add_fail_pointer(build_trie(words))

    def search(self, text: str, special_symbols: set):
        """
        Match sensitive words in *text*.

        Args:
            text: input text
            special_symbols: special characters to skip over while matching
        Returns:
            De-duplicated list of matched substrings (skipped special symbols
            inside a match span are included in the returned slice).
        """
        seq_list = []
        node = self._root
        valid_len = 0  # length of the text span currently being matched
        for i, s in enumerate(text):
            if s in special_symbols:  # skip special characters
                if valid_len != 0:
                    # mid-match: the skipped char still widens the slice taken
                    # out of `text` when a word completes
                    valid_len += 1
                continue
            matched = True
            while s not in node.child:  # no outgoing edge labelled `s`
                if node == self._root:  # at root (no fail link): reset and bail out
                    valid_len = 0
                    matched = False
                    break
                elif node.fail == self._root:  # fail link is root: reset length, keep going
                    valid_len = 0
                # NOTE(review): when falling back to a non-root fail node,
                # valid_len is kept unchanged; the reported span may include
                # extra prefix characters — confirm this is intended.
                node = node.fail  # follow the failure link
            if not matched:
                continue
            node = node.child.get(s)
            valid_len += 1
            if node.word:  # this node terminates at least one word
                sensitive_word = text[i - valid_len + 1:i + 1]
                seq_list.append(sensitive_word)
        seq_list = list(set(seq_list))
        return seq_list
def build_trie(words: list):
    """
    Build a prefix trie from a word list.

    Args:
        words: list of sensitive words.
    Returns:
        The root TrieNode of the trie.
    """
    root = TrieNode('root')
    for word in words:
        cursor = root
        for ch in word:
            if ch not in cursor.child:
                cursor.child[ch] = TrieNode(ch)
            cursor = cursor.child[ch]
        # Terminal node: record the complete word(s) ending here.
        if cursor.word is None:
            cursor.word = {word}
        else:
            cursor.word.add(word)
    return root
def add_fail_pointer(root: TrieNode):
    """
    Attach failure pointers to a prefix trie (breadth-first, level by level).

    Children of the root fail back to the root. For any deeper node we walk
    its parent's failure chain until some node has a child with the same
    character, and fail to that child — or to the root when none exists.

    Args:
        root: root node of the trie.
    Returns:
        The same root, with fail pointers populated.
    """
    pending = deque()
    pending.appendleft((root, None))
    while pending:
        current, parent = pending.pop()
        # Enqueue children before handling the current node (BFS order).
        for child in current.child.values():
            pending.appendleft((child, current))
        if parent is None:
            continue
        if parent is root:
            current.fail = root
            continue
        candidate = parent.fail
        while candidate and current.value not in candidate.child:
            candidate = candidate.fail
        current.fail = candidate.child[current.value] if candidate else root
    return root

View File

@@ -0,0 +1,66 @@
# -- encoding: utf-8 --
import pickle
import base64
from io import BytesIO
import cv2
import numpy as np
from PIL import Image
def bytes_to_numpy(image_bytes):
    """Decode an in-memory image (bytes) into a BGR OpenCV array."""
    raw = np.frombuffer(image_bytes, dtype=np.uint8)
    decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    return decoded
def numpy_to_bytes(image_np, file_type):
    """
    Encode an image array to raw bytes.

    Params:
        file_type: as required by OpenCV, extension must have a leading period.
    """
    if image_np.size == 0:
        return b""
    return cv2.imencode(file_type, image_np)[1].tobytes()
def pil_to_bytes(src: Image.Image) -> bytes:
    """Serialize a PIL image to PNG bytes (converted to RGB first)."""
    rgb = src.convert("RGB")
    with BytesIO() as buffer:
        rgb.save(buffer, format='PNG')
        return buffer.getvalue()
def bytes_to_pil(src: bytes) -> Image.Image:
    """Deserialize image bytes into a PIL.Image.

    Returns a detached copy so the image stays usable after the underlying
    buffer and file object are closed.
    """
    # BUG FIX: the buffer must wrap `src`; the previous empty BytesIO()
    # never received the input bytes, so Image.open failed on every call.
    with BytesIO(src) as bytes_io:
        with Image.open(bytes_io) as pil_img:
            pil_img.load()       # force-decode while the buffer is still open
            return pil_img.copy()
def pil_to_base64(src: Image.Image):
    """Encode a PIL.Image as base64 bytes of its PNG serialization."""
    with BytesIO() as img_buffer:
        src.save(img_buffer, format='png')
        return base64.b64encode(img_buffer.getvalue())
def obj_to_bytes(src: object) -> bytes:
    """Pickle an arbitrary Python object into bytes."""
    payload = pickle.dumps(src)
    return payload
def bytes_to_obj(src: bytes) -> object:
    """Unpickle bytes back into a Python object.

    NOTE(review): pickle.loads executes arbitrary code when fed untrusted
    data — confirm all callers pass in-process payloads only.
    """
    return pickle.loads(src)

View File

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import importlib.abc
import importlib.util
from pathlib import Path
class CustomImporter(importlib.abc.MetaPathFinder):
    """Meta-path finder that resolves modules from a fixed base directory.

    Dotted module names map onto the filesystem below *base_path*
    (e.g. ``mypkg.mymodule`` -> ``<base>/mypkg/mymodule.py``).
    """

    def __init__(self, base_path):
        self.base_path = Path(base_path).resolve()

    def find_spec(self, fullname, path, target=None):
        candidate = self.base_path.joinpath(*fullname.split("."))
        py_file = candidate.with_suffix(".py")
        # Plain module: <base>/.../name.py
        if py_file.exists():
            return importlib.util.spec_from_file_location(
                fullname,
                str(py_file),
                submodule_search_locations=[str(candidate.parent)]
            )
        # Package: <base>/.../name/__init__.py
        init_file = candidate / "__init__.py"
        if candidate.is_dir() and init_file.exists():
            return importlib.util.spec_from_file_location(
                fullname,
                str(init_file),
                submodule_search_locations=[str(candidate)]
            )
        return None

View File

@@ -0,0 +1,229 @@
# -*- coding: utf-8 -*-
import os
import re
import sys
import importlib
import subprocess
from pathlib import Path
from types import ModuleType
from loguru import logger
from packaging.version import parse as parse_version
def is_valid_whl_filename(package_name):
    """
    Validate that a WHL file name is safe (guards against command injection).

    Args:
        package_name (str): whl file name to validate

    Returns:
        bool: True if safe, False if it contains dangerous characters
    """
    # Reject path separators outright.
    if os.path.sep in package_name or (os.path.altsep and os.path.altsep in package_name):
        return False
    # Blacklist of shell-sensitive characters: command separators (; & | `),
    # expansion/substitution ($ ( ) { } < >), brackets, history/escape (! \),
    # quotes, globs (* ?), '#' and whitespace.
    # BUG FIX: the previous pattern used re.VERBOSE with comments *inside* the
    # character class — there '#' and whitespace are literal, so the comment
    # text itself became members of the class and the pattern rejected many
    # unintended characters.
    dangerous_pattern = re.compile(r"[;&|`$(){}<>\[\]!\\'\"*?#\s]")
    if dangerous_pattern.search(package_name):
        return False
    return True
class PackageNotFoundError(Exception):
    """Raised when `pip show` output contains no Version line for a package."""
    pass
class LazyLoader(ModuleType):
    """Module proxy that defers import (and offline WHL install) until first use.

    The instance registers itself under the target module name in the caller's
    globals; the first attribute access triggers the real import — installing
    the package from a local WHL directory when necessary — and then replaces
    the proxy with the real module in both the caller's globals and sys.modules.
    """

    def __init__(self,
                 package_name,
                 module_name=None,
                 whl_path="/dataset/ops_whl",
                 exact_version=None,
                 force_reinstall=False
                 ):
        """
        :param package_name: module-name part of the WHL package name
            (usually the module name with ``_`` replaced by ``-``)
        :param module_name: importable module name after the WHL is installed;
            set it when the WHL name and the import name differ.
        :param whl_path: directory containing the WHL files
        :param exact_version: exact version requirement
        :param force_reinstall: force re-installation
        """
        try:
            # Grab the caller's frame: the proxy is injected into *their* globals.
            frame = sys._getframe(1)
            self._parent_globals = frame.f_globals
        except (AttributeError, ValueError) as e:
            logger.error(f"Failed to get stack frame: {e}")
            raise RuntimeError("Stack frame retrieval failed") from e
        self._module_name = module_name if module_name else package_name
        self._package_name = package_name
        self.whl_path = Path(whl_path).resolve()
        self.exact_version = exact_version
        self.force_reinstall = force_reinstall
        self._cached_module = None
        # Register the alias in the parent namespace.
        self._parent_globals[self._module_name] = self
        super().__init__(self._module_name)

    def __getattr__(self, name):
        # First attribute access loads (and possibly installs) the real module.
        if self._cached_module is None:
            self._cached_module = self._load_module()
        return getattr(self._cached_module, name)

    def __dir__(self):
        return dir(self._load_module())

    def _load_module(self):
        """Module-loading logic: import if possible, install from WHL otherwise."""
        if self._cached_module is not None:
            return self._cached_module
        package_name: str = self._package_name.split('.')[0]
        if not is_valid_whl_filename(package_name):
            logger.error(f"Invalid package_name: {package_name}")
            raise RuntimeError("Invalide package_name, please check it again!")
        # NOTE(review): self._module_name is always truthy (set in __init__),
        # so the replace("_", "-") fallback below is dead code — confirm.
        module_name = self._module_name if self._module_name else package_name.replace("_", "-")
        need_install = False
        try:
            if not self.force_reinstall:
                module = importlib.import_module(module_name)
                self._cached_module = module
                self._register_alias(module)
                return module
        except ImportError:
            need_install = True
        if self.force_reinstall:
            # Version check when forcing installation.
            installed = self._check_package_exists(package_name)
            if installed and self.exact_version:
                installed_version = self._get_installed_version(package_name)
                if parse_version(installed_version) != parse_version(self.exact_version):
                    logger.info(f"Version mismatch detected: {installed_version} vs {self.exact_version}")
                    need_install = True
            else:
                need_install = True
        if need_install:
            self._pip_install_package(package_name)
            module = importlib.import_module(module_name)
            self._cached_module = module
            self._register_alias(module)
        else:
            # Version check passed; no reinstall required.
            module = importlib.import_module(module_name)
            self._cached_module = module
            self._register_alias(module)
        return self._cached_module

    def _register_alias(self, module):
        """Publish the real module under the alias in caller globals and sys.modules."""
        self._parent_globals[self._module_name] = module
        sys.modules[self._module_name] = module

    def _check_package_exists(self, package_name):
        """Return True when `pip show` reports the package as installed."""
        try:
            result = subprocess.run(
                [sys.executable, "-m", "pip", "show", package_name],
                capture_output=True,
                text=True
            )
            return result.returncode == 0
        except subprocess.SubprocessError as e:
            logger.error(f"Package check failed: {e}")
            return False

    def _get_installed_version(self, package_name):
        """Parse the installed version out of `pip show` output.

        Raises:
            PackageNotFoundError: when no Version line is present.
        """
        result = subprocess.run(
            [sys.executable, "-m", "pip", "show", package_name],
            capture_output=True,
            text=True
        )
        for line in result.stdout.split('\n'):
            if line.startswith('Version:'):
                return line.split()[-1]
        raise PackageNotFoundError()

    def _pip_install_package(self, package_name: str):
        """Pick the best local WHL for *package_name* and pip-install it offline."""
        if not self.whl_path.exists():
            raise FileNotFoundError(f"WHL directory not found: {self.whl_path}")
        whl_files = list(self.whl_path.glob(f"{package_name}*.whl"))
        if not whl_files:
            raise RuntimeError(f"No WHL files found for {package_name}")
        # Filter by exact version when one is requested.
        if self.exact_version:
            pattern = re.compile(
                rf'^{re.escape(package_name)}-{re.escape(self.exact_version)}-\S*\.whl$',
                re.IGNORECASE
            )
            whl_files = [f for f in whl_files if pattern.match(f.name)]
        # Pick the newest parseable version.
        whl_versions = []
        for f in whl_files:
            try:
                version = self._extract_version(f)
                whl_versions.append((f, version))
            except ValueError:
                continue
        if not whl_versions:
            raise FileNotFoundError("No valid WHL files")
        whl_versions.sort(key=lambda x: x[1], reverse=True)
        target_whl = whl_versions[0][0]
        # Run the installation (offline: --no-index + local find-links).
        try:
            subprocess.check_call([
                sys.executable, "-m", "pip", "install",
                "--no-index",
                f"--find-links={self.whl_path}",
                str(target_whl)
            ], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
            logger.info(f"Successfully installed {target_whl}")
        except subprocess.CalledProcessError as e:
            logger.error(f"Installation failed: {e}")
            raise RuntimeError(f"Installation failed: {e}") from e

    def _extract_version(self, filename):
        """Parse the version embedded in a WHL file name.

        Raises:
            ValueError: when the name does not match the expected pattern.
        """
        version_pattern = r"(\d+([.]\d+)+([ab]|rc\d+)*([.]post\d+)*([.]dev\d+)*)"
        match = re.search(
            rf"^{re.escape(self._package_name)}-({version_pattern})",
            filename.name,
            re.IGNORECASE
        )
        if not match:
            raise ValueError(f"Invalid version format: {filename.name}")
        return parse_version(match.group(1))

View File

@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
import json
import os
import ssl
from pathlib import Path
from typing import Dict
from urllib.request import Request, urlopen
import requests
import urllib3
from loguru import logger
from datamate.common.utils import decrypt
class LlmReq:
    """Callable LLM client that sends chat-completion requests.

    Transport selection: falsy access_type -> mTLS via a urllib3 PoolManager;
    truthy access_type -> urllib with a custom SSL context when is_https,
    otherwise plain `requests`.
    """

    # Constants explaining the error codes raised below.
    ERRORCODE_INCOMPLETE_CONFIG = 83005
    ERRORCODE_INVALID_RESPONSE = 83006
    ERRORCODE_SERVICE_UNAVAILABLE = 83007

    def __init__(self, url: str = None, header: Dict = None, body: Dict = None, access_type: int = None,
                 is_https: bool = False, is_certificate: bool = False, certificate_path: Path = None):
        # NOTE(review): `body` is assumed to be a dict with a non-empty
        # "messages" list; a None body or empty list raises here — confirm
        # callers always pass a pre-shaped chat body.
        self.url = url
        self.header = header
        self.access_type = access_type
        self.is_https = is_https
        self.context = self._load_certificate(certificate_path, is_certificate) if is_https else None
        self.body = body
        if not self.body.get("messages", [])[0].get("content"):
            self.body["messages"][0]["content"] = "你好"

    def __call__(self, input_str: str) -> str:
        """Send *input_str* as the first message's content; return the reply
        ('' when the body shape is broken)."""
        outputs = ''
        try:
            self.body["messages"][0]["content"] = input_str
            outputs = self._call_service()
        except KeyError as e:
            logger.error(f"The body format is not completed, error detail: {e}")
            self.body["messages"][0]["content"] = "你好"
        return outputs

    @staticmethod
    def _load_certificate(certificate_path: Path, is_certificate: bool) -> ssl.SSLContext:
        """Build the TLS context; verify the peer only when is_certificate.

        NOTE(review): check_hostname is disabled in both branches, and
        CERT_NONE is used when no certificate is provided — confirm this
        weakening of TLS verification is acceptable for the deployment.
        """
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        context.check_hostname = False
        if is_certificate:
            context.load_verify_locations(certificate_path)
            context.verify_mode = ssl.CERT_REQUIRED
        else:
            context.verify_mode = ssl.CERT_NONE
        return context

    @staticmethod
    def _pool_manager():
        """Create an mTLS urllib3 pool from cert/key paths in the environment.

        The key password comes from GLOBAL_PWD, or is read (encrypted) from
        the pwd file and decrypted through KMC.
        """
        cert_file = os.getenv("RAY_TLS_SERVER_CERT", "/certPersonal/global/identity/global.crt")
        key_file = os.getenv("RAY_TLS_SERVER_KEY", "/certPersonal/global/identity/global.key")
        ca_crt = os.getenv("RAY_TLS_CA_CERT", "/certPersonal/global/trust/ca.crt")
        pwd = os.getenv("RAY_TLS_SERVER_KEY_PASSWORD", "/certPersonal/global/identity/pwd.txt")
        key_password = os.getenv("GLOBAL_PWD", None)
        if not key_password:
            with open(pwd, "r") as f:
                key_password = f.read().strip()
            key_password = decrypt(key_password)
        pool_manager = urllib3.PoolManager(cert_file=cert_file,
                                           key_file=key_file,
                                           key_password=key_password,
                                           cert_reqs='CERT_REQUIRED',
                                           ca_certs=ca_crt,
                                           assert_hostname='edatamate',
                                           ssl_version='TLSv1_2')
        return pool_manager

    def _call_service(self):
        """Dispatch the request on the configured transport; extract the reply.

        Raises:
            RuntimeError: carrying ERRORCODE_INCOMPLETE_CONFIG,
                ERRORCODE_INVALID_RESPONSE or ERRORCODE_SERVICE_UNAVAILABLE
                as the first argument.
        """
        if not all([self.url, self.header, self.body.get("messages", [])[0].get("content")]):
            logger.error("LLM is not configured completely")
            raise RuntimeError(self.ERRORCODE_INCOMPLETE_CONFIG, "LLM is not configured completely") from None
        if not self.access_type:
            try:
                pool_manager = self._pool_manager()
                response = pool_manager.request(
                    "POST",
                    url=self.url,
                    body=json.dumps(self.body).encode(),
                    headers=self.header
                )
                logger.info(f"Response status code: {response.status}")
                response_json = json.loads(response.data.decode('utf-8'))
                outputs = response_json.get("choices", [])[0].get("message", {}).get("content")
                if not outputs:
                    logger.error("Invalid response format for LLM, missing the 'prompt' key word")
                    raise RuntimeError(self.ERRORCODE_INVALID_RESPONSE,
                                       "Invalid response format for LLM, missing the 'prompt' key word") from None
                return outputs
            except Exception as e:
                logger.error(f"LLM service is not available, error detail: {e}")
                raise RuntimeError(self.ERRORCODE_SERVICE_UNAVAILABLE, "LLM service is not available") from None
        if self.access_type:
            try:
                if self.is_https:
                    req = Request(url=self.url, data=json.dumps(self.body).encode(), headers=self.header, method="POST")
                    response_json = urlopen(req, context=self.context).read().decode("utf-8")
                    response = json.loads(response_json)
                else:
                    response = requests.post(url=self.url, data=json.dumps(self.body), headers=self.header,
                                             stream=False).json()
                outputs = response.get("choices", [])[0].get("message", {}).get("content")
                if not outputs:
                    logger.error("Invalid response format for LLM, missing the 'prompt' key word")
                    raise RuntimeError(self.ERRORCODE_INVALID_RESPONSE,
                                       "Invalid response format for LLM, missing the 'prompt' key word") from None
                return outputs
            except Exception as e:
                logger.error(f"LLM service is not available, error detail: {e}")
                raise RuntimeError(self.ERRORCODE_SERVICE_UNAVAILABLE, "LLM service is not available") from None
        return None  # ensure a return on all paths

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
import sys
import re
import subprocess
from pathlib import Path
from typing import Optional
from loguru import logger
from packaging.version import parse as parse_version, Version
def install_whl(
        package_name: str,
        whl_path: str,
        exact_version: Optional[str] = None,
        filename_pattern: Optional[str] = None,
        force_reinstall: bool = False
) -> None:
    """Install a package offline from a local directory of WHL files.

    :param package_name: eg: ("zh_core_web_sm")
    :param whl_path: WHL file save path
    :param exact_version: version number
    :param filename_pattern: custom filename pattern for REGEX
    :param force_reinstall: which decide to overlap the original number or not (default: False)
    """
    whl_dir = Path(whl_path).resolve()
    candidates = _get_whl_files(exact_version, filename_pattern, package_name, whl_dir)
    # Newest semantic version wins.
    target_whl = _sort_whl_files(candidates)
    # Assemble the offline pip command.
    cmd = [
        sys.executable, "-m", "pip", "install",
        "--no-index",
        f"--find-links={whl_dir}",
        str(target_whl)
    ]
    if force_reinstall:
        cmd.append("--force-reinstall")
    try:
        subprocess.check_call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
        logger.info(f"Successfully installed {target_whl.name}")
    except subprocess.CalledProcessError as e:
        error_msg = (
            f"Installation failed for {package_name}\n"
            f"Possible reasons:\n"
            f"1. Missing dependencies in {whl_dir}\n"
            f"2. Incompatible Python version\n"
            f"3. Platform mismatch (e.g., x86 vs ARM)"
        )
        raise RuntimeError(error_msg) from e
def _sort_whl_files(whl_files):
    """Return the WHL file with the highest parseable version."""
    logger.info(f"[load_offline_module]whl_files: {whl_files}")
    parsed = []
    for candidate in whl_files:
        try:
            parsed.append((candidate, _extract_version(candidate)))
        except ValueError as e:
            # Unparseable names are skipped, not fatal.
            logger.warning(f"Skipping invalid file {candidate.name}: {e}")
    if not parsed:
        raise FileNotFoundError("No valid WHL files with parseable versions")
    return max(parsed, key=lambda pair: pair[1])[0]
def _get_whl_files(exact_version, filename_pattern, package_name, whl_dir):
# 正则表达式
if filename_pattern:
pattern = filename_pattern
else:
if exact_version:
version_part = re.escape(exact_version)
pattern = rf"^{re.escape(package_name)}-{version_part}-\S*\.whl$"
else:
pattern = rf"^{re.escape(package_name)}\S*\.whl$"
regex = re.compile(pattern, re.IGNORECASE)
whl_files = [f for f in whl_dir.glob("*.whl") if regex.match(f.name)]
if not whl_files:
available_files = "\n".join([f.name for f in whl_dir.glob("*.whl")])
raise FileNotFoundError(
f"No matching WHL found for {package_name} in {whl_dir}\n"
f"Available files:\n{available_files}"
)
return whl_files
def _extract_version(filename: Path) -> Version:
    """Extract the semantic version embedded in a WHL file name."""
    found = re.search(
        r"-(\d+([.]\d+)+([ab]|rc\d+)*([.]post\d+)*([.]dev\d+)*)-",
        filename.name
    )
    if found is None:
        raise ValueError(f"Invalid version format: {filename.name}")
    return parse_version(found.group(1))

View File

@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
from loguru import logger
class Registry(object):
    """Registry mapping operator names to classes or file-path strings."""

    def __init__(self, name: str):
        self._name = name      # registry display name, used in logs/errors
        self._modules = {}     # mapping: module name -> class or path string

    @property
    def name(self):
        return self._name

    @property
    def modules(self):
        return self._modules

    def list(self):
        """Log every operator currently registered."""
        for m in self._modules.keys():
            logger.info(f'{self._name}\t{m}')

    def get(self, module_key):
        """Return the registered class/path for *module_key*, or None."""
        return self._modules.get(module_key, None)

    def register_module(self, module_name: str = None, module_cls: type = None, module_path: str = None, force=False):
        """
        Register a module under a specific name.

        Called directly with module_cls or module_path it registers at once;
        with neither, it returns a class decorator that registers on definition.

        :param module_name: module name (defaults to the class __name__)
        :param module_cls: module class object
        :param module_path: path where the module lives
        :param force: overwrite an existing same-name entry, default False.
        Example:
            >>> registry = Registry('ops')
            >>> @registry.register_module()
            >>> class TextFormatter:
            >>>     pass
            >>> class TextFormatter2:
            >>>     pass
            >>> registry.register_module( module_name='text_formatter2', module_cls=TextFormatter2)
        """
        if not (module_name is None or isinstance(module_name, str)):
            raise TypeError(f'module_name must be either of None, str,'
                            f'got {type(module_name)}')
        if module_cls is not None:
            self._register_module(module_name=module_name,
                                  module_cls=module_cls,
                                  force=force)
            return module_cls
        elif module_cls is None and isinstance(module_path, str):
            self._register_module(module_name=module_name,
                                  module_path=module_path,
                                  force=force)
            return module_path

        def _register(module_cls):
            """
            Decorator path (module_cls was None): register the decorated class.
            """
            self._register_module(module_name=module_name,
                                  module_cls=module_cls,
                                  force=force)
            return module_cls
        return _register

    def _register_module(self, module_name=None, module_cls=None, module_path=None, force=False):
        """
        Insert an entry into the registry.

        :param module_name: module name
        :param module_cls: module class object
        :param force: overwrite an existing same-name entry, default False.
        """
        if module_name is None and module_cls is not None:
            module_name = module_cls.__name__
        if module_name in self._modules:
            # Re-registering the identical object is a silent no-op.
            if module_cls is not None and module_cls == self._modules[module_name]:
                return
            if module_path is not None and module_path == self._modules[module_name]:
                return
            if not force:
                raise KeyError(
                    f'{module_name} is already registered in {self._name}, content: {self.modules.keys()}')
        if module_cls is not None:
            self._modules[module_name] = module_cls
        elif module_path is not None:
            self._modules[module_name] = module_path

View File

@@ -0,0 +1,171 @@
#!/user/bin/python
import re
from typing import List
from collections import deque
from loguru import logger
class TextSplitter:
    """Recursive text chunker: splits on separators, then merges/truncates
    fragments into chunks of at most `chunk_size` characters with
    `chunk_overlap` characters of overlap."""

    # Sentence-level split on common punctuation, keeping sentences intact.
    # NOTE(review): the empty strings below look like CJK punctuation lost in
    # an encoding conversion — confirm against the original source.
    COMMON_PUNCTUATIONS = ["", "", "", "", "", ",", "?", "!", ";"]
    PUNC_PATTERN = f"[{''.join(COMMON_PUNCTUATIONS)}]"

    def __init__(self, max_characters: int, chunk_size: int, chunk_overlap: int):
        """Initialize the splitter.

        Args:
            max_characters: document character cap, longer input is truncated; -1 disables
            chunk_size: target chunk size
            chunk_overlap: overlap between adjacent chunks

        Raises:
            Exception: code 83000 when chunk_size <= chunk_overlap.
        """
        if chunk_size <= chunk_overlap:
            logger.error(f"param chunk_size should larger than chunk_overlap, "
                         f"current chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}")
            # NOTE(review): str(ValueError) stringifies the class itself —
            # likely meant a descriptive message; confirm intent.
            raise Exception(83000, str(ValueError)) from None
        self.max_characters = max_characters
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = ["\n\n", "\n"]  # paragraph breaks first, then lines

    @staticmethod
    def split_text_by_separator(text: str, separator: str):
        """Split on *separator* (a regex, pre-escaped by the caller) while
        re-attaching each separator to the preceding fragment."""
        # Disambiguate a leading double newline when splitting on single "\n".
        if text.startswith("\n\n") and separator == "\n":
            chunks = re.split(f"({separator})", text.strip())
            chunks[0] = f"\n\n{chunks[0]}"
        else:
            chunks = re.split(f"({separator})", text)
        # Odd indices hold captured separators: glue each onto its fragment.
        new_chunks = [chunks[idx] + chunks[idx + 1] for idx in range(1, len(chunks), 2)]
        new_chunks = [chunks[0]] + new_chunks
        return [chunk for chunk in new_chunks if chunk.strip() != ""]

    @staticmethod
    def split_sentences(chunk: str):
        """Split a chunk into sentences on punctuation, keeping the punctuation."""
        sentences = re.split(TextSplitter.PUNC_PATTERN, chunk)
        delimiters = [s for s in chunk if s in TextSplitter.COMMON_PUNCTUATIONS]
        restore_chunks = []
        for chunk, delimiter in zip(sentences[:-1], delimiters):
            restore_chunks.append(chunk + delimiter)
        return restore_chunks + [sentences[-1]]

    def split_text(self, input_data: str):
        """Full pipeline: truncate -> recursive split -> merge -> size-limit."""
        if self.max_characters > 0:
            logger.info(f"The document characters should be within: {self.max_characters}")
            input_data = input_data[:self.max_characters]
        logger.info(f"characters of the document: {len(input_data)}")
        chunks = self.split_text_recursive(input_data, self.separators)
        final_chunks = self.merge_chunks(chunks)
        final_chunks = self.split_text_by_chunk_size(final_chunks)
        return [chunk.strip() for chunk in final_chunks if chunk]

    def split_text_recursive(self, input_data: str, separators: List[str]):
        """Recursively split the document by separator priority:
        1. Fragments within chunk_size are kept as-is.
        2. Oversized fragments are re-split with lower-priority separators.

        Args:
            input_data: input text
            separators: ordered separator list
        Returns:
            List[str]: split fragments
        """
        chunks = []
        cur_separator = ""
        next_separators = []
        # Find the highest-priority separator present in the text.
        for idx, sep in enumerate(separators):
            sep = re.escape(sep)
            if re.search(sep, input_data.strip()):
                cur_separator = sep
                next_separators = separators[idx + 1:]
                break
        if not cur_separator:
            return [input_data]
        else:
            cur_chunks = TextSplitter.split_text_by_separator(input_data, cur_separator)
            for chunk in cur_chunks:
                if len(chunk.strip()) <= self.chunk_size:
                    chunks.append(chunk)
                else:
                    if not next_separators:
                        chunks.append(chunk)
                    else:
                        next_chunks = self.split_text_recursive(chunk, next_separators)
                        chunks.extend(next_chunks)
        return chunks

    def merge_chunks(self, chunks: List[str]):
        """Merge split fragments into ~chunk_size chunks, honoring overlap."""
        final_chunks = []
        idx = 0
        while idx < len(chunks):
            # A fragment already at/over chunk_size passes through unmerged.
            if len(chunks[idx]) >= self.chunk_size:
                final_chunks.append(chunks[idx])
                idx += 1
                continue
            merge_idxes = self.get_merge_idxes(idx, chunks)
            content = ""
            for inner_idx in merge_idxes:
                content += chunks[inner_idx]
            final_chunks.append(content)
            # Resume after the last merged fragment.
            idx = merge_idxes[-1] + 1
        return final_chunks

    def get_merge_idxes(self, cur_idx: int, chunks: List[str]):
        """Collect indices to merge: backwards to fill chunk_overlap as far as
        possible, forwards to fill chunk_size as far as possible."""
        idxes = deque([cur_idx])
        overlap_idx = cur_idx - 1
        cur_len = len(chunks[cur_idx])
        cur_idx += 1
        # Walk backwards gathering overlap fragments.
        over_lap_len = 0
        while overlap_idx >= 0:
            over_lap_len += len(chunks[overlap_idx])
            if over_lap_len > self.chunk_overlap or (cur_len + over_lap_len) > self.chunk_size:
                over_lap_len -= len(chunks[overlap_idx])
                break
            idxes.appendleft(overlap_idx)
            overlap_idx -= 1
        cur_len += over_lap_len
        # Walk forwards gathering fragments until chunk_size would overflow.
        while cur_idx < len(chunks):
            cur_len += len(chunks[cur_idx])
            if cur_len > self.chunk_size:
                break
            idxes.append(cur_idx)
            cur_idx += 1
        return idxes

    def split_chunks(self, chunks: List[str]):
        """Hard-truncate chunks above chunk_size into overlapping windows."""
        final_chunks = []
        for chunk in chunks:
            if len(chunk) <= self.chunk_size:
                final_chunks.append(chunk)
            else:
                start = 0
                end = self.chunk_size
                while end < len(chunk):
                    final_chunks.append(chunk[start: end])
                    start += self.chunk_size - self.chunk_overlap
                    end = start + self.chunk_size
                final_chunks.append(chunk[start:])
        return final_chunks

    def split_text_by_chunk_size(self, chunks: List[str]):
        """Re-split oversized chunks: sentence-split, re-merge, then truncate
        with overlap."""
        final_chunks = []
        for chunk in chunks:
            if len(chunk) <= self.chunk_size:
                final_chunks.append(chunk)
                continue
            sentences = TextSplitter.split_sentences(chunk)
            sub_chunks = self.merge_chunks(sentences)
            final_chunks.extend(self.split_chunks(sub_chunks))
        return final_chunks