You've already forked DataMate
init datamate
This commit is contained in:
59
runtime/python-executor/datamate/common/utils/__init__.py
Normal file
59
runtime/python-executor/datamate/common/utils/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import pytz
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def check_valid_path(file_path):
    """Return True if *file_path* resolves to an existing filesystem entry."""
    return os.path.exists(os.path.abspath(file_path))
|
||||
|
||||
|
||||
def get_realpath_with_prefix_check(path, prefix):
    """Resolve *path* (following symlinks) and ensure it lies under *prefix*.

    Args:
        path: path to resolve.
        prefix: required leading string of the resolved path.

    Returns:
        The resolved real path.

    Raises:
        ValueError: if the resolved path does not start with *prefix*.
    """
    realpath = os.path.realpath(path)
    if not realpath.startswith(prefix):
        raise ValueError(f"The path {realpath} does not start with the prefix '{prefix}'.")
    return realpath
|
||||
|
||||
|
||||
def bytes_to_numpy(image_bytes):
    """Decode raw image bytes into an image array (cv2.IMREAD_COLOR, i.e. BGR)."""
    raw = np.frombuffer(image_bytes, dtype=np.uint8)
    decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    return decoded
|
||||
|
||||
|
||||
def numpy_to_bytes(image_np, file_type):
    """Encode an image array into raw bytes of the given format.

    Args:
        image_np: image array to encode.
        file_type: target extension understood by OpenCV (leading period required).

    Returns:
        Encoded image bytes, or b"" when the array is empty.
    """
    if not image_np.size:
        return b""
    encoded = cv2.imencode(file_type, image_np)[1]
    return encoded.tobytes()
|
||||
|
||||
|
||||
def get_now_time(timezone, time_format, file_name, method):
    """Return the current time in *timezone* formatted with *time_format*.

    Args:
        timezone: tz database name, e.g. "Asia/Shanghai".
        time_format: strftime format string.
        file_name: caller's file name, used only for error logging.
        method: caller's method name, used only for error logging.

    Returns:
        The formatted timestamp, or "" when formatting fails.
    """
    timestamp = ""
    try:
        tz = pytz.timezone(timezone)  # resolve the tz database entry
        now = datetime.now(tz=tz)  # timezone-aware current time
        timestamp = now.strftime(time_format)
    except ValueError as e:
        # Bug fix: loguru formats messages with str.format() ({} placeholders),
        # not %-style, and does not take exc_info= — the original call logged
        # the literal "%s" placeholders and silently dropped the arguments.
        logger.error("fileName: {}, method: {}, formatting time failed: {}", file_name, method, e)
    # NOTE(review): pytz.timezone raises UnknownTimeZoneError (not ValueError)
    # for a bad zone name; that exception propagates, as in the original.
    return timestamp
|
||||
|
||||
|
||||
def decrypt(enc_pass):
    """Decrypt *enc_pass* via the KMC key-management library.

    The KMC API reads its inputs from environment variables, so the
    ciphertext is passed through the process environment and wiped after use.

    Args:
        enc_pass: KMC-encrypted password string.

    Returns:
        The decrypted password returned by the KMC API.
    """
    # Local import: kmc is an optional, platform-specific dependency.
    import kmc.kmc as K
    os.environ['KMC_DATA_USER'] = 'modelenginepublic'
    os.environ['KMC_PYTHON_ENCRYPT_DATA'] = enc_pass
    dec_pass = K.API().decrypt(0)
    # Clear the ciphertext from the environment once consumed.
    os.environ['KMC_PYTHON_ENCRYPT_DATA'] = ""
    return dec_pass
|
||||
115
runtime/python-executor/datamate/common/utils/aho_corasick.py
Normal file
115
runtime/python-executor/datamate/common/utils/aho_corasick.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# -- encoding: utf-8 --
|
||||
from collections import deque
|
||||
|
||||
|
||||
class TrieNode:
    """Single node of the Aho-Corasick prefix trie."""

    def __init__(self, value):
        self.value = value  # character stored at this node
        self.child = dict()  # outgoing edges: char -> TrieNode
        self.fail = None  # failure link, populated later by add_fail_pointer
        self.word = None  # set of complete words ending at this node, or None
