Make extraction and export fixed steps of the operator pipeline (#134)
* feature: move the extraction step into each operator
* feature: run the export (disk-write) operator by default
* feature: improve the frontend display
* feature: manage dependencies with pyproject
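In short, extraction now happens inside the first operator of every pipeline, and export runs by default on the last one. A minimal sketch of the resulting control flow (the run_pipeline driver and its argument names are illustrative, not DataMate code):

def run_pipeline(ops, sample):
    # Illustrative driver: mirrors how is_first_op/is_last_op are wired
    # into init_kwargs in the RayDataset hunk at the end of this diff.
    for index, op in enumerate(ops):
        op.is_first_op = (index == 0)            # first op extracts the raw file itself
        op.is_last_op = (index == len(ops) - 1)  # last op exports via FileExporter and persists
        sample = op.execute(sample)
    return sample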
@@ -2,10 +2,15 @@
import json
import os
import time
import traceback
import uuid
from typing import List, Dict, Any, Tuple

import cv2
import numpy as np
from loguru import logger
from unstructured.partition.auto import partition

from datamate.common.error_code import ERROR_CODE_TABLE, UNKNOWN_ERROR_CODE
from datamate.common.utils.llm_request import LlmReq
@@ -52,6 +57,7 @@ class BaseOp:
    def __init__(self, *args, **kwargs):
        self.accelerator = kwargs.get('accelerator', "cpu")
        self.is_last_op = kwargs.get('is_last_op', False)
        self.is_first_op = kwargs.get('is_first_op', False)
        self._name = kwargs.get('op_name', None)
        self.infer_model = None
        self.text_key = kwargs.get('text_key', "text")
@@ -122,10 +128,10 @@ class BaseOp:
         raise NotImplementedError("This is in BaseOp, please re-define this method in Sub-classes")

     def fill_sample_params(self, sample: Dict[str, Any], **kwargs):
-        if not sample.get("text", None):
+        if not sample.get(self.text_key, None):
             sample[self.text_key] = ""

-        if not sample.get("data", None):
+        if not sample.get(self.data_key, None):
             sample[self.data_key] = b""

         if not sample[self.data_key] and not sample[self.text_key]:
@@ -137,6 +143,27 @@ class BaseOp:
        failed_reason = {"op_name": op_name, "error_code": error_code, "reason": exc_info}
        sample["failed_reason"] = failed_reason

    def read_file(self, sample):
        filepath = sample[self.filepath_key]
        filetype = sample[self.filetype_key]
        if filetype in ["ppt", "pptx", "docx", "doc", "xlsx"]:
            elements = partition(filename=filepath)
            sample[self.text_key] = "\n\n".join([str(el) for el in elements])
        elif filetype in ["txt", "md", "markdown", "xml", "html", "csv", "json", "jsonl"]:
            with open(filepath, 'rb') as f:
                content = f.read()
            sample[self.text_key] = content.decode("utf-8-sig").replace("\r\n", "\n")
        elif filetype in ['jpg', 'jpeg', 'png', 'bmp']:
            image_np = cv2.imdecode(np.fromfile(filepath, dtype=np.uint8), -1)
            if image_np.size:
                data = cv2.imencode(f".{filetype}", image_np)[1]
                image_bytes = data.tobytes()
                sample[self.data_key] = image_bytes

    def read_file_first(self, sample):
        if self.is_first_op:
            self.read_file(sample)
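A hedged usage sketch of the new extraction hook; constructing BaseOp directly and the literal key names ("filepath", "filetype") are assumptions for illustration, since the *_key attributes are set outside the shown hunk:

# Illustration only: assumes filepath_key/filetype_key default to the literal
# key names used here; real pipelines construct concrete subclasses of BaseOp.
op = BaseOp(is_first_op=True)
sample = {"filepath": "/tmp/demo.txt", "filetype": "txt", "text": "", "data": b""}
op.read_file_first(sample)  # populates sample["text"]; a non-first op would skip this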

class Mapper(BaseOp):
    def __init__(self, *args, **kwargs):
@@ -158,15 +185,16 @@ class Mapper(BaseOp):
             logger.error(f"Ops named {self.name} map failed, Error Info: \n"
                          f"{str(get_exception_info(e))}")
             sample["execute_status"] = execute_status
-            task_info = TaskInfoPersistence()
-            task_info.persistence_task_info(sample)
+            sample[self.filesize_key] = "0"
+            sample[self.filetype_key] = ""
+            TaskInfoPersistence().update_task_result(sample)
             raise e

         sample["execute_status"] = execute_status
         # Persist the file's successful-execution info to the database
         if self.is_last_op:
-            task_info = TaskInfoPersistence()
-            task_info.persistence_task_info(sample)
+            if FileExporter().execute(sample):
+                TaskInfoPersistence().persistence_task_info(sample)
         return sample

     def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
@@ -197,8 +225,9 @@ class Slicer(BaseOp):
             logger.error(f"Ops named {self.name} map failed, Error Info: \n"
                          f"{str(get_exception_info(e))}")
             sample["execute_status"] = execute_status
-            task_info = TaskInfoPersistence()
-            task_info.persistence_task_info(sample)
+            sample[self.filesize_key] = "0"
+            sample[self.filetype_key] = ""
+            TaskInfoPersistence().update_task_result(sample)
             return [sample]

         self.load_sample_to_sample(sample, sample_list)
@@ -206,8 +235,8 @@ class Slicer(BaseOp):

         # Persist the file's successful-execution info to the database
         if self.is_last_op:
-            task_info = TaskInfoPersistence()
-            task_info.persistence_task_info(sample)
+            if FileExporter().execute(sample):
+                TaskInfoPersistence().persistence_task_info(sample)

         return [sample]

@@ -286,22 +315,24 @@ class Filter(BaseOp):
             sample["execute_status"] = execute_status
             logger.error(f"Ops named {self.name} map failed, Error Info: \n"
                          f"{str(get_exception_info(e))}")
-            task_info = TaskInfoPersistence()
-            task_info.persistence_task_info(sample)
+            sample[self.filesize_key] = "0"
+            sample[self.filetype_key] = ""
+            TaskInfoPersistence().update_task_result(sample)
             raise e

         sample["execute_status"] = execute_status
         # Files with no content are filtered out
         if sample[self.text_key] == "" and sample[self.data_key] == b"":
             task_info = TaskInfoPersistence()
-            sample["fileSize"] = "0"
-            task_info.persistence_task_info(sample)
+            sample[self.filesize_key] = "0"
+            sample[self.filetype_key] = ""
+            task_info.update_task_result(sample)
             return False

         # Persist the file's successful-execution info to the database
         if self.is_last_op:
-            task_info = TaskInfoPersistence()
-            task_info.persistence_task_info(sample)
+            if FileExporter().execute(sample):
+                TaskInfoPersistence().persistence_task_info(sample)
         return True

     def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
@@ -379,3 +410,131 @@ class LLM(Mapper):
            raise RuntimeError(f"Save jsonl file failed! save_path: {save_path}.") from e

        logger.info(f"LLM output has been saved to {save_path}.")


class FileExporter(BaseOp):
    """Extract the incoming json file stream to txt."""

    def __init__(self, *args, **kwargs):
        super(FileExporter, self).__init__(*args, **kwargs)
        self.last_ops = True
        self.text_support_ext = kwargs.get("text_support_ext", ['txt', 'html', 'md', 'markdown',
                                                                'xlsx', 'xls', 'csv', 'pptx', 'ppt',
                                                                'xml', 'json', 'doc', 'docx', 'pdf'])
        self.data_support_ext = kwargs.get("data_support_ext", ['jpg', 'jpeg', 'png', 'bmp'])
        self.medical_support_ext = kwargs.get("medical_support_ext", ['svs', 'tif', 'tiff'])

    def execute(self, sample: Dict[str, Any]):
        file_name = sample[self.filename_key]
        file_type = sample[self.filetype_key]

        try:
            start = time.time()
            if file_type in self.text_support_ext:
                sample, save_path = self.get_textfile_handler(sample)
            elif file_type in self.data_support_ext:
                sample, save_path = self.get_datafile_handler(sample)
            elif file_type in self.medical_support_ext:
                sample, save_path = self.get_medicalfile_handler(sample)
            else:
                raise TypeError(f"{file_type} is unsupported! Please check support_ext in FileExporter Ops")

            if sample[self.text_key] == '' and sample[self.data_key] == b'':
                sample[self.filesize_key] = "0"
                return False

            if save_path:
                self.save_file(sample, save_path)
                sample[self.text_key] = ''
                sample[self.data_key] = b''
                sample[Fields.result] = True

                file_type = save_path.split('.')[-1]
                sample[self.filetype_key] = file_type

                base_name, _ = os.path.splitext(file_name)
                new_file_name = base_name + '.' + file_type
                sample[self.filename_key] = new_file_name

                base_name, _ = os.path.splitext(save_path)
                sample[self.filepath_key] = base_name
                file_size = os.path.getsize(base_name)
                sample[self.filesize_key] = f"{file_size}"

            logger.info(f"origin file named {file_name} has been saved to {save_path}")
            logger.info(f"fileName: {sample[self.filename_key]}, "
                        f"method: FileExporter costs {time.time() - start:.6f} s")
        except UnicodeDecodeError as err:
            logger.error(f"fileName: {sample[self.filename_key]}, "
                         f"method: FileExporter causes decode error: {err}")
            raise
        return True

    def get_save_path(self, sample: Dict[str, Any], target_type):
        export_path = os.path.abspath(sample[self.export_path_key])
        file_name = sample[self.filename_key]
        new_file_name = os.path.splitext(file_name)[0] + '.' + target_type

        if not check_valid_path(export_path):
            os.makedirs(export_path, exist_ok=True)
        return os.path.join(export_path, new_file_name)

    def get_textfile_handler(self, sample: Dict[str, Any]):
        target_type = sample.get("target_type", None)

        # If target_type is set, save as a scanned document (docx format)
        if target_type:
            sample = self._get_from_data(sample)
            save_path = self.get_save_path(sample, target_type)
        # Otherwise save as a txt file (normal text cleaning)
        else:
            sample = self._get_from_text(sample)
            save_path = self.get_save_path(sample, 'txt')
        return sample, save_path

    def get_datafile_handler(self, sample: Dict[str, Any]):
        target_type = sample.get("target_type", None)

        # If target_type is set, image-to-text output is saved as target_type (markdown format)
        if target_type:
            sample = self._get_from_text(sample)
            save_path = self.get_save_path(sample, target_type)
        # Otherwise keep the original image file format (normal image cleaning)
        else:
            sample = self._get_from_data(sample)
            save_path = self.get_save_path(sample, sample[self.filetype_key])
        return sample, save_path

    def get_medicalfile_handler(self, sample: Dict[str, Any]):
        target_type = 'png'

        sample = self._get_from_data(sample)
        save_path = self.get_save_path(sample, target_type)

        return sample, save_path

    def save_file(self, sample, save_path):
        file_name, _ = os.path.splitext(save_path)
        # Save the file in binary format
        file_sample = sample[self.text_key].encode('utf-8') if sample[self.text_key] else sample[self.data_key]
        with open(file_name, 'wb') as f:
            f.write(file_sample)

        # Get the parent directory path
        parent_dir = os.path.dirname(file_name)
        os.chmod(parent_dir, 0o770)
        os.chmod(file_name, 0o640)

    def _get_from_data(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample[self.data_key] = bytes(sample[self.data_key])
        sample[self.text_key] = ''
        return sample

    def _get_from_text(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample[self.data_key] = b''
        sample[self.text_key] = str(sample[self.text_key])
        return sample

    @staticmethod
    def _get_uuid():
        return str(uuid.uuid4())
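For context, a sketch of the default export path that Mapper/Slicer/Filter now trigger on their last op; the sample keys below are assumptions inferred from the *_key attributes this class reads:

# Hypothetical sample: real key names come from filename_key, filetype_key, etc.
sample = {"fileName": "report.docx", "fileType": "docx",
          "text": "cleaned text ...", "data": b"", "exportPath": "/data/export"}
# execute() returns False when both text and data are empty (nothing to export);
# otherwise it writes the file, clears text/data, and rewrites name/type/size.
if FileExporter().execute(sample):
    TaskInfoPersistence().persistence_task_info(sample)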
@@ -119,6 +119,8 @@ class RayDataset(BasicDataset):

            # Load the Ops module
            temp_ops = self.load_ops_module(op_name)
            if index == 0:
                init_kwargs["is_first_op"] = True

            if index == len(cfg_process) - 1:
                init_kwargs["is_last_op"] = True
@@ -182,7 +184,8 @@ class RayDataset(BasicDataset):
                                           fn_kwargs=kwargs,
                                           resources=resources,
-                                          compute=rd.ActorPoolStrategy(min_size=1,
-                                                                       max_size=int(max_actor_nums)))
+                                          num_cpus=0.05,
+                                          concurrency=(1, 1 if operators_cls.use_model else int(max_actor_nums)))

         elif issubclass(operators_cls, (Slicer, RELATIVE_Slicer)):
             self.data = self.data.flat_map(operators_cls,
@@ -190,7 +193,8 @@ class RayDataset(BasicDataset):
                                            fn_kwargs=kwargs,
                                            resources=resources,
-                                           compute=rd.ActorPoolStrategy(min_size=1,
-                                                                        max_size=int(max_actor_nums)))
+                                           num_cpus=0.05,
+                                           concurrency=(1, int(max_actor_nums)))

         elif issubclass(operators_cls, (Filter, RELATIVE_Filter)):
             self.data = self.data.filter(operators_cls,
@@ -198,7 +202,8 @@ class RayDataset(BasicDataset):
                                          fn_kwargs=kwargs,
                                          resources=resources,
-                                         compute=rd.ActorPoolStrategy(min_size=1,
-                                                                      max_size=int(max_actor_nums)))
+                                         num_cpus=0.05,
+                                         concurrency=(1, int(max_actor_nums)))
         else:
             logger.error(
                 'Ray executor only supports Filter, Mapper and Slicer OPs for now')
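These three hunks all make the same swap: the compute=rd.ActorPoolStrategy(...) argument, deprecated in newer Ray Data releases, is replaced by the concurrency=(min, max) parameter accepted directly on map/flat_map/filter. A minimal standalone sketch of the new pattern (toy UDF, not DataMate code):

import ray.data as rd

class Double:
    def __call__(self, row):
        row["id"] *= 2
        return row

# (min, max) actor-pool bounds passed directly to map(); replaces
# compute=rd.ActorPoolStrategy(min_size=1, max_size=4) on older Ray versions.
ds = rd.range(8).map(Double, concurrency=(1, 4), num_cpus=0.05)
print(ds.take_all())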