init datamate

This commit is contained in:
Dallas98
2025-10-21 23:00:48 +08:00
commit 1c97afed7d
692 changed files with 135442 additions and 0 deletions

View File

@@ -0,0 +1,89 @@
# Custom Operator Development Guide
## Operator Specification
### Operator Metadata Format
Every custom operator must include a `metadata.yml` file:
```yaml
name: '落盘算子'
name_en: 'save file operator'
description: '将文件内容保存为文件。'
description_en: 'Save the file data as a file.'
language: 'Python'
vendor: 'Huawei'
raw_id: 'FileExporter'
version: '1.0.0'
types:
- 'collect'
modal: 'others'
effect:
before: ''
after: ''
inputs: 'all'
outputs: 'all'
```
### Operator Implementation
Create a `process.py` file:
```python
# -*- coding: utf-8 -*-
"""
Description: JSON text extraction
Create: 2024/06/06 15:43
"""
import time
from loguru import logger
from typing import Dict, Any
from datamate.core.base_op import Mapper
class TextFormatter(Mapper):
"""把输入的json文件流抽取为txt"""
def __init__(self, *args, **kwargs):
super(TextFormatter, self).__init__(*args, **kwargs)
@staticmethod
def _extract_json(byte_io):
"""将默认使用utf-8编码的Json文件流解码,抽取为txt"""
# 用utf-8-sig的格式进行抽取,可以避免uft-8 BOM编码格式的文件在抽取后产生隐藏字符作为前缀。
return byte_io.decode("utf-8-sig").replace("\r\n", "\n")
def byte_read(self, sample: Dict[str, Any]):
filepath = sample[self.filepath_key]
with open(filepath, "rb") as file:
byte_data = file.read()
sample[self.data_key] = byte_data
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
start = time.time()
try:
self.byte_read(sample)
sample[self.text_key] = self._extract_json(sample[self.data_key])
sample[self.data_key] = b""  # clear sample[self.data_key]
logger.info(
f"fileName: {sample[self.filename_key]}, method: TextFormatter costs {(time.time() - start):6f} s")
except UnicodeDecodeError as err:
logger.exception(f"fileName: {sample[self.filename_key]}, method: TextFormatter causes decode error: {err}")
raise
return sample
```
Create an `__init__.py` file:
```python
# -*- coding: utf-8 -*-
from datamate.core.base_op import OPERATORS
OPERATORS.register_module(module_name='TextFormatter',
module_path="ops.formatter.text_formatter.process")
```
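Taken together, the operator is expected to live in a package directory that matches the `module_path` registered above. The layout below is an assumption inferred from `ops.formatter.text_formatter.process`:
```
ops/
└── formatter/
    └── text_formatter/
        ├── __init__.py
        ├── metadata.yml
        └── process.py
```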

View File

@@ -0,0 +1 @@
__version__ = "0.0.1"

View File

@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
"""
The class hierarchy for built-in exceptions is:
see https://docs.python.org/3/library/exceptions.html
BaseException
├── SystemExit
├── KeyboardInterrupt
├── GeneratorExit
└── Exception
├── ArithmeticError
│ ├── FloatingPointError
│ ├── OverflowError
│ └── ZeroDivisionError
├── AssertionError
├── AttributeError
├── BufferError
├── EOFError
├── ImportError
│ └── ModuleNotFoundError
├── LookupError
│ ├── IndexError
│ └── KeyError
├── MemoryError
├── NameError
│ └── UnboundLocalError
├── OSError
│ ├── BlockingIOError
│ ├── ChildProcessError
│ ├── ConnectionError
│ │ ├── BrokenPipeError
│ │ ├── ConnectionAbortedError
│ │ ├── ConnectionRefusedError
│ │ └── ConnectionResetError
│ ├── FileExistsError
│ ├── FileNotFoundError
│ ├── InterruptedError
│ ├── IsADirectoryError
│ ├── NotADirectoryError
│ ├── PermissionError
│ ├── ProcessLookupError
│ └── TimeoutError
├── ReferenceError
├── RuntimeError
│ ├── NotImplementedError
│ └── RecursionError
├── SyntaxError
│ └── IndentationError
│ └── TabError
├── SystemError
├── TypeError
├── ValueError
│ └── UnicodeError
│ ├── UnicodeDecodeError
│ ├── UnicodeEncodeError
│ └── UnicodeTranslateError
└── Warning
├── DeprecationWarning
├── PendingDeprecationWarning
├── RuntimeWarning
├── SyntaxWarning
├── UserWarning
├── FutureWarning
├── ImportWarning
└── UnicodeWarning
"""
from enum import Enum
ERROR_CODE_TABLE = {
ImportError: "ops.0001",
ModuleNotFoundError: "ops.0002",
NameError: "ops.0003",
KeyError: "ops.0004",
IndexError: "ops.0005",
ValueError: "ops.0006",
TypeError: "ops.0007",
SyntaxError: "ops.0008",
AttributeError: "ops.0009",
ArithmeticError: "ops.0010",
MemoryError: "ops.0011",
OSError: "ops.0012",
FileNotFoundError: "ops.0013",
NotADirectoryError: "ops.0014",
PermissionError: "ops.0015",
TimeoutError: "ops.0016",
}
UNKNOWN_ERROR_CODE = "ops.9999"
class ErrorCode(Enum):
# Generic error codes
SUCCESS = (0, "Success")
UNKNOWN_ERROR = (1, "Unknown error")
FILE_NOT_FOUND_ERROR = (1000, "File not found!")
SUBMIT_TASK_ERROR = (1001, "Task submission failed!")
CANCEL_TASK_ERROR = (1002, "Task cancellation failed!")

View File

@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
from datetime import datetime
import cv2
import numpy as np
import os
import pytz
from loguru import logger
def check_valid_path(file_path):
full_path = os.path.abspath(file_path)
return os.path.exists(full_path)
def get_realpath_with_prefix_check(path, prefix):
realpath = os.path.realpath(path)
if realpath.startswith(prefix):
return realpath
else:
raise ValueError(f"The path {realpath} does not start with the prefix '{prefix}'.")
def bytes_to_numpy(image_bytes):
"""bytes转数组"""
image_np = np.frombuffer(image_bytes, dtype=np.uint8)
image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
return image_np2
def numpy_to_bytes(image_np, file_type):
"""数组转bytes"""
if not image_np.size:
return b""
data = cv2.imencode(file_type, image_np)[1]
image_bytes = data.tobytes()
return image_bytes
def get_now_time(timezone, time_format, file_name, method):
timestamp = ""
try:
china_tz = pytz.timezone(timezone)  # resolve the target timezone
china_time = datetime.now(tz=china_tz)  # current time converted to that timezone
timestamp = china_time.strftime(time_format)  # format the timestamp
except ValueError as e:
logger.error(f"fileName: {file_name}, method: {method}, formatting time failed: {e}")
return timestamp
def decrypt(enc_pass):
import kmc.kmc as K
os.environ['KMC_DATA_USER'] = 'modelenginepublic'
os.environ['KMC_PYTHON_ENCRYPT_DATA'] = enc_pass
dec_pass = K.API().decrypt(0)
os.environ['KMC_PYTHON_ENCRYPT_DATA'] = ""
return dec_pass

View File

@@ -0,0 +1,115 @@
# -- encoding: utf-8 --
from collections import deque
class TrieNode:
def __init__(self, value):
self.value = value
self.child = dict()
self.fail = None
self.word = None
class AhoCorasic:
"""AC自动机算法进行目标字符串搜索"""
def __init__(self, words):
self._root = add_fail_pointer(build_trie(words))
def search(self, text: str, special_symbols: set):
"""
Match sensitive words in the given text.
Args:
text: the input text
special_symbols: special characters to skip during matching
Returns:
List of matched strings (deduplicated)
"""
seq_list = []
node = self._root
valid_len = 0  # effective length of the current partial match
for i, s in enumerate(text):
if s in special_symbols:  # skip special characters
if valid_len != 0:
valid_len += 1
continue
matched = True
while s not in node.child:  # node.child has no edge for character s
if node == self._root:  # at root (no fail pointer): reset the length and stop
valid_len = 0
matched = False
break
elif node.fail == self._root:  # fail pointer is root: reset the length but keep going
valid_len = 0
node = node.fail  # follow the fail pointer
if not matched:
continue
node = node.child.get(s)
valid_len += 1
if node.word:  # node marks the last character of a word
sensitive_word = text[i - valid_len + 1:i + 1]
seq_list.append(sensitive_word)
seq_list = list(set(seq_list))
return seq_list
def build_trie(words: list):
"""
Build the prefix trie.
Args:
words: list of sensitive words.
Returns:
Root node of the trie.
"""
root = TrieNode('root')
for word in words:
node = root
for s in word:
if s not in node.child:
node.child[s] = TrieNode(s)
node = node.child[s]
if not node.word:
node.word = {word}
else:
node.word.add(word)
return root
def add_fail_pointer(root: TrieNode):
"""
Add fail pointers to the prefix trie (BFS over the trie).
Steps:
1. Starting from root, push (node, node.parent) pairs onto a queue level by level. root has no fail pointer; the fail pointer of every child of root is root.
2. For any node other than root and root's children, look at node.parent.fail.child.
3. If node.parent.fail.child contains a node with the same value, set node.fail to that child node; otherwise fall back towards root.
Args:
root: root node of the trie.
Returns:
Root node of the trie with fail pointers attached.
"""
queue = deque()
queue.appendleft((root, None))
while len(queue) > 0:
node_parent = queue.pop()
curr, parent = node_parent[0], node_parent[1]
for sub in curr.child.values():
queue.appendleft((sub, curr))
if parent is None:
continue
elif parent is root:
curr.fail = root
else:
parent_fail = parent.fail
while parent_fail and curr.value not in parent_fail.child:
parent_fail = parent_fail.fail
if parent_fail:
curr.fail = parent_fail.child[curr.value]
else:
curr.fail = root
return root
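A minimal usage sketch of the automaton above; the word list, text and special symbols are made-up examples:
```python
# Build the automaton from a hypothetical word list and search a text,
# skipping spaces and hyphens while matching.
ac = AhoCorasic(["secret", "password"])
matches = ac.search("the admin pass-word is a secret", special_symbols={" ", "-"})
print(matches)  # e.g. ['pass-word', 'secret'] -- skipped symbols stay inside the matched span
```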

View File

@@ -0,0 +1,66 @@
# -- encoding: utf-8 --
import pickle
import base64
from io import BytesIO
import cv2
import numpy as np
from PIL import Image
def bytes_to_numpy(image_bytes):
"""bytes转数组"""
image_np = np.frombuffer(image_bytes, dtype=np.uint8)
image_np2 = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
return image_np2
def numpy_to_bytes(image_np, file_type):
"""
Convert a numpy image array into bytes
Params:
file_type: as required by OpenCV, extension must have a leading period.
"""
if not image_np.size:
return b""
data = cv2.imencode(file_type, image_np)[1]
image_bytes = data.tobytes()
return image_bytes
def pil_to_bytes(src: Image.Image) -> bytes:
"""将 PIL.Image 转换为字节流"""
# 确保图像是 RGB 模式
src = src.convert("RGB")
with BytesIO() as bytes_io:
src.save(bytes_io, format='PNG')
im_bytes = bytes_io.getvalue()
return im_bytes
def bytes_to_pil(src: bytes) -> Image.Image:
"""将字节流转换为 PIL.Image"""
with BytesIO() as bytes_io:
with Image.open(bytes_io) as pil_img: # 使用with/as语句确保资源被正确释放
pil_img.load() # 确保图像数据被加载
return pil_img.copy() # 返回图像的副本以避免资源被关闭后无法使用
def pil_to_base64(src: Image.Image):
"""PIl.Image转base64"""
with BytesIO() as img_buffer:
src.save(img_buffer, format='png')
byte_data = img_buffer.getvalue()
base64_str = base64.b64encode(byte_data)
return base64_str
def obj_to_bytes(src: object) -> bytes:
return pickle.dumps(src)
def bytes_to_obj(src: bytes) -> object:
return pickle.loads(src)

View File

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import importlib.abc
import importlib.util
from pathlib import Path
class CustomImporter(importlib.abc.MetaPathFinder):
def __init__(self, base_path):
self.base_path = Path(base_path).resolve()
def find_spec(self, fullname, path, target=None):
# Convert the module name into a path (e.g. mypkg.mymodule -> mypkg/mymodule.py)
parts = fullname.split(".")
module_path = self.base_path.joinpath(*parts)
# Check for a matching .py file or a package directory
if module_path.with_suffix(".py").exists():
return importlib.util.spec_from_file_location(
fullname,
str(module_path.with_suffix(".py")),
submodule_search_locations=[str(module_path.parent)]
)
elif module_path.is_dir() and (module_path / "__init__.py").exists():
return importlib.util.spec_from_file_location(
fullname,
str(module_path / "__init__.py"),
submodule_search_locations=[str(module_path)]
)
else:
return None
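A short sketch of how this finder might be registered (the base path is a placeholder):
```python
import sys

# Put the finder ahead of the default ones so modules under the given
# directory are resolved first; the path is a hypothetical example.
sys.meta_path.insert(0, CustomImporter("/dataset/custom_ops"))

# After registration, `import mypkg.mymodule` would load
# /dataset/custom_ops/mypkg/mymodule.py, provided mypkg/ contains an __init__.py.
```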

View File

@@ -0,0 +1,229 @@
# -*- coding: utf-8 -*-
import os
import re
import sys
import importlib
import subprocess
from pathlib import Path
from types import ModuleType
from loguru import logger
from packaging.version import parse as parse_version
def is_valid_whl_filename(package_name):
"""
Validate that a WHL filename is safe (focused on preventing command injection).
Args:
package_name (str): the whl filename to validate
Returns:
bool: True if safe, False if it contains dangerous characters
"""
# Reject path separators
if os.path.sep in package_name or (os.path.altsep and os.path.altsep in package_name):
return False
# Blacklist of dangerous shell metacharacters: command separators, command
# substitution, redirection, globbing, quoting, escapes and whitespace.
dangerous_pattern = re.compile(r'[;&|`$(){}<>\[\]!\\"\'*?#\s]')
if dangerous_pattern.search(package_name):
return False
return True
class PackageNotFoundError(Exception):
pass
class LazyLoader(ModuleType):
def __init__(self,
package_name,
module_name=None,
whl_path="/dataset/ops_whl",
exact_version=None,
force_reinstall=False
):
"""
:param package_name: module-name part of the WHL package name (usually the module name with "_" replaced by "-")
:param module_name: module name to import after the WHL is installed; set it when the WHL name and the import name differ.
:param whl_path: directory containing the WHL files
:param exact_version: exact version requirement
:param force_reinstall: force a reinstall
"""
try:
frame = sys._getframe(1)
self._parent_globals = frame.f_globals
except (AttributeError, ValueError) as e:
logger.error(f"Failed to get stack frame: {e}")
raise RuntimeError("Stack frame retrieval failed") from e
self._module_name = module_name if module_name else package_name
self._package_name = package_name
self.whl_path = Path(whl_path).resolve()
self.exact_version = exact_version
self.force_reinstall = force_reinstall
self._cached_module = None
# Register the alias into the parent namespace
self._parent_globals[self._module_name] = self
super().__init__(self._module_name)
def __getattr__(self, name):
if self._cached_module is None:
self._cached_module = self._load_module()
return getattr(self._cached_module, name)
def __dir__(self):
return dir(self._load_module())
def _load_module(self):
"""模块加载逻辑"""
if self._cached_module is not None:
return self._cached_module
package_name: str = self._package_name.split('.')[0]
if not is_valid_whl_filename(package_name):
logger.error(f"Invalid package_name: {package_name}")
raise RuntimeError("Invalide package_name, please check it again!")
module_name = self._module_name if self._module_name else package_name.replace("_", "-")
need_install = False
try:
if not self.force_reinstall:
module = importlib.import_module(module_name)
self._cached_module = module
self._register_alias(module)
return module
except ImportError:
need_install = True
if self.force_reinstall:
# Version check when force_reinstall is set
installed = self._check_package_exists(package_name)
if installed and self.exact_version:
installed_version = self._get_installed_version(package_name)
if parse_version(installed_version) != parse_version(self.exact_version):
logger.info(f"Version mismatch detected: {installed_version} vs {self.exact_version}")
need_install = True
else:
need_install = True
if need_install:
self._pip_install_package(package_name)
module = importlib.import_module(module_name)
self._cached_module = module
self._register_alias(module)
else:
# Version check passed; no reinstall needed
module = importlib.import_module(module_name)
self._cached_module = module
self._register_alias(module)
return self._cached_module
def _register_alias(self, module):
"""注册本地别名 """
self._parent_globals[self._module_name] = module
sys.modules[self._module_name] = module
def _check_package_exists(self, package_name):
"""增强版包检查"""
try:
result = subprocess.run(
[sys.executable, "-m", "pip", "show", package_name],
capture_output=True,
text=True
)
return result.returncode == 0
except subprocess.SubprocessError as e:
logger.error(f"Package check failed: {e}")
return False
def _get_installed_version(self, package_name):
"""获取已安装版本 """
result = subprocess.run(
[sys.executable, "-m", "pip", "show", package_name],
capture_output=True,
text=True
)
for line in result.stdout.split('\n'):
if line.startswith('Version:'):
return line.split()[-1]
raise PackageNotFoundError()
def _pip_install_package(self, package_name: str):
"""安装逻辑 """
if not self.whl_path.exists():
raise FileNotFoundError(f"WHL directory not found: {self.whl_path}")
whl_files = list(self.whl_path.glob(f"{package_name}*.whl"))
if not whl_files:
raise RuntimeError(f"No WHL files found for {package_name}")
# Filter by version
if self.exact_version:
pattern = re.compile(
rf'^{re.escape(package_name)}-{re.escape(self.exact_version)}-\S*\.whl$',
re.IGNORECASE
)
whl_files = [f for f in whl_files if pattern.match(f.name)]
# Pick the newest version
whl_versions = []
for f in whl_files:
try:
version = self._extract_version(f)
whl_versions.append((f, version))
except ValueError:
continue
if not whl_versions:
raise FileNotFoundError("No valid WHL files")
whl_versions.sort(key=lambda x: x[1], reverse=True)
target_whl = whl_versions[0][0]
# Run the installation
try:
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"--no-index",
f"--find-links={self.whl_path}",
str(target_whl)
], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
logger.info(f"Successfully installed {target_whl}")
except subprocess.CalledProcessError as e:
logger.error(f"Installation failed: {e}")
raise RuntimeError(f"Installation failed: {e}") from e
def _extract_version(self, filename):
"""版本解析 """
version_pattern = r"(\d+([.]\d+)+([ab]|rc\d+)*([.]post\d+)*([.]dev\d+)*)"
match = re.search(
rf"^{re.escape(self._package_name)}-({version_pattern})",
filename.name,
re.IGNORECASE
)
if not match:
raise ValueError(f"Invalid version format: {filename.name}")
return parse_version(match.group(1))
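A hedged usage sketch; the package name and WHL directory below are placeholders and depend on the actual deployment:
```python
# Nothing is imported (or installed) until the first attribute access.
np = LazyLoader("numpy", whl_path="/dataset/ops_whl")  # hypothetical offline WHL directory
print(np.array([1, 2, 3]).sum())  # triggers the import, installing from the WHL directory if needed
```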

View File

@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
import json
import os
import ssl
from pathlib import Path
from typing import Dict
from urllib.request import Request, urlopen
import requests
import urllib3
from loguru import logger
from datamate.common.utils import decrypt
class LlmReq:
# Error-code constants
ERRORCODE_INCOMPLETE_CONFIG = 83005
ERRORCODE_INVALID_RESPONSE = 83006
ERRORCODE_SERVICE_UNAVAILABLE = 83007
def __init__(self, url: str = None, header: Dict = None, body: Dict = None, access_type: int = None,
is_https: bool = False, is_certificate: bool = False, certificate_path: Path = None):
self.url = url
self.header = header
self.access_type = access_type
self.is_https = is_https
self.context = self._load_certificate(certificate_path, is_certificate) if is_https else None
self.body = body
if not self.body.get("messages", [])[0].get("content"):
self.body["messages"][0]["content"] = "你好"
def __call__(self, input_str: str) -> str:
outputs = ''
try:
self.body["messages"][0]["content"] = input_str
outputs = self._call_service()
except KeyError as e:
logger.error(f"The body format is not completed, error detail: {e}")
self.body["messages"][0]["content"] = "你好"
return outputs
@staticmethod
def _load_certificate(certificate_path: Path, is_certificate: bool) -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
if is_certificate:
context.load_verify_locations(certificate_path)
context.verify_mode = ssl.CERT_REQUIRED
else:
context.verify_mode = ssl.CERT_NONE
return context
@staticmethod
def _pool_manager():
cert_file = os.getenv("RAY_TLS_SERVER_CERT", "/certPersonal/global/identity/global.crt")
key_file = os.getenv("RAY_TLS_SERVER_KEY", "/certPersonal/global/identity/global.key")
ca_crt = os.getenv("RAY_TLS_CA_CERT", "/certPersonal/global/trust/ca.crt")
pwd = os.getenv("RAY_TLS_SERVER_KEY_PASSWORD", "/certPersonal/global/identity/pwd.txt")
key_password = os.getenv("GLOBAL_PWD", None)
if not key_password:
with open(pwd, "r") as f:
key_password = f.read().strip()
key_password = decrypt(key_password)
pool_manager = urllib3.PoolManager(cert_file=cert_file,
key_file=key_file,
key_password=key_password,
cert_reqs='CERT_REQUIRED',
ca_certs=ca_crt,
assert_hostname='edatamate',
ssl_version='TLSv1_2')
return pool_manager
def _call_service(self):
if not all([self.url, self.header, self.body.get("messages", [])[0].get("content")]):
logger.error("LLM is not configured completely")
raise RuntimeError(self.ERRORCODE_INCOMPLETE_CONFIG, "LLM is not configured completely") from None
if not self.access_type:
try:
pool_manager = self._pool_manager()
response = pool_manager.request(
"POST",
url=self.url,
body=json.dumps(self.body).encode(),
headers=self.header
)
logger.info(f"Response status code: {response.status}")
response_json = json.loads(response.data.decode('utf-8'))
outputs = response_json.get("choices", [])[0].get("message", {}).get("content")
if not outputs:
logger.error("Invalid response format for LLM, missing the 'prompt' key word")
raise RuntimeError(self.ERRORCODE_INVALID_RESPONSE,
"Invalid response format for LLM, missing the 'prompt' key word") from None
return outputs
except Exception as e:
logger.error(f"LLM service is not available, error detail: {e}")
raise RuntimeError(self.ERRORCODE_SERVICE_UNAVAILABLE, "LLM service is not available") from None
if self.access_type:
try:
if self.is_https:
req = Request(url=self.url, data=json.dumps(self.body).encode(), headers=self.header, method="POST")
response_json = urlopen(req, context=self.context).read().decode("utf-8")
response = json.loads(response_json)
else:
response = requests.post(url=self.url, data=json.dumps(self.body), headers=self.header,
stream=False).json()
outputs = response.get("choices", [])[0].get("message", {}).get("content")
if not outputs:
logger.error("Invalid response format for LLM, missing the 'prompt' key word")
raise RuntimeError(self.ERRORCODE_INVALID_RESPONSE,
"Invalid response format for LLM, missing the 'prompt' key word") from None
return outputs
except Exception as e:
logger.error(f"LLM service is not available, error detail: {e}")
raise RuntimeError(self.ERRORCODE_SERVICE_UNAVAILABLE, "LLM service is not available") from None
return None  # make sure every code path returns a value

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
import sys
import re
import subprocess
from pathlib import Path
from typing import Optional
from loguru import logger
from packaging.version import parse as parse_version, Version
def install_whl(
package_name: str,
whl_path: str,
exact_version: Optional[str] = None,
filename_pattern: Optional[str] = None,
force_reinstall: bool = False
) -> None:
"""
:param package_name: package name, e.g. "zh_core_web_sm"
:param whl_path: directory where the WHL files are stored
:param exact_version: exact version number to install
:param filename_pattern: custom regex pattern for the WHL filename
:param force_reinstall: whether to reinstall over an existing installation (default: False)
"""
whl_dir = Path(whl_path).resolve()
whl_files = _get_whl_files(exact_version, filename_pattern, package_name, whl_dir)
# Sort by semantic version
target_whl = _sort_whl_files(whl_files)
# Build the pip install command
cmd = [
sys.executable, "-m", "pip", "install",
"--no-index",
f"--find-links={whl_dir}",
str(target_whl)
]
if force_reinstall:
cmd.append("--force-reinstall")
try:
subprocess.check_call(
cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.STDOUT
)
logger.info(f"Successfully installed {target_whl.name}")
except subprocess.CalledProcessError as e:
error_msg = (
f"Installation failed for {package_name}\n"
f"Possible reasons:\n"
f"1. Missing dependencies in {whl_dir}\n"
f"2. Incompatible Python version\n"
f"3. Platform mismatch (e.g., x86 vs ARM)"
)
raise RuntimeError(error_msg) from e
def _sort_whl_files(whl_files):
whl_versions = []
logger.info(f"[load_offline_module]whl_files: {whl_files}")
for f in whl_files:
try:
version = _extract_version(f)
whl_versions.append((f, version))
except ValueError as e:
logger.warning(f"Skipping invalid file {f.name}: {e}")
continue
if not whl_versions:
raise FileNotFoundError("No valid WHL files with parseable versions")
whl_versions.sort(key=lambda x: x[1], reverse=True)
target_whl = whl_versions[0][0]
return target_whl
def _get_whl_files(exact_version, filename_pattern, package_name, whl_dir):
# Build the filename regex
if filename_pattern:
pattern = filename_pattern
else:
if exact_version:
version_part = re.escape(exact_version)
pattern = rf"^{re.escape(package_name)}-{version_part}-\S*\.whl$"
else:
pattern = rf"^{re.escape(package_name)}\S*\.whl$"
regex = re.compile(pattern, re.IGNORECASE)
whl_files = [f for f in whl_dir.glob("*.whl") if regex.match(f.name)]
if not whl_files:
available_files = "\n".join([f.name for f in whl_dir.glob("*.whl")])
raise FileNotFoundError(
f"No matching WHL found for {package_name} in {whl_dir}\n"
f"Available files:\n{available_files}"
)
return whl_files
def _extract_version(filename: Path) -> Version:
"""从文件名提取语义化版本("""
match = re.search(
r"-(\d+([.]\d+)+([ab]|rc\d+)*([.]post\d+)*([.]dev\d+)*)-",
filename.name
)
if not match:
raise ValueError(f"Invalid version format: {filename.name}")
return parse_version(match.group(1))
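For reference, a hedged example call; the package name comes from the docstring above, while the directory and version are placeholders:
```python
# Install the zh_core_web_sm WHL from a local, offline directory.
install_whl(
    package_name="zh_core_web_sm",   # must match the WHL filename prefix
    whl_path="/dataset/ops_whl",     # hypothetical WHL directory
    exact_version="3.7.0",           # optional exact version pin (placeholder)
)
```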

View File

@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
from loguru import logger
class Registry(object):
"""注册器类,用于注册所有的算子."""
def __init__(self, name: str):
self._name = name
self._modules = {}
@property
def name(self):
return self._name
@property
def modules(self):
return self._modules
def list(self):
"""日志打印注册器中所有的算子"""
for m in self._modules.keys():
logger.info(f'{self._name}\t{m}')
def get(self, module_key):
return self._modules.get(module_key, None)
def register_module(self, module_name: str = None, module_cls: type = None, module_path: str = None, force=False):
"""
Register a module under a specific module name.
:param module_name: module name
:param module_cls: module class definition
:param module_path: import path of the module
:param force: whether to overwrite a module registered under the same name, defaults to False.
Example:
>>> registry = Registry('my_ops')
>>> @registry.register_module()
... class TextFormatter:
...     pass
>>> class TextFormatter2:
...     pass
>>> registry.register_module(module_name='text_formatter2', module_cls=TextFormatter2)
"""
if not (module_name is None or isinstance(module_name, str)):
raise TypeError(f'module_name must be either None or str, '
f'got {type(module_name)}')
if module_cls is not None:
self._register_module(module_name=module_name,
module_cls=module_cls,
force=force)
return module_cls
elif module_cls is None and isinstance(module_path, str):
self._register_module(module_name=module_name,
module_path=module_path,
force=force)
return module_path
def _register(module_cls):
"""
When module_cls is None, return a decorator function that performs the registration.
"""
self._register_module(module_name=module_name,
module_cls=module_cls,
force=force)
return module_cls
return _register
def _register_module(self, module_name=None, module_cls=None, module_path=None, force=False):
"""
Register a module into the registry.
:param module_name: module name
:param module_cls: module class definition
:param force: whether to overwrite a module registered under the same name, defaults to False.
"""
if module_name is None and module_cls is not None:
module_name = module_cls.__name__
if module_name in self._modules:
if module_cls is not None and module_cls == self._modules[module_name]:
return
if module_path is not None and module_path == self._modules[module_name]:
return
if not force:
raise KeyError(
f'{module_name} is already registered in {self._name}, content: {self.modules.keys()}')
if module_cls is not None:
self._modules[module_name] = module_cls
elif module_path is not None:
self._modules[module_name] = module_path

View File

@@ -0,0 +1,171 @@
#!/usr/bin/python
import re
from typing import List
from collections import deque
from loguru import logger
class TextSplitter:
"""文本切片"""
# 基于常用标点符号分句,保持句子完整
COMMON_PUNCTUATIONS = ["。", ",", "、", "?", "!", ",", "?", "!", ";"]
PUNC_PATTERN = f"[{''.join(COMMON_PUNCTUATIONS)}]"
def __init__(self, max_characters: int, chunk_size: int, chunk_overlap: int):
"""文本切片初始化
Args:
max_characters :文件最大字符,超过截断,-1不处理
chunk_size: 块大小
chunk_overlap: 块重叠度
"""
if chunk_size <= chunk_overlap:
logger.error(f"param chunk_size should larger than chunk_overlap, "
f"current chunk_size: {chunk_size}, chunk_overlap: {chunk_overlap}")
raise Exception(83000, "chunk_size must be larger than chunk_overlap") from None
self.max_characters = max_characters
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.separators = ["\n\n", "\n"]
@staticmethod
def split_text_by_separator(text: str, separator: str):
"""指定分隔符对文本进行切分,并且切分后的片段需要保留分隔符"""
# 处理一个换行符与两个换行符之间的冲突
if text.startswith("\n\n") and separator == "\n":
chunks = re.split(f"({separator})", text.strip())
chunks[0] = f"\n\n{chunks[0]}"
else:
chunks = re.split(f"({separator})", text)
new_chunks = [chunks[idx] + chunks[idx + 1] for idx in range(1, len(chunks), 2)]
new_chunks = [chunks[0]] + new_chunks
return [chunk for chunk in new_chunks if chunk.strip() != ""]
@staticmethod
def split_sentences(chunk: str):
"""对切片按照标点符号切分成句子,并且保持标点符号不丢失"""
sentences = re.split(TextSplitter.PUNC_PATTERN, chunk)
delimiters = [s for s in chunk if s in TextSplitter.COMMON_PUNCTUATIONS]
restore_chunks = []
for chunk, delimiter in zip(sentences[:-1], delimiters):
restore_chunks.append(chunk + delimiter)
return restore_chunks + [sentences[-1]]
def split_text(self, input_data: str):
if self.max_characters > 0:
logger.info(f"The document characters should be within: {self.max_characters}")
input_data = input_data[:self.max_characters]
logger.info(f"characters of the document: {len(input_data)}")
chunks = self.split_text_recursive(input_data, self.separators)
final_chunks = self.merge_chunks(chunks)
final_chunks = self.split_text_by_chunk_size(final_chunks)
return [chunk.strip() for chunk in final_chunks if chunk]
def split_text_recursive(self, input_data: str, separators: List[str]):
"""对文档按照分隔符优先级进行递归切分:
1. 符合chunk_size要求的切片不再切分。
2. 大于chunk_size要求的切片,继续进行递归切分。
Args:
input_data: 输入文本
separators: 分隔符
Returns:
List[str]: 切分后的文本片段
"""
chunks = []
cur_separator = ""
next_separators = []
for idx, sep in enumerate(separators):
sep = re.escape(sep)
if re.search(sep, input_data.strip()):
cur_separator = sep
next_separators = separators[idx + 1:]
break
if not cur_separator:
return [input_data]
else:
cur_chunks = TextSplitter.split_text_by_separator(input_data, cur_separator)
for chunk in cur_chunks:
if len(chunk.strip()) <= self.chunk_size:
chunks.append(chunk)
else:
if not next_separators:
chunks.append(chunk)
else:
next_chunks = self.split_text_recursive(chunk, next_separators)
chunks.extend(next_chunks)
return chunks
def merge_chunks(self, chunks: List[str]):
"""对切分后的文本片段进行合并,合并过程考虑overlap"""
final_chunks = []
idx = 0
while idx < len(chunks):
if len(chunks[idx]) >= self.chunk_size:
final_chunks.append(chunks[idx])
idx += 1
continue
merge_idxes = self.get_merge_idxes(idx, chunks)
content = ""
for inner_idx in merge_idxes:
content += chunks[inner_idx]
final_chunks.append(content)
idx = merge_idxes[-1] + 1
return final_chunks
def get_merge_idxes(self, cur_idx: int, chunks: List[str]):
"""获取可以合并的分片index,前向尽可能满足overlap,后向尽可能满足chunk_size"""
idxes = deque([cur_idx])
overlap_idx = cur_idx - 1
cur_len = len(chunks[cur_idx])
cur_idx += 1
# Collect the overlap indexes (looking backwards)
over_lap_len = 0
while overlap_idx >= 0:
over_lap_len += len(chunks[overlap_idx])
if over_lap_len > self.chunk_overlap or (cur_len + over_lap_len) > self.chunk_size:
over_lap_len -= len(chunks[overlap_idx])
break
idxes.appendleft(overlap_idx)
overlap_idx -= 1
cur_len += over_lap_len
# Collect the merge indexes (looking forwards)
while cur_idx < len(chunks):
cur_len += len(chunks[cur_idx])
if cur_len > self.chunk_size:
break
idxes.append(cur_idx)
cur_idx += 1
return idxes
def split_chunks(self, chunks: List[str]):
"""对超过`chunk_size`限制的切片进行截断,过程中需要考虑overlap参数"""
final_chunks = []
for chunk in chunks:
if len(chunk) <= self.chunk_size:
final_chunks.append(chunk)
else:
start = 0
end = self.chunk_size
while end < len(chunk):
final_chunks.append(chunk[start: end])
start += self.chunk_size - self.chunk_overlap
end = start + self.chunk_size
final_chunks.append(chunk[start:])
return final_chunks
def split_text_by_chunk_size(self, chunks: List[str]):
"""对切片后超长的文本块进行二次切分,使用截断,并考虑overlap"""
final_chunks = []
for chunk in chunks:
if len(chunk) <= self.chunk_size:
final_chunks.append(chunk)
continue
sentences = TextSplitter.split_sentences(chunk)
sub_chunks = self.merge_chunks(sentences)
final_chunks.extend(self.split_chunks(sub_chunks))
return final_chunks
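A small usage sketch of the splitter; the parameter values and sample text are arbitrary examples:
```python
# chunk_size must be larger than chunk_overlap; max_characters=-1 disables truncation.
splitter = TextSplitter(max_characters=-1, chunk_size=50, chunk_overlap=10)
text = "First paragraph, short.\n\nSecond paragraph, which is a little longer and gets merged or re-split as needed."
for chunk in splitter.split_text(text):
    print(repr(chunk))
```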

View File

@@ -0,0 +1,381 @@
# -*- coding: utf-8 -*-
import json
import os
import traceback
from typing import List, Dict, Any, Tuple
from loguru import logger
from datamate.common.error_code import ERROR_CODE_TABLE, UNKNOWN_ERROR_CODE
from datamate.common.utils.llm_request import LlmReq
from datamate.common.utils.registry import Registry
from datamate.common.utils import check_valid_path
from datamate.core.constant import Fields
from datamate.sql_manager.persistence_atction import TaskInfoPersistence
OPERATORS = Registry('Operators')
FAILED_STATUS = "FAILED"
SUCCESS_STATUS = "COMPLETED"
def get_exception_info(e):
exc_type = type(e).__name__  # exception type (e.g. 'ZeroDivisionError')
exc_msg = str(e)  # exception message (e.g. 'division by zero')
# Extract the detailed stack information
tb = traceback.extract_tb(e.__traceback__)  # parse the traceback object
error_line = tb[-1].lineno  # line number where the error occurred
error_file = tb[-1].filename  # file where the error occurred
code_snippet = tb[-1].line  # source line of the error
# Assemble the output message
error_info = (
f"Exception type: {exc_type}\n"
f"Exception message: {exc_msg}\n"
f"File: {error_file}\n"
f"Line: {error_line}\n"
f"Code: {code_snippet}"
)
return error_info
class BaseOp:
"""
Base class for all operator classes
"""
use_model = False
custom_ops = False
def __init__(self, *args, **kwargs):
self.accelerator = kwargs.get('accelerator', "cpu")
self.is_last_op = kwargs.get('is_last_op', False)
self._name = kwargs.get('op_name', None)
self.infer_model = None
self.text_key = kwargs.get('text_key', "text")
self.data_key = kwargs.get('data_key', "data")
self.image_key = kwargs.get('image_key', "image")
self.video_key = kwargs.get('video_key', "video")
self.audio_key = kwargs.get('audio_key', "audio")
self.filename_key = kwargs.get('fileName_key', "fileName")
self.filetype_key = kwargs.get('fileType_key', "fileType")
self.fileid_key = kwargs.get('fileId_key', "fileId")
self.filepath_key = kwargs.get('filePath_key', "filePath")
self.filesize_key = kwargs.get('fileSize_key', "fileSize")
self.export_path_key = kwargs.get('export_path_key', "export_path")
self.ext_params_key = kwargs.get('ext_params_key', "ext_params")
@property
def name(self):
if self._name:
return self._name
else:
return "UnknownOp"
@staticmethod
def is_npu_available():
try:
import torch_npu
return torch_npu.npu.is_available()
except ImportError as e:
logger.warning("Import torch_npu failed.")
return False
@staticmethod
def update_kwargs(sample: Dict[str, Any], not_update_keys=("text", "data", "meta")) -> Dict:
"""获取sample_data中文件相关的信息"""
res = {}
for k, v in sample.items():
if k not in not_update_keys:
res[k] = v
return res
@staticmethod
def _get_error_info(e: BaseException) -> Tuple[str, str]:
error_code = UNKNOWN_ERROR_CODE
exc_info = get_exception_info(e)
for exc_type in type(e).__mro__:
if exc_type in ERROR_CODE_TABLE.keys():
error_code = ERROR_CODE_TABLE[exc_type]
break
return error_code, exc_info
def use_npu(self):
"""确认算子是否可以使用npu"""
return self.accelerator == 'npu' and self.is_npu_available()
def get_model(self, *args, **kwargs):
if self.infer_model is None and self.use_model:
return self.init_model(*args, **kwargs)
else:
logger.info(f"Op named {self.name} get infer model Failed. please "
f" check Attribute self.use_model: {self.use_model} or model has been initialized!")
return self.infer_model
def init_model(self, *args, **kwargs):
"""执行函数(子类实现)"""
raise NotImplementedError("This is in BaseOp, plese re-define this method in Sub-classes")
def fill_sample_params(self, sample: Dict[str, Any], **kwargs):
if not sample.get("text", None):
sample[self.text_key] = ""
if not sample.get("data", None):
sample[self.data_key] = b""
if not sample[self.data_key] and not sample[self.text_key]:
sample.update(kwargs)
def create_failure_sample(self, sample: Dict[str, Any], op_name, excp: BaseException):
sample["execute_result"] = False
error_code, exc_info = self._get_error_info(excp)
failed_reason = {"op_name": op_name, "error_code": error_code, "reason": exc_info}
sample["failed_reason"] = failed_reason
class Mapper(BaseOp):
def __init__(self, *args, **kwargs):
super(Mapper, self).__init__(*args, **kwargs)
def __call__(self, sample: Dict[str, Any], **kwargs):
# A previous operator in the pipeline already failed on this file
if sample.get(Fields.result) is False:
return sample
self.fill_sample_params(sample, **kwargs)
execute_status = FAILED_STATUS
try:
sample = self.execute(sample)
execute_status = SUCCESS_STATUS
except Exception as e:
# The operator failed: persist the file's execution info to the database and update its result status
self.create_failure_sample(sample, self.name, e)
logger.error(f"Ops named {self.name} map failed, Error Info: \n"
f"{str(get_exception_info(e))}")
sample["execute_status"] = execute_status
task_info = TaskInfoPersistence()
task_info.persistence_task_info(sample)
return sample
sample["execute_status"] = execute_status
# Persist the file's successful execution info to the database
if self.is_last_op:
task_info = TaskInfoPersistence()
task_info.persistence_task_info(sample)
return sample
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""执行函数(子类实现)"""
raise NotImplementedError("This is in Mapper Class, plese re-define this method in Sub-classes")
class Slicer(BaseOp):
def __init__(self, *args, **kwargs):
super(Slicer, self).__init__(*args, **kwargs)
self.target_file_type = None
def __call__(self, sample: Dict[str, Any], **kwargs):
# A previous operator in the pipeline already failed on this file
if sample.get(Fields.result) is False:
return sample
self.fill_sample_params(sample, **kwargs)
sample_list = []
execute_status = FAILED_STATUS
try:
sample_list = self.execute(sample)
execute_status = SUCCESS_STATUS
except Exception as e:
# The operator failed: persist the file's execution info to the database and update its result status
self.create_failure_sample(sample, self.name, e)
self.load_sample_to_sample(sample, sample_list)
logger.error(f"Ops named {self.name} map failed, Error Info: \n"
f"{str(get_exception_info(e))}")
sample["execute_status"] = execute_status
task_info = TaskInfoPersistence()
task_info.persistence_task_info(sample)
return [sample]
self.load_sample_to_sample(sample, sample_list)
sample["execute_status"] = execute_status
# Persist the file's successful execution info to the database
if self.is_last_op:
task_info = TaskInfoPersistence()
task_info.persistence_task_info(sample)
return [sample]
@staticmethod
def load_sample_to_sample(sample: Dict, sample_list: List[Dict]):
"""使用sample中的k-v更新sample"""
for sample_i in sample_list:
for k, v in sample_i.items():
sample[k] = v
if not sample.get("fileNum", None):
sample["fileNum"] = 1
def execute(self, sample: Dict[str, Any]) -> List[Dict]:
"""执行函数(子类实现)"""
raise NotImplementedError("This is in Mapper Class, plese re-define this method in Sub-classes")
def save_patch_sample(self, sample: Dict[str, Any], patch_no, save_format="text"):
if save_format == "text":
target_file_type = 'txt'
elif save_format == "image":
target_file_type = 'png'
else:
target_file_type = None
raise RuntimeError(f"target file type is {target_file_type}!")
if self.target_file_type:
target_file_type = self.target_file_type
save_path = self.get_save_path(sample, patch_no, target_file_type)
self.save_file(sample, save_path)
def get_save_path(self, sample: Dict[str, Any], patch_no, target_type) -> str:
export_path = os.path.abspath(sample[self.export_path_key])
logger.info(f"export path: {export_path}.")
base_file_name, _ = os.path.splitext(sample[self.filename_key])
file_id = str(sample[self.fileid_key])
new_file_name = file_id + '_' + str(patch_no) + '.' + target_type
logger.info(f"base_file_name: {base_file_name}, new file name: {new_file_name}.")
if not check_valid_path(export_path):
os.makedirs(export_path, exist_ok=True)
res = os.path.join(export_path, new_file_name)
return res
def save_file(self, sample, save_path):
# Save the file in binary mode
file_sample = sample[self.text_key].encode('utf-8') if sample[self.text_key] else sample[self.data_key]
with open(save_path, 'wb') as f:
f.write(file_sample)
os.chmod(save_path, 0o640)
try:
parent_dir = os.path.dirname(save_path)
os.chmod(parent_dir, 0o770)
except Exception as e:
logger.warning("Failed to modify the permission on the parent_dir.")
logger.info(f"patch sample has been save to {save_path}.")
class Filter(BaseOp):
def __init__(self, *args, **kwargs):
super(Filter, self).__init__(*args, **kwargs)
def __call__(self, sample: Dict[str, Any], **kwargs):
# A previous operator in the pipeline already failed on this file
if sample.get(Fields.result) is False:
return sample
self.fill_sample_params(sample, **kwargs)
execute_status = FAILED_STATUS
try:
sample = self.execute(sample)
execute_status = SUCCESS_STATUS
except Exception as e:
# If the filter operator fails, drop the file and persist its execution info to the database
self.create_failure_sample(sample, self.name, e)
sample["execute_status"] = execute_status
logger.error(f"Ops named {self.name} map failed, Error Info: \n"
f"{str(get_exception_info(e))}")
task_info = TaskInfoPersistence()
task_info.persistence_task_info(sample)
return False
sample["execute_status"] = execute_status
# Files with no content are filtered out
if sample[self.text_key] == "" and sample[self.data_key] == b"":
task_info = TaskInfoPersistence()
sample["fileSize"] = "0"
task_info.persistence_task_info(sample)
return False
# Persist the file's successful execution info to the database
if self.is_last_op:
task_info = TaskInfoPersistence()
task_info.persistence_task_info(sample)
return True
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""执行函数(子类实现)"""
raise NotImplementedError("This is in Filter Class, plese re-define this method in Sub-classes")
class LLM(Mapper):
def __init__(self, *args, **kwargs):
super(LLM, self).__init__(*args, **kwargs)
self.llm = self.get_llm(*args, **kwargs)
self.prompt_template = None
self.target_file_type = None
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""执行函数(子类实现)"""
raise NotImplementedError("This is in LLM Class, plese re-define this method in Sub-classes")
@staticmethod
def get_llm(*args, **kwargs):
url = kwargs.get("LLMUrl", '')
header = kwargs.get("LLMHeaders", {"Content-type": "application/json"})
body = kwargs.get("LLMBody", {})
access_type = kwargs.get("accessType", False)
is_https = kwargs.get("isHttps", False)
is_certificate = kwargs.get("isCertificate", False)
certificate_path = kwargs.get("certificatePath", None)
return LlmReq(url=url, header=header, body=body, access_type=access_type, is_https=is_https,
is_certificate=is_certificate, certificate_path=certificate_path)
def build_llm_prompt(self, *args, **kwargs):
"""执行函数(子类实现)"""
raise NotImplementedError("This is in LLM Class, plese re-define this method in Sub-classes")
def save_sample(self, object_list: List, sample: Dict[str, Any]):
if self.target_file_type:
target_file_type = self.target_file_type
else:
target_file_type = "jsonl"
save_path = self.get_save_path(sample, target_file_type)
self.save_json_file(object_list, save_path)
def get_save_path(self, sample: Dict[str, Any], target_type) -> str:
export_path = os.path.abspath(sample[self.export_path_key])
logger.info(f"export path: {export_path}.")
base_file_name, _ = os.path.splitext(sample[self.filename_key])
file_id = str(sample[self.fileid_key])
new_file_name = file_id + '.' + target_type
logger.info(f"base_file_name: {base_file_name}, new file name: {new_file_name}.")
if not check_valid_path(export_path):
os.makedirs(export_path, exist_ok=True)
res = os.path.join(export_path, new_file_name)
return res
@staticmethod
def save_json_file(object_list: List, save_path):
if len(object_list) == 0:
logger.warning("Please check the param: object_list, which has length equal to 0.")
return
try:
with open(save_path, 'w', encoding='utf-8') as f:
for item in object_list:
json_str = json.dumps(item, ensure_ascii=False)
f.write(json_str + '\n')
os.chmod(save_path, 0o640)
try:
parent_dir = os.path.dirname(save_path)
os.chmod(parent_dir, 0o770)
except Exception as e:
logger.warning("Failed to modify the permission on the parent_dir.")
except Exception as e:
raise RuntimeError(f"Save jsonl file Failed!, save_path: {save_path}.") from e
logger.info(f"LLM output has been save to {save_path}.")

View File

@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
class Fields(object):
result = 'execute_result'
instance_id = 'instance_id'
export_path = 'export_path'

View File

@@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import importlib
import sys
from abc import ABC, abstractmethod
import pyarrow as pa
from enum import Enum
from loguru import logger
from ray import data as rd
from datamate.core.base_op import Filter, Mapper, Slicer
from datamate.core.constant import Fields
from datamate.core.base_op import OPERATORS, BaseOp
rd.DataContext.get_current().enable_progress_bars = False
def is_valid_path(item, dataset_dir):
full_path = os.path.abspath(os.path.join(dataset_dir, item))
return os.path.exists(full_path)
def new_get_num_npus(init_kwargs):
if init_kwargs.get("accelerator", "cpu") != "npu":
return 0.0
return 0.1
class Formatters(Enum):
"""
Enum of extraction (formatter) and export operators
"""
FILE_EXPORTER = "FileExporter"
IMG_FORMATTER = "ImgFormatter"
OCR_FORMATTER = "OcrFormatter"
PDF_CPU_FORMATTER = "PdfCpuFormatter"
SLID_FORMATTER = "SlideFormatter"
TEXT_FORMATTER = "TextFormatter"
WORD_FORMATTER = "WordFormatter"
UNION_FORMATTER = "UnionFormatter"
ONNX_FORMATTER = "OnnxImg2TextFormatter"
@classmethod
def is_member(cls, op_name):
return op_name in cls._value2member_map_
class BasicDataset(ABC):
@abstractmethod
def process(
self,
cfg_process,
*,
exporter=None,
checkpointer=None
) -> BasicDataset:
pass
def preprocess_dataset(dataset: rd.Dataset, cfg) -> rd.Dataset:
columns = dataset.columns()
new_column_names = [getattr(Fields, attr_name)
for attr_name in vars(Fields)
if attr_name not in columns and not attr_name.startswith('__')]
def process_batch_arrow(table: pa.Table, names_list=None) -> pa.Table:
name2value_table = {
Fields.instance_id: cfg.instance_id,
Fields.export_path: cfg.export_path
}
for column_name in names_list:
if column_name in name2value_table.keys():
new_column_data = [name2value_table[column_name] for _ in range(len(table))]
else:
new_column_data = [None for _ in range(len(table))]
table = table.append_column(column_name, [new_column_data])
return table
if new_column_names:
dataset = dataset.map_batches(process_batch_arrow,
fn_kwargs={"names_list": new_column_names},
num_cpus=0.05,
batch_format='pyarrow')
return dataset
class RayDataset(BasicDataset):
def __init__(self,
dataset: rd.Dataset,
cfg=None) -> None:
self.onnx_ops_name = ["OnnxImg2TextFormatter", "OnnxImageContentFilter"]
self.npu_ops_name = ["Img2TextFormatter", "ImageContentFilter"]
self.data = preprocess_dataset(dataset, cfg)
def process(self,
cfg_process,
*,
exporter=None,
checkpointer=None,
**kwargs) -> BasicDataset:
# Load operator classes from the registry
operators_cls_list = []
init_kwargs_list = []
for index, process in enumerate(cfg_process):
op_name, init_kwargs = list(process.items())[0]
init_kwargs = {} if not init_kwargs else init_kwargs
init_kwargs.update({'op_name': op_name})
# Load the Ops module
temp_ops = self.load_ops_module(op_name)
if index == len(cfg_process) - 1:
init_kwargs["is_last_op"] = True
operators_cls_list.append(temp_ops)
init_kwargs_list.append(init_kwargs)
for cls_id, operators_cls in enumerate(operators_cls_list):
self._run_single_op(operators_cls, init_kwargs_list[cls_id], **kwargs)
return self
def load_ops_module(self, op_name):
'''
Load an operator module.
:param op_name: operator name
:return: operator class
'''
parent_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ops")
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
registry_content = OPERATORS.modules[op_name]
if isinstance(registry_content, str):
# registry_content is the import path of the module
submodule = importlib.import_module(registry_content)
res = getattr(submodule, op_name, None)
if res is None:
raise ImportError(f"Import Ops module {op_name} Failed.")
else:
logger.info(f"Import Ops module {op_name} Success.")
elif isinstance(registry_content, type) and issubclass(registry_content, BaseOp):
# registry_content is the module class itself
res = registry_content
else:
res = None
return res
def _run_single_op(self, operators_cls, init_kwargs, **kwargs):
num_npus = new_get_num_npus(init_kwargs)
max_actor_nums = os.getenv("MAX_ACTOR_NUMS", "20")
# If this is an ONNX operator, cap the number of concurrent actors
if self._use_onnx_model(init_kwargs['op_name']):
max_actor_nums = 4
resources = {}
if num_npus > 0:
resources["node_npu"] = 0.1
if init_kwargs.get("arch", "arm").startswith("x86"):
resources["arch"] = "x86"
kwargs.update({"ext_params": {}, "failed_reason": {}, "target_type": None})
try:
if issubclass(operators_cls, Mapper):
self.data = self.data.map(operators_cls,
fn_constructor_kwargs=init_kwargs,
fn_kwargs=kwargs,
resources=resources,
num_cpus=0.05,
concurrency=(1, 1 if operators_cls.use_model else int(max_actor_nums)))
elif issubclass(operators_cls, Slicer):
self.data = self.data.flat_map(operators_cls,
fn_constructor_kwargs=init_kwargs,
fn_kwargs=kwargs,
resources=resources,
num_cpus=0.05,
concurrency=(1, int(max_actor_nums)))
elif issubclass(operators_cls, Filter):
self.data = self.data.filter(operators_cls,
fn_constructor_kwargs=init_kwargs,
fn_kwargs=kwargs,
resources=resources,
num_cpus=0.05,
concurrency=(1, int(max_actor_nums)))
else:
logger.error(
'Ray executor only support Filter, Mapper and Slicer OPs for now')
raise NotImplementedError
except Exception as e:
logger.error(e)
raise Exception("Error! Ops Details:") from e
def _use_onnx_model(self, ops_name):
if ops_name in self.onnx_ops_name:
return True
return False
def _use_npu_model(self, ops_name):
if ops_name in self.npu_ops_name:
return True
return False
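For orientation, a hedged sketch of the `cfg_process` list consumed by `RayDataset.process`; the operator names come from the registry, while the keyword arguments are assumptions:
```python
# Each entry maps one registered operator name to its init kwargs
# (None means default construction); the values below are illustrative only.
cfg_process = [
    {"TextFormatter": {"accelerator": "cpu"}},  # extract text from the raw file bytes
    {"FileExporter": None},                     # write the result to the export path
]
# ray_dataset.process(cfg_process) then resolves each name via OPERATORS and
# chains map / flat_map / filter calls on the underlying Ray dataset.
```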

View File

@@ -0,0 +1,163 @@
import os
from typing import Optional, Dict, Any, List
import uvicorn
import yaml
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from jsonargparse import ArgumentParser
from loguru import logger
from pydantic import BaseModel
from datamate.common.error_code import ErrorCode
from datamate.scheduler import cmd_scheduler
from datamate.scheduler import func_scheduler
from datamate.wrappers import WRAPPERS
# Logging configuration
LOG_DIR = "/var/log/data-mate/runtime"
os.makedirs(LOG_DIR, exist_ok=True)
logger.add(
f"{LOG_DIR}/runtime.log",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}",
level="DEBUG",
enqueue=True
)
app = FastAPI()
class APIException(Exception):
"""自定义API异常"""
def __init__(self, error_code: ErrorCode, detail: Optional[str] = None,
extra_data: Optional[Dict] = None):
self.error_code = error_code
self.detail = detail or error_code.value[1]
self.code = error_code.value[0]
self.extra_data = extra_data
super().__init__(self.detail)
def to_dict(self) -> Dict[str, Any]:
result = {
"code": self.code,
"message": self.detail,
"success": False
}
if self.extra_data:
result["data"] = self.extra_data
return result
@app.exception_handler(APIException)
async def api_exception_handler(request: Request, exc: APIException):
return JSONResponse(
status_code=200,  # business errors return HTTP 200; the error details go in the response body
content=exc.to_dict()
)
class QueryTaskRequest(BaseModel):
task_ids: List[str]
@app.post("/api/task/list")
async def query_task_info(request: QueryTaskRequest):
try:
return [{task_id: cmd_scheduler.get_task_status(task_id)} for task_id in request.task_ids]
except Exception as e:
raise APIException(ErrorCode.UNKNOWN_ERROR)
@app.post("/api/task/{task_id}/submit")
async def submit_task(task_id):
config_path = f"/flow/{task_id}/process.yaml"
logger.info("Start submitting job...")
dataset_path = get_from_cfg(task_id, "dataset_path")
if not check_valid_path(dataset_path):
logger.error(f"dataset_path is not existed! please check this path.")
raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)
try:
executor_type = get_from_cfg(task_id, "executor_type")
await WRAPPERS.get(executor_type).submit(task_id, config_path)
except Exception as e:
logger.error(f"Error happens during submitting task. Error Info following: {e}")
raise APIException(ErrorCode.SUBMIT_TASK_ERROR)
logger.info(f"task id: {task_id} has been submitted.")
success_json_info = JSONResponse(
content={"status": "Success", "message": f"{task_id} has been submitted"},
status_code=200
)
return success_json_info
@app.post("/api/task/{task_id}/stop")
async def stop_task(task_id):
logger.info("Start stopping ray job...")
success_json_info = JSONResponse(
content={"status": "Success", "message": f"{task_id} has been stopped"},
status_code=200
)
try:
executor_type = get_from_cfg(task_id, "executor_type")
if not WRAPPERS.get(executor_type).cancel(task_id):
raise APIException(ErrorCode.CANCEL_TASK_ERROR)
except Exception as e:
if isinstance(e, APIException):
raise e
raise APIException(ErrorCode.UNKNOWN_ERROR)
logger.info(f"{task_id} has been stopped.")
return success_json_info
def check_valid_path(file_path):
full_path = os.path.abspath(file_path)
return os.path.exists(full_path)
def get_from_cfg(task_id, key):
config_path = f"/flow/{task_id}/process.yaml"
if not check_valid_path(config_path):
logger.error(f"config_path is not existed! please check this path.")
raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)
with open(config_path, "r", encoding='utf-8') as f:
content = f.read()
cfg = yaml.safe_load(content)
return cfg[key]
def parse_args():
parser = ArgumentParser(description="Create API for Submitting Job to Data-juicer")
parser.add_argument(
'--ip',
type=str,
default="0.0.0.0",
help='Service IP for this API, defaults to 0.0.0.0.'
)
parser.add_argument(
'--port',
type=int,
default=8080,
help='Service port for this API, defaults to 8080.'
)
return parser.parse_args()
if __name__ == '__main__':
p_args = parse_args()
uvicorn.run(
app,
host=p_args.ip,
port=p_args.port
)
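A hedged example of driving the task endpoints from Python; host, port and task id are placeholders, and `/flow/<task_id>/process.yaml` must already exist on the server:
```python
import requests

base = "http://127.0.0.1:8080"   # matches the default --port above
task_id = "demo-task-001"        # placeholder task id

print(requests.post(f"{base}/api/task/{task_id}/submit").json())
print(requests.post(f"{base}/api/task/list", json={"task_ids": [task_id]}).json())
print(requests.post(f"{base}/api/task/{task_id}/stop").json())
```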

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
import importlib
import os
import sys
from pathlib import Path
from loguru import logger
# Get the current directory
current_dir = os.path.dirname(__file__)
# Walk the sub-directories
for module_name in os.listdir(current_dir):
module_path = os.path.join(current_dir, module_name)
# Check that it is a directory containing an __init__.py
if os.path.isdir(module_path) and '__init__.py' in os.listdir(module_path):
# Import the module dynamically
try:
importlib.import_module(f".{module_name}", package=__name__)
except Exception as e:
logger.error(f"Failed to load Ops {module_name}: {e}")

View File

@@ -0,0 +1,6 @@
from .cmd_task_scheduler import CommandScheduler
from .func_task_scheduler import CallableScheduler
cmd_scheduler = CommandScheduler(max_concurrent=5)
func_scheduler = CallableScheduler(max_concurrent=5)

View File

@@ -0,0 +1,214 @@
import asyncio
from datetime import datetime
from typing import Optional, List
from loguru import logger
from .scheduler import Task, TaskStatus, TaskResult, TaskScheduler
class CommandTask(Task):
"""命令任务包装类"""
def __init__(self, task_id: str, command: str, shell: bool = True,
timeout: Optional[int] = None, *args, **kwargs):
super().__init__(task_id, *args, **kwargs)
self.command = command
self.shell = shell
self.timeout = timeout
self.stdout = None
self.stderr = None
self.return_code = None
self._process = None
def start(self) -> 'CommandTask':
"""启动任务"""
if self.status == TaskStatus.PENDING:
self.status = TaskStatus.RUNNING
self.started_at = datetime.now()
self._task = asyncio.create_task(self._execute())
return self
async def _execute(self):
"""执行命令"""
try:
self.status = TaskStatus.RUNNING
self.started_at = datetime.now()
# Use asyncio.create_subprocess_shell or create_subprocess_exec
if self.shell:
process = await asyncio.create_subprocess_shell(
self.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
**self.kwargs
)
else:
process = await asyncio.create_subprocess_exec(
*self.command.split(),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
**self.kwargs
)
self._process = process
# Wait for the process to finish (with an optional timeout)
try:
if self.timeout:
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout=self.timeout
)
else:
stdout, stderr = await process.communicate()
self.stdout = stdout.decode() if stdout else ""
self.stderr = stderr.decode() if stderr else ""
self.return_code = process.returncode
if self._cancelled:
self.status = TaskStatus.CANCELLED
elif process.returncode == 0:
self.status = TaskStatus.COMPLETED
else:
self.status = TaskStatus.FAILED
except asyncio.TimeoutError:
# Timeout handling
self._process.terminate()
try:
await asyncio.wait_for(self._process.wait(), timeout=5.0)
except asyncio.TimeoutError:
self._process.kill()
await self._process.wait()
self.status = TaskStatus.FAILED
self.stderr = f"Command timed out after {self.timeout} seconds"
except asyncio.CancelledError:
# The task was cancelled
if self._process:
self._process.terminate()
try:
await asyncio.wait_for(self._process.wait(), timeout=5.0)
except asyncio.TimeoutError:
self._process.kill()
await self._process.wait()
self.status = TaskStatus.CANCELLED
self._cancelled = True
except Exception as e:
self.status = TaskStatus.FAILED
self.stderr = str(e)
finally:
self.completed_at = datetime.now()
def cancel(self) -> bool:
"""取消任务"""
if self._process and self.status == TaskStatus.RUNNING:
try:
# Try a graceful termination first
self._process.terminate()
self._cancelled = True
return True
except Exception:
# If termination fails, force kill
try:
self._process.kill()
self._cancelled = True
return True
except Exception:
return False
return False
def to_result(self) -> TaskResult:
"""转换为结果对象"""
self.result = {
"command": self.command,
"stdout": self.stdout,
"stderr": self.stderr,
"return_code": self.return_code,
}
return super().to_result()
class CommandScheduler(TaskScheduler):
"""命令调度器"""
def __init__(self, max_concurrent: int = 5):
super().__init__(max_concurrent)
async def submit(self, task_id, command: str, shell: bool = True,
timeout: Optional[int] = None, **kwargs) -> str:
"""提交命令任务"""
task = CommandTask(task_id, command, shell, timeout, **kwargs)
self.tasks[task_id] = task
# Use the semaphore to limit concurrency
async with self.semaphore:
# Execute the task asynchronously
task.start()
logger.info(f"Command task {task_id} has been submitted and started")
return task_id
def get_task_status(self, task_id: str) -> Optional[TaskResult]:
"""获取任务状态"""
task = self.tasks.get(task_id)
if task:
return task.to_result()
return None
def get_all_tasks(self) -> List[TaskResult]:
"""获取所有任务状态"""
return [task.to_result() for task in self.tasks.values()]
def cancel_task(self, task_id: str) -> bool:
"""取消任务"""
task = self.tasks.get(task_id)
if not task:
return True
if task.status == TaskStatus.RUNNING:
cancelled = task.cancel()
if cancelled:
logger.info(f"命令任务 {task_id} 已取消")
return cancelled
return False
def get_tasks_by_status(self, status: TaskStatus) -> List[TaskResult]:
"""根据状态获取任务"""
return [
task.to_result()
for task in self.tasks.values()
if task.status == status
]
async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> TaskResult:
"""等待任务完成"""
task = self.tasks.get(task_id)
if not task:
raise ValueError(f"任务 {task_id} 不存在")
if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
return task.to_result()
# Running tasks are already being awaited by their own coroutine,
# so simply return the current result here
return task.to_result()
async def shutdown(self):
"""关闭调度器,取消所有运行中的任务"""
logger.info("正在关闭命令调度器...")
running_tasks = [
task for task in self.tasks.values()
if task.status == TaskStatus.RUNNING
]
for task in running_tasks:
logger.info(f"取消运行中的命令任务: {task.task_id}")
task.cancel()
logger.info("命令调度器已关闭")

View File

@@ -0,0 +1,133 @@
import asyncio
from datetime import datetime
from typing import Callable, Optional, List
from loguru import logger
from .scheduler import TaskStatus, TaskResult, Task, TaskScheduler
class CallableTask(Task):
"""任务包装类"""
def __init__(self, task_id: str, func: Callable, *args, **kwargs):
super().__init__(task_id, *args, **kwargs)
self.func = func
def start(self) -> 'CallableTask':
"""启动任务"""
if self.status == TaskStatus.PENDING:
self.status = TaskStatus.RUNNING
self.started_at = datetime.now()
self._task = asyncio.create_task(self._execute())
return self
async def _execute(self):
"""执行任务"""
try:
self.result = await self.func(*self.args, **self.kwargs)
self.status = TaskStatus.COMPLETED
except asyncio.CancelledError:
self.status = TaskStatus.CANCELLED
self._cancelled = True
except Exception as e:
self.status = TaskStatus.FAILED
self.error = str(e)
finally:
self.completed_at = datetime.now()
def cancel(self) -> bool:
"""取消任务"""
if self._task and not self._task.done():
self._task.cancel()
return True
return False
class CallableScheduler(TaskScheduler):
"""异步任务调度器"""
def __init__(self, max_concurrent: int = 10):
super().__init__(max_concurrent)
async def submit(self, task_id, func: Callable, *args, **kwargs) -> str:
"""提交任务"""
task = CallableTask(task_id, func, *args, **kwargs)
self.tasks[task_id] = task
# 使用信号量限制并发
async with self.semaphore:
task.start()
logger.info(f"任务 {task_id} 已提交并开始执行")
return task_id
def get_task_status(self, task_id: str) -> Optional[TaskResult]:
"""获取任务状态"""
task = self.tasks.get(task_id)
if task:
return task.to_result()
return None
def get_all_tasks(self) -> List[TaskResult]:
"""获取所有任务状态"""
return [task.to_result() for task in self.tasks.values()]
def cancel_task(self, task_id: str) -> bool:
"""取消任务"""
task = self.tasks.get(task_id)
if task and task.status == TaskStatus.RUNNING:
cancelled = task.cancel()
if cancelled:
logger.info(f"任务 {task_id} 已取消")
return cancelled
return False
def get_tasks_by_status(self, status: TaskStatus) -> List[TaskResult]:
"""根据状态获取任务"""
return [
task.to_result()
for task in self.tasks.values()
if task.status == status
]
async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> TaskResult:
"""等待任务完成"""
task = self.tasks.get(task_id)
if not task:
raise ValueError(f"任务 {task_id} 不存在")
if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
return task.to_result()
        # Wait for the underlying asyncio task; shield it so a timeout here does
        # not cancel the running task itself.
        if task.get():
            try:
                await asyncio.wait_for(asyncio.shield(task.get()), timeout=timeout)
except asyncio.TimeoutError:
raise TimeoutError(f"任务 {task_id} 超时")
return task.to_result()
async def shutdown(self):
"""关闭调度器,取消所有运行中的任务"""
logger.info("正在关闭调度器...")
running_tasks = [
task for task in self.tasks.values()
if task.status == TaskStatus.RUNNING
]
for task in running_tasks:
logger.info(f"取消运行中的任务: {task.task_id}")
task.cancel()
# 等待所有任务完成
for task in running_tasks:
if task.get() and not task.get().done():
try:
await asyncio.wait_for(task.get(), timeout=5.0)
except asyncio.TimeoutError:
logger.warning(f"任务 {task.task_id} 无法正常停止")
logger.info("调度器已关闭")

View File

@@ -0,0 +1,160 @@
# 任务状态枚举
import asyncio
import signal
import sys
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Optional, Dict, List
from loguru import logger
class TaskStatus(Enum):
PENDING = "pending" # 等待执行
RUNNING = "running" # 正在运行
COMPLETED = "completed" # 已完成
FAILED = "failed" # 执行失败
CANCELLED = "cancelled" # 已取消
@dataclass
class TaskResult:
"""任务结果数据类"""
task_id: str
status: TaskStatus
result: Any = None
error: Optional[str] = None
    created_at: Optional[datetime] = None
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
progress: float = 0.0
class Task:
def __init__(self, task_id: str, *args, **kwargs):
self.task_id = task_id
self.args = args
self.kwargs = kwargs
self.status = TaskStatus.PENDING
self.result = None
self.error = None
self.created_at = datetime.now()
self.started_at = None
self.completed_at = None
self.progress = 0.0
self._task = None # asyncio.Task 实例
self._cancelled = False
def get(self):
return self._task
def start(self) -> 'Task':
"""启动任务"""
pass
async def _execute(self):
"""执行任务"""
pass
def cancel(self) -> bool:
"""取消任务"""
pass
def to_result(self) -> TaskResult:
"""转换为结果对象"""
return TaskResult(
task_id=self.task_id,
status=self.status,
result=self.result,
error=self.error,
created_at=self.created_at,
started_at=self.started_at,
completed_at=self.completed_at,
progress=self.progress
)
class TaskScheduler:
"""异步任务调度器"""
def __init__(self, max_concurrent: int = 10):
self.max_concurrent = max_concurrent
self.tasks: Dict[str, Task] = {}
self.semaphore = asyncio.Semaphore(max_concurrent)
# 注册信号处理器
try:
signal.signal(signal.SIGINT, self._signal_handler)
signal.signal(signal.SIGTERM, self._signal_handler)
except (ValueError, AttributeError):
# 在某些平台上可能不支持
pass
def _signal_handler(self, signum, frame):
"""信号处理器"""
logger.info(f"收到信号 {signum},正在清理任务...")
asyncio.create_task(self.shutdown())
sys.exit(0)
async def submit(self, task_id, task, *args, **kwargs) -> str:
"""提交任务"""
pass
def get_task_status(self, task_id: str) -> Optional[TaskResult]:
"""获取任务状态"""
task = self.tasks.get(task_id)
if task:
return task.to_result()
return None
def get_all_tasks(self) -> List[TaskResult]:
"""获取所有任务状态"""
return [task.to_result() for task in self.tasks.values()]
def cancel_task(self, task_id: str) -> bool:
"""取消任务"""
task = self.tasks.get(task_id)
if task and task.status == TaskStatus.RUNNING:
cancelled = task.cancel()
if cancelled:
logger.info(f"任务 {task_id} 已取消")
return cancelled
return False
def get_tasks_by_status(self, status: TaskStatus) -> List[TaskResult]:
"""根据状态获取任务"""
return [
task.to_result()
for task in self.tasks.values()
if task.status == status
]
async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> TaskResult:
"""等待任务完成"""
pass
async def shutdown(self):
"""关闭调度器,取消所有运行中的任务"""
pass
def get_statistics(self) -> Dict[str, int]:
"""获取统计信息"""
stats = {
TaskStatus.PENDING: 0,
TaskStatus.RUNNING: 0,
TaskStatus.COMPLETED: 0,
TaskStatus.FAILED: 0,
TaskStatus.CANCELLED: 0
}
for task in self.tasks.values():
stats[task.status] += 1
return {
"pending": stats[TaskStatus.PENDING],
"running": stats[TaskStatus.RUNNING],
"completed": stats[TaskStatus.COMPLETED],
"failed": stats[TaskStatus.FAILED],
"cancelled": stats[TaskStatus.CANCELLED],
"total": len(self.tasks)
}
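

# Hedged sketch of the subclassing contract: concrete tasks implement start(),
# _execute() and cancel() (compare CallableTask / CommandTask elsewhere in this
# package). _SleepTask is a made-up example, not a shipped class.
class _SleepTask(Task):
    """Example task that just sleeps for a configurable number of seconds."""

    def __init__(self, task_id: str, seconds: float = 1.0):
        super().__init__(task_id, seconds)
        self.seconds = seconds

    def start(self) -> 'Task':
        if self.status == TaskStatus.PENDING:
            self.status = TaskStatus.RUNNING
            self.started_at = datetime.now()
            self._task = asyncio.create_task(self._execute())
        return self

    async def _execute(self):
        try:
            await asyncio.sleep(self.seconds)
            self.result = f"slept {self.seconds}s"
            self.status = TaskStatus.COMPLETED
        except asyncio.CancelledError:
            self.status = TaskStatus.CANCELLED
        finally:
            self.completed_at = datetime.now()

    def cancel(self) -> bool:
        if self._task and not self._task.done():
            self._task.cancel()
            return True
        return False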

View File

@@ -0,0 +1,2 @@

View File

@@ -0,0 +1,176 @@
# -*- coding: utf-8 -*-
import json
import time
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, Any
from loguru import logger
from sqlalchemy import text
from datamate.sql_manager.sql_manager import SQLManager
class TaskInfoPersistence:
def __init__(self):
self.sql_dict = self.load_sql_dict()
@staticmethod
def load_sql_dict():
"""获取sql语句"""
sql_config_path = str(Path(__file__).parent / 'sql' / 'sql_config.json')
with open(sql_config_path, 'r', encoding='utf-8') as f:
return json.load(f)
def persistence_task_info(self, sample: Dict[str, Any]):
instance_id = str(sample.get("instance_id"))
src_file_name = str(sample.get("sourceFileName"))
src_file_type = str(sample.get("sourceFileType"))
src_file_id = str(sample.get("sourceFileId"))
src_file_size = int(sample.get("sourceFileSize"))
file_id = str(uuid.uuid4())
file_size = str(sample.get("fileSize"))
file_type = str(sample.get("fileType"))
file_name = str(sample.get("fileName"))
status = str(sample.get("execute_status"))
failed_reason = str(sample.get("failed_reason"))
result_data = {
"instance_id": instance_id,
"src_file_id": src_file_id,
"dest_file_id": file_id,
"src_name": src_file_name,
"dest_name": file_name,
"src_type": src_file_type,
"dest_type": file_type,
"src_size": src_file_size,
"dest_size": file_size,
"status": status,
"result": failed_reason
}
self.insert_result(result_data, str(self.sql_dict.get("insert_clean_result_sql")))
dataset_id = str(sample.get("dataset_id"))
file_path = str(sample.get("filePath"))
create_time = datetime.now()
last_access_time = datetime.fromtimestamp(os.path.getmtime(file_path))
file_data = {
"id": file_id,
"dataset_id": dataset_id,
"file_name": file_name,
"file_path": file_path,
"file_type": file_type,
"file_size": file_size,
"status": "COMPLETED",
"upload_time": create_time,
"last_access_time": last_access_time,
"created_at": create_time,
"updated_at": create_time
}
self.insert_result(file_data, str(self.sql_dict.get("insert_dataset_file_sql")))
@staticmethod
def insert_result(data, sql):
retries = 0
max_retries = 20
retry_delay = 1
while retries <= max_retries:
try:
with SQLManager.create_connect() as conn:
conn.execute(text(sql), data)
return
except Exception as e:
if "database is locked" in str(e) or "locking protocol" in str(e):
retries += 1
time.sleep(retry_delay)
else:
logger.error("database execute failed: {}", str(e))
raise RuntimeError(82000, str(e)) from None
raise Exception("Max retries exceeded")
def update_result(self, dataset_id, instance_id, status):
dataset_data = {
"dataset_id": dataset_id
}
query_dataset_sql = str(self.sql_dict.get("query_dataset_sql"))
with SQLManager.create_connect() as conn:
result = conn.execute(text(query_dataset_sql), dataset_data)
if result:
rows = result.fetchall()
total_size = sum(int(row[0]) for row in rows)
file_count = len(rows)
else:
total_size = 0
file_count = 0
dataset_data.update({
"task_id": instance_id,
"total_size": total_size,
"file_count": file_count
})
update_dataset_sql = str(self.sql_dict.get("update_dataset_sql"))
self.insert_result(dataset_data, update_dataset_sql)
task_data = {
"task_id": instance_id,
"status": status,
"total_size": total_size,
"finished_time": datetime.now()
}
update_task_sql = str(self.sql_dict.get("update_task_sql"))
self.insert_result(task_data, update_task_sql)
def query_task_info(self, instance_ids: list[str]):
result = {}
current_result = None
for instance_id in instance_ids:
try:
current_result = self.execute_sql_query(instance_id)
except Exception as e:
logger.warning("instance_id: {}, query job result error: {}", instance_id, str(e))
if current_result:
result[instance_id] = current_result
return result
def execute_sql_query(self, instance_id):
result = None
create_tables_sql = str(self.sql_dict.get("create_tables_sql"))
query_sql = str(self.sql_dict.get("query_sql"))
with SQLManager.create_connect() as conn:
conn.execute(text(create_tables_sql))
execute_result = conn.execute(text(query_sql), {"instance_id": instance_id})
result = execute_result.fetchall()
return result
# todo 删除接口待实现
def delete_task_info(self, instance_id: str):
create_tables_sql = self.sql_dict.get("create_tables_sql")
delete_task_instance_sql = self.sql_dict.get("delete_task_instance_sql")
try:
with SQLManager.create_connect() as conn:
conn.execute(text(create_tables_sql))
conn.execute(text(delete_task_instance_sql), {"instance_id": instance_id})
except Exception as e:
logger.warning(f"delete database for flow: {instance_id}", e)
def delete_task_operate_info(self, instance_id: str):
create_duplicate_img_tables_sql = self.sql_dict.get("create_duplicate_img_tables_sql")
create_similar_img_tables_sql = self.sql_dict.get("create_similar_img_tables_sql")
create_similar_text_tables_sql = self.sql_dict.get("create_similar_text_tables_sql")
delete_duplicate_img_tables_sql = self.sql_dict.get("delete_duplicate_img_tables_sql")
delete_similar_img_tables_sql = self.sql_dict.get("delete_similar_img_tables_sql")
delete_similar_text_tables_sql = self.sql_dict.get("delete_similar_text_tables_sql")
try:
with SQLManager.create_connect() as conn:
conn.execute(text(create_duplicate_img_tables_sql))
conn.execute(text(delete_duplicate_img_tables_sql), {"instance_id": instance_id})
conn.execute(text(create_similar_img_tables_sql))
conn.execute(text(delete_similar_img_tables_sql), {"instance_id": instance_id})
conn.execute(text(create_similar_text_tables_sql))
conn.execute(text(delete_similar_text_tables_sql), {"instance_id": instance_id})
except Exception as e:
logger.warning(f"delete database for flow: {instance_id} error", e)

View File

@@ -0,0 +1,17 @@
{
"query_sql": "SELECT * FROM t_task_instance_info WHERE instance_id IN (:instance_id)",
"insert_sql": "INSERT INTO t_task_instance_info (instance_id, meta_file_name, meta_file_type, meta_file_id, meta_file_size, file_id, file_size, file_type, file_name, file_path, status, operator_id, error_code, incremental, child_id, slice_num) VALUES (:instance_id, :meta_file_name, :meta_file_type, :meta_file_id, :meta_file_size, :file_id, :file_size, :file_type, :file_name, :file_path, :status, :operator_id, :error_code, :incremental, :child_id, :slice_num)",
"insert_dataset_file_sql": "INSERT INTO t_dm_dataset_files (id, dataset_id, file_name, file_path, file_type, file_size, status, upload_time, last_access_time, created_at, updated_at) VALUES (:id, :dataset_id, :file_name, :file_path, :file_type, :file_size, :status, :upload_time, :last_access_time, :created_at, :updated_at)",
"insert_clean_result_sql": "INSERT INTO t_clean_result (instance_id, src_file_id, dest_file_id, src_name, dest_name, src_type, dest_type, src_size, dest_size, status, result) VALUES (:instance_id, :src_file_id, :dest_file_id, :src_name, :dest_name, :src_type, :dest_type, :src_size, :dest_size, :status, :result)",
"query_dataset_sql": "SELECT file_size FROM t_dm_dataset_files WHERE dataset_id = :dataset_id",
"update_dataset_sql": "UPDATE t_dm_datasets SET size_bytes = :total_size, file_count = :file_count WHERE id = :dataset_id;",
"update_task_sql": "UPDATE t_clean_task SET status = :status, after_size = :total_size, finished_at = :finished_time WHERE id = :task_id",
"create_tables_sql": "CREATE TABLE IF NOT EXISTS t_task_instance_info (instance_id VARCHAR(255), meta_file_name TEXT, meta_file_type VARCHAR(100), meta_file_id BIGINT, meta_file_size VARCHAR(100), file_id BIGINT, file_size VARCHAR(100), file_type VARCHAR(100), file_name TEXT, file_path TEXT, status INT, operator_id VARCHAR(255), error_code VARCHAR(100), incremental VARCHAR(50), child_id BIGINT, slice_num INT DEFAULT 0);",
"delete_task_instance_sql": "DELETE FROM t_task_instance_info WHERE instance_id = :instance_id",
"create_duplicate_img_tables_sql": "CREATE TABLE IF NOT EXISTS operator_duplicate_img_features (id INT AUTO_INCREMENT PRIMARY KEY,task_uuid VARCHAR(255),file_feature TEXT,file_name TEXT,timestamp DATETIME);",
"delete_duplicate_img_tables_sql": "DELETE FROM operator_duplicate_img_features WHERE flow_id = :flow_id",
"create_similar_img_tables_sql": "CREATE TABLE IF NOT EXISTS operator_similar_img_features (id INT AUTO_INCREMENT PRIMARY KEY,task_uuid VARCHAR(255),p_hash TEXT,des_matrix BLOB,matrix_shape TEXT,file_name TEXT,timestamp DATETIME);",
"delete_similar_img_tables_sql": "DELETE FROM operator_similar_img_features WHERE flow_id = :flow_id",
"create_similar_text_tables_sql": "CREATE TABLE IF NOT EXISTS operators_similar_text_features (id INT AUTO_INCREMENT PRIMARY KEY, task_uuid VARCHAR(255),file_feature TEXT,file_name TEXT,timestamp DATETIME);",
"delete_similar_text_tables_sql": "DELETE FROM operators_similar_text_features WHERE flow_id = :flow_id"
}

View File

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import os
import time
from random import uniform
from loguru import logger
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine import URL
class SQLManager:
@staticmethod
def create_connect(max_retries=5, base_delay=1):
"""
连接到 MySQL 数据库,使用 SQLAlchemy 和 PyMySQL。
:param max_retries: 最大重试次数
:param base_delay: 基础时延
:return: 返回 SQLAlchemy 连接对象
"""
connection_url = URL.create(
drivername="mysql+pymysql",
username=os.getenv("MYSQL_USER", "root"),
password=os.getenv("MYSQL_PASSWORD", "Huawei@123"),
host=os.getenv("MYSQL_HOST", "mysql"),
port=os.getenv("MYSQL_PORT", 3306),
database=os.getenv("MYSQL_DATABASE", "datamate"),
query={"charset": "utf8mb4"},
)
attempt = 0
while True:
try:
engine = create_engine(connection_url, pool_pre_ping=True, isolation_level="AUTOCOMMIT")
return engine.connect()
except Exception as e:
logger.error(f"Attempt {attempt + 1} failed with error: {str(e)}")
if attempt >= max_retries - 1:
raise
wait_time = min(30, base_delay * (2 ** attempt)) # 不超过30秒的最大延时
jitter = uniform(-wait_time / 4, wait_time / 4) # 增加随机抖动因子
time.sleep(wait_time + jitter)
attempt += 1
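

# Hedged usage sketch (not part of the class above): running a one-off
# parameterized query through the connection helper. The table and columns come
# from sql_config.json; the dataset id is a made-up example.
def _demo_query(dataset_id: str = "ds-001"):
    from sqlalchemy import text  # local import keeps the sketch self-contained
    with SQLManager.create_connect() as conn:
        result = conn.execute(
            text("SELECT size_bytes, file_count FROM t_dm_datasets WHERE id = :id"),
            {"id": dataset_id},
        )
        return result.fetchall()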
if __name__ == "__main__":
with SQLManager.create_connect() as connection:
inspector = inspect(connection)
print(inspector.get_table_names())

View File

@@ -0,0 +1,6 @@
from . import data_juicer_wrapper, datamate_wrapper
WRAPPERS = {
"data_juicer": data_juicer_wrapper,
"datamate": datamate_wrapper
}
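

# Hedged usage sketch: dispatching a job to one of the registered engines. Both
# wrappers expose an async submit(task_id, config_path); the task id and config
# path below are made-up examples.
async def _demo_dispatch(engine: str = "datamate"):
    await WRAPPERS[engine].submit("task-001", "/path/to/config.yaml")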

View File

@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
from datamate.scheduler import cmd_scheduler
async def submit(task_id, config_path):
await cmd_scheduler.submit(task_id, f"dj-process --config {config_path}")

View File

@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
import base64
import json
import time
from typing import Dict
import ray
import yaml
from jsonargparse import dict_to_namespace, ArgumentParser
from loguru import logger
from datamate.common.utils import check_valid_path
from datamate.core.dataset import RayDataset
from datamate.sql_manager.persistence_atction import TaskInfoPersistence
import datamate.ops
class RayExecutor:
"""
基于Ray的执行器.
1. 当前仅支持Mapper,Filter类型的算子。
2. 当前仅加载json文件类型的数据集。
"""
def __init__(self, cfg=None, meta=None):
if isinstance(cfg, Dict):
self.cfg = dict_to_namespace(cfg)
else:
logger.error(f"Please set param: cfg as type Dict, but given cfg as type {type(cfg).__name__}")
raise TypeError(f"To params cfg, Dict type is required, but type {type(cfg).__name__} is given!")
self.cfg.process = cfg['process']
self.meta = meta
# init ray
        logger.info('Initializing Ray ...')
ray.init()
def load_meta(self, line):
meta = json.loads(line)
if meta.get("fileId"):
meta["sourceFileId"] = meta.get("fileId")
if meta.get("fileName"):
meta["sourceFileName"] = meta.get("fileName")
if meta.get("fileType"):
meta["sourceFileType"] = meta.get("fileType")
if meta.get("fileSize"):
meta["sourceFileSize"] = meta.get("fileSize")
if not meta.get("totalPageNum"):
meta["totalPageNum"] = 0
if not meta.get("extraFilePath"):
meta["extraFilePath"] = None
if not meta.get("extraFileType"):
meta["extraFileType"] = None
meta["dataset_id"] = self.cfg.dataset_id
return meta
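
    # Hedged example of one metadata line consumed by load_meta() above; all
    # values are made up and only the keys actually read by the method matter:
    #   {"fileId": "f-001", "fileName": "a.json", "fileType": "json",
    #    "fileSize": 1024, "totalPageNum": 3}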
def run(self):
# 1. 加载数据集
logger.info('Loading dataset with Ray...')
if self.meta:
file_content = base64.b64decode(self.meta)
lines = file_content.splitlines()
dataset = ray.data.from_items([json.loads(line) for line in lines])
else:
dataset = self.load_dataset()
dataset = RayDataset(dataset, self.cfg)
# 3. 处理数据
logger.info('Processing data...')
tstart = time.time()
dataset.process(self.cfg.process, **getattr(self.cfg, 'kwargs', {}))
tend = time.time()
logger.info(f'All Ops are done in {tend - tstart:.3f}s.')
dataset.data.materialize()
def load_dataset(self):
retry = 0
dataset = None
jsonl_file_path = self.cfg.dataset_path
while True:
if check_valid_path(jsonl_file_path):
with open(jsonl_file_path, "r", encoding='utf-8') as meta:
lines = meta.readlines()
dataset = ray.data.from_items([self.load_meta(line) for line in lines])
break
if retry < 5:
retry += 1
time.sleep(retry)
continue
else:
logger.error(f"can not load dataset from dataset_path")
raise RuntimeError(f"Load dataset Failed!, dataset_path: {self.cfg.dataset_path}.")
return dataset
def update_db(self, status):
task_info = TaskInfoPersistence()
task_info.update_result(self.cfg.dataset_id, self.cfg.instance_id, status)
if __name__ == '__main__':
    parser = ArgumentParser(description="Submit a processing job to the datamate Ray executor")
parser.add_argument("--config_path", type=str, required=False, default="../configs/demo.yaml")
parser.add_argument("--flow_config", type=str, required=False, default=None)
args = parser.parse_args()
config_path = args.config_path
flow_config = args.flow_config
if flow_config:
m_cfg = yaml.safe_load(base64.b64decode(flow_config))
else:
with open(config_path, "r", encoding='utf-8') as cfg:
m_cfg = yaml.safe_load(cfg)
executor = RayExecutor(m_cfg)
try:
executor.run()
except Exception as e:
executor.update_db("FAILED")
raise e
executor.update_db("COMPLETED")

View File

@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
import os
from datamate.scheduler import cmd_scheduler
async def submit(task_id, config_path):
current_dir = os.path.dirname(__file__)
await cmd_scheduler.submit(task_id, f"python {os.path.join(current_dir, 'datamate_executor.py')} "
f"--config_path={config_path}")
def cancel(task_id):
return cmd_scheduler.cancel_task(task_id)

View File

@@ -0,0 +1,76 @@
[project]
name = "datamate"
dynamic = ["version"]
description = "Data Processing for and with Foundation Models."
authors = [
{ name = "Huawei datamate team" }
]
readme = "README.md"
license = { text = "Apache-2.0" }
requires-python = ">=3.10"
urls = { repository = "https://github.com/ModelEngine-Group/datamate" }
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Operating System :: OS Independent"
]
# Core dependencies
dependencies = [
"uvicorn",
"fastapi",
"loguru",
"jsonargparse",
"ray[default, data]==2.46.0",
"opencv-python"
]
[project.optional-dependencies]
dj = [
"py-data-juicer~=1.4.0"
]
op = [
"python-docx==1.1.0"
]
# All dependencies
all = [
"datamate[dj]",
"datamate[op]"
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.version]
path = "datamate/__init__.py"
[tool.hatch.build.targets.wheel]
packages = ["datamate"]
include = ["pyproject.toml"]
[tool.hatch.build]
include = ["pyproject.toml"]
[tool.hatch.build.targets.wheel.shared-data]
"pyproject.toml" = "pyproject.toml"
[tool.flake8]
per-file-ignores = [
"*/__init__.py: F401"
]
max-line-length = 120
extend-ignore = [
"E203", # whitespace before ':' (black handles this)
"E501", # line too long (black handles this)
"BLK100", # black would make changes (black handles this)
]
[tool.black]
line-length = 120
target-version = ['py310']
[tool.isort]
profile = "black"