You've already forked DataMate
init datamate
This commit is contained in:
0
runtime/python-executor/datamate/core/__init__.py
Normal file
0
runtime/python-executor/datamate/core/__init__.py
Normal file
381
runtime/python-executor/datamate/core/base_op.py
Normal file
381
runtime/python-executor/datamate/core/base_op.py
Normal file
@@ -0,0 +1,381 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from datamate.common.error_code import ERROR_CODE_TABLE, UNKNOWN_ERROR_CODE
|
||||
from datamate.common.utils.llm_request import LlmReq
|
||||
from datamate.common.utils.registry import Registry
|
||||
from datamate.common.utils import check_valid_path
|
||||
from datamate.core.constant import Fields
|
||||
from datamate.sql_manager.persistence_atction import TaskInfoPersistence
|
||||
|
||||
# Global registry mapping operator names to operator classes / module paths.
OPERATORS = Registry('Operators')

# Execution status values persisted with each processed sample.
FAILED_STATUS = "FAILED"
SUCCESS_STATUS = "COMPLETED"
|
||||
|
||||
|
||||
def get_exception_info(e):
    """Build a human-readable, multi-line description of an exception.

    Collects the exception type and message, plus — when a traceback is
    attached — the file, line number and source line where it occurred.

    :param e: the exception instance to describe
    :return: formatted error details as a single string
    """
    exc_type = type(e).__name__  # e.g. 'ZeroDivisionError'
    exc_msg = str(e)  # e.g. 'division by zero'

    # An exception that was never raised carries no traceback; guard against
    # the empty frame list instead of crashing with IndexError on tb[-1].
    tb = traceback.extract_tb(e.__traceback__)
    if tb:
        last_frame = tb[-1]
        error_line = last_frame.lineno  # line number where the error occurred
        error_file = last_frame.filename  # file where the error occurred
        code_snippet = last_frame.line  # source text of the failing line
    else:
        error_line = None
        error_file = None
        code_snippet = None

    error_info = (
        f"错误类型: {exc_type}\n"
        f"错误原因: {exc_msg}\n"
        f"文件名: {error_file}\n"
        f"行号: {error_line}\n"
        f"代码行: {code_snippet}"
    )
    return error_info