|
||||
|
||||
|
||||
class AhoCorasic:
    """Aho-Corasick automaton for multi-pattern (sensitive word) search."""

    def __init__(self, words):
        # Build the trie from the word list, then wire up the failure links.
        self._root = add_fail_pointer(build_trie(words))

    def search(self, text: str, special_symbols: set):
        """
        Match sensitive words in *text*.

        Args:
            text: text to scan.
            special_symbols: special characters to skip over while matching.
        Returns:
            De-duplicated list of matched substrings (matched spans may include
            skipped special symbols that occurred inside them).
        """
        seq_list = []
        node = self._root

        valid_len = 0  # length of the span currently being matched
        for i, s in enumerate(text):
            if s in special_symbols:  # skip special characters
                if valid_len != 0:
                    # Mid-match: the skipped char still widens the matched span.
                    valid_len += 1
                continue

            matched = True
            while s not in node.child:  # node.child has no edge for s
                if node == self._root:  # at root (no fail link): reset length and bail out
                    valid_len = 0
                    matched = False
                    break
                elif node.fail == self._root:  # fail link is root: reset length but keep going
                    valid_len = 0
                node = node.fail  # follow the failure pointer
            if not matched:
                continue

            node = node.child.get(s)
            valid_len += 1
            if node.word:  # node is the last letter of at least one word
                sensitive_word = text[i - valid_len + 1:i + 1]
                seq_list.append(sensitive_word)
        seq_list = list(set(seq_list))
        return seq_list
|
||||
|
||||
|
||||
def build_trie(words: list):
    """
    Build a prefix trie.

    Args:
        words: list of sensitive words.
    Returns:
        Root node of the trie; terminal nodes carry the set of words ending there.
    """
    root = TrieNode('root')
    for word in words:
        cursor = root
        for ch in word:
            cursor = cursor.child.setdefault(ch, TrieNode(ch))
        if not cursor.word:
            cursor.word = {word}
        else:
            cursor.word.add(word)
    return root
|
||||
|
||||
|
||||
def add_fail_pointer(root: TrieNode):
    """
    Attach failure pointers to the prefix trie (level-order BFS).

    Steps:
    1. Starting from root, enqueue (node, parent) pairs level by level. The root
       has no fail pointer; direct children of root fail back to root.
    2. For every deeper node, consult node.parent.fail.child.
    3. If an ancestor's fail node has a child with the same value, that child
       becomes this node's fail pointer; otherwise fall back to root.

    Args:
        root: root node of the prefix trie.
    returns:
        The same root, with fail pointers populated.
    """
    queue = deque()
    queue.appendleft((root, None))
    while len(queue) > 0:
        node_parent = queue.pop()
        curr, parent = node_parent[0], node_parent[1]
        # Enqueue children before processing, so traversal stays level-order.
        for sub in curr.child.values():
            queue.appendleft((sub, curr))
        if parent is None:
            # curr is the root itself: it carries no fail pointer.
            continue
        elif parent is root:
            # Depth-1 nodes always fail back to the root.
            curr.fail = root
        else:
            parent_fail = parent.fail
            # Climb the fail chain until a node with a matching child is found.
            while parent_fail and curr.value not in parent_fail.child:
                parent_fail = parent_fail.fail
            if parent_fail:
                curr.fail = parent_fail.child[curr.value]
            else:
                curr.fail = root
    return root
|
||||
@@ -0,0 +1,66 @@
|
||||
# -- encoding: utf-8 --
|
||||
import pickle
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def bytes_to_numpy(image_bytes):
    """Decode raw image bytes into an image array (cv2.IMREAD_COLOR, i.e. BGR)."""
    buffer = np.frombuffer(image_bytes, dtype=np.uint8)
    return cv2.imdecode(buffer, cv2.IMREAD_COLOR)
|
||||
|
||||
|
||||
def numpy_to_bytes(image_np, file_type):
    """
    Encode an image array to bytes.

    Params:

        file_type: as required by OpenCV, extension must have a leading period.
    """
    if not image_np.size:
        return b""
    buf = cv2.imencode(file_type, image_np)[1]
    return buf.tobytes()
|
||||
|
||||
|
||||
def pil_to_bytes(src: Image.Image) -> bytes:
    """Serialize a PIL.Image to PNG-encoded bytes (forcing RGB mode first)."""
    rgb = src.convert("RGB")  # normalize to RGB before encoding
    with BytesIO() as sink:
        rgb.save(sink, format='PNG')
        return sink.getvalue()
|
||||
|
||||
|
||||
def bytes_to_pil(src: bytes) -> Image.Image:
    """Deserialize raw image bytes into a PIL.Image.

    Returns a detached copy so the image stays usable after the underlying
    buffer and file object are closed.
    """
    # Bug fix: the original created an *empty* BytesIO() and never wrote `src`
    # into it, so Image.open always failed. The buffer must wrap the input.
    with BytesIO(src) as bytes_io:
        with Image.open(bytes_io) as pil_img:  # with/as guarantees resource release
            pil_img.load()  # force-decode pixel data before the buffer closes
            return pil_img.copy()  # detached copy remains valid after close
|
||||
|
||||
|
||||
def pil_to_base64(src: Image.Image):
    """Encode a PIL.Image as base64 (PNG container); returns base64 bytes."""
    with BytesIO() as img_buffer:
        src.save(img_buffer, format='png')
        raw = img_buffer.getvalue()
    return base64.b64encode(raw)
|
||||
|
||||
|
||||
def obj_to_bytes(src: object) -> bytes:
    """Serialize any picklable object to bytes."""
    payload = pickle.dumps(src)
    return payload
