You've already forked DataMate
fix: prevent deletion of predefined operators and improve error handling (#192)
* fix: prevent deletion of predefined operators and improve error handling
* fix: prevent deletion of predefined operators and improve error handling
This commit is contained in:
@@ -5,6 +5,7 @@ import os
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
import cv2
|
||||
@@ -445,7 +446,7 @@ class FileExporter(BaseOp):
|
||||
return False
|
||||
|
||||
if save_path:
|
||||
self.save_file(sample, save_path)
|
||||
save_path = self.save_file(sample, save_path)
|
||||
sample[self.text_key] = ''
|
||||
sample[self.data_key] = b''
|
||||
sample[Fields.result] = True
|
||||
@@ -453,6 +454,7 @@ class FileExporter(BaseOp):
|
||||
file_type = save_path.split('.')[-1]
|
||||
sample[self.filetype_key] = file_type
|
||||
|
||||
file_name = os.path.basename(save_path)
|
||||
base_name, _ = os.path.splitext(file_name)
|
||||
new_file_name = base_name + '.' + file_type
|
||||
sample[self.filename_key] = new_file_name
|
||||
@@ -516,13 +518,28 @@ class FileExporter(BaseOp):
|
||||
def save_file(self, sample, save_path):
|
||||
# 以二进制格式保存文件
|
||||
file_sample = sample[self.text_key].encode('utf-8') if sample[self.text_key] else sample[self.data_key]
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(file_sample)
|
||||
# 获取父目录路径
|
||||
path_obj = Path(save_path).resolve()
|
||||
parent_dir = path_obj.parent
|
||||
stem = path_obj.stem # 文件名不含后缀
|
||||
suffix = path_obj.suffix # 后缀 (.txt)
|
||||
|
||||
parent_dir = os.path.dirname(save_path)
|
||||
counter = 0
|
||||
current_path = path_obj
|
||||
while True:
|
||||
try:
|
||||
# x 模式保证:如果文件存在则报错,如果不存在则创建。
|
||||
# 这个检查+创建的过程是操作系统级的原子操作,没有竞态条件。
|
||||
with open(current_path, 'xb') as f:
|
||||
f.write(file_sample)
|
||||
break
|
||||
except FileExistsError:
|
||||
# 文件已存在(被其他线程/进程抢占),更新文件名重试
|
||||
counter += 1
|
||||
new_filename = f"{stem}_{counter}{suffix}"
|
||||
current_path = parent_dir / new_filename
|
||||
os.chmod(parent_dir, 0o770)
|
||||
os.chmod(save_path, 0o640)
|
||||
os.chmod(current_path, 0o640)
|
||||
return str(current_path)
|
||||
|
||||
def _get_from_data(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
sample[self.data_key] = bytes(sample[self.data_key])
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import importlib
|
||||
import sys
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import pyarrow as pa
|
||||
@@ -119,12 +120,14 @@ class RayDataset(BasicDataset):
|
||||
|
||||
# 加载Ops module
|
||||
temp_ops = self.load_ops_module(op_name)
|
||||
operators_cls_list.append(temp_ops)
|
||||
|
||||
if index == 0:
|
||||
init_kwargs["is_first_op"] = True
|
||||
|
||||
if index == len(cfg_process) - 1:
|
||||
init_kwargs["is_last_op"] = True
|
||||
operators_cls_list.append(temp_ops)
|
||||
init_kwargs["instance_id"] = kwargs.get("instance_id", str(uuid.uuid4()))
|
||||
init_kwargs_list.append(init_kwargs)
|
||||
|
||||
for cls_id, operators_cls in enumerate(operators_cls_list):
|
||||
|
||||
Reference in New Issue
Block a user