||||
|
||||
|
||||
class BaseOp:
    """Base class for all operators in the processing pipeline.

    Subclasses (Mapper / Slicer / Filter / LLM) implement ``execute``;
    this class holds the shared sample-dict key names, accelerator
    handling and failure bookkeeping.
    """

    # Set to True by subclasses that need an inference model.
    use_model = False
    # Set to True for user-provided (custom) operators.
    custom_ops = False

    def __init__(self, *args, **kwargs):
        # Accelerator backend requested for this operator ("cpu" or "npu").
        self.accelerator = kwargs.get('accelerator', "cpu")
        # Whether this operator is the last in the pipeline (controls when
        # a successful task record is persisted).
        self.is_last_op = kwargs.get('is_last_op', False)
        self._name = kwargs.get('op_name', None)
        self.infer_model = None
        # Configurable names of the well-known keys inside each sample dict.
        self.text_key = kwargs.get('text_key', "text")
        self.data_key = kwargs.get('data_key', "data")
        self.image_key = kwargs.get('image_key', "image")
        self.video_key = kwargs.get('video_key', "video")
        self.audio_key = kwargs.get('audio_key', "audio")
        self.filename_key = kwargs.get('fileName_key', "fileName")
        self.filetype_key = kwargs.get('fileType_key', "fileType")
        self.fileid_key = kwargs.get('fileId_key', "fileId")
        self.filepath_key = kwargs.get('filePath_key', "filePath")
        self.filesize_key = kwargs.get('fileSize_key', "fileSize")
        self.export_path_key = kwargs.get('export_path_key', "export_path")
        self.ext_params_key = kwargs.get('ext_params_key', "ext_params")

    @property
    def name(self):
        """Operator name used in logs and failure records."""
        return self._name if self._name else "UnknownOp"

    @staticmethod
    def is_npu_available():
        """Return True when torch_npu is importable and reports an NPU."""
        try:
            import torch_npu
            return torch_npu.npu.is_available()
        except ImportError:
            logger.warning("Import torch_npu failed.")
            return False

    @staticmethod
    def update_kwargs(sample: Dict[str, Any], not_update_keys=("text", "data", "meta")) -> Dict:
        """Collect file-related metadata from *sample*, excluding the content keys."""
        return {k: v for k, v in sample.items() if k not in not_update_keys}

    @staticmethod
    def _get_error_info(e: BaseException) -> Tuple[str, str]:
        """Map an exception to its error code and detail text.

        Walks the exception's MRO so subclasses inherit the code of the
        closest registered ancestor; falls back to UNKNOWN_ERROR_CODE.
        """
        error_code = UNKNOWN_ERROR_CODE
        exc_info = get_exception_info(e)

        for exc_type in type(e).__mro__:
            if exc_type in ERROR_CODE_TABLE.keys():
                error_code = ERROR_CODE_TABLE[exc_type]
                break

        return error_code, exc_info

    def use_npu(self):
        """Return True when this operator is configured for, and can use, an NPU."""
        return self.accelerator == 'npu' and self.is_npu_available()

    def get_model(self, *args, **kwargs):
        """Return the operator's inference model, initializing it on first use.

        Returns None (with a warning log) when the operator does not use a
        model at all.
        """
        if not self.use_model:
            # BUGFIX: previously this "Failed" message was also logged on the
            # success path where the model was simply already initialized.
            logger.info(f"Op named {self.name} get infer model Failed. please "
                        f" check Attribute self.use_model: {self.use_model} or model has been initialized!")
            return self.infer_model
        if self.infer_model is None:
            return self.init_model(*args, **kwargs)
        return self.infer_model

    def init_model(self, *args, **kwargs):
        """Initialize the inference model (subclasses implement)."""
        raise NotImplementedError("This is in BaseOp, please re-define this method in Sub-classes")

    def fill_sample_params(self, sample: Dict[str, Any], **kwargs):
        """Guarantee the text/data keys exist; merge *kwargs* into empty samples."""
        # BUGFIX: look up the configured key names instead of the hard-coded
        # "text"/"data" literals, so custom text_key/data_key values work and
        # existing content is not clobbered with an empty string.
        if not sample.get(self.text_key, None):
            sample[self.text_key] = ""

        if not sample.get(self.data_key, None):
            sample[self.data_key] = b""

        if not sample[self.data_key] and not sample[self.text_key]:
            sample.update(kwargs)

    def create_failure_sample(self, sample: Dict[str, Any], op_name, excp: BaseException):
        """Mark *sample* as failed and attach the structured failure reason."""
        sample["execute_result"] = False
        error_code, exc_info = self._get_error_info(excp)
        failed_reason = {"op_name": op_name, "error_code": error_code, "reason": exc_info}
        sample["failed_reason"] = failed_reason
||||
|
||||
|
||||
class Mapper(BaseOp):
    """Operator base class producing exactly one output sample per input."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, sample: Dict[str, Any], **kwargs):
        # A previous operator already failed on this file: pass it through.
        if sample.get(Fields.result) is False:
            return sample

        self.fill_sample_params(sample, **kwargs)
        try:
            sample = self.execute(sample)
        except Exception as exc:
            # Execution failed: record the failure, persist the task info,
            # and mark the sample so downstream operators skip it.
            self.create_failure_sample(sample, self.name, exc)
            logger.error(f"Ops named {self.name} map failed, Error Info: \n"
                         f"{str(get_exception_info(exc))}")
            sample["execute_status"] = FAILED_STATUS
            TaskInfoPersistence().persistence_task_info(sample)
            return sample

        sample["execute_status"] = SUCCESS_STATUS
        # Only the last operator in the pipeline persists the success record.
        if self.is_last_op:
            TaskInfoPersistence().persistence_task_info(sample)
        return sample

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Subclasses implement the actual transformation here."""
        raise NotImplementedError("This is in Mapper Class, plese re-define this method in Sub-classes")
