fix: prevent deletion of predefined operators and improve error handling (#192)

* fix: prevent deletion of predefined operators and improve error handling

* fix: prevent deletion of predefined operators and improve error handling
This commit is contained in:
hhhhsc701
2025-12-22 19:30:41 +08:00
committed by GitHub
parent c1516c87b6
commit d82bff441a
15 changed files with 98 additions and 55 deletions

View File

@@ -5,6 +5,7 @@ import os
import time
import traceback
import uuid
from pathlib import Path
from typing import List, Dict, Any, Tuple
import cv2
@@ -445,7 +446,7 @@ class FileExporter(BaseOp):
return False
if save_path:
self.save_file(sample, save_path)
save_path = self.save_file(sample, save_path)
sample[self.text_key] = ''
sample[self.data_key] = b''
sample[Fields.result] = True
@@ -453,6 +454,7 @@ class FileExporter(BaseOp):
file_type = save_path.split('.')[-1]
sample[self.filetype_key] = file_type
file_name = os.path.basename(save_path)
base_name, _ = os.path.splitext(file_name)
new_file_name = base_name + '.' + file_type
sample[self.filename_key] = new_file_name
@@ -516,13 +518,28 @@ class FileExporter(BaseOp):
def save_file(self, sample, save_path):
# 以二进制格式保存文件
file_sample = sample[self.text_key].encode('utf-8') if sample[self.text_key] else sample[self.data_key]
with open(save_path, 'wb') as f:
f.write(file_sample)
# 获取父目录路径
path_obj = Path(save_path).resolve()
parent_dir = path_obj.parent
stem = path_obj.stem # 文件名不含后缀
suffix = path_obj.suffix # 后缀 (.txt)
parent_dir = os.path.dirname(save_path)
counter = 0
current_path = path_obj
while True:
try:
# x 模式保证:如果文件存在则报错,如果不存在则创建。
# 这个检查+创建的过程是操作系统级的原子操作,没有竞态条件。
with open(current_path, 'xb') as f:
f.write(file_sample)
break
except FileExistsError:
# 文件已存在(被其他线程/进程抢占),更新文件名重试
counter += 1
new_filename = f"{stem}_{counter}{suffix}"
current_path = parent_dir / new_filename
os.chmod(parent_dir, 0o770)
os.chmod(save_path, 0o640)
os.chmod(current_path, 0o640)
return str(current_path)
def _get_from_data(self, sample: Dict[str, Any]) -> Dict[str, Any]:
sample[self.data_key] = bytes(sample[self.data_key])

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import os
import importlib
import sys
import uuid
from abc import ABC, abstractmethod
import pyarrow as pa
@@ -119,12 +120,14 @@ class RayDataset(BasicDataset):
# 加载Ops module
temp_ops = self.load_ops_module(op_name)
operators_cls_list.append(temp_ops)
if index == 0:
init_kwargs["is_first_op"] = True
if index == len(cfg_process) - 1:
init_kwargs["is_last_op"] = True
operators_cls_list.append(temp_ops)
init_kwargs["instance_id"] = kwargs.get("instance_id", str(uuid.uuid4()))
init_kwargs_list.append(init_kwargs)
for cls_id, operators_cls in enumerate(operators_cls_list):