|
||||
|
||||
|
||||
def bytes_to_obj(src: bytes) -> object:
    """Deserialize pickle bytes back into an object.

    NOTE(review): pickle.loads must never be fed untrusted data.
    """
    obj = pickle.loads(src)
    return obj
|
||||
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import importlib.abc
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class CustomImporter(importlib.abc.MetaPathFinder):
    """Meta-path finder that resolves modules relative to a fixed base directory."""

    def __init__(self, base_path):
        self.base_path = Path(base_path).resolve()

    def find_spec(self, fullname, path, target=None):
        # Map the dotted module name onto the base directory
        # (e.g. mypkg.mymodule -> <base>/mypkg/mymodule.py).
        candidate = self.base_path.joinpath(*fullname.split("."))
        py_file = candidate.with_suffix(".py")

        if py_file.exists():
            # Plain module: the spec points at the .py file.
            return importlib.util.spec_from_file_location(
                fullname,
                str(py_file),
                submodule_search_locations=[str(candidate.parent)]
            )
        init_file = candidate / "__init__.py"
        if candidate.is_dir() and init_file.exists():
            # Package directory: the spec points at its __init__.py.
            return importlib.util.spec_from_file_location(
                fullname,
                str(init_file),
                submodule_search_locations=[str(candidate)]
            )
        return None
|
||||
229
runtime/python-executor/datamate/common/utils/lazy_loader.py
Normal file
229
runtime/python-executor/datamate/common/utils/lazy_loader.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import importlib
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from loguru import logger
|
||||
from packaging.version import parse as parse_version
|
||||
|
||||
|
||||
def is_valid_whl_filename(package_name):
    """
    Validate that a WHL file name is safe (guards against command injection).

    Args:
        package_name (str): whl file name to validate

    Returns:
        bool: True when safe, False when it contains dangerous characters
    """
    # Reject path separators outright.
    if os.path.sep in package_name or (os.path.altsep and os.path.altsep in package_name):
        return False

    # Blacklist of shell metacharacters. Bug fix: the original wrote this as a
    # re.VERBOSE character class with inline comments, but whitespace and '#'
    # comments are NOT stripped *inside* [...], so the comment text itself was
    # silently part of the class. The intended set is spelled out explicitly:
    # command separators ; & | `  — substitution $ ( ) { } < >
    # brackets [ ]  — history expansion / escape ! \
    # quotes / globs / whitespace ' " * ? # \s
    dangerous_pattern = re.compile(r"""[;&|`$(){}<>\[\]!\\'"*?#\s]""")

    if dangerous_pattern.search(package_name):
        return False

    return True
|
||||
|
||||
|
||||
class PackageNotFoundError(Exception):
    """Raised when an installed package's version cannot be read from pip output."""
    pass