|
||||
|
||||
|
||||
class Slicer(BaseOp):
    """Operator base class that splits one sample into several patches.

    ``__call__`` always returns a single-element list (for use with a
    flat-map style executor); ``execute`` may additionally write patch
    files to disk via ``save_patch_sample``.
    """

    def __init__(self, *args, **kwargs):
        super(Slicer, self).__init__(*args, **kwargs)
        # Optional override for the file extension of exported patches.
        self.target_file_type = None

    def __call__(self, sample: Dict[str, Any], **kwargs):
        # A previous operator already failed on this file: pass it through.
        if sample.get(Fields.result) is False:
            return sample

        self.fill_sample_params(sample, **kwargs)
        sample_list = []
        execute_status = FAILED_STATUS
        try:
            sample_list = self.execute(sample)
            execute_status = SUCCESS_STATUS
        except Exception as e:
            # Slicing failed: record the failure and persist the task info.
            self.create_failure_sample(sample, self.name, e)
            self.load_sample_to_sample(sample, sample_list)
            logger.error(f"Ops named {self.name} map failed, Error Info: \n"
                         f"{str(get_exception_info(e))}")
            sample["execute_status"] = execute_status
            task_info = TaskInfoPersistence()
            task_info.persistence_task_info(sample)
            return [sample]

        self.load_sample_to_sample(sample, sample_list)
        sample["execute_status"] = execute_status

        # Only the last operator in the pipeline persists the success record.
        if self.is_last_op:
            task_info = TaskInfoPersistence()
            task_info.persistence_task_info(sample)

        return [sample]

    @staticmethod
    def load_sample_to_sample(sample: Dict, sample_list: List[Dict]):
        """Merge the k-v pairs of every produced patch back into *sample*.

        NOTE(review): later patches overwrite earlier ones, so *sample*
        ends up carrying the last patch's values — confirm this is intended.
        """
        for sample_i in sample_list:
            for k, v in sample_i.items():
                sample[k] = v
        if not sample.get("fileNum", None):
            sample["fileNum"] = 1

    def execute(self, sample: Dict[str, Any]) -> List[Dict]:
        """Subclasses implement the slicing logic here."""
        raise NotImplementedError("This is in Slicer Class, please re-define this method in Sub-classes")

    def save_patch_sample(self, sample: Dict[str, Any], patch_no, save_format="text"):
        """Write one patch of *sample* to the export directory.

        :param sample: sample whose text/data content is written
        :param patch_no: sequence number used in the output file name
        :param save_format: "text" or "image"; ignored when the operator
            sets an explicit ``self.target_file_type``
        :raises RuntimeError: when *save_format* is unsupported and no
            explicit target file type is configured
        """
        # BUGFIX: honor an explicit target_file_type before rejecting an
        # unknown save_format, and report the offending save_format instead
        # of an always-None placeholder in the error message.
        if self.target_file_type:
            target_file_type = self.target_file_type
        elif save_format == "text":
            target_file_type = 'txt'
        elif save_format == "image":
            target_file_type = 'png'
        else:
            raise RuntimeError(f"Unsupported save_format: {save_format}!")

        save_path = self.get_save_path(sample, patch_no, target_file_type)
        self.save_file(sample, save_path)

    def get_save_path(self, sample: Dict[str, Any], patch_no, target_type) -> str:
        """Build ``<export_path>/<fileId>_<patch_no>.<target_type>``, creating the directory if needed."""
        export_path = os.path.abspath(sample[self.export_path_key])
        logger.info(f"export path: {export_path}.")
        base_file_name, _ = os.path.splitext(sample[self.filename_key])
        file_id = str(sample[self.fileid_key])
        new_file_name = file_id + '_' + str(patch_no) + '.' + target_type
        logger.info(f"base_file_name: {base_file_name}, new file name: {new_file_name}.")
        if not check_valid_path(export_path):
            os.makedirs(export_path, exist_ok=True)
        res = os.path.join(export_path, new_file_name)
        return res

    def save_file(self, sample, save_path):
        """Write the sample's text (utf-8 encoded) or raw bytes to *save_path*."""
        file_sample = sample[self.text_key].encode('utf-8') if sample[self.text_key] else sample[self.data_key]
        with open(save_path, 'wb') as f:
            f.write(file_sample)

        # Owner rw / group r for the file; directory permission is best-effort.
        os.chmod(save_path, 0o640)
        try:
            parent_dir = os.path.dirname(save_path)
            os.chmod(parent_dir, 0o770)
        except Exception:
            logger.warning("Failed to modify the permission on the parent_dir.")

        logger.info(f"patch sample has been save to {save_path}.")
|
||||
|
||||
|
||||
class Filter(BaseOp):
    """Operator base class that decides whether a sample is kept.

    ``__call__`` returns True to keep the file and False to drop it
    (a sample that already failed upstream is passed through unchanged).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, sample: Dict[str, Any], **kwargs):
        # A previous operator already failed on this file: pass it through.
        if sample.get(Fields.result) is False:
            return sample

        self.fill_sample_params(sample, **kwargs)
        try:
            sample = self.execute(sample)
        except Exception as exc:
            # Filtering failed: drop the file and persist the failure record.
            self.create_failure_sample(sample, self.name, exc)
            sample["execute_status"] = FAILED_STATUS
            logger.error(f"Ops named {self.name} map failed, Error Info: \n"
                         f"{str(get_exception_info(exc))}")
            TaskInfoPersistence().persistence_task_info(sample)
            return False

        sample["execute_status"] = SUCCESS_STATUS
        # Files with no content at all are filtered out as well.
        if sample[self.text_key] == "" and sample[self.data_key] == b"":
            recorder = TaskInfoPersistence()
            sample["fileSize"] = "0"
            recorder.persistence_task_info(sample)
            return False

        # Only the last operator in the pipeline persists the success record.
        if self.is_last_op:
            TaskInfoPersistence().persistence_task_info(sample)
        return True

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Subclasses implement the filtering logic here."""
        raise NotImplementedError("This is in Filter Class, plese re-define this method in Sub-classes")
