You've already forked DataMate
init datamate
This commit is contained in:
1
runtime/python-executor/datamate/__init__.py
Normal file
1
runtime/python-executor/datamate/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.0.1"
|
||||
0
runtime/python-executor/datamate/common/__init__.py
Normal file
0
runtime/python-executor/datamate/common/__init__.py
Normal file
98
runtime/python-executor/datamate/common/error_code.py
Normal file
98
runtime/python-executor/datamate/common/error_code.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
The class hierarchy for built-in exceptions is:
|
||||
see https://docs.python.org/3/library/exceptions.html
|
||||
|
||||
BaseException
|
||||
├── SystemExit
|
||||
├── KeyboardInterrupt
|
||||
├── GeneratorExit
|
||||
└── Exception
|
||||
├── ArithmeticError
|
||||
│ ├── FloatingPointError
|
||||
│ ├── OverflowError
|
||||
│ └── ZeroDivisionError
|
||||
├── AssertionError
|
||||
├── AttributeError
|
||||
├── BufferError
|
||||
├── EOFError
|
||||
├── ImportError
|
||||
│ └── ModuleNotFoundError
|
||||
├── LookupError
|
||||
│ ├── IndexError
|
||||
│ └── KeyError
|
||||
├── MemoryError
|
||||
├── NameError
|
||||
│ └── UnboundLocalError
|
||||
├── OSError
|
||||
│ ├── BlockingIOError
|
||||
│ ├── ChildProcessError
|
||||
│ ├── ConnectionError
|
||||
│ │ ├── BrokenPipeError
|
||||
│ │ ├── ConnectionAbortedError
|
||||
│ │ ├── ConnectionRefusedError
|
||||
│ │ └── ConnectionResetError
|
||||
│ ├── FileExistsError
|
||||
│ ├── FileNotFoundError
|
||||
│ ├── InterruptedError
|
||||
│ ├── IsADirectoryError
|
||||
│ ├── NotADirectoryError
|
||||
│ ├── PermissionError
|
||||
│ ├── ProcessLookupError
|
||||
│ └── TimeoutError
|
||||
├── ReferenceError
|
||||
├── RuntimeError
|
||||
│ ├── NotImplementedError
|
||||
│ └── RecursionError
|
||||
├── SyntaxError
|
||||
│ └── IndentationError
|
||||
│ └── TabError
|
||||
├── SystemError
|
||||
├── TypeError
|
||||
├── ValueError
|
||||
│ └── UnicodeError
|
||||
│ ├── DecodeError
|
||||
│ ├── EncodeError
|
||||
│ └── UnicodeTranslateError
|
||||
└── Warning
|
||||
├── DeprecationWarning
|
||||
├── PendingDeprecationWarning
|
||||
├── RuntimeWarning
|
||||
├── SyntaxWarning
|
||||
├── UserWarning
|
||||
├── FutureWarning
|
||||
├── ImportWarning
|
||||
└── UnicodeWarning
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
ERROR_CODE_TABLE = {
|
||||
ImportError: "ops.0001",
|
||||
ModuleNotFoundError: "ops.0002",
|
||||
NameError: "ops.0003",
|
||||
KeyError: "ops.0004",
|
||||
IndexError: "ops.0005",
|
||||
ValueError: "ops.0006",
|
||||
TypeError: "ops.0007",
|
||||
SyntaxError: "ops.0008",
|
||||
AttributeError: "ops.0009",
|
||||
ArithmeticError: "ops.0010",
|
||||
MemoryError: "ops.0011",
|
||||
OSError: "ops.0012",
|
||||
FileNotFoundError: "ops.0013",
|
||||
NotADirectoryError: "ops.0014",
|
||||
PermissionError: "ops.0015",
|
||||
TimeoutError: "ops.0016",
|
||||
}
|
||||
|
||||
UNKNOWN_ERROR_CODE = "ops.9999"
|
||||
|
||||
|
||||
class ErrorCode(Enum):
|
||||
# 通用错误
|
||||
SUCCESS = (0, "Success")
|
||||
UNKNOWN_ERROR = (1, "Unknown error")
|
||||
FILE_NOT_FOUND_ERROR = (1000, "File not found!")
|
||||
SUBMIT_TASK_ERROR = (1001, "Task submitted Failed!")
|
||||
CANCEL_TASK_ERROR = (1002, "Task canceled Failed!")
|
||||
59
runtime/python-executor/datamate/common/utils/__init__.py
Normal file
59
runtime/python-executor/datamate/common/utils/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import pytz
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def check_valid_path(file_path):
|
||||
full_path = os.path.abspath(file_path)
|
||||
return os.path.exists(full_path)
|
||||
|
||||
|
||||
def get_realpath_with_prefix_check(path, prefix):
|
||||
realpath = os.path.realpath(path)
|
||||
|
||||
if realpath.startswith(prefix):
|
||||
return realpath
|
||||
else:
|
||||
raise ValueError(f"The path {realpath} does not start with the prefix '{prefix}'.")
|
||||
|
||||
|
||||
def bytes_to_numpy(image_bytes):
|
||||
"""bytes转数组"""
|
||||
image_np = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
|
||||
return image_np2
|
||||
|
||||
|
||||
def numpy_to_bytes(image_np, file_type):
|
||||
"""数组转bytes"""
|
||||
if not image_np.size:
|
||||
return b""
|
||||
data = cv2.imencode(file_type, image_np)[1]
|
||||
image_bytes = data.tobytes()
|
||||
return image_bytes
|
||||
|
||||
|
||||
def get_now_time(timezone, time_format, file_name, method):
|
||||
timestamp = ""
|
||||
try:
|
||||
china_tz = pytz.timezone(timezone) # 设置时区
|
||||
china_time = datetime.now(tz=china_tz) # 获取当前时间并转换为对应的时区
|
||||
timestamp = china_time.strftime(time_format) # 格式化输出时间
|
||||
except ValueError as e:
|
||||
logger.error("fileName: %s, method: %s, formatting time failed: %s", file_name, method, e, exc_info=True)
|
||||
return timestamp
|
||||
|
||||
|
||||
def decrypt(enc_pass):
|
||||
import kmc.kmc as K
|
||||
os.environ['KMC_DATA_USER'] = 'modelenginepublic'
|
||||
os.environ['KMC_PYTHON_ENCRYPT_DATA'] = enc_pass
|
||||
dec_pass = K.API().decrypt(0)
|
||||
os.environ['KMC_PYTHON_ENCRYPT_DATA'] = ""
|
||||
return dec_pass
|
||||
115
runtime/python-executor/datamate/common/utils/aho_corasick.py
Normal file
115
runtime/python-executor/datamate/common/utils/aho_corasick.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# -- encoding: utf-8 --
|
||||
from collections import deque
|
||||
|
||||
|
||||
class TrieNode:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
self.child = dict()
|
||||
self.fail = None
|
||||
self.word = None
|
||||
|
||||
|
||||
class AhoCorasic:
|
||||
"""AC自动机算法进行目标字符串搜索"""
|
||||
|
||||
def __init__(self, words):
|
||||
self._root = add_fail_pointer(build_trie(words))
|
||||
|
||||
def search(self, text: str, special_symbols: set):
|
||||
"""
|
||||
匹配敏感词。
|
||||
|
||||
Args:
|
||||
text: 文本
|
||||
special_symbols: 特殊字符(需跳过)
|
||||
Returns:
|
||||
匹配成功的字符串列表
|
||||
"""
|
||||
seq_list = []
|
||||
node = self._root
|
||||
|
||||
valid_len = 0 # 当前遍历的有效长度
|
||||
for i, s in enumerate(text):
|
||||
if s in special_symbols: # 跳过特殊字符
|
||||
if valid_len != 0:
|
||||
valid_len += 1
|
||||
continue
|
||||
|
||||
matched = True
|
||||
while s not in node.child: # 当node.child没有字符s
|
||||
if node == self._root: # 当node为root(无node.fail),有效长度归0且跳出
|
||||
valid_len = 0
|
||||
matched = False
|
||||
break
|
||||
elif node.fail == self._root: # node.fail为root场景,有效长度归0,但可继续
|
||||
valid_len = 0
|
||||
node = node.fail # 移动到失败指针节点
|
||||
if not matched:
|
||||
continue
|
||||
|
||||
node = node.child.get(s)
|
||||
valid_len += 1
|
||||
if node.word: # node是单词尾字母
|
||||
sensitive_word = text[i - valid_len + 1:i + 1]
|
||||
seq_list.append(sensitive_word)
|
||||
seq_list = list(set(seq_list))
|
||||
return seq_list
|
||||
|
||||
|
||||
def build_trie(words: list):
|
||||
"""
|
||||
构建前缀树。
|
||||
|
||||
Args:
|
||||
words: 敏感词列表。
|
||||
Returns:
|
||||
前缀树根节点。
|
||||
"""
|
||||
root = TrieNode('root')
|
||||
for word in words:
|
||||
node = root
|
||||
for s in word:
|
||||
if s not in node.child:
|
||||
node.child[s] = TrieNode(s)
|
||||
node = node.child[s]
|
||||
if not node.word:
|
||||
node.word = {word}
|
||||
else:
|
||||
node.word.add(word)
|
||||
return root
|
||||
|
||||
|
||||
def add_fail_pointer(root: TrieNode):
|
||||
"""
|
||||
为前缀树添加失败指针。
|
||||
步骤:
|
||||
1. 从root开始逐层将node和node.parent以二元组存放队列。root没有fail指针,root.child的失败指针即为root。
|
||||
2. 对于root和root.child以外的node,查询node.parent.fail.child。
|
||||
3. 如果存在node.parent.fail.child.value == node.value,则构建node.fail = node.parent.fail.child.value。
|
||||
|
||||
Args:
|
||||
root: 前缀树根节点。
|
||||
returns:
|
||||
添加失败指针后的前缀树根节点。
|
||||
"""
|
||||
queue = deque()
|
||||
queue.appendleft((root, None))
|
||||
while len(queue) > 0:
|
||||
node_parent = queue.pop()
|
||||
curr, parent = node_parent[0], node_parent[1]
|
||||
for sub in curr.child.values():
|
||||
queue.appendleft((sub, curr))
|
||||
if parent is None:
|
||||
continue
|
||||
elif parent is root:
|
||||
curr.fail = root
|
||||
else:
|
||||
parent_fail = parent.fail
|
||||
while parent_fail and curr.value not in parent_fail.child:
|
||||
parent_fail = parent_fail.fail
|
||||
if parent_fail:
|
||||
curr.fail = parent_fail.child[curr.value]
|
||||
else:
|
||||
curr.fail = root
|
||||
return root
|
||||
@@ -0,0 +1,66 @@
|
||||
# -- encoding: utf-8 --
|
||||
import pickle
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def bytes_to_numpy(image_bytes):
|
||||
"""bytes转数组"""
|
||||
image_np = np.frombuffer(image_bytes, dtype=np.uint8)
|
||||
image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
|
||||
return image_np2
|
||||
|
||||
|
||||
def numpy_to_bytes(image_np, file_type):
|
||||
"""
|
||||
数组转bytes
|
||||
|
||||
Params:
|
||||
|
||||
file_type: as required by OpenCV, extension must have a leading period.
|
||||
"""
|
||||
if not image_np.size:
|
||||
return b""
|
||||
data = cv2.imencode(file_type, image_np)[1]
|
||||
image_bytes = data.tobytes()
|
||||
return image_bytes
|
||||
|
||||
|
||||
def pil_to_bytes(src: Image.Image) -> bytes:
|
||||
"""将 PIL.Image 转换为字节流"""
|
||||
# 确保图像是 RGB 模式
|
||||
src = src.convert("RGB")
|
||||
with BytesIO() as bytes_io:
|
||||
src.save(bytes_io, format='PNG')
|
||||
im_bytes = bytes_io.getvalue()
|
||||
return im_bytes
|
||||
|
||||
|
||||
def bytes_to_pil(src: bytes) -> Image.Image:
|
||||
"""将字节流转换为 PIL.Image"""
|
||||
with BytesIO() as bytes_io:
|
||||
with Image.open(bytes_io) as pil_img: # 使用with/as语句确保资源被正确释放
|
||||
pil_img.load() # 确保图像数据被加载
|
||||
return pil_img.copy() # 返回图像的副本以避免资源被关闭后无法使用
|
||||
|
||||
|
||||
def pil_to_base64(src: Image.Image):
|
||||
"""PIl.Image转base64"""
|
||||
with BytesIO() as img_buffer:
|
||||
src.save(img_buffer, format='png')
|
||||
byte_data = img_buffer.getvalue()
|
||||
base64_str = base64.b64encode(byte_data)
|
||||
return base64_str
|
||||
|
||||
|
||||
def obj_to_bytes(src: object) -> bytes:
|
||||
return pickle.dumps(src)
|
||||
|
||||
|
||||
def bytes_to_obj(src: bytes) -> object:
|
||||
return pickle.loads(src)
|
||||
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import importlib.abc
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class CustomImporter(importlib.abc.MetaPathFinder):
|
||||
def __init__(self, base_path):
|
||||
self.base_path = Path(base_path).resolve()
|
||||
|
||||
def find_spec(self, fullname, path, target=None):
|
||||
# 将模块名转换为路径(例如:mypkg.mymodule -> mypkg/mymodule.py)
|
||||
parts = fullname.split(".")
|
||||
module_path = self.base_path.joinpath(*parts)
|
||||
|
||||
# 检查是否存在 .py 文件或目录
|
||||
if module_path.with_suffix(".py").exists():
|
||||
return importlib.util.spec_from_file_location(
|
||||
fullname,
|
||||
str(module_path.with_suffix(".py")),
|
||||
submodule_search_locations=[str(module_path.parent)]
|
||||
)
|
||||
elif module_path.is_dir() and (module_path / "__init__.py").exists():
|
||||
return importlib.util.spec_from_file_location(
|
||||
fullname,
|
||||
str(module_path / "__init__.py"),
|
||||
submodule_search_locations=[str(module_path)]
|
||||
)
|
||||
else:
|
||||
return None
|
||||
229
runtime/python-executor/datamate/common/utils/lazy_loader.py
Normal file
229
runtime/python-executor/datamate/common/utils/lazy_loader.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import importlib
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from loguru import logger
|
||||
from packaging.version import parse as parse_version
|
||||
|
||||
|
||||
def is_valid_whl_filename(package_name):
|
||||
"""
|
||||
验证WHL文件名是否安全(聚焦防范命令注入攻击)
|
||||
|
||||
Args:
|
||||
package_name (str): 要验证的whl文件名
|
||||
|
||||
Returns:
|
||||
bool: True表示安全,False表示包含危险字符
|
||||
"""
|
||||
|
||||
# 禁止路径分隔符
|
||||
if os.path.sep in package_name or (os.path.altsep and os.path.altsep in package_name):
|
||||
return False
|
||||
|
||||
# 定义危险字符黑名单
|
||||
dangerous_pattern = re.compile(r"""
|
||||
[ # 匹配任意以下字符
|
||||
; & | ` # 命令分隔符
|
||||
$ () {} <> # 变量展开/命令替换
|
||||
\] \[ # 特殊符号
|
||||
! \\ # 历史扩展/转义符
|
||||
'"*?#\s # 引号/通配符/空格
|
||||
]
|
||||
""", re.VERBOSE)
|
||||
|
||||
if dangerous_pattern.search(package_name):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class PackageNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LazyLoader(ModuleType):
|
||||
def __init__(self,
|
||||
package_name,
|
||||
module_name=None,
|
||||
whl_path="/dataset/ops_whl",
|
||||
exact_version=None,
|
||||
force_reinstall=False
|
||||
):
|
||||
"""
|
||||
:param package_name: WHL包名称中的模块名称部分(一般是模块名称_替换为-)
|
||||
:param module_name: WHL包安装后,可用于import的模块名称, 当whl包名称和import名称不一致时,填写local_name.
|
||||
:param whl_path: WHL文件所在目录
|
||||
:param exact_version: 精确版本要求
|
||||
:param force_reinstall: 强制重新安装
|
||||
"""
|
||||
try:
|
||||
frame = sys._getframe(1)
|
||||
self._parent_globals = frame.f_globals
|
||||
except (AttributeError, ValueError) as e:
|
||||
logger.error(f"Failed to get stack frame: {e}")
|
||||
raise RuntimeError("Stack frame retrieval failed") from e
|
||||
|
||||
self._module_name = module_name if module_name else package_name
|
||||
self._package_name = package_name
|
||||
|
||||
self.whl_path = Path(whl_path).resolve()
|
||||
self.exact_version = exact_version
|
||||
|
||||
self.force_reinstall = force_reinstall
|
||||
self._cached_module = None
|
||||
|
||||
# 注册别名到父级命名空间
|
||||
self._parent_globals[self._module_name] = self
|
||||
super().__init__(self._module_name)
|
||||
|
||||
def __getattr__(self, name):
|
||||
if self._cached_module is None:
|
||||
self._cached_module = self._load_module()
|
||||
return getattr(self._cached_module, name)
|
||||
|
||||
def __dir__(self):
|
||||
return dir(self._load_module())
|
||||
|
||||
def _load_module(self):
|
||||
"""模块加载逻辑"""
|
||||
|
||||
if self._cached_module is not None:
|
||||
return self._cached_module
|
||||
|
||||
package_name: str = self._package_name.split('.')[0]
|
||||
if not is_valid_whl_filename(package_name):
|
||||
logger.error(f"Invalid package_name: {package_name}")
|
||||
raise RuntimeError("Invalide package_name, please check it again!")
|
||||
|
||||
module_name = self._module_name if self._module_name else package_name.replace("_", "-")
|
||||
|
||||
need_install = False
|
||||
|
||||
try:
|
||||
if not self.force_reinstall:
|
||||
module = importlib.import_module(module_name)
|
||||
self._cached_module = module
|
||||
self._register_alias(module)
|
||||
return module
|
||||
except ImportError:
|
||||
need_install = True
|
||||
|
||||
if self.force_reinstall:
|
||||
# 强制安装时的版本检查
|
||||
installed = self._check_package_exists(package_name)
|
||||
if installed and self.exact_version:
|
||||
installed_version = self._get_installed_version(package_name)
|
||||
if parse_version(installed_version) != parse_version(self.exact_version):
|
||||
logger.info(f"Version mismatch detected: {installed_version} vs {self.exact_version}")
|
||||
need_install = True
|
||||
else:
|
||||
need_install = True
|
||||
|
||||
if need_install:
|
||||
self._pip_install_package(package_name)
|
||||
module = importlib.import_module(module_name)
|
||||
self._cached_module = module
|
||||
self._register_alias(module)
|
||||
else:
|
||||
# 版本检查通过,无需再次安装
|
||||
module = importlib.import_module(module_name)
|
||||
self._cached_module = module
|
||||
self._register_alias(module)
|
||||
|
||||
return self._cached_module
|
||||
|
||||
def _register_alias(self, module):
|
||||
"""注册本地别名 """
|
||||
self._parent_globals[self._module_name] = module
|
||||
sys.modules[self._module_name] = module
|
||||
|
||||
def _check_package_exists(self, package_name):
|
||||
"""增强版包检查"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "show", package_name],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
return result.returncode == 0
|
||||
except subprocess.SubprocessError as e:
|
||||
logger.error(f"Package check failed: {e}")
|
||||
return False
|
||||
|
||||
def _get_installed_version(self, package_name):
|
||||
"""获取已安装版本 """
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "show", package_name],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
if line.startswith('Version:'):
|
||||
return line.split()[-1]
|
||||
raise PackageNotFoundError()
|
||||
|
||||
def _pip_install_package(self, package_name: str):
|
||||
"""安装逻辑 """
|
||||
|
||||
if not self.whl_path.exists():
|
||||
raise FileNotFoundError(f"WHL directory not found: {self.whl_path}")
|
||||
|
||||
whl_files = list(self.whl_path.glob(f"{package_name}*.whl"))
|
||||
if not whl_files:
|
||||
raise RuntimeError(f"No WHL files found for {package_name}")
|
||||
|
||||
# 版本过滤
|
||||
if self.exact_version:
|
||||
pattern = re.compile(
|
||||
rf'^{re.escape(package_name)}-{re.escape(self.exact_version)}-\S*\.whl$',
|
||||
re.IGNORECASE
|
||||
)
|
||||
whl_files = [f for f in whl_files if pattern.match(f.name)]
|
||||
|
||||
# 选择最新版本
|
||||
whl_versions = []
|
||||
for f in whl_files:
|
||||
try:
|
||||
version = self._extract_version(f)
|
||||
whl_versions.append((f, version))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not whl_versions:
|
||||
raise FileNotFoundError("No valid WHL files")
|
||||
|
||||
whl_versions.sort(key=lambda x: x[1], reverse=True)
|
||||
target_whl = whl_versions[0][0]
|
||||
|
||||
# 执行安装
|
||||
try:
|
||||
subprocess.check_call([
|
||||
sys.executable, "-m", "pip", "install",
|
||||
"--no-index",
|
||||
f"--find-links={self.whl_path}",
|
||||
str(target_whl)
|
||||
], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
|
||||
logger.info(f"Successfully installed {target_whl}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Installation failed: {e}")
|
||||
raise RuntimeError(f"Installation failed: {e}") from e
|
||||
|
||||
def _extract_version(self, filename):
|
||||
"""版本解析 """
|
||||
|
||||
version_pattern = r"(\d+([.]\d+)+([ab]|rc\d+)*([.]post\d+)*([.]dev\d+)*)"
|
||||
|
||||
match = re.search(
|
||||
rf"^{re.escape(self._package_name)}-({version_pattern})",
|
||||
filename.name,
|
||||
re.IGNORECASE
|
||||
)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid version format: {filename.name}")
|
||||
return parse_version(match.group(1))
|
||||
118
runtime/python-executor/datamate/common/utils/llm_request.py
Normal file
118
runtime/python-executor/datamate/common/utils/llm_request.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import requests
|
||||
import urllib3
|
||||
from loguru import logger
|
||||
|
||||
from datamate.common.utils import decrypt
|
||||
|
||||
|
||||
class LlmReq:
|
||||
# 定义常量用于解释错误码
|
||||
ERRORCODE_INCOMPLETE_CONFIG = 83005
|
||||
ERRORCODE_INVALID_RESPONSE = 83006
|
||||
ERRORCODE_SERVICE_UNAVAILABLE = 83007
|
||||
|
||||
def __init__(self, url: str = None, header: Dict = None, body: Dict = None, access_type: int = None,
|
||||
is_https: bool = False, is_certificate: bool = False, certificate_path: Path = None):
|
||||
self.url = url
|
||||
self.header = header
|
||||
self.access_type = access_type
|
||||
self.is_https = is_https
|
||||
self.context = self._load_certificate(certificate_path, is_certificate) if is_https else None
|
||||
self.body = body
|
||||
if not self.body.get("messages", [])[0].get("content"):
|
||||
self.body["messages"][0]["content"] = "你好"
|
||||
|
||||
def __call__(self, input_str: str) -> str:
|
||||
outputs = ''
|
||||
try:
|
||||
self.body["messages"][0]["content"] = input_str
|
||||
outputs = self._call_service()
|
||||
except KeyError as e:
|
||||
logger.error(f"The body format is not completed, error detail: {e}")
|
||||
self.body["messages"][0]["content"] = "你好"
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _load_certificate(certificate_path: Path, is_certificate: bool) -> ssl.SSLContext:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
context.check_hostname = False
|
||||
if is_certificate:
|
||||
context.load_verify_locations(certificate_path)
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
else:
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
def _pool_manager():
|
||||
cert_file = os.getenv("RAY_TLS_SERVER_CERT", "/certPersonal/global/identity/global.crt")
|
||||
key_file = os.getenv("RAY_TLS_SERVER_KEY", "/certPersonal/global/identity/global.key")
|
||||
ca_crt = os.getenv("RAY_TLS_CA_CERT", "/certPersonal/global/trust/ca.crt")
|
||||
pwd = os.getenv("RAY_TLS_SERVER_KEY_PASSWORD", "/certPersonal/global/identity/pwd.txt")
|
||||
key_password = os.getenv("GLOBAL_PWD", None)
|
||||
if not key_password:
|
||||
with open(pwd, "r") as f:
|
||||
key_password = f.read().strip()
|
||||
key_password = decrypt(key_password)
|
||||
|
||||
pool_manager = urllib3.PoolManager(cert_file=cert_file,
|
||||
key_file=key_file,
|
||||
key_password=key_password,
|
||||
cert_reqs='CERT_REQUIRED',
|
||||
ca_certs=ca_crt,
|
||||
assert_hostname='edatamate',
|
||||
ssl_version='TLSv1_2')
|
||||
return pool_manager
|
||||
|
||||
def _call_service(self):
|
||||
if not all([self.url, self.header, self.body.get("messages", [])[0].get("content")]):
|
||||
logger.error("LLM is not configured completely")
|
||||
raise RuntimeError(self.ERRORCODE_INCOMPLETE_CONFIG, "LLM is not configured completely") from None
|
||||
if not self.access_type:
|
||||
try:
|
||||
pool_manager = self._pool_manager()
|
||||
response = pool_manager.request(
|
||||
"POST",
|
||||
url=self.url,
|
||||
body=json.dumps(self.body).encode(),
|
||||
headers=self.header
|
||||
)
|
||||
logger.info(f"Response status code: {response.status}")
|
||||
response_json = json.loads(response.data.decode('utf-8'))
|
||||
outputs = response_json.get("choices", [])[0].get("message", {}).get("content")
|
||||
if not outputs:
|
||||
logger.error("Invalid response format for LLM, missing the 'prompt' key word")
|
||||
raise RuntimeError(self.ERRORCODE_INVALID_RESPONSE,
|
||||
"Invalid response format for LLM, missing the 'prompt' key word") from None
|
||||
return outputs
|
||||
except Exception as e:
|
||||
logger.error(f"LLM service is not available, error detail: {e}")
|
||||
raise RuntimeError(self.ERRORCODE_SERVICE_UNAVAILABLE, "LLM service is not available") from None
|
||||
if self.access_type:
|
||||
try:
|
||||
if self.is_https:
|
||||
req = Request(url=self.url, data=json.dumps(self.body).encode(), headers=self.header, method="POST")
|
||||
response_json = urlopen(req, context=self.context).read().decode("utf-8")
|
||||
response = json.loads(response_json)
|
||||
else:
|
||||
response = requests.post(url=self.url, data=json.dumps(self.body), headers=self.header,
|
||||
stream=False).json()
|
||||
outputs = response.get("choices", [])[0].get("message", {}).get("content")
|
||||
if not outputs:
|
||||
logger.error("Invalid response format for LLM, missing the 'prompt' key word")
|
||||
raise RuntimeError(self.ERRORCODE_INVALID_RESPONSE,
|
||||
"Invalid response format for LLM, missing the 'prompt' key word") from None
|
||||
return outputs
|
||||
except Exception as e:
|
||||
logger.error(f"LLM service is not available, error detail: {e}")
|
||||
raise RuntimeError(self.ERRORCODE_SERVICE_UNAVAILABLE, "LLM service is not available") from None
|
||||
return None # 确保在所有情况下都返回
|
||||
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import sys
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
from packaging.version import parse as parse_version, Version
|
||||
|
||||
|
||||
def install_whl(
|
||||
package_name: str,
|
||||
whl_path: str,
|
||||
exact_version: Optional[str] = None,
|
||||
filename_pattern: Optional[str] = None,
|
||||
force_reinstall: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
:param package_name: eg: ("zh_core_web_sm")
|
||||
:param whl_path: WHL file save path
|
||||
:param exact_version: version number
|
||||
:param filename_pattern: custom filename pattern for REGEX
|
||||
:param force_reinstall: which decide to overlap the original number or not (default: False)
|
||||
"""
|
||||
|
||||
whl_dir = Path(whl_path).resolve()
|
||||
whl_files = _get_whl_files(exact_version, filename_pattern, package_name, whl_dir)
|
||||
|
||||
# 语义化版本排序
|
||||
target_whl = _sort_whl_files(whl_files)
|
||||
|
||||
# 安装命令
|
||||
cmd = [
|
||||
sys.executable, "-m", "pip", "install",
|
||||
"--no-index",
|
||||
f"--find-links={whl_dir}",
|
||||
str(target_whl)
|
||||
]
|
||||
if force_reinstall:
|
||||
cmd.append("--force-reinstall")
|
||||
|
||||
try:
|
||||
subprocess.check_call(
|
||||
cmd,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.STDOUT
|
||||
)
|
||||
logger.info(f"Successfully installed {target_whl.name}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_msg = (
|
||||
f"Installation failed for {package_name}\n"
|
||||
f"Possible reasons:\n"
|
||||
f"1. Missing dependencies in {whl_dir}\n"
|
||||
f"2. Incompatible Python version\n"
|
||||
f"3. Platform mismatch (e.g., x86 vs ARM)"
|
||||
)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
|
||||
def _sort_whl_files(whl_files):
|
||||
whl_versions = []
|
||||
logger.info(f"[load_offline_module]whl_files: {whl_files}")
|
||||
for f in whl_files:
|
||||
try:
|
||||
version = _extract_version(f)
|
||||
whl_versions.append((f, version))
|
||||
except ValueError as e:
|
||||
logger.warning(f"Skipping invalid file {f.name}: {e}")
|
||||
continue
|
||||
if not whl_versions:
|
||||
raise FileNotFoundError("No valid WHL files with parseable versions")
|
||||
whl_versions.sort(key=lambda x: x[1], reverse=True)
|
||||
target_whl = whl_versions[0][0]
|
||||
return target_whl
|
||||
|
||||
|
||||
def _get_whl_files(exact_version, filename_pattern, package_name, whl_dir):
|
||||
# 正则表达式
|
||||
if filename_pattern:
|
||||
pattern = filename_pattern
|
||||
else:
|
||||
if exact_version:
|
||||
version_part = re.escape(exact_version)
|
||||
pattern = rf"^{re.escape(package_name)}-{version_part}-\S*\.whl$"
|
||||
else:
|
||||
pattern = rf"^{re.escape(package_name)}\S*\.whl$"
|
||||
regex = re.compile(pattern, re.IGNORECASE)
|
||||
whl_files = [f for f in whl_dir.glob("*.whl") if regex.match(f.name)]
|
||||
if not whl_files:
|
||||
available_files = "\n".join([f.name for f in whl_dir.glob("*.whl")])
|
||||
raise FileNotFoundError(
|
||||
f"No matching WHL found for {package_name} in {whl_dir}\n"
|
||||
f"Available files:\n{available_files}"
|
||||
)
|
||||
return whl_files
|
||||
|
||||
|
||||
def _extract_version(filename: Path) -> Version:
|
||||
"""从文件名提取语义化版本("""
|
||||
|
||||
match = re.search(
|
||||
r"-(\d+([.]\d+)+([ab]|rc\d+)*([.]post\d+)*([.]dev\d+)*)-",
|
||||
filename.name
|
||||
)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid version format: {filename.name}")
|
||||
return parse_version(match.group(1))
|
||||
102
runtime/python-executor/datamate/common/utils/registry.py
Normal file
102
runtime/python-executor/datamate/common/utils/registry.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Registry(object):
|
||||
"""注册器类,用于注册所有的算子."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
|
||||
self._name = name
|
||||
self._modules = {}
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def modules(self):
|
||||
return self._modules
|
||||
|
||||
def list(self):
|
||||
"""日志打印注册器中所有的算子"""
|
||||
for m in self._modules.keys():
|
||||
logger.info(f'{self._name}\t{m}')
|
||||
|
||||
def get(self, module_key):
|
||||
return self._modules.get(module_key, None)
|
||||
|
||||
def register_module(self, module_name: str = None, module_cls: type = None, module_path: str = None, force=False):
|
||||
"""
|
||||
使用特定的模块名称注册模块
|
||||
|
||||
:param module_name: 模块名称
|
||||
:param module_cls: 模块类定义
|
||||
:param module_path: 模块所在的路径
|
||||
:param force: 是否强行覆盖同名模块,默认值为False.
|
||||
|
||||
Example:
|
||||
>>> registry = Registry()
|
||||
>>> @registry.register_module()
|
||||
>>> class TextFormatter:
|
||||
>>> pass
|
||||
|
||||
>>> class TextFormatter2:
|
||||
>>> pass
|
||||
>>> registry.register_module( module_name='text_formatter2', module_cls=TextFormatter2)
|
||||
"""
|
||||
if not (module_name is None or isinstance(module_name, str)):
|
||||
raise TypeError(f'module_name must be either of None, str,'
|
||||
f'got {type(module_name)}')
|
||||
if module_cls is not None:
|
||||
self._register_module(module_name=module_name,
|
||||
module_cls=module_cls,
|
||||
force=force)
|
||||
return module_cls
|
||||
|
||||
elif module_cls is None and isinstance(module_path, str):
|
||||
self._register_module(module_name=module_name,
|
||||
module_path=module_path,
|
||||
force=force)
|
||||
return module_path
|
||||
|
||||
def _register(module_cls):
|
||||
"""
|
||||
注册其中module_cls为None是,返回装饰器函数
|
||||
"""
|
||||
self._register_module(module_name=module_name,
|
||||
module_cls=module_cls,
|
||||
force=force)
|
||||
return module_cls
|
||||
|
||||
return _register
|
||||
|
||||
def _register_module(self, module_name=None, module_cls=None, module_path=None, force=False):
|
||||
"""
|
||||
注册模块到注册器中.
|
||||
|
||||
:param module_name: 模块名称
|
||||
:param module_cls: 模块类定义
|
||||
:param force: 是否强行覆盖同名模块,默认值为False.
|
||||
"""
|
||||
|
||||
if module_name is None and module_cls is not None:
|
||||
module_name = module_cls.__name__
|
||||
|
||||
if module_name in self._modules:
|
||||
if module_cls is not None and module_cls == self._modules[module_name]:
|
||||
return
|
||||
|
||||
if module_path is not None and module_path == self._modules[module_name]:
|
||||
return
|
||||
|
||||
if not force:
|
||||
raise KeyError(
|
||||
f'{module_name} is already registered in {self._name}, content: {self.modules.keys()}')
|
||||
|
||||
if module_cls is not None:
|
||||
self._modules[module_name] = module_cls
|
||||
|
||||
elif module_path is not None:
|
||||
self._modules[module_name] = module_path
|
||||
171
runtime/python-executor/datamate/common/utils/text_splitter.py
Normal file
171
runtime/python-executor/datamate/common/utils/text_splitter.py
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/user/bin/python
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
from collections import deque
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TextSplitter:
|
||||
"""文本切片"""
|
||||
# 基于常用标点符号分句,保持句子完整
|
||||
COMMON_PUNCTUATIONS = [",", "。", "?", "!", ";", ",", "?", "!", ";"]
|
||||
PUNC_PATTERN = f"[{''.join(COMMON_PUNCTUATIONS)}]"
|
||||
|
||||
def __init__(self, max_characters: int, chunk_size: int, chunk_overlap: int):
|
||||
"""文本切片初始化
|
||||
Args:
|
||||
max_characters :文件最大字符,超过截断,-1不处理
|
||||
chunk_size: 块大小
|
||||
chunk_overlap: 块重叠度
|
||||
"""
|
||||
if chunk_size <= chunk_overlap:
|
||||
logger.error(f"param chunk_size should larger than chunk_overlap, "
|
||||
f"current chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}")
|
||||
raise Exception(83000, str(ValueError)) from None
|
||||
self.max_characters = max_characters
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
self.separators = ["\n\n", "\n"]
|
||||
|
||||
@staticmethod
|
||||
def split_text_by_separator(text: str, separator: str):
|
||||
"""指定分隔符对文本进行切分,并且切分后的片段需要保留分隔符"""
|
||||
# 处理一个换行符与两个换行符之间的冲突
|
||||
if text.startswith("\n\n") and separator == "\n":
|
||||
chunks = re.split(f"({separator})", text.strip())
|
||||
chunks[0] = f"\n\n{chunks[0]}"
|
||||
else:
|
||||
chunks = re.split(f"({separator})", text)
|
||||
new_chunks = [chunks[idx] + chunks[idx + 1] for idx in range(1, len(chunks), 2)]
|
||||
new_chunks = [chunks[0]] + new_chunks
|
||||
return [chunk for chunk in new_chunks if chunk.strip() != ""]
|
||||
|
||||
@staticmethod
|
||||
def split_sentences(chunk: str):
|
||||
"""对切片按照标点符号切分成句子,并且保持标点符号不丢失"""
|
||||
sentences = re.split(TextSplitter.PUNC_PATTERN, chunk)
|
||||
delimiters = [s for s in chunk if s in TextSplitter.COMMON_PUNCTUATIONS]
|
||||
restore_chunks = []
|
||||
for chunk, delimiter in zip(sentences[:-1], delimiters):
|
||||
restore_chunks.append(chunk + delimiter)
|
||||
return restore_chunks + [sentences[-1]]
|
||||
|
||||
def split_text(self, input_data: str):
|
||||
if self.max_characters > 0:
|
||||
logger.info(f"The document characters should be within: {self.max_characters}")
|
||||
input_data = input_data[:self.max_characters]
|
||||
logger.info(f"characters of the document: {len(input_data)}")
|
||||
chunks = self.split_text_recursive(input_data, self.separators)
|
||||
final_chunks = self.merge_chunks(chunks)
|
||||
final_chunks = self.split_text_by_chunk_size(final_chunks)
|
||||
return [chunk.strip() for chunk in final_chunks if chunk]
|
||||
|
||||
def split_text_recursive(self, input_data: str, separators: List[str]):
|
||||
"""对文档按照分隔符优先级进行递归切分:
|
||||
1. 符合chunk_size要求的切片不再切分。
|
||||
2. 大于chunk_size要求的切片,继续进行递归切分。
|
||||
Args:
|
||||
input_data: 输入文本
|
||||
separators: 分隔符
|
||||
|
||||
Returns:
|
||||
List[str]: 切分后的文本片段
|
||||
|
||||
"""
|
||||
chunks = []
|
||||
cur_separator = ""
|
||||
next_separators = []
|
||||
for idx, sep in enumerate(separators):
|
||||
sep = re.escape(sep)
|
||||
if re.search(sep, input_data.strip()):
|
||||
cur_separator = sep
|
||||
next_separators = separators[idx + 1:]
|
||||
break
|
||||
|
||||
if not cur_separator:
|
||||
return [input_data]
|
||||
else:
|
||||
cur_chunks = TextSplitter.split_text_by_separator(input_data, cur_separator)
|
||||
|
||||
for chunk in cur_chunks:
|
||||
if len(chunk.strip()) <= self.chunk_size:
|
||||
chunks.append(chunk)
|
||||
else:
|
||||
if not next_separators:
|
||||
chunks.append(chunk)
|
||||
else:
|
||||
next_chunks = self.split_text_recursive(chunk, next_separators)
|
||||
chunks.extend(next_chunks)
|
||||
return chunks
|
||||
|
||||
def merge_chunks(self, chunks: List[str]):
|
||||
"""对切分后的文本片段进行合并,合并过程考虑overlap"""
|
||||
final_chunks = []
|
||||
idx = 0
|
||||
while idx < len(chunks):
|
||||
if len(chunks[idx]) >= self.chunk_size:
|
||||
final_chunks.append(chunks[idx])
|
||||
idx += 1
|
||||
continue
|
||||
merge_idxes = self.get_merge_idxes(idx, chunks)
|
||||
content = ""
|
||||
for inner_idx in merge_idxes:
|
||||
content += chunks[inner_idx]
|
||||
final_chunks.append(content)
|
||||
idx = merge_idxes[-1] + 1
|
||||
return final_chunks
|
||||
|
||||
def get_merge_idxes(self, cur_idx: int, chunks: List[str]):
|
||||
"""获取可以合并的分片index,前向尽可能满足overlap,后向尽可能满足chunk_size"""
|
||||
idxes = deque([cur_idx])
|
||||
overlap_idx = cur_idx - 1
|
||||
cur_len = len(chunks[cur_idx])
|
||||
cur_idx += 1
|
||||
# 获取overlap的index
|
||||
over_lap_len = 0
|
||||
while overlap_idx >= 0:
|
||||
over_lap_len += len(chunks[overlap_idx])
|
||||
if over_lap_len > self.chunk_overlap or (cur_len + over_lap_len) > self.chunk_size:
|
||||
over_lap_len -= len(chunks[overlap_idx])
|
||||
break
|
||||
idxes.appendleft(overlap_idx)
|
||||
overlap_idx -= 1
|
||||
cur_len += over_lap_len
|
||||
# 获取merge的index
|
||||
while cur_idx < len(chunks):
|
||||
cur_len += len(chunks[cur_idx])
|
||||
if cur_len > self.chunk_size:
|
||||
break
|
||||
idxes.append(cur_idx)
|
||||
cur_idx += 1
|
||||
return idxes
|
||||
|
||||
def split_chunks(self, chunks: List[str]):
|
||||
"""对超过`chunk_size`限制的切片进行截断,过程中需要考虑overlap参数"""
|
||||
final_chunks = []
|
||||
for chunk in chunks:
|
||||
if len(chunk) <= self.chunk_size:
|
||||
final_chunks.append(chunk)
|
||||
else:
|
||||
start = 0
|
||||
end = self.chunk_size
|
||||
while end < len(chunk):
|
||||
final_chunks.append(chunk[start: end])
|
||||
start += self.chunk_size - self.chunk_overlap
|
||||
end = start + self.chunk_size
|
||||
final_chunks.append(chunk[start:])
|
||||
return final_chunks
|
||||
|
||||
def split_text_by_chunk_size(self, chunks: List[str]):
|
||||
"""对切片后超长的文本块进行二次切分,使用截断,并考虑overlap"""
|
||||
final_chunks = []
|
||||
for chunk in chunks:
|
||||
if len(chunk) <= self.chunk_size:
|
||||
final_chunks.append(chunk)
|
||||
continue
|
||||
sentences = TextSplitter.split_sentences(chunk)
|
||||
sub_chunks = self.merge_chunks(sentences)
|
||||
final_chunks.extend(self.split_chunks(sub_chunks))
|
||||
return final_chunks
|
||||
0
runtime/python-executor/datamate/core/__init__.py
Normal file
0
runtime/python-executor/datamate/core/__init__.py
Normal file
381
runtime/python-executor/datamate/core/base_op.py
Normal file
381
runtime/python-executor/datamate/core/base_op.py
Normal file
@@ -0,0 +1,381 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from datamate.common.error_code import ERROR_CODE_TABLE, UNKNOWN_ERROR_CODE
|
||||
from datamate.common.utils.llm_request import LlmReq
|
||||
from datamate.common.utils.registry import Registry
|
||||
from datamate.common.utils import check_valid_path
|
||||
from datamate.core.constant import Fields
|
||||
from datamate.sql_manager.persistence_atction import TaskInfoPersistence
|
||||
|
||||
OPERATORS = Registry('Operators')
|
||||
|
||||
FAILED_STATUS = "FAILED"
|
||||
SUCCESS_STATUS = "COMPLETED"
|
||||
|
||||
|
||||
def get_exception_info(e):
|
||||
exc_type = type(e).__name__ # 异常类型(如 'ZeroDivisionError')
|
||||
exc_msg = str(e) # 异常原因(如 'division by zero')
|
||||
|
||||
# 提取详细的堆栈信息
|
||||
tb = traceback.extract_tb(e.__traceback__) # 解析 traceback 对象
|
||||
error_line = tb[-1].lineno # 错误发生的行号
|
||||
error_file = tb[-1].filename # 错误发生的文件名
|
||||
code_snippet = tb[-1].line # 错误行的代码
|
||||
|
||||
# 组合输出信息
|
||||
error_info = (
|
||||
f"错误类型: {exc_type}\n"
|
||||
f"错误原因: {exc_msg}\n"
|
||||
f"文件名: {error_file}\n"
|
||||
f"行号: {error_line}\n"
|
||||
f"代码行: {code_snippet}"
|
||||
)
|
||||
return error_info
|
||||
|
||||
|
||||
class BaseOp:
|
||||
"""
|
||||
所有算子类的父类
|
||||
"""
|
||||
|
||||
use_model = False
|
||||
custom_ops = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.accelerator = kwargs.get('accelerator', "cpu")
|
||||
self.is_last_op = kwargs.get('is_last_op', False)
|
||||
self._name = kwargs.get('op_name', None)
|
||||
self.infer_model = None
|
||||
self.text_key = kwargs.get('text_key', "text")
|
||||
self.data_key = kwargs.get('data_key', "data")
|
||||
self.image_key = kwargs.get('image_key', "image")
|
||||
self.video_key = kwargs.get('video_key', "video")
|
||||
self.audio_key = kwargs.get('audio_key', "audio")
|
||||
self.filename_key = kwargs.get('fileName_key', "fileName")
|
||||
self.filetype_key = kwargs.get('fileType_key', "fileType")
|
||||
self.fileid_key = kwargs.get('fileId_key', "fileId")
|
||||
self.filepath_key = kwargs.get('filePath_key', "filePath")
|
||||
self.filesize_key = kwargs.get('fileSize_key', "fileSize")
|
||||
self.export_path_key = kwargs.get('export_path_key', "export_path")
|
||||
self.ext_params_key = kwargs.get('ext_params_key', "ext_params")
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
if self._name:
|
||||
return self._name
|
||||
else:
|
||||
return "UnknownOp"
|
||||
|
||||
@staticmethod
|
||||
def is_npu_available():
|
||||
try:
|
||||
import torch_npu
|
||||
return torch_npu.npu.is_available()
|
||||
except ImportError as e:
|
||||
logger.warning("Import torch_npu failed.")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def update_kwargs(sample: Dict[str, Any], not_update_keys=("text", "data", "meta")) -> Dict:
|
||||
"""获取sample_data中文件相关的信息"""
|
||||
res = {}
|
||||
for k, v in sample.items():
|
||||
if k not in not_update_keys:
|
||||
res[k] = v
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _get_error_info(e: BaseException) -> Tuple[str, str]:
|
||||
|
||||
error_code = UNKNOWN_ERROR_CODE
|
||||
exc_info = get_exception_info(e)
|
||||
|
||||
for exc_type in type(e).__mro__:
|
||||
if exc_type in ERROR_CODE_TABLE.keys():
|
||||
error_code = ERROR_CODE_TABLE[exc_type]
|
||||
break
|
||||
|
||||
return error_code, exc_info
|
||||
|
||||
def use_npu(self):
|
||||
"""确认算子是否可以使用npu"""
|
||||
return self.accelerator == 'npu' and self.is_npu_available()
|
||||
|
||||
def get_model(self, *args, **kwargs):
|
||||
if self.infer_model is None and self.use_model:
|
||||
return self.init_model(*args, **kwargs)
|
||||
else:
|
||||
logger.info(f"Op named {self.name} get infer model Failed. please "
|
||||
f" check Attribute self.use_model: {self.use_model} or model has been initialized!")
|
||||
return self.infer_model
|
||||
|
||||
def init_model(self, *args, **kwargs):
|
||||
"""执行函数(子类实现)"""
|
||||
raise NotImplementedError("This is in BaseOp, plese re-define this method in Sub-classes")
|
||||
|
||||
def fill_sample_params(self, sample: Dict[str, Any], **kwargs):
|
||||
if not sample.get("text", None):
|
||||
sample[self.text_key] = ""
|
||||
|
||||
if not sample.get("data", None):
|
||||
sample[self.data_key] = b""
|
||||
|
||||
if not sample[self.data_key] and not sample[self.text_key]:
|
||||
sample.update(kwargs)
|
||||
|
||||
def create_failure_sample(self, sample: Dict[str, Any], op_name, excp: BaseException):
|
||||
sample["execute_result"] = False
|
||||
error_code, exc_info = self._get_error_info(excp)
|
||||
failed_reason = {"op_name": op_name, "error_code": error_code, "reason": exc_info}
|
||||
sample["failed_reason"] = failed_reason
|
||||
|
||||
|
||||
class Mapper(BaseOp):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Mapper, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, sample: Dict[str, Any], **kwargs):
|
||||
# 该算子前已有算子执行该文件失败
|
||||
if sample.get(Fields.result) is False:
|
||||
return sample
|
||||
|
||||
self.fill_sample_params(sample, **kwargs)
|
||||
execute_status = FAILED_STATUS
|
||||
try:
|
||||
sample = self.execute(sample)
|
||||
execute_status = SUCCESS_STATUS
|
||||
except Exception as e:
|
||||
# 算子执行失败,记录文件执行信息到数据库,并更该文件执行结果状态
|
||||
self.create_failure_sample(sample, self.name, e)
|
||||
logger.error(f"Ops named {self.name} map failed, Error Info: \n"
|
||||
f"{str(get_exception_info(e))}")
|
||||
sample["execute_status"] = execute_status
|
||||
task_info = TaskInfoPersistence()
|
||||
task_info.persistence_task_info(sample)
|
||||
return sample
|
||||
|
||||
sample["execute_status"] = execute_status
|
||||
# 加载文件成功执行信息到数据库
|
||||
if self.is_last_op:
|
||||
task_info = TaskInfoPersistence()
|
||||
task_info.persistence_task_info(sample)
|
||||
return sample
|
||||
|
||||
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行函数(子类实现)"""
|
||||
raise NotImplementedError("This is in Mapper Class, plese re-define this method in Sub-classes")
|
||||
|
||||
|
||||
class Slicer(BaseOp):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Slicer, self).__init__(*args, **kwargs)
|
||||
self.target_file_type = None
|
||||
|
||||
def __call__(self, sample: Dict[str, Any], **kwargs):
|
||||
# 该算子前已有算子执行该文件失败
|
||||
if sample.get(Fields.result) is False:
|
||||
return sample
|
||||
|
||||
self.fill_sample_params(sample, **kwargs)
|
||||
sample_list = []
|
||||
execute_status = FAILED_STATUS
|
||||
try:
|
||||
sample_list = self.execute(sample)
|
||||
execute_status = SUCCESS_STATUS
|
||||
except Exception as e:
|
||||
# 算子执行失败,记录文件执行信息到数据库,并更该文件执行结果状态
|
||||
self.create_failure_sample(sample, self.name, e)
|
||||
self.load_sample_to_sample(sample, sample_list)
|
||||
logger.error(f"Ops named {self.name} map failed, Error Info: \n"
|
||||
f"{str(get_exception_info(e))}")
|
||||
sample["execute_status"] = execute_status
|
||||
task_info = TaskInfoPersistence()
|
||||
task_info.persistence_task_info(sample)
|
||||
return [sample]
|
||||
|
||||
self.load_sample_to_sample(sample, sample_list)
|
||||
sample["execute_status"] = execute_status
|
||||
|
||||
# 加载文件成功执行信息到数据库
|
||||
if self.is_last_op:
|
||||
task_info = TaskInfoPersistence()
|
||||
task_info.persistence_task_info(sample)
|
||||
|
||||
return [sample]
|
||||
|
||||
@staticmethod
|
||||
def load_sample_to_sample(sample: Dict, sample_list: List[Dict]):
|
||||
"""使用sample中的k-v更新sample"""
|
||||
for sample_i in sample_list:
|
||||
for k, v in sample_i.items():
|
||||
sample[k] = v
|
||||
if not sample.get("fileNum", None):
|
||||
sample["fileNum"] = 1
|
||||
|
||||
def execute(self, sample: Dict[str, Any]) -> List[Dict]:
|
||||
"""执行函数(子类实现)"""
|
||||
raise NotImplementedError("This is in Mapper Class, plese re-define this method in Sub-classes")
|
||||
|
||||
def save_patch_sample(self, sample: Dict[str, Any], patch_no, save_format="text"):
|
||||
if save_format == "text":
|
||||
target_file_type = 'txt'
|
||||
elif save_format == "image":
|
||||
target_file_type = 'png'
|
||||
else:
|
||||
target_file_type = None
|
||||
raise RuntimeError(f"target file type is {target_file_type}!")
|
||||
|
||||
if self.target_file_type:
|
||||
target_file_type = self.target_file_type
|
||||
save_path = self.get_save_path(sample, patch_no, target_file_type)
|
||||
self.save_file(sample, save_path)
|
||||
|
||||
def get_save_path(self, sample: Dict[str, Any], patch_no, target_type) -> str:
|
||||
export_path = os.path.abspath(sample[self.export_path_key])
|
||||
logger.info(f"export path: {export_path}.")
|
||||
base_file_name, _ = os.path.splitext(sample[self.filename_key])
|
||||
file_id = str(sample[self.fileid_key])
|
||||
new_file_name = file_id + '_' + str(patch_no) + '.' + target_type
|
||||
logger.info(f"base_file_name: {base_file_name}, new file name: {new_file_name}.")
|
||||
if not check_valid_path(export_path):
|
||||
os.makedirs(export_path, exist_ok=True)
|
||||
res = os.path.join(export_path, new_file_name)
|
||||
return res
|
||||
|
||||
def save_file(self, sample, save_path):
|
||||
# 以二进制格式保存文件
|
||||
file_sample = sample[self.text_key].encode('utf-8') if sample[self.text_key] else sample[self.data_key]
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(file_sample)
|
||||
|
||||
os.chmod(save_path, 0o640)
|
||||
try:
|
||||
parent_dir = os.path.dirname(save_path)
|
||||
os.chmod(parent_dir, 0o770)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to modify the permission on the parent_dir.")
|
||||
|
||||
logger.info(f"patch sample has been save to {save_path}.")
|
||||
|
||||
|
||||
class Filter(BaseOp):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Filter, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, sample: Dict[str, Any], **kwargs):
|
||||
# 该算子前已有算子执行该文件失败
|
||||
if sample.get(Fields.result) is False:
|
||||
return sample
|
||||
|
||||
self.fill_sample_params(sample, **kwargs)
|
||||
execute_status = FAILED_STATUS
|
||||
try:
|
||||
sample = self.execute(sample)
|
||||
execute_status = SUCCESS_STATUS
|
||||
except Exception as e:
|
||||
# 如果filter算子过滤失败, 不保留文件, 并记录文件执行信息到数据库
|
||||
self.create_failure_sample(sample, self.name, e)
|
||||
sample["execute_status"] = execute_status
|
||||
logger.error(f"Ops named {self.name} map failed, Error Info: \n"
|
||||
f"{str(get_exception_info(e))}")
|
||||
task_info = TaskInfoPersistence()
|
||||
task_info.persistence_task_info(sample)
|
||||
return False
|
||||
|
||||
sample["execute_status"] = execute_status
|
||||
# 文件无内容会被过滤
|
||||
if sample[self.text_key] == "" and sample[self.data_key] == b"":
|
||||
task_info = TaskInfoPersistence()
|
||||
sample["fileSize"] = "0"
|
||||
task_info.persistence_task_info(sample)
|
||||
return False
|
||||
|
||||
# 加载文件成功执行信息到数据库
|
||||
if self.is_last_op:
|
||||
task_info = TaskInfoPersistence()
|
||||
task_info.persistence_task_info(sample)
|
||||
return True
|
||||
|
||||
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行函数(子类实现)"""
|
||||
raise NotImplementedError("This is in Filter Class, plese re-define this method in Sub-classes")
|
||||
|
||||
|
||||
class LLM(Mapper):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LLM, self).__init__(*args, **kwargs)
|
||||
self.llm = self.get_llm(*args, **kwargs)
|
||||
self.prompt_template = None
|
||||
|
||||
self.target_file_type = None
|
||||
|
||||
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行函数(子类实现)"""
|
||||
raise NotImplementedError("This is in LLM Class, plese re-define this method in Sub-classes")
|
||||
|
||||
@staticmethod
|
||||
def get_llm(*args, **kwargs):
|
||||
url = kwargs.get("LLMUrl", '')
|
||||
header = kwargs.get("LLMHeaders", {"Content-type": "application/json"})
|
||||
body = kwargs.get("LLMBody", {})
|
||||
access_type = kwargs.get("accessType", False)
|
||||
is_https = kwargs.get("isHttps", False)
|
||||
is_certificate = kwargs.get("isCertificate", False)
|
||||
certificate_path = kwargs.get("certificatePath", None)
|
||||
return LlmReq(url=url, header=header, body=body, access_type=access_type, is_https=is_https,
|
||||
is_certificate=is_certificate, certificate_path=certificate_path)
|
||||
|
||||
def build_llm_prompt(self, *args, **kwargs):
|
||||
"""执行函数(子类实现)"""
|
||||
raise NotImplementedError("This is in LLM Class, plese re-define this method in Sub-classes")
|
||||
|
||||
def save_sample(self, object_list: List, sample: Dict[str, Any]):
|
||||
if self.target_file_type:
|
||||
target_file_type = self.target_file_type
|
||||
else:
|
||||
target_file_type = "jsonl"
|
||||
save_path = self.get_save_path(sample, target_file_type)
|
||||
self.save_json_file(object_list, save_path)
|
||||
|
||||
def get_save_path(self, sample: Dict[str, Any], target_type) -> str:
|
||||
export_path = os.path.abspath(sample[self.export_path_key])
|
||||
logger.info(f"export path: {export_path}.")
|
||||
base_file_name, _ = os.path.splitext(sample[self.filename_key])
|
||||
file_id = str(sample[self.fileid_key])
|
||||
new_file_name = file_id + '.' + target_type
|
||||
logger.info(f"base_file_name: {base_file_name}, new file name: {new_file_name}.")
|
||||
if not check_valid_path(export_path):
|
||||
os.makedirs(export_path, exist_ok=True)
|
||||
res = os.path.join(export_path, new_file_name)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def save_json_file(object_list: List, save_path):
|
||||
if len(object_list) == 0:
|
||||
logger.warning("Please check the param: object_list, which has length equal to 0.")
|
||||
return
|
||||
try:
|
||||
with open(save_path, 'w', encoding='utf-8') as f:
|
||||
for item in object_list:
|
||||
json_str = json.dumps(item, ensure_ascii=False)
|
||||
f.write(json_str + '\n')
|
||||
|
||||
os.chmod(save_path, 0o640)
|
||||
try:
|
||||
parent_dir = os.path.dirname(save_path)
|
||||
os.chmod(parent_dir, 0o770)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to modify the permission on the parent_dir.")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Save jsonl file Failed!, save_path: {save_path}.") from e
|
||||
|
||||
logger.info(f"LLM output has been save to {save_path}.")
|
||||
8
runtime/python-executor/datamate/core/constant.py
Normal file
8
runtime/python-executor/datamate/core/constant.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
class Fields(object):
|
||||
result = 'execute_result'
|
||||
instance_id = 'instance_id'
|
||||
export_path = 'export_path'
|
||||
|
||||
|
||||
213
runtime/python-executor/datamate/core/dataset.py
Normal file
213
runtime/python-executor/datamate/core/dataset.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import importlib
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import pyarrow as pa
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
from ray import data as rd
|
||||
|
||||
from datamate.core.base_op import Filter, Mapper, Slicer
|
||||
from datamate.core.constant import Fields
|
||||
from datamate.core.base_op import OPERATORS, BaseOp
|
||||
|
||||
rd.DataContext.get_current().enable_progress_bars = False
|
||||
|
||||
|
||||
def is_valid_path(item, dataset_dir):
|
||||
full_path = os.path.abspath(os.path.join(dataset_dir, item))
|
||||
return os.path.exists(full_path)
|
||||
|
||||
|
||||
def new_get_num_npus(init_kwargs):
|
||||
if init_kwargs.get("accelerator", "cpu") != "npu":
|
||||
return 0.0
|
||||
return 0.1
|
||||
|
||||
|
||||
class Formatters(Enum):
|
||||
"""
|
||||
抽取算子和落盘算子枚举类
|
||||
"""
|
||||
FILE_EXPORTER = "FileExporter"
|
||||
IMG_FORMATTER = "ImgFormatter"
|
||||
OCR_FORMATTER = "OcrFormatter"
|
||||
PDF_CPU_FORMATTER = "PdfCpuFormatter"
|
||||
SLID_FORMATTER = "SlideFormatter"
|
||||
TEXT_FORMATTER = "TextFormatter"
|
||||
WORD_FORMATTER = "WordFormatter"
|
||||
UNION_FORMATTER = "UnionFormatter"
|
||||
ONNX_FORMATTER = "OnnxImg2TextFormatter"
|
||||
|
||||
@classmethod
|
||||
def is_member(cls, op_name):
|
||||
return op_name in cls._value2member_map_
|
||||
|
||||
|
||||
class BasicDataset(ABC):

    @abstractmethod
    def process(
            self,
            cfg_process,
            *,
            exporter=None,
            checkpointer=None
    ) -> BasicDataset:
        pass


def preprocess_dataset(dataset: rd.Dataset, cfg) -> rd.Dataset:
    columns = dataset.columns()
    new_column_names = [getattr(Fields, attr_name)
                        for attr_name in vars(Fields)
                        if attr_name not in columns and not attr_name.startswith('__')]

    def process_batch_arrow(table: pa.Table, names_list=None) -> pa.Table:
        name2value_table = {
            Fields.instance_id: cfg.instance_id,
            Fields.export_path: cfg.export_path
        }

        for column_name in names_list:
            if column_name in name2value_table.keys():
                new_column_data = [name2value_table[column_name] for _ in range(len(table))]
            else:
                new_column_data = [None for _ in range(len(table))]
            table = table.append_column(column_name, [new_column_data])
        return table

    if new_column_names:
        dataset = dataset.map_batches(process_batch_arrow,
                                      fn_kwargs={"names_list": new_column_names},
                                      num_cpus=0.05,
                                      batch_format='pyarrow')

    return dataset


class RayDataset(BasicDataset):

    def __init__(self,
                 dataset: rd.Dataset,
                 cfg=None) -> None:
        self.onnx_ops_name = ["OnnxImg2TextFormatter", "OnnxImageContentFilter"]
        self.npu_ops_name = ["Img2TextFormatter", "ImageContentFilter"]
        self.data = preprocess_dataset(dataset, cfg)

    def process(self,
                cfg_process,
                *,
                exporter=None,
                checkpointer=None,
                **kwargs) -> BasicDataset:

        # Load the operator classes from the registry
        operators_cls_list = []
        init_kwargs_list = []
        for index, process in enumerate(cfg_process):
            op_name, init_kwargs = list(process.items())[0]
            init_kwargs = {} if not init_kwargs else init_kwargs
            init_kwargs.update({'op_name': op_name})

            # Load the Ops module
            temp_ops = self.load_ops_module(op_name)

            if index == len(cfg_process) - 1:
                init_kwargs["is_last_op"] = True
            operators_cls_list.append(temp_ops)
            init_kwargs_list.append(init_kwargs)

        for cls_id, operators_cls in enumerate(operators_cls_list):
            self._run_single_op(operators_cls, init_kwargs_list[cls_id], **kwargs)
        return self

    def load_ops_module(self, op_name):
        """
        Load an operator module.
        :param op_name: operator name
        :return: operator class, or None if it cannot be resolved
        """
        parent_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ops")
        if parent_dir not in sys.path:
            sys.path.insert(0, parent_dir)
        registry_content = OPERATORS.modules[op_name]
        if isinstance(registry_content, str):
            # registry_content is the module's import path
            submodule = importlib.import_module(registry_content)
            res = getattr(submodule, op_name, None)
            if res is None:
                raise ImportError(f"Import Ops module {op_name} failed.")
            else:
                logger.info(f"Import Ops module {op_name} succeeded.")
        elif isinstance(registry_content, type) and issubclass(registry_content, BaseOp):
            # registry_content is the operator class itself
            res = registry_content
        else:
            res = None
        return res

    def _run_single_op(self, operators_cls, init_kwargs, **kwargs):

        num_npus = new_get_num_npus(init_kwargs)
        max_actor_nums = os.getenv("MAX_ACTOR_NUMS", "20")

        # Check whether this is an ONNX operator; if so, cap the number of concurrent actors
        if self._use_onnx_model(init_kwargs['op_name']):
            max_actor_nums = 4

        resources = {}

        if num_npus > 0:
            resources["node_npu"] = 0.1

        if init_kwargs.get("arch", "arm").startswith("x86"):
            resources["arch"] = "x86"

        kwargs.update({"ext_params": {}, "failed_reason": {}, "target_type": None})
        try:
            if issubclass(operators_cls, Mapper):
                self.data = self.data.map(operators_cls,
                                          fn_constructor_kwargs=init_kwargs,
                                          fn_kwargs=kwargs,
                                          resources=resources,
                                          num_cpus=0.05,
                                          concurrency=(1, 1 if operators_cls.use_model else int(max_actor_nums)))

            elif issubclass(operators_cls, Slicer):
                self.data = self.data.flat_map(operators_cls,
                                               fn_constructor_kwargs=init_kwargs,
                                               fn_kwargs=kwargs,
                                               resources=resources,
                                               num_cpus=0.05,
                                               concurrency=(1, int(max_actor_nums)))

            elif issubclass(operators_cls, Filter):
                self.data = self.data.filter(operators_cls,
                                             fn_constructor_kwargs=init_kwargs,
                                             fn_kwargs=kwargs,
                                             resources=resources,
                                             num_cpus=0.05,
                                             concurrency=(1, int(max_actor_nums)))
            else:
                logger.error(
                    'Ray executor only supports Filter, Mapper and Slicer OPs for now')
                raise NotImplementedError
        except Exception as e:
            logger.error(e)
            raise Exception(f"Error while running op {init_kwargs.get('op_name')}.") from e

    def _use_onnx_model(self, ops_name):
        if ops_name in self.onnx_ops_name:
            return True
        return False

    def _use_npu_model(self, ops_name):
        if ops_name in self.npu_ops_name:
            return True
        return False
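# A minimal sketch (illustrative only, not part of the module) of the `cfg_process`
# structure that RayDataset.process() iterates over: a list of single-key dicts, each
# mapping a registered operator name to its init kwargs (or None). The operator names
# below are placeholders.
#
#     cfg_process = [
#         {"TextCleanMapper": {"remove_html": True}},
#         {"LengthFilter": {"min_len": 10, "max_len": 4096}},
#     ]
#     dataset = RayDataset(ray.data.from_items(records), cfg)
#     dataset.process(cfg_process)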
163
runtime/python-executor/datamate/operator_runtime.py
Normal file
163
runtime/python-executor/datamate/operator_runtime.py
Normal file
@@ -0,0 +1,163 @@
import os
from typing import Optional, Dict, Any, List

import uvicorn
import yaml
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from jsonargparse import ArgumentParser
from loguru import logger
from pydantic import BaseModel

from datamate.common.error_code import ErrorCode
from datamate.scheduler import cmd_scheduler
from datamate.scheduler import func_scheduler
from datamate.wrappers import WRAPPERS

# Logging configuration
LOG_DIR = "/var/log/data-mate/runtime"
os.makedirs(LOG_DIR, exist_ok=True)
logger.add(
    f"{LOG_DIR}/runtime.log",
    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}",
    level="DEBUG",
    enqueue=True
)

app = FastAPI()


class APIException(Exception):
    """Custom API exception"""

    def __init__(self, error_code: ErrorCode, detail: Optional[str] = None,
                 extra_data: Optional[Dict] = None):
        self.error_code = error_code
        self.detail = detail or error_code.value[1]
        self.code = error_code.value[0]
        self.extra_data = extra_data
        super().__init__(self.detail)

    def to_dict(self) -> Dict[str, Any]:
        result = {
            "code": self.code,
            "message": self.detail,
            "success": False
        }
        if self.extra_data:
            result["data"] = self.extra_data
        return result


@app.exception_handler(APIException)
async def api_exception_handler(request: Request, exc: APIException):
    return JSONResponse(
        status_code=200,  # business errors return 200; the error details go in the response body
        content=exc.to_dict()
    )


class QueryTaskRequest(BaseModel):
    task_ids: List[str]


@app.post("/api/task/list")
async def query_task_info(request: QueryTaskRequest):
    try:
        return [{task_id: cmd_scheduler.get_task_status(task_id)} for task_id in request.task_ids]
    except Exception as e:
        raise APIException(ErrorCode.UNKNOWN_ERROR) from e


@app.post("/api/task/{task_id}/submit")
async def submit_task(task_id):
    config_path = f"/flow/{task_id}/process.yaml"
    logger.info("Start submitting job...")

    dataset_path = get_from_cfg(task_id, "dataset_path")
    if not check_valid_path(dataset_path):
        logger.error(f"dataset_path {dataset_path} does not exist! Please check this path.")
        raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)

    try:
        executor_type = get_from_cfg(task_id, "executor_type")
        await WRAPPERS.get(executor_type).submit(task_id, config_path)

    except Exception as e:
        logger.error(f"Error happened while submitting the task. Error info: {e}")
        raise APIException(ErrorCode.SUBMIT_TASK_ERROR) from e

    logger.info(f"task id: {task_id} has been submitted.")
    success_json_info = JSONResponse(
        content={"status": "Success", "message": f"{task_id} has been submitted"},
        status_code=200
    )
    return success_json_info


@app.post("/api/task/{task_id}/stop")
async def stop_task(task_id):
    logger.info("Start stopping ray job...")
    success_json_info = JSONResponse(
        content={"status": "Success", "message": f"{task_id} has been stopped"},
        status_code=200
    )

    try:
        executor_type = get_from_cfg(task_id, "executor_type")
        if not WRAPPERS.get(executor_type).cancel(task_id):
            raise APIException(ErrorCode.CANCEL_TASK_ERROR)
    except Exception as e:
        if isinstance(e, APIException):
            raise e
        raise APIException(ErrorCode.UNKNOWN_ERROR) from e

    logger.info(f"{task_id} has been stopped.")
    return success_json_info


def check_valid_path(file_path):
    full_path = os.path.abspath(file_path)
    return os.path.exists(full_path)


def get_from_cfg(task_id, key):
    config_path = f"/flow/{task_id}/process.yaml"
    if not check_valid_path(config_path):
        logger.error(f"config_path {config_path} does not exist! Please check this path.")
        raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)

    with open(config_path, "r", encoding='utf-8') as f:
        content = f.read()
    cfg = yaml.safe_load(content)
    return cfg[key]


def parse_args():
    parser = ArgumentParser(description="Create API for Submitting Job to Data-juicer")

    parser.add_argument(
        '--ip',
        type=str,
        default="0.0.0.0",
        help='Service ip for this API, defaults to 0.0.0.0.'
    )

    parser.add_argument(
        '--port',
        type=int,
        default=8080,
        help='Service port for this API, defaults to 8080.'
    )

    return parser.parse_args()


if __name__ == '__main__':
    p_args = parse_args()

    uvicorn.run(
        app,
        host=p_args.ip,
        port=p_args.port
    )
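# Illustrative client calls for the endpoints above (paths and payloads are taken from
# the route definitions; the host, port and task id are assumptions for a local run):
#
#     import requests
#     base = "http://127.0.0.1:8080"
#     requests.post(f"{base}/api/task/task-001/submit")
#     requests.post(f"{base}/api/task/list", json={"task_ids": ["task-001"]})
#     requests.post(f"{base}/api/task/task-001/stop")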
22
runtime/python-executor/datamate/ops/__init__.py
Normal file
22
runtime/python-executor/datamate/ops/__init__.py
Normal file
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
import importlib
import os
import sys
from pathlib import Path

from loguru import logger


# Current package directory
current_dir = os.path.dirname(__file__)

# Walk the sub-directories
for module_name in os.listdir(current_dir):
    module_path = os.path.join(current_dir, module_name)
    # Only consider directories that contain an __init__.py
    if os.path.isdir(module_path) and '__init__.py' in os.listdir(module_path):
        # Import the module dynamically
        try:
            importlib.import_module(f".{module_name}", package=__name__)
        except Exception as e:
            logger.error(f"Failed to load Ops {module_name}: {e}")
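# Expected layout for this auto-discovery (a sketch; the sub-package names are
# placeholders): every importable sub-package under datamate/ops/ is loaded at import
# time, presumably so its operators can register themselves with the OPERATORS registry
# used by the Ray dataset.
#
#     datamate/ops/
#     ├── __init__.py          # this file
#     ├── mappers/__init__.py
#     └── filters/__init__.py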
6
runtime/python-executor/datamate/scheduler/__init__.py
Normal file
6
runtime/python-executor/datamate/scheduler/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .cmd_task_scheduler import CommandScheduler
from .func_task_scheduler import CallableScheduler


cmd_scheduler = CommandScheduler(max_concurrent=5)
func_scheduler = CallableScheduler(max_concurrent=5)
214
runtime/python-executor/datamate/scheduler/cmd_task_scheduler.py
Normal file
214
runtime/python-executor/datamate/scheduler/cmd_task_scheduler.py
Normal file
@@ -0,0 +1,214 @@
import asyncio
from datetime import datetime
from typing import Optional, List

from loguru import logger

from .scheduler import Task, TaskStatus, TaskResult, TaskScheduler


class CommandTask(Task):
    """Wraps a shell/exec command as a schedulable task."""

    def __init__(self, task_id: str, command: str, shell: bool = True,
                 timeout: Optional[int] = None, *args, **kwargs):
        super().__init__(task_id, *args, **kwargs)
        self.command = command
        self.shell = shell
        self.timeout = timeout
        self.stdout = None
        self.stderr = None
        self.return_code = None
        self._process = None

    def start(self) -> 'CommandTask':
        """Start the task."""
        if self.status == TaskStatus.PENDING:
            self.status = TaskStatus.RUNNING
            self.started_at = datetime.now()
            self._task = asyncio.create_task(self._execute())
        return self

    async def _execute(self):
        """Execute the command."""
        try:
            self.status = TaskStatus.RUNNING
            self.started_at = datetime.now()

            # Use asyncio.create_subprocess_shell or create_subprocess_exec
            if self.shell:
                process = await asyncio.create_subprocess_shell(
                    self.command,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                    **self.kwargs
                )
            else:
                process = await asyncio.create_subprocess_exec(
                    *self.command.split(),
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                    **self.kwargs
                )

            self._process = process

            # Wait for the process to finish (with an optional timeout)
            try:
                if self.timeout:
                    stdout, stderr = await asyncio.wait_for(
                        process.communicate(),
                        timeout=self.timeout
                    )
                else:
                    stdout, stderr = await process.communicate()

                self.stdout = stdout.decode() if stdout else ""
                self.stderr = stderr.decode() if stderr else ""
                self.return_code = process.returncode

                if self._cancelled:
                    self.status = TaskStatus.CANCELLED
                elif process.returncode == 0:
                    self.status = TaskStatus.COMPLETED
                else:
                    self.status = TaskStatus.FAILED

            except asyncio.TimeoutError:
                # Timeout handling
                self._process.terminate()
                try:
                    await asyncio.wait_for(self._process.wait(), timeout=5.0)
                except asyncio.TimeoutError:
                    self._process.kill()
                    await self._process.wait()

                self.status = TaskStatus.FAILED
                self.stderr = f"Command timed out after {self.timeout} seconds"

        except asyncio.CancelledError:
            # The task was cancelled
            if self._process:
                self._process.terminate()
                try:
                    await asyncio.wait_for(self._process.wait(), timeout=5.0)
                except asyncio.TimeoutError:
                    self._process.kill()
                    await self._process.wait()

            self.status = TaskStatus.CANCELLED
            self._cancelled = True

        except Exception as e:
            self.status = TaskStatus.FAILED
            self.stderr = str(e)
        finally:
            self.completed_at = datetime.now()

    def cancel(self) -> bool:
        """Cancel the task."""
        if self._process and self.status == TaskStatus.RUNNING:
            try:
                # Try a graceful terminate first
                self._process.terminate()
                self._cancelled = True
                return True
            except Exception:
                # Fall back to a hard kill if terminate fails
                try:
                    self._process.kill()
                    self._cancelled = True
                    return True
                except Exception:
                    return False
        return False

    def to_result(self) -> TaskResult:
        """Convert to a result object."""
        self.result = {
            "command": self.command,
            "stdout": self.stdout,
            "stderr": self.stderr,
            "return_code": self.return_code,
        }
        return super().to_result()


class CommandScheduler(TaskScheduler):
    """Scheduler for command tasks."""

    def __init__(self, max_concurrent: int = 5):
        super().__init__(max_concurrent)

    async def submit(self, task_id, command: str, shell: bool = True,
                     timeout: Optional[int] = None, **kwargs) -> str:
        """Submit a command task."""
        task = CommandTask(task_id, command, shell, timeout, **kwargs)
        self.tasks[task_id] = task

        # Limit concurrency with the semaphore
        async with self.semaphore:
            # Run the task asynchronously
            task.start()

        logger.info(f"Command task {task_id} has been submitted and started")
        return task_id

    def get_task_status(self, task_id: str) -> Optional[TaskResult]:
        """Get the status of a task."""
        task = self.tasks.get(task_id)
        if task:
            return task.to_result()
        return None

    def get_all_tasks(self) -> List[TaskResult]:
        """Get the status of all tasks."""
        return [task.to_result() for task in self.tasks.values()]

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task."""
        task = self.tasks.get(task_id)
        if not task:
            return True
        if task.status == TaskStatus.RUNNING:
            cancelled = task.cancel()
            if cancelled:
                logger.info(f"Command task {task_id} has been cancelled")
            return cancelled
        return False

    def get_tasks_by_status(self, status: TaskStatus) -> List[TaskResult]:
        """Get tasks filtered by status."""
        return [
            task.to_result()
            for task in self.tasks.values()
            if task.status == status
        ]

    async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> TaskResult:
        """Wait for a task to finish."""
        task = self.tasks.get(task_id)
        if not task:
            raise ValueError(f"Task {task_id} does not exist")

        if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
            return task.to_result()

        # Running tasks are already awaited inside their own asyncio.Task,
        # so simply return the current result here
        return task.to_result()

    async def shutdown(self):
        """Shut down the scheduler and cancel all running tasks."""
        logger.info("Shutting down the command scheduler...")

        running_tasks = [
            task for task in self.tasks.values()
            if task.status == TaskStatus.RUNNING
        ]

        for task in running_tasks:
            logger.info(f"Cancelling running command task: {task.task_id}")
            task.cancel()

        logger.info("Command scheduler has been shut down")
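# Usage sketch (illustrative; the task id and command are placeholders):
#
#     import asyncio
#     from datamate.scheduler import cmd_scheduler
#
#     async def main():
#         await cmd_scheduler.submit("demo-task", "echo hello", timeout=60)
#         print(cmd_scheduler.get_task_status("demo-task"))
#
#     asyncio.run(main())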
@@ -0,0 +1,133 @@
import asyncio
from datetime import datetime
from typing import Callable, Optional, List

from loguru import logger

from .scheduler import TaskStatus, TaskResult, Task, TaskScheduler


class CallableTask(Task):
    """Wraps an async callable as a schedulable task."""

    def __init__(self, task_id: str, func: Callable, *args, **kwargs):
        super().__init__(task_id, *args, **kwargs)
        self.func = func

    def start(self) -> 'CallableTask':
        """Start the task."""
        if self.status == TaskStatus.PENDING:
            self.status = TaskStatus.RUNNING
            self.started_at = datetime.now()
            self._task = asyncio.create_task(self._execute())
        return self

    async def _execute(self):
        """Execute the task."""
        try:
            self.result = await self.func(*self.args, **self.kwargs)
            self.status = TaskStatus.COMPLETED
        except asyncio.CancelledError:
            self.status = TaskStatus.CANCELLED
            self._cancelled = True
        except Exception as e:
            self.status = TaskStatus.FAILED
            self.error = str(e)
        finally:
            self.completed_at = datetime.now()

    def cancel(self) -> bool:
        """Cancel the task."""
        if self._task and not self._task.done():
            self._task.cancel()
            return True
        return False


class CallableScheduler(TaskScheduler):
    """Asynchronous scheduler for callable tasks."""

    def __init__(self, max_concurrent: int = 10):
        super().__init__(max_concurrent)

    async def submit(self, task_id, func: Callable, *args, **kwargs) -> str:
        """Submit a task."""
        task = CallableTask(task_id, func, *args, **kwargs)
        self.tasks[task_id] = task

        # Limit concurrency with the semaphore
        async with self.semaphore:
            task.start()

        logger.info(f"Task {task_id} has been submitted and started")
        return task_id

    def get_task_status(self, task_id: str) -> Optional[TaskResult]:
        """Get the status of a task."""
        task = self.tasks.get(task_id)
        if task:
            return task.to_result()
        return None

    def get_all_tasks(self) -> List[TaskResult]:
        """Get the status of all tasks."""
        return [task.to_result() for task in self.tasks.values()]

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task."""
        task = self.tasks.get(task_id)
        if task and task.status == TaskStatus.RUNNING:
            cancelled = task.cancel()
            if cancelled:
                logger.info(f"Task {task_id} has been cancelled")
            return cancelled
        return False

    def get_tasks_by_status(self, status: TaskStatus) -> List[TaskResult]:
        """Get tasks filtered by status."""
        return [
            task.to_result()
            for task in self.tasks.values()
            if task.status == status
        ]

    async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> TaskResult:
        """Wait for a task to finish."""
        task = self.tasks.get(task_id)
        if not task:
            raise ValueError(f"Task {task_id} does not exist")

        if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
            return task.to_result()

        # Wait for the underlying asyncio.Task to finish
        if task.get():
            try:
                await asyncio.wait_for(task.get(), timeout=timeout)
            except asyncio.TimeoutError:
                raise TimeoutError(f"Task {task_id} timed out")

        return task.to_result()

    async def shutdown(self):
        """Shut down the scheduler and cancel all running tasks."""
        logger.info("Shutting down the scheduler...")

        running_tasks = [
            task for task in self.tasks.values()
            if task.status == TaskStatus.RUNNING
        ]

        for task in running_tasks:
            logger.info(f"Cancelling running task: {task.task_id}")
            task.cancel()

        # Wait for all tasks to wind down
        for task in running_tasks:
            if task.get() and not task.get().done():
                try:
                    await asyncio.wait_for(task.get(), timeout=5.0)
                except asyncio.TimeoutError:
                    logger.warning(f"Task {task.task_id} could not be stopped cleanly")

        logger.info("Scheduler has been shut down")
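# Usage sketch (illustrative; the coroutine and task id below are placeholders):
#
#     import asyncio
#     from datamate.scheduler import func_scheduler
#
#     async def crunch(x):
#         await asyncio.sleep(1)
#         return x * 2
#
#     async def main():
#         await func_scheduler.submit("calc-001", crunch, 21)
#         result = await func_scheduler.wait_for_task("calc-001", timeout=10)
#         print(result.status, result.result)
#
#     asyncio.run(main())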
160
runtime/python-executor/datamate/scheduler/scheduler.py
Normal file
160
runtime/python-executor/datamate/scheduler/scheduler.py
Normal file
@@ -0,0 +1,160 @@
# Task status enum and scheduler base classes
import asyncio
import signal
import sys
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Dict, List
from loguru import logger


class TaskStatus(Enum):
    PENDING = "pending"        # waiting to run
    RUNNING = "running"        # currently running
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # execution failed
    CANCELLED = "cancelled"    # cancelled


@dataclass
class TaskResult:
    """Task result data class."""
    task_id: str
    status: TaskStatus
    result: Any = None
    error: Optional[str] = None
    created_at: datetime = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    progress: float = 0.0


class Task:
    def __init__(self, task_id: str, *args, **kwargs):
        self.task_id = task_id
        self.args = args
        self.kwargs = kwargs
        self.status = TaskStatus.PENDING
        self.result = None
        self.error = None
        self.created_at = datetime.now()
        self.started_at = None
        self.completed_at = None
        self.progress = 0.0
        self._task = None  # the underlying asyncio.Task instance
        self._cancelled = False

    def get(self):
        return self._task

    def start(self) -> 'Task':
        """Start the task."""
        pass

    async def _execute(self):
        """Execute the task."""
        pass

    def cancel(self) -> bool:
        """Cancel the task."""
        pass

    def to_result(self) -> TaskResult:
        """Convert to a result object."""
        return TaskResult(
            task_id=self.task_id,
            status=self.status,
            result=self.result,
            error=self.error,
            created_at=self.created_at,
            started_at=self.started_at,
            completed_at=self.completed_at,
            progress=self.progress
        )


class TaskScheduler:
    """Asynchronous task scheduler base class."""

    def __init__(self, max_concurrent: int = 10):
        self.max_concurrent = max_concurrent
        self.tasks: Dict[str, Task] = {}
        self.semaphore = asyncio.Semaphore(max_concurrent)

        # Register signal handlers
        try:
            signal.signal(signal.SIGINT, self._signal_handler)
            signal.signal(signal.SIGTERM, self._signal_handler)
        except (ValueError, AttributeError):
            # Not supported on some platforms
            pass

    def _signal_handler(self, signum, frame):
        """Signal handler."""
        logger.info(f"Received signal {signum}, cleaning up tasks...")
        asyncio.create_task(self.shutdown())
        sys.exit(0)

    async def submit(self, task_id, task, *args, **kwargs) -> str:
        """Submit a task."""
        pass

    def get_task_status(self, task_id: str) -> Optional[TaskResult]:
        """Get the status of a task."""
        task = self.tasks.get(task_id)
        if task:
            return task.to_result()
        return None

    def get_all_tasks(self) -> List[TaskResult]:
        """Get the status of all tasks."""
        return [task.to_result() for task in self.tasks.values()]

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task."""
        task = self.tasks.get(task_id)
        if task and task.status == TaskStatus.RUNNING:
            cancelled = task.cancel()
            if cancelled:
                logger.info(f"Task {task_id} has been cancelled")
            return cancelled
        return False

    def get_tasks_by_status(self, status: TaskStatus) -> List[TaskResult]:
        """Get tasks filtered by status."""
        return [
            task.to_result()
            for task in self.tasks.values()
            if task.status == status
        ]

    async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> TaskResult:
        """Wait for a task to finish."""
        pass

    async def shutdown(self):
        """Shut down the scheduler and cancel all running tasks."""
        pass

    def get_statistics(self) -> Dict[str, int]:
        """Aggregate task counts by status."""
        stats = {
            TaskStatus.PENDING: 0,
            TaskStatus.RUNNING: 0,
            TaskStatus.COMPLETED: 0,
            TaskStatus.FAILED: 0,
            TaskStatus.CANCELLED: 0
        }

        for task in self.tasks.values():
            stats[task.status] += 1

        return {
            "pending": stats[TaskStatus.PENDING],
            "running": stats[TaskStatus.RUNNING],
            "completed": stats[TaskStatus.COMPLETED],
            "failed": stats[TaskStatus.FAILED],
            "cancelled": stats[TaskStatus.CANCELLED],
            "total": len(self.tasks)
        }
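# get_statistics() returns a plain dict keyed by status name, e.g. (counts below are
# illustrative):
#
#     {"pending": 0, "running": 2, "completed": 5, "failed": 1, "cancelled": 0, "total": 8}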
2
runtime/python-executor/datamate/sql_manager/__init__.py
Normal file
2
runtime/python-executor/datamate/sql_manager/__init__.py
Normal file
@@ -0,0 +1,2 @@
@@ -0,0 +1,176 @@
# -*- coding: utf-8 -*-

import json
import time
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, Any

from loguru import logger
from sqlalchemy import text

from datamate.sql_manager.sql_manager import SQLManager


class TaskInfoPersistence:
    def __init__(self):
        self.sql_dict = self.load_sql_dict()

    @staticmethod
    def load_sql_dict():
        """Load the SQL statements from the config file."""
        sql_config_path = str(Path(__file__).parent / 'sql' / 'sql_config.json')
        with open(sql_config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def persistence_task_info(self, sample: Dict[str, Any]):
        instance_id = str(sample.get("instance_id"))
        src_file_name = str(sample.get("sourceFileName"))
        src_file_type = str(sample.get("sourceFileType"))
        src_file_id = str(sample.get("sourceFileId"))
        src_file_size = int(sample.get("sourceFileSize"))
        file_id = str(uuid.uuid4())
        file_size = str(sample.get("fileSize"))
        file_type = str(sample.get("fileType"))
        file_name = str(sample.get("fileName"))

        status = str(sample.get("execute_status"))
        failed_reason = str(sample.get("failed_reason"))
        result_data = {
            "instance_id": instance_id,
            "src_file_id": src_file_id,
            "dest_file_id": file_id,
            "src_name": src_file_name,
            "dest_name": file_name,
            "src_type": src_file_type,
            "dest_type": file_type,
            "src_size": src_file_size,
            "dest_size": file_size,
            "status": status,
            "result": failed_reason
        }
        self.insert_result(result_data, str(self.sql_dict.get("insert_clean_result_sql")))

        dataset_id = str(sample.get("dataset_id"))
        file_path = str(sample.get("filePath"))
        create_time = datetime.now()
        last_access_time = datetime.fromtimestamp(os.path.getmtime(file_path))
        file_data = {
            "id": file_id,
            "dataset_id": dataset_id,
            "file_name": file_name,
            "file_path": file_path,
            "file_type": file_type,
            "file_size": file_size,
            "status": "COMPLETED",
            "upload_time": create_time,
            "last_access_time": last_access_time,
            "created_at": create_time,
            "updated_at": create_time
        }
        self.insert_result(file_data, str(self.sql_dict.get("insert_dataset_file_sql")))

    @staticmethod
    def insert_result(data, sql):
        retries = 0
        max_retries = 20
        retry_delay = 1
        while retries <= max_retries:
            try:
                with SQLManager.create_connect() as conn:
                    conn.execute(text(sql), data)
                    return
            except Exception as e:
                if "database is locked" in str(e) or "locking protocol" in str(e):
                    retries += 1
                    time.sleep(retry_delay)
                else:
                    logger.error("database execute failed: {}", str(e))
                    raise RuntimeError(82000, str(e)) from None
        raise Exception("Max retries exceeded")

    def update_result(self, dataset_id, instance_id, status):
        dataset_data = {
            "dataset_id": dataset_id
        }
        query_dataset_sql = str(self.sql_dict.get("query_dataset_sql"))
        with SQLManager.create_connect() as conn:
            result = conn.execute(text(query_dataset_sql), dataset_data)
            if result:
                rows = result.fetchall()
                total_size = sum(int(row[0]) for row in rows)
                file_count = len(rows)
            else:
                total_size = 0
                file_count = 0

        dataset_data.update({
            "task_id": instance_id,
            "total_size": total_size,
            "file_count": file_count
        })

        update_dataset_sql = str(self.sql_dict.get("update_dataset_sql"))
        self.insert_result(dataset_data, update_dataset_sql)

        task_data = {
            "task_id": instance_id,
            "status": status,
            "total_size": total_size,
            "finished_time": datetime.now()
        }
        update_task_sql = str(self.sql_dict.get("update_task_sql"))
        self.insert_result(task_data, update_task_sql)

    def query_task_info(self, instance_ids: list[str]):
        result = {}
        current_result = None
        for instance_id in instance_ids:
            try:
                current_result = self.execute_sql_query(instance_id)
            except Exception as e:
                logger.warning("instance_id: {}, query job result error: {}", instance_id, str(e))
            if current_result:
                result[instance_id] = current_result
        return result

    def execute_sql_query(self, instance_id):
        result = None
        create_tables_sql = str(self.sql_dict.get("create_tables_sql"))
        query_sql = str(self.sql_dict.get("query_sql"))
        with SQLManager.create_connect() as conn:
            conn.execute(text(create_tables_sql))
            execute_result = conn.execute(text(query_sql), {"instance_id": instance_id})
            result = execute_result.fetchall()
        return result

    # TODO: the delete API is not implemented yet
    def delete_task_info(self, instance_id: str):
        create_tables_sql = self.sql_dict.get("create_tables_sql")
        delete_task_instance_sql = self.sql_dict.get("delete_task_instance_sql")
        try:
            with SQLManager.create_connect() as conn:
                conn.execute(text(create_tables_sql))
                conn.execute(text(delete_task_instance_sql), {"instance_id": instance_id})
        except Exception as e:
            logger.warning("delete database for flow: {} error: {}", instance_id, str(e))

    def delete_task_operate_info(self, instance_id: str):
        create_duplicate_img_tables_sql = self.sql_dict.get("create_duplicate_img_tables_sql")
        create_similar_img_tables_sql = self.sql_dict.get("create_similar_img_tables_sql")
        create_similar_text_tables_sql = self.sql_dict.get("create_similar_text_tables_sql")
        delete_duplicate_img_tables_sql = self.sql_dict.get("delete_duplicate_img_tables_sql")
        delete_similar_img_tables_sql = self.sql_dict.get("delete_similar_img_tables_sql")
        delete_similar_text_tables_sql = self.sql_dict.get("delete_similar_text_tables_sql")
        try:
            with SQLManager.create_connect() as conn:
                conn.execute(text(create_duplicate_img_tables_sql))
                conn.execute(text(delete_duplicate_img_tables_sql), {"instance_id": instance_id})
                conn.execute(text(create_similar_img_tables_sql))
                conn.execute(text(delete_similar_img_tables_sql), {"instance_id": instance_id})
                conn.execute(text(create_similar_text_tables_sql))
                conn.execute(text(delete_similar_text_tables_sql), {"instance_id": instance_id})
        except Exception as e:
            logger.warning("delete database for flow: {} error: {}", instance_id, str(e))
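# Illustrative shape of the `sample` dict consumed by persistence_task_info()
# (keys are taken from the lookups above; all values are placeholders):
#
#     sample = {
#         "instance_id": "inst-001", "dataset_id": "ds-001",
#         "sourceFileId": "f-src", "sourceFileName": "a.pdf",
#         "sourceFileType": "pdf", "sourceFileSize": 1024,
#         "fileName": "a.txt", "fileType": "txt", "fileSize": 512,
#         "filePath": "/datasets/ds-001/a.txt",
#         "execute_status": "COMPLETED", "failed_reason": "",
#     }
#     TaskInfoPersistence().persistence_task_info(sample)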
@@ -0,0 +1,17 @@
{
  "query_sql": "SELECT * FROM t_task_instance_info WHERE instance_id IN (:instance_id)",
  "insert_sql": "INSERT INTO t_task_instance_info (instance_id, meta_file_name, meta_file_type, meta_file_id, meta_file_size, file_id, file_size, file_type, file_name, file_path, status, operator_id, error_code, incremental, child_id, slice_num) VALUES (:instance_id, :meta_file_name, :meta_file_type, :meta_file_id, :meta_file_size, :file_id, :file_size, :file_type, :file_name, :file_path, :status, :operator_id, :error_code, :incremental, :child_id, :slice_num)",
  "insert_dataset_file_sql": "INSERT INTO t_dm_dataset_files (id, dataset_id, file_name, file_path, file_type, file_size, status, upload_time, last_access_time, created_at, updated_at) VALUES (:id, :dataset_id, :file_name, :file_path, :file_type, :file_size, :status, :upload_time, :last_access_time, :created_at, :updated_at)",
  "insert_clean_result_sql": "INSERT INTO t_clean_result (instance_id, src_file_id, dest_file_id, src_name, dest_name, src_type, dest_type, src_size, dest_size, status, result) VALUES (:instance_id, :src_file_id, :dest_file_id, :src_name, :dest_name, :src_type, :dest_type, :src_size, :dest_size, :status, :result)",
  "query_dataset_sql": "SELECT file_size FROM t_dm_dataset_files WHERE dataset_id = :dataset_id",
  "update_dataset_sql": "UPDATE t_dm_datasets SET size_bytes = :total_size, file_count = :file_count WHERE id = :dataset_id;",
  "update_task_sql": "UPDATE t_clean_task SET status = :status, after_size = :total_size, finished_at = :finished_time WHERE id = :task_id",
  "create_tables_sql": "CREATE TABLE IF NOT EXISTS t_task_instance_info (instance_id VARCHAR(255), meta_file_name TEXT, meta_file_type VARCHAR(100), meta_file_id BIGINT, meta_file_size VARCHAR(100), file_id BIGINT, file_size VARCHAR(100), file_type VARCHAR(100), file_name TEXT, file_path TEXT, status INT, operator_id VARCHAR(255), error_code VARCHAR(100), incremental VARCHAR(50), child_id BIGINT, slice_num INT DEFAULT 0);",
  "delete_task_instance_sql": "DELETE FROM t_task_instance_info WHERE instance_id = :instance_id",
  "create_duplicate_img_tables_sql": "CREATE TABLE IF NOT EXISTS operator_duplicate_img_features (id INT AUTO_INCREMENT PRIMARY KEY, task_uuid VARCHAR(255), file_feature TEXT, file_name TEXT, timestamp DATETIME);",
  "delete_duplicate_img_tables_sql": "DELETE FROM operator_duplicate_img_features WHERE task_uuid = :instance_id",
  "create_similar_img_tables_sql": "CREATE TABLE IF NOT EXISTS operator_similar_img_features (id INT AUTO_INCREMENT PRIMARY KEY, task_uuid VARCHAR(255), p_hash TEXT, des_matrix BLOB, matrix_shape TEXT, file_name TEXT, timestamp DATETIME);",
  "delete_similar_img_tables_sql": "DELETE FROM operator_similar_img_features WHERE task_uuid = :instance_id",
  "create_similar_text_tables_sql": "CREATE TABLE IF NOT EXISTS operators_similar_text_features (id INT AUTO_INCREMENT PRIMARY KEY, task_uuid VARCHAR(255), file_feature TEXT, file_name TEXT, timestamp DATETIME);",
  "delete_similar_text_tables_sql": "DELETE FROM operators_similar_text_features WHERE task_uuid = :instance_id"
}
52
runtime/python-executor/datamate/sql_manager/sql_manager.py
Normal file
52
runtime/python-executor/datamate/sql_manager/sql_manager.py
Normal file
@@ -0,0 +1,52 @@
# -- encoding: utf-8 --
import os
import time
from random import uniform

from loguru import logger
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine import URL


class SQLManager:

    @staticmethod
    def create_connect(max_retries=5, base_delay=1):
        """
        Connect to the MySQL database via SQLAlchemy and PyMySQL.
        :param max_retries: maximum number of retries
        :param base_delay: base delay between retries, in seconds
        :return: a SQLAlchemy connection object
        """

        connection_url = URL.create(
            drivername="mysql+pymysql",
            username=os.getenv("MYSQL_USER", "root"),
            password=os.getenv("MYSQL_PASSWORD", "Huawei@123"),
            host=os.getenv("MYSQL_HOST", "mysql"),
            port=int(os.getenv("MYSQL_PORT", 3306)),
            database=os.getenv("MYSQL_DATABASE", "datamate"),
            query={"charset": "utf8mb4"},
        )

        attempt = 0

        while True:
            try:
                engine = create_engine(connection_url, pool_pre_ping=True, isolation_level="AUTOCOMMIT")
                return engine.connect()
            except Exception as e:
                logger.error(f"Attempt {attempt + 1} failed with error: {str(e)}")
                if attempt >= max_retries - 1:
                    raise
                wait_time = min(30, base_delay * (2 ** attempt))  # cap the delay at 30 seconds
                jitter = uniform(-wait_time / 4, wait_time / 4)  # add random jitter
                time.sleep(wait_time + jitter)
                attempt += 1


if __name__ == "__main__":
    with SQLManager.create_connect() as connection:
        inspector = inspect(connection)
        print(inspector.get_table_names())
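# Connection settings come from the environment variables read above; a minimal sketch
# of overriding them before connecting (all values are placeholders):
#
#     import os
#     from sqlalchemy import text
#     os.environ.update({
#         "MYSQL_HOST": "127.0.0.1",
#         "MYSQL_USER": "datamate",
#         "MYSQL_PASSWORD": "change-me",
#         "MYSQL_DATABASE": "datamate",
#     })
#     with SQLManager.create_connect() as conn:
#         print(conn.execute(text("SELECT 1")).scalar())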
6
runtime/python-executor/datamate/wrappers/__init__.py
Normal file
6
runtime/python-executor/datamate/wrappers/__init__.py
Normal file
@@ -0,0 +1,6 @@
from . import data_juicer_wrapper, datamate_wrapper

WRAPPERS = {
    "data_juicer": data_juicer_wrapper,
    "datamate": datamate_wrapper
}
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
from datamate.scheduler import cmd_scheduler


async def submit(task_id, config_path):
    await cmd_scheduler.submit(task_id, f"dj-process --config {config_path}")
131
runtime/python-executor/datamate/wrappers/datamate_executor.py
Normal file
131
runtime/python-executor/datamate/wrappers/datamate_executor.py
Normal file
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-

import base64
import json
import time
from typing import Dict

import ray
import yaml
from jsonargparse import dict_to_namespace, ArgumentParser
from loguru import logger

from datamate.common.utils import check_valid_path
from datamate.core.dataset import RayDataset
from datamate.sql_manager.persistence_atction import TaskInfoPersistence

import datamate.ops


class RayExecutor:
    """
    Ray-based executor.

    1. Currently only Mapper and Filter type operators are supported.
    2. Currently only JSON-file datasets are loaded.
    """

    def __init__(self, cfg=None, meta=None):
        if isinstance(cfg, Dict):
            self.cfg = dict_to_namespace(cfg)
        else:
            logger.error(f"Please set param: cfg as type Dict, but given cfg as type {type(cfg).__name__}")
            raise TypeError(f"For param cfg, Dict type is required, but type {type(cfg).__name__} was given!")

        self.cfg.process = cfg['process']
        self.meta = meta

        # init ray
        logger.info('Initializing Ray ...')
        ray.init()

    def load_meta(self, line):
        meta = json.loads(line)
        if meta.get("fileId"):
            meta["sourceFileId"] = meta.get("fileId")
        if meta.get("fileName"):
            meta["sourceFileName"] = meta.get("fileName")
        if meta.get("fileType"):
            meta["sourceFileType"] = meta.get("fileType")
        if meta.get("fileSize"):
            meta["sourceFileSize"] = meta.get("fileSize")
        if not meta.get("totalPageNum"):
            meta["totalPageNum"] = 0
        if not meta.get("extraFilePath"):
            meta["extraFilePath"] = None
        if not meta.get("extraFileType"):
            meta["extraFileType"] = None
        meta["dataset_id"] = self.cfg.dataset_id
        return meta

    def run(self):
        # 1. Load the dataset
        logger.info('Loading dataset with Ray...')

        if self.meta:
            file_content = base64.b64decode(self.meta)
            lines = file_content.splitlines()
            dataset = ray.data.from_items([json.loads(line) for line in lines])
        else:
            dataset = self.load_dataset()
        dataset = RayDataset(dataset, self.cfg)

        # 2. Process the data
        logger.info('Processing data...')
        tstart = time.time()
        dataset.process(self.cfg.process, **getattr(self.cfg, 'kwargs', {}))
        tend = time.time()
        logger.info(f'All Ops are done in {tend - tstart:.3f}s.')

        dataset.data.materialize()

    def load_dataset(self):
        retry = 0
        dataset = None
        jsonl_file_path = self.cfg.dataset_path
        while True:
            if check_valid_path(jsonl_file_path):
                with open(jsonl_file_path, "r", encoding='utf-8') as meta:
                    lines = meta.readlines()
                dataset = ray.data.from_items([self.load_meta(line) for line in lines])
                break
            if retry < 5:
                retry += 1
                time.sleep(retry)
                continue
            else:
                logger.error(f"Cannot load dataset from dataset_path: {jsonl_file_path}")
                raise RuntimeError(f"Loading dataset failed! dataset_path: {self.cfg.dataset_path}.")

        return dataset

    def update_db(self, status):
        task_info = TaskInfoPersistence()
        task_info.update_result(self.cfg.dataset_id, self.cfg.instance_id, status)


if __name__ == '__main__':

    parser = ArgumentParser(description="Run a DataMate processing flow with the Ray executor")

    parser.add_argument("--config_path", type=str, required=False, default="../configs/demo.yaml")
    parser.add_argument("--flow_config", type=str, required=False, default=None)

    args = parser.parse_args()

    config_path = args.config_path
    flow_config = args.flow_config

    if flow_config:
        m_cfg = yaml.safe_load(base64.b64decode(flow_config))
    else:
        with open(config_path, "r", encoding='utf-8') as cfg:
            m_cfg = yaml.safe_load(cfg)

    executor = RayExecutor(m_cfg)
    try:
        executor.run()
    except Exception as e:
        executor.update_db("FAILED")
        raise e
    executor.update_db("COMPLETED")
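# A minimal sketch of the YAML consumed via --config_path (keys inferred from the
# attributes read above and by the runtime; paths, ids and operator names are
# placeholders):
#
#     dataset_path: /flow/task-001/meta.jsonl
#     dataset_id: ds-001
#     instance_id: inst-001
#     export_path: /flow/task-001/output
#     executor_type: datamate
#     process:
#       - TextCleanMapper:
#           remove_html: true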
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
import os

from datamate.scheduler import cmd_scheduler


async def submit(task_id, config_path):
    current_dir = os.path.dirname(__file__)

    await cmd_scheduler.submit(task_id, f"python {os.path.join(current_dir, 'datamate_executor.py')} "
                                        f"--config_path={config_path}")


def cancel(task_id):
    return cmd_scheduler.cancel_task(task_id)