|
||||
|
||||
|
||||
class LazyLoader(ModuleType):
    """Module proxy that defers import until first attribute access.

    On first use it tries a normal import; if that fails (or a reinstall is
    forced), it pip-installs the package from a local directory of WHL files,
    then imports it and rebinds the caller's name to the real module.
    """

    def __init__(self,
                 package_name,
                 module_name=None,
                 whl_path="/dataset/ops_whl",
                 exact_version=None,
                 force_reinstall=False
                 ):
        """
        :param package_name: module-name portion of the WHL package name (usually the module name with _ replaced by -)
        :param module_name: importable module name after the WHL is installed; set it when the whl name and the import name differ.
        :param whl_path: directory containing the WHL files
        :param exact_version: exact version requirement
        :param force_reinstall: force reinstallation
        """
        try:
            # Capture the caller's globals so the alias can be rebound there later.
            frame = sys._getframe(1)
            self._parent_globals = frame.f_globals
        except (AttributeError, ValueError) as e:
            logger.error(f"Failed to get stack frame: {e}")
            raise RuntimeError("Stack frame retrieval failed") from e

        self._module_name = module_name if module_name else package_name
        self._package_name = package_name

        self.whl_path = Path(whl_path).resolve()
        self.exact_version = exact_version

        self.force_reinstall = force_reinstall
        self._cached_module = None  # real module once loaded

        # Register this proxy under the alias in the caller's namespace.
        self._parent_globals[self._module_name] = self
        super().__init__(self._module_name)

    def __getattr__(self, name):
        # First attribute access triggers the real import/installation.
        if self._cached_module is None:
            self._cached_module = self._load_module()
        return getattr(self._cached_module, name)

    def __dir__(self):
        return dir(self._load_module())

    def _load_module(self):
        """Import the target module, installing it from local WHLs if needed."""

        if self._cached_module is not None:
            return self._cached_module

        # Only the top-level package name is relevant for pip.
        package_name: str = self._package_name.split('.')[0]
        if not is_valid_whl_filename(package_name):
            logger.error(f"Invalid package_name: {package_name}")
            # NOTE(review): "Invalide" typo in the message, left unchanged here.
            raise RuntimeError("Invalide package_name, please check it again!")

        module_name = self._module_name if self._module_name else package_name.replace("_", "-")

        need_install = False

        try:
            # Fast path: the module is already importable.
            if not self.force_reinstall:
                module = importlib.import_module(module_name)
                self._cached_module = module
                self._register_alias(module)
                return module
        except ImportError:
            need_install = True

        if self.force_reinstall:
            # Version check when a reinstall is forced.
            installed = self._check_package_exists(package_name)
            if installed and self.exact_version:
                installed_version = self._get_installed_version(package_name)
                if parse_version(installed_version) != parse_version(self.exact_version):
                    logger.info(f"Version mismatch detected: {installed_version} vs {self.exact_version}")
                    need_install = True
            else:
                need_install = True

        if need_install:
            self._pip_install_package(package_name)
            module = importlib.import_module(module_name)
            self._cached_module = module
            self._register_alias(module)
        else:
            # Version check passed; no reinstall required.
            module = importlib.import_module(module_name)
            self._cached_module = module
            self._register_alias(module)

        return self._cached_module

    def _register_alias(self, module):
        """Rebind the alias to the real module in the caller's globals and sys.modules."""
        self._parent_globals[self._module_name] = module
        sys.modules[self._module_name] = module

    def _check_package_exists(self, package_name):
        """Return True when `pip show` reports the package as installed."""
        try:
            result = subprocess.run(
                [sys.executable, "-m", "pip", "show", package_name],
                capture_output=True,
                text=True
            )
            return result.returncode == 0
        except subprocess.SubprocessError as e:
            logger.error(f"Package check failed: {e}")
            return False

    def _get_installed_version(self, package_name):
        """Return the installed version string parsed from `pip show` output.

        Raises:
            PackageNotFoundError: when no 'Version:' line is present.
        """
        result = subprocess.run(
            [sys.executable, "-m", "pip", "show", package_name],
            capture_output=True,
            text=True
        )
        for line in result.stdout.split('\n'):
            if line.startswith('Version:'):
                return line.split()[-1]
        raise PackageNotFoundError()

    def _pip_install_package(self, package_name: str):
        """Install the newest matching WHL from self.whl_path via pip --no-index."""

        if not self.whl_path.exists():
            raise FileNotFoundError(f"WHL directory not found: {self.whl_path}")

        whl_files = list(self.whl_path.glob(f"{package_name}*.whl"))
        if not whl_files:
            raise RuntimeError(f"No WHL files found for {package_name}")

        # Filter to the exact version, when one was requested.
        if self.exact_version:
            pattern = re.compile(
                rf'^{re.escape(package_name)}-{re.escape(self.exact_version)}-\S*\.whl$',
                re.IGNORECASE
            )
            whl_files = [f for f in whl_files if pattern.match(f.name)]

        # Keep only files with parseable versions; pick the newest.
        whl_versions = []
        for f in whl_files:
            try:
                version = self._extract_version(f)
                whl_versions.append((f, version))
            except ValueError:
                continue

        if not whl_versions:
            raise FileNotFoundError("No valid WHL files")

        whl_versions.sort(key=lambda x: x[1], reverse=True)
        target_whl = whl_versions[0][0]

        # Run the actual installation; output is suppressed.
        try:
            subprocess.check_call([
                sys.executable, "-m", "pip", "install",
                "--no-index",
                f"--find-links={self.whl_path}",
                str(target_whl)
            ], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
            logger.info(f"Successfully installed {target_whl}")
        except subprocess.CalledProcessError as e:
            logger.error(f"Installation failed: {e}")
            raise RuntimeError(f"Installation failed: {e}") from e

    def _extract_version(self, filename):
        """Parse the semantic version out of a WHL filename.

        Raises:
            ValueError: when no version can be extracted.
        """

        version_pattern = r"(\d+([.]\d+)+([ab]|rc\d+)*([.]post\d+)*([.]dev\d+)*)"

        match = re.search(
            rf"^{re.escape(self._package_name)}-({version_pattern})",
            filename.name,
            re.IGNORECASE
        )
        if not match:
            raise ValueError(f"Invalid version format: {filename.name}")
        return parse_version(match.group(1))
|
||||
118
runtime/python-executor/datamate/common/utils/llm_request.py
Normal file
118
runtime/python-executor/datamate/common/utils/llm_request.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import requests
|
||||
import urllib3
|
||||
from loguru import logger
|
||||
|
||||
from datamate.common.utils import decrypt
|
||||
|
||||
|
||||
class LlmReq:
    """Callable client for an OpenAI-style chat-completions LLM endpoint.

    Depending on ``access_type`` it either uses an mTLS urllib3 pool (falsy)
    or a plain requests/urllib call (truthy). Calling the instance with a
    prompt string returns the model's reply content.
    """

    # Error-code constants used in raised RuntimeErrors
    ERRORCODE_INCOMPLETE_CONFIG = 83005
    ERRORCODE_INVALID_RESPONSE = 83006
    ERRORCODE_SERVICE_UNAVAILABLE = 83007

    def __init__(self, url: str = None, header: Dict = None, body: Dict = None, access_type: int = None,
                 is_https: bool = False, is_certificate: bool = False, certificate_path: Path = None):
        """
        :param url: chat-completions endpoint URL
        :param header: HTTP headers to send
        :param body: request body template; must contain a "messages" list
        :param access_type: falsy -> mTLS pool manager path; truthy -> requests/urllib path
        :param is_https: use an SSL context for the urllib path
        :param is_certificate: verify against *certificate_path* instead of disabling verification
        :param certificate_path: CA bundle used when *is_certificate* is True
        """
        self.url = url
        self.header = header
        self.access_type = access_type
        self.is_https = is_https
        self.context = self._load_certificate(certificate_path, is_certificate) if is_https else None
        self.body = body
        # Ensure the first message always has non-empty content as a placeholder.
        if not self.body.get("messages", [])[0].get("content"):
            self.body["messages"][0]["content"] = "你好"

    def __call__(self, input_str: str) -> str:
        """Send *input_str* as the first message's content; return the reply (or '' on KeyError)."""
        outputs = ''
        try:
            self.body["messages"][0]["content"] = input_str
            outputs = self._call_service()
        except KeyError as e:
            logger.error(f"The body format is not completed, error detail: {e}")
            # Reset the content to the placeholder after a malformed body.
            self.body["messages"][0]["content"] = "你好"
        return outputs

    @staticmethod
    def _load_certificate(certificate_path: Path, is_certificate: bool) -> ssl.SSLContext:
        """Build the SSL context; verification is enabled only with a CA file."""
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        context.check_hostname = False
        if is_certificate:
            context.load_verify_locations(certificate_path)
            context.verify_mode = ssl.CERT_REQUIRED
        else:
            # NOTE(review): CERT_NONE disables server verification entirely.
            context.verify_mode = ssl.CERT_NONE
        return context

    @staticmethod
    def _pool_manager():
        """Build an mTLS urllib3 PoolManager from RAY_TLS_* environment settings."""
        cert_file = os.getenv("RAY_TLS_SERVER_CERT", "/certPersonal/global/identity/global.crt")
        key_file = os.getenv("RAY_TLS_SERVER_KEY", "/certPersonal/global/identity/global.key")
        ca_crt = os.getenv("RAY_TLS_CA_CERT", "/certPersonal/global/trust/ca.crt")
        pwd = os.getenv("RAY_TLS_SERVER_KEY_PASSWORD", "/certPersonal/global/identity/pwd.txt")
        key_password = os.getenv("GLOBAL_PWD", None)
        if not key_password:
            # Fall back to the on-disk encrypted password, decrypted via KMC.
            with open(pwd, "r") as f:
                key_password = f.read().strip()
            key_password = decrypt(key_password)

        pool_manager = urllib3.PoolManager(cert_file=cert_file,
                                           key_file=key_file,
                                           key_password=key_password,
                                           cert_reqs='CERT_REQUIRED',
                                           ca_certs=ca_crt,
                                           assert_hostname='edatamate',
                                           ssl_version='TLSv1_2')
        return pool_manager

    def _call_service(self):
        """POST the body to the LLM and return the reply content.

        Raises:
            RuntimeError: with one of the ERRORCODE_* constants on missing
                configuration, malformed response, or transport failure.
        """
        if not all([self.url, self.header, self.body.get("messages", [])[0].get("content")]):
            logger.error("LLM is not configured completely")
            raise RuntimeError(self.ERRORCODE_INCOMPLETE_CONFIG, "LLM is not configured completely") from None
        if not self.access_type:
            # mTLS path through the urllib3 pool manager.
            try:
                pool_manager = self._pool_manager()
                response = pool_manager.request(
                    "POST",
                    url=self.url,
                    body=json.dumps(self.body).encode(),
                    headers=self.header
                )
                logger.info(f"Response status code: {response.status}")
                response_json = json.loads(response.data.decode('utf-8'))
                outputs = response_json.get("choices", [])[0].get("message", {}).get("content")
                if not outputs:
                    logger.error("Invalid response format for LLM, missing the 'prompt' key word")
                    raise RuntimeError(self.ERRORCODE_INVALID_RESPONSE,
                                       "Invalid response format for LLM, missing the 'prompt' key word") from None
                return outputs
            except Exception as e:
                logger.error(f"LLM service is not available, error detail: {e}")
                raise RuntimeError(self.ERRORCODE_SERVICE_UNAVAILABLE, "LLM service is not available") from None
        if self.access_type:
            # Direct path: urllib with SSL context, or plain requests.
            try:
                if self.is_https:
                    req = Request(url=self.url, data=json.dumps(self.body).encode(), headers=self.header, method="POST")
                    response_json = urlopen(req, context=self.context).read().decode("utf-8")
                    response = json.loads(response_json)
                else:
                    response = requests.post(url=self.url, data=json.dumps(self.body), headers=self.header,
                                             stream=False).json()
                outputs = response.get("choices", [])[0].get("message", {}).get("content")
                if not outputs:
                    logger.error("Invalid response format for LLM, missing the 'prompt' key word")
                    raise RuntimeError(self.ERRORCODE_INVALID_RESPONSE,
                                       "Invalid response format for LLM, missing the 'prompt' key word") from None
                return outputs
            except Exception as e:
                logger.error(f"LLM service is not available, error detail: {e}")
                raise RuntimeError(self.ERRORCODE_SERVICE_UNAVAILABLE, "LLM service is not available") from None
        return None  # ensure a value is returned in all cases
|
||||
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import sys
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
from packaging.version import parse as parse_version, Version
|
||||
|
||||
|
||||
def install_whl(
        package_name: str,
        whl_path: str,
        exact_version: Optional[str] = None,
        filename_pattern: Optional[str] = None,
        force_reinstall: bool = False
) -> None:
    """Install a package from a local directory of WHL files via `pip --no-index`.

    The newest wheel matching the name (and optional exact version or custom
    regex) is selected and installed offline.

    :param package_name: eg: ("zh_core_web_sm")
    :param whl_path: WHL file save path
    :param exact_version: version number
    :param filename_pattern: custom filename pattern for REGEX
    :param force_reinstall: which decide to overlap the original number or not (default: False)

    Raises:
        FileNotFoundError: when no matching/parseable wheel exists.
        RuntimeError: when the pip installation itself fails.
    """

    whl_dir = Path(whl_path).resolve()
    whl_files = _get_whl_files(exact_version, filename_pattern, package_name, whl_dir)

    # Semantic version sort: pick the newest matching wheel.
    target_whl = _sort_whl_files(whl_files)

    # Offline installation command.
    cmd = [
        sys.executable, "-m", "pip", "install",
        "--no-index",
        f"--find-links={whl_dir}",
        str(target_whl)
    ]
    if force_reinstall:
        cmd.append("--force-reinstall")

    try:
        subprocess.check_call(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT
        )
        logger.info(f"Successfully installed {target_whl.name}")
    except subprocess.CalledProcessError as e:
        error_msg = (
            f"Installation failed for {package_name}\n"
            f"Possible reasons:\n"
            f"1. Missing dependencies in {whl_dir}\n"
            f"2. Incompatible Python version\n"
            f"3. Platform mismatch (e.g., x86 vs ARM)"
        )
        raise RuntimeError(error_msg) from e
|
||||
|
||||
|
||||
def _sort_whl_files(whl_files):
    """Return the WHL file with the highest parseable version.

    Files whose version cannot be parsed are logged and skipped.

    Raises:
        FileNotFoundError: when no file has a parseable version.
    """
    logger.info(f"[load_offline_module]whl_files: {whl_files}")
    parsed = []
    for whl in whl_files:
        try:
            parsed.append((whl, _extract_version(whl)))
        except ValueError as e:
            logger.warning(f"Skipping invalid file {whl.name}: {e}")
    if not parsed:
        raise FileNotFoundError("No valid WHL files with parseable versions")
    parsed.sort(key=lambda pair: pair[1], reverse=True)
    return parsed[0][0]
|
||||
|
||||
|
||||
def _get_whl_files(exact_version, filename_pattern, package_name, whl_dir):
|
||||
# 正则表达式
|
||||
if filename_pattern:
|
||||
pattern = filename_pattern
|
||||
else:
|
||||
if exact_version:
|
||||
version_part = re.escape(exact_version)
|
||||
pattern = rf"^{re.escape(package_name)}-{version_part}-\S*\.whl$"
|
||||
else:
|
||||
pattern = rf"^{re.escape(package_name)}\S*\.whl$"
|
||||
regex = re.compile(pattern, re.IGNORECASE)
|
||||
whl_files = [f for f in whl_dir.glob("*.whl") if regex.match(f.name)]
|
||||
if not whl_files:
|
||||
available_files = "\n".join([f.name for f in whl_dir.glob("*.whl")])
|
||||
raise FileNotFoundError(
|
||||
f"No matching WHL found for {package_name} in {whl_dir}\n"
|
||||
f"Available files:\n{available_files}"
|
||||
)
|
||||
return whl_files
|
||||
|
||||
|
||||
def _extract_version(filename: Path) -> Version:
    """Extract the semantic version embedded in a WHL file name.

    Raises:
        ValueError: when the name carries no parseable version segment.
    """
    version_re = r"-(\d+([.]\d+)+([ab]|rc\d+)*([.]post\d+)*([.]dev\d+)*)-"
    match = re.search(version_re, filename.name)
    if match is None:
        raise ValueError(f"Invalid version format: {filename.name}")
    return parse_version(match.group(1))
|
||||
102
runtime/python-executor/datamate/common/utils/registry.py
Normal file
102
runtime/python-executor/datamate/common/utils/registry.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Registry(object):
    """Registry mapping operator names to their classes or module paths."""

    def __init__(self, name: str):
        # Human-readable registry name, used in logs and error messages.
        self._name = name
        self._modules = {}

    @property
    def name(self):
        return self._name

    @property
    def modules(self):
        return self._modules

    def list(self):
        """Log every operator registered in this registry."""
        for m in self._modules.keys():
            logger.info(f'{self._name}\t{m}')

    def get(self, module_key):
        """Return the registered entry for *module_key*, or None."""
        return self._modules.get(module_key, None)

    def register_module(self, module_name: str = None, module_cls: type = None, module_path: str = None, force=False):
        """
        Register a module under a specific name.

        Can be used directly (passing *module_cls* or *module_path*) or as a
        decorator factory when neither is given.

        :param module_name: module name; defaults to the class name
        :param module_cls: module class definition
        :param module_path: path where the module lives
        :param force: whether to overwrite an existing entry of the same name; defaults to False.

        Example:
            >>> registry = Registry('ops')
            >>> @registry.register_module()
            >>> class TextFormatter:
            >>>     pass

            >>> class TextFormatter2:
            >>>     pass
            >>> registry.register_module( module_name='text_formatter2', module_cls=TextFormatter2)
        """
        if not (module_name is None or isinstance(module_name, str)):
            raise TypeError(f'module_name must be either of None, str,'
                            f'got {type(module_name)}')
        if module_cls is not None:
            # Direct registration of a class.
            self._register_module(module_name=module_name,
                                  module_cls=module_cls,
                                  force=force)
            return module_cls

        elif module_cls is None and isinstance(module_path, str):
            # Registration of a module path instead of a class.
            self._register_module(module_name=module_name,
                                  module_path=module_path,
                                  force=force)
            return module_path

        def _register(module_cls):
            """
            Decorator returned when *module_cls* is None: registers the
            decorated class and returns it unchanged.
            """
            self._register_module(module_name=module_name,
                                  module_cls=module_cls,
                                  force=force)
            return module_cls

        return _register

    def _register_module(self, module_name=None, module_cls=None, module_path=None, force=False):
        """
        Insert the module into the registry.

        :param module_name: module name
        :param module_cls: module class definition
        :param force: whether to overwrite an existing entry of the same name; defaults to False.
        """

        if module_name is None and module_cls is not None:
            module_name = module_cls.__name__

        if module_name in self._modules:
            # Re-registering the identical entry is a silent no-op.
            if module_cls is not None and module_cls == self._modules[module_name]:
                return

            if module_path is not None and module_path == self._modules[module_name]:
                return

            if not force:
                raise KeyError(
                    f'{module_name} is already registered in {self._name}, content: {self.modules.keys()}')

        if module_cls is not None:
            self._modules[module_name] = module_cls

        elif module_path is not None:
            self._modules[module_name] = module_path
|
||||
171
runtime/python-executor/datamate/common/utils/text_splitter.py
Normal file
171
runtime/python-executor/datamate/common/utils/text_splitter.py
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/user/bin/python
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
from collections import deque
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TextSplitter:
    """Recursive text chunker with overlap, preserving sentence boundaries."""
    # Split on common punctuation so that sentences stay intact.
    COMMON_PUNCTUATIONS = [",", "。", "?", "!", ";", ",", "?", "!", ";"]
    PUNC_PATTERN = f"[{''.join(COMMON_PUNCTUATIONS)}]"

    def __init__(self, max_characters: int, chunk_size: int, chunk_overlap: int):
        """Initialize the splitter.
        Args:
            max_characters: maximum document length; longer input is truncated, -1 disables truncation
            chunk_size: chunk size
            chunk_overlap: overlap between adjacent chunks
        """
        if chunk_size <= chunk_overlap:
            logger.error(f"param chunk_size should larger than chunk_overlap, "
                         f"current chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}")
            # NOTE(review): str(ValueError) stringifies the class itself, not a
            # message — the raised payload is "<class 'ValueError'>".
            raise Exception(83000, str(ValueError)) from None
        self.max_characters = max_characters
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = ["\n\n", "\n"]

    @staticmethod
    def split_text_by_separator(text: str, separator: str):
        """Split *text* on *separator*, keeping the separator attached to each piece."""
        # Resolve the conflict between one-newline and two-newline separators.
        if text.startswith("\n\n") and separator == "\n":
            chunks = re.split(f"({separator})", text.strip())
            chunks[0] = f"\n\n{chunks[0]}"
        else:
            chunks = re.split(f"({separator})", text)
        # Re-attach each captured separator to the fragment that precedes it.
        new_chunks = [chunks[idx] + chunks[idx + 1] for idx in range(1, len(chunks), 2)]
        new_chunks = [chunks[0]] + new_chunks
        return [chunk for chunk in new_chunks if chunk.strip() != ""]

    @staticmethod
    def split_sentences(chunk: str):
        """Split a chunk into sentences by punctuation, without losing the punctuation."""
        sentences = re.split(TextSplitter.PUNC_PATTERN, chunk)
        delimiters = [s for s in chunk if s in TextSplitter.COMMON_PUNCTUATIONS]
        restore_chunks = []
        for chunk, delimiter in zip(sentences[:-1], delimiters):
            restore_chunks.append(chunk + delimiter)
        return restore_chunks + [sentences[-1]]

    def split_text(self, input_data: str):
        """Top-level entry: truncate, recursively split, merge with overlap, and trim."""
        if self.max_characters > 0:
            logger.info(f"The document characters should be within: {self.max_characters}")
            input_data = input_data[:self.max_characters]
        logger.info(f"characters of the document: {len(input_data)}")
        chunks = self.split_text_recursive(input_data, self.separators)
        final_chunks = self.merge_chunks(chunks)
        final_chunks = self.split_text_by_chunk_size(final_chunks)
        return [chunk.strip() for chunk in final_chunks if chunk]

    def split_text_recursive(self, input_data: str, separators: List[str]):
        """Recursively split the document by separator priority:
        1. Pieces within chunk_size are kept as-is.
        2. Pieces larger than chunk_size are split again with lower-priority separators.
        Args:
            input_data: input text
            separators: separators, in priority order

        Returns:
            List[str]: text fragments after splitting

        """
        chunks = []
        cur_separator = ""
        next_separators = []
        # Pick the highest-priority separator actually present in the text.
        for idx, sep in enumerate(separators):
            sep = re.escape(sep)
            if re.search(sep, input_data.strip()):
                cur_separator = sep
                next_separators = separators[idx + 1:]
                break

        if not cur_separator:
            return [input_data]
        else:
            cur_chunks = TextSplitter.split_text_by_separator(input_data, cur_separator)

            for chunk in cur_chunks:
                if len(chunk.strip()) <= self.chunk_size:
                    chunks.append(chunk)
                else:
                    if not next_separators:
                        chunks.append(chunk)
                    else:
                        next_chunks = self.split_text_recursive(chunk, next_separators)
                        chunks.extend(next_chunks)
            return chunks

    def merge_chunks(self, chunks: List[str]):
        """Merge split fragments back into chunks, honouring the overlap."""
        final_chunks = []
        idx = 0
        while idx < len(chunks):
            if len(chunks[idx]) >= self.chunk_size:
                final_chunks.append(chunks[idx])
                idx += 1
                continue
            merge_idxes = self.get_merge_idxes(idx, chunks)
            content = ""
            for inner_idx in merge_idxes:
                content += chunks[inner_idx]
            final_chunks.append(content)
            idx = merge_idxes[-1] + 1
        return final_chunks

    def get_merge_idxes(self, cur_idx: int, chunks: List[str]):
        """Collect indexes of fragments that can be merged: look backwards to fill
        the overlap budget, forwards to fill the chunk_size budget."""
        idxes = deque([cur_idx])
        overlap_idx = cur_idx - 1
        cur_len = len(chunks[cur_idx])
        cur_idx += 1
        # Collect overlap indexes (backwards).
        over_lap_len = 0
        while overlap_idx >= 0:
            over_lap_len += len(chunks[overlap_idx])
            if over_lap_len > self.chunk_overlap or (cur_len + over_lap_len) > self.chunk_size:
                over_lap_len -= len(chunks[overlap_idx])
                break
            idxes.appendleft(overlap_idx)
            overlap_idx -= 1
        cur_len += over_lap_len
        # Collect merge indexes (forwards).
        while cur_idx < len(chunks):
            cur_len += len(chunks[cur_idx])
            if cur_len > self.chunk_size:
                break
            idxes.append(cur_idx)
            cur_idx += 1
        return idxes

    def split_chunks(self, chunks: List[str]):
        """Hard-truncate chunks exceeding `chunk_size`, sliding by chunk_size - chunk_overlap."""
        final_chunks = []
        for chunk in chunks:
            if len(chunk) <= self.chunk_size:
                final_chunks.append(chunk)
            else:
                start = 0
                end = self.chunk_size
                while end < len(chunk):
                    final_chunks.append(chunk[start: end])
                    start += self.chunk_size - self.chunk_overlap
                    end = start + self.chunk_size
                final_chunks.append(chunk[start:])
        return final_chunks

    def split_text_by_chunk_size(self, chunks: List[str]):
        """Re-split over-long chunks sentence-wise, then truncate with overlap."""
        final_chunks = []
        for chunk in chunks:
            if len(chunk) <= self.chunk_size:
                final_chunks.append(chunk)
                continue
            sentences = TextSplitter.split_sentences(chunk)
            sub_chunks = self.merge_chunks(sentences)
            final_chunks.extend(self.split_chunks(sub_chunks))
        return final_chunks
|
||||
Reference in New Issue
Block a user