|
||||
|
||||
|
||||
class LLM(Mapper):
    """Mapper that calls a large language model and writes jsonl output."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # HTTP client wrapper configured from the operator kwargs.
        self.llm = self.get_llm(*args, **kwargs)
        # Prompt template, to be filled in by subclasses.
        self.prompt_template = None

        # Optional override for the exported file extension.
        self.target_file_type = None

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Subclasses implement the LLM call here."""
        raise NotImplementedError("This is in LLM Class, plese re-define this method in Sub-classes")

    @staticmethod
    def get_llm(*args, **kwargs):
        """Build an LlmReq client from the operator's keyword arguments."""
        return LlmReq(url=kwargs.get("LLMUrl", ''),
                      header=kwargs.get("LLMHeaders", {"Content-type": "application/json"}),
                      body=kwargs.get("LLMBody", {}),
                      access_type=kwargs.get("accessType", False),
                      is_https=kwargs.get("isHttps", False),
                      is_certificate=kwargs.get("isCertificate", False),
                      certificate_path=kwargs.get("certificatePath", None))

    def build_llm_prompt(self, *args, **kwargs):
        """Subclasses build the model prompt here."""
        raise NotImplementedError("This is in LLM Class, plese re-define this method in Sub-classes")

    def save_sample(self, object_list: List, sample: Dict[str, Any]):
        """Write *object_list* to the export directory as one jsonl file."""
        target_file_type = self.target_file_type if self.target_file_type else "jsonl"
        save_path = self.get_save_path(sample, target_file_type)
        self.save_json_file(object_list, save_path)

    def get_save_path(self, sample: Dict[str, Any], target_type) -> str:
        """Build ``<export_path>/<fileId>.<target_type>``, creating the directory if needed."""
        export_path = os.path.abspath(sample[self.export_path_key])
        logger.info(f"export path: {export_path}.")
        base_file_name, _ = os.path.splitext(sample[self.filename_key])
        new_file_name = str(sample[self.fileid_key]) + '.' + target_type
        logger.info(f"base_file_name: {base_file_name}, new file name: {new_file_name}.")
        if not check_valid_path(export_path):
            os.makedirs(export_path, exist_ok=True)
        return os.path.join(export_path, new_file_name)

    @staticmethod
    def save_json_file(object_list: List, save_path):
        """Dump *object_list* to *save_path*, one JSON object per line."""
        if len(object_list) == 0:
            logger.warning("Please check the param: object_list, which has length equal to 0.")
            return
        try:
            with open(save_path, 'w', encoding='utf-8') as f:
                for entry in object_list:
                    f.write(json.dumps(entry, ensure_ascii=False) + '\n')

            # Owner rw / group r for the file; directory permission is best-effort.
            os.chmod(save_path, 0o640)
            try:
                os.chmod(os.path.dirname(save_path), 0o770)
            except Exception:
                logger.warning("Failed to modify the permission on the parent_dir.")

        except Exception as e:
            raise RuntimeError(f"Save jsonl file Failed!, save_path: {save_path}.") from e

        logger.info(f"LLM output has been save to {save_path}.")
|
||||
8
runtime/python-executor/datamate/core/constant.py
Normal file
8
runtime/python-executor/datamate/core/constant.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
class Fields:
    """Canonical sample-dict field names shared across operators."""

    # Per-file execution result flag (set to False once any operator fails).
    result = 'execute_result'
    # Identifier of the pipeline task instance.
    instance_id = 'instance_id'
    # Directory that operators export generated files into.
    export_path = 'export_path'
|
||||
|
||||
|
||||
213
runtime/python-executor/datamate/core/dataset.py
Normal file
213
runtime/python-executor/datamate/core/dataset.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import importlib
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import pyarrow as pa
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
from ray import data as rd
|
||||
|
||||
from datamate.core.base_op import Filter, Mapper, Slicer
|
||||
from datamate.core.constant import Fields
|
||||
from datamate.core.base_op import OPERATORS, BaseOp
|
||||
|
||||
# Disable Ray Data's per-stage progress bars to keep executor logs clean.
rd.DataContext.get_current().enable_progress_bars = False
|
||||
|
||||
|
||||
def is_valid_path(item, dataset_dir):
    """Return True when *item*, resolved against *dataset_dir*, exists on disk."""
    candidate = os.path.join(dataset_dir, item)
    return os.path.exists(os.path.abspath(candidate))
|
||||
|
||||
|
||||
def new_get_num_npus(init_kwargs):
    """Return the fractional NPU count to reserve for an operator actor.

    0.1 for NPU-accelerated operators, 0.0 otherwise.
    """
    accelerator = init_kwargs.get("accelerator", "cpu")
    return 0.1 if accelerator == "npu" else 0.0
|
||||
|
||||
|
||||
class Formatters(Enum):
    """Enumeration of the extraction (formatter) and export operator names."""

    FILE_EXPORTER = "FileExporter"
    IMG_FORMATTER = "ImgFormatter"
    OCR_FORMATTER = "OcrFormatter"
    PDF_CPU_FORMATTER = "PdfCpuFormatter"
    SLID_FORMATTER = "SlideFormatter"
    TEXT_FORMATTER = "TextFormatter"
    WORD_FORMATTER = "WordFormatter"
    UNION_FORMATTER = "UnionFormatter"
    ONNX_FORMATTER = "OnnxImg2TextFormatter"

    @classmethod
    def is_member(cls, op_name):
        """Return True when *op_name* equals some member's value."""
        return any(member.value == op_name for member in cls)
|
||||
|
||||
|
||||
class BasicDataset(ABC):
    """Abstract interface for dataset implementations the executor can drive."""

    @abstractmethod
    def process(
            self,
            cfg_process,
            *,
            exporter=None,
            checkpointer=None
    ) -> BasicDataset:
        """Run the configured operator pipeline and return the dataset."""
|
||||
|
||||
|
||||
def preprocess_dataset(dataset: rd.Dataset, cfg) -> rd.Dataset:
    """Ensure every column declared on Fields exists in *dataset*.

    Missing columns are appended in one map_batches pass: instance_id and
    export_path are filled from *cfg*, every other new column gets None.

    :param dataset: the Ray dataset to normalize
    :param cfg: config object providing instance_id and export_path
    :return: the dataset with all Fields columns present
    """
    columns = dataset.columns()
    # BUGFIX: compare the *column names* (the Fields attribute values, e.g.
    # 'execute_result') against the existing columns. The previous code
    # compared the attribute names (e.g. 'result'), so an already-present
    # value column would have been appended a second time.
    new_column_names = [getattr(Fields, attr_name)
                        for attr_name in vars(Fields)
                        if not attr_name.startswith('__')
                        and getattr(Fields, attr_name) not in columns]

    def process_batch_arrow(table: pa.Table, names_list=None) -> pa.Table:
        # Columns with a known constant value; everything else gets None.
        name2value_table = {
            Fields.instance_id: cfg.instance_id,
            Fields.export_path: cfg.export_path
        }

        for column_name in names_list:
            if column_name in name2value_table.keys():
                new_column_data = [name2value_table[column_name] for _ in range(len(table))]
            else:
                new_column_data = [None for _ in range(len(table))]
            table = table.append_column(column_name, [new_column_data])
        return table

    if new_column_names:
        dataset = dataset.map_batches(process_batch_arrow,
                                      fn_kwargs={"names_list": new_column_names},
                                      num_cpus=0.05,
                                      batch_format='pyarrow')

    return dataset
|
||||
|
||||
|
||||
class RayDataset(BasicDataset):
    """BasicDataset implementation backed by a Ray Data dataset.

    Loads operator classes from the OPERATORS registry and wires them into
    the Ray dataset as map / flat_map / filter stages.
    """

    def __init__(self,
                 dataset: rd.Dataset,
                 cfg=None) -> None:
        # Operator names that run ONNX models (actor concurrency is capped).
        self.onnx_ops_name = ["OnnxImg2TextFormatter", "OnnxImageContentFilter"]
        # Operator names that run NPU-backed models.
        self.npu_ops_name = ["Img2TextFormatter", "ImageContentFilter"]
        self.data = preprocess_dataset(dataset, cfg)

    def process(self,
                cfg_process,
                *,
                exporter=None,
                checkpointer=None,
                **kwargs) -> BasicDataset:
        """Resolve each configured operator and apply it to the dataset.

        :param cfg_process: list of single-entry dicts {op_name: init_kwargs}
        :return: self, with self.data rebuilt through the operator stages
        """
        # Load the operator classes from the registry.
        operators_cls_list = []
        init_kwargs_list = []
        for index, process in enumerate(cfg_process):
            op_name, init_kwargs = list(process.items())[0]
            init_kwargs = {} if not init_kwargs else init_kwargs
            init_kwargs.update({'op_name': op_name})

            # Load the operator module/class.
            temp_ops = self.load_ops_module(op_name)

            # The final operator is flagged so it persists success records.
            if index == len(cfg_process) - 1:
                init_kwargs["is_last_op"] = True
            operators_cls_list.append(temp_ops)
            init_kwargs_list.append(init_kwargs)

        for cls_id, operators_cls in enumerate(operators_cls_list):
            self._run_single_op(operators_cls, init_kwargs_list[cls_id], **kwargs)
        return self

    def load_ops_module(self, op_name):
        '''
        Load an operator from the registry.
        :param op_name: operator name
        :return: the operator class, or None when the registry entry is
            neither a module path string nor a BaseOp subclass
        '''
        # Make the sibling "ops" package importable.
        parent_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "ops")
        if parent_dir not in sys.path:
            sys.path.insert(0, parent_dir)
        registry_content = OPERATORS.modules[op_name]
        if isinstance(registry_content, str):
            # registry_content is the dotted path of the operator's module.
            submodule = importlib.import_module(registry_content)
            res = getattr(submodule, op_name, None)
            if res is None:
                raise ImportError(f"Import Ops module {op_name} Failed.")
            else:
                logger.info(f"Import Ops module {op_name} Success.")
        elif isinstance(registry_content, type) and issubclass(registry_content, BaseOp):
            # registry_content is the operator class itself.
            res = registry_content
        else:
            res = None
        return res

    def _run_single_op(self, operators_cls, init_kwargs, **kwargs):
        """Attach one operator to self.data as the matching Ray stage.

        :raises Exception: wrapping any failure while building the stage
        """
        num_npus = new_get_num_npus(init_kwargs)
        max_actor_nums = os.getenv("MAX_ACTOR_NUMS", "20")

        # ONNX operators get a hard cap on actor concurrency.
        if self._use_onnx_model(init_kwargs['op_name']):
            max_actor_nums = 4

        resources = {}

        # Request a fraction of a node-level NPU resource when needed.
        if num_npus > 0:
            resources["node_npu"] = 0.1

        # Pin x86-only operators to x86 nodes.
        if init_kwargs.get("arch", "arm").startswith("x86"):
            resources["arch"] = "x86"

        kwargs.update({"ext_params": {}, "failed_reason": {}, "target_type": None})
        try:
            if issubclass(operators_cls, Mapper):
                self.data = self.data.map(operators_cls,
                                          fn_constructor_kwargs=init_kwargs,
                                          fn_kwargs=kwargs,
                                          resources=resources,
                                          num_cpus=0.05,
                                          concurrency=(1, 1 if operators_cls.use_model else int(max_actor_nums)))

            elif issubclass(operators_cls, Slicer):
                self.data = self.data.flat_map(operators_cls,
                                               fn_constructor_kwargs=init_kwargs,
                                               fn_kwargs=kwargs,
                                               resources=resources,
                                               num_cpus=0.05,
                                               concurrency=(1, int(max_actor_nums)))

            elif issubclass(operators_cls, Filter):
                self.data = self.data.filter(operators_cls,
                                             fn_constructor_kwargs=init_kwargs,
                                             fn_kwargs=kwargs,
                                             resources=resources,
                                             num_cpus=0.05,
                                             concurrency=(1, int(max_actor_nums)))
            else:
                logger.error(
                    'Ray executor only support Filter, Mapper and Slicer OPs for now')
                raise NotImplementedError
        except Exception as e:
            logger.error(e)
            raise Exception("Error! Ops Details:") from e

    def _use_onnx_model(self, ops_name):
        """Return True when *ops_name* is an ONNX-model operator."""
        if ops_name in self.onnx_ops_name:
            return True
        return False

    def _use_npu_model(self, ops_name):
        """Return True when *ops_name* is an NPU-model operator."""
        if ops_name in self.npu_ops_name:
            return True
        return False
|
||||
Reference in New Issue
Block a user