You've already forked DataMate
feature: multiple ratio configurations can be set for the data set. (#103)
feature: multiple ratio configurations can be set for the data set.
This commit is contained in:
@@ -27,15 +27,15 @@ class PagedDatasetFileResponse(BaseModel):
|
||||
size: int = Field(..., description="每页大小")
|
||||
|
||||
class DatasetFileTag(BaseModel):
|
||||
id: str = Field(..., description="标签ID")
|
||||
type: str = Field(..., description="类型")
|
||||
from_name: str = Field(..., description="标签名称")
|
||||
value: dict = Field(..., description="标签值")
|
||||
id: str = Field(None, description="标签ID")
|
||||
type: str = Field(None, description="类型")
|
||||
from_name: str = Field(None, description="标签名称")
|
||||
values: dict = Field(None, description="标签值")
|
||||
|
||||
def get_tags(self) -> List[str]:
|
||||
tags = []
|
||||
# 如果 value 是字典类型,根据 type 获取对应的值
|
||||
tag_values = self.value.get(self.type, [])
|
||||
# 如果 values 是字典类型,根据 type 获取对应的值
|
||||
tag_values = self.values.get(self.type, [])
|
||||
|
||||
# 处理标签值
|
||||
if isinstance(tag_values, list):
|
||||
@@ -55,7 +55,7 @@ class FileTagUpdate(BaseModel):
|
||||
"""单个文件的标签更新请求"""
|
||||
file_id: str = Field(..., alias="fileId", description="文件ID")
|
||||
tags: List[Dict[str, Any]] = Field(..., description="要更新的标签列表(部分更新)")
|
||||
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
@@ -63,7 +63,7 @@ class FileTagUpdate(BaseModel):
|
||||
class BatchUpdateFileTagsRequest(BaseModel):
|
||||
"""批量更新文件标签请求"""
|
||||
updates: List[FileTagUpdate] = Field(..., description="文件标签更新列表", min_length=1)
|
||||
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
@@ -74,7 +74,7 @@ class FileTagUpdateResult(BaseModel):
|
||||
success: bool = Field(..., description="是否更新成功")
|
||||
message: Optional[str] = Field(None, description="结果信息")
|
||||
tags_updated_at: Optional[datetime] = Field(None, alias="tagsUpdatedAt", description="标签更新时间")
|
||||
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
@@ -85,6 +85,6 @@ class BatchUpdateFileTagsResponse(BaseModel):
|
||||
total: int = Field(..., description="总更新数量")
|
||||
success_count: int = Field(..., alias="successCount", description="成功数量")
|
||||
failure_count: int = Field(..., alias="failureCount", description="失败数量")
|
||||
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
@@ -170,7 +170,6 @@ async def list_ratio_tasks(
|
||||
description=i.description,
|
||||
status=i.status,
|
||||
totals=i.totals,
|
||||
ratio_method=i.ratio_method,
|
||||
target_dataset_id=i.target_dataset_id,
|
||||
target_dataset_name=(ds.name if ds else None),
|
||||
created_at=str(i.created_at) if getattr(i, "created_at", None) else None,
|
||||
@@ -330,7 +329,6 @@ async def get_ratio_task(
|
||||
description=instance.description,
|
||||
status=instance.status or "UNKNOWN",
|
||||
totals=instance.totals or 0,
|
||||
ratio_method=instance.ratio_method or "",
|
||||
config=config,
|
||||
target_dataset=target_dataset_info,
|
||||
created_at=instance.created_at,
|
||||
|
||||
@@ -88,7 +88,6 @@ class RatioTaskItem(BaseModel):
|
||||
description: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
totals: Optional[int] = None
|
||||
ratio_method: Optional[str] = None
|
||||
target_dataset_id: Optional[str] = None
|
||||
target_dataset_name: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
@@ -110,7 +109,6 @@ class RatioTaskDetailResponse(BaseModel):
|
||||
description: Optional[str] = Field(None, description="任务描述")
|
||||
status: str = Field(..., description="任务状态")
|
||||
totals: int = Field(..., description="目标总数")
|
||||
ratio_method: str = Field(..., description="配比方式")
|
||||
config: List[Dict[str, Any]] = Field(..., description="配比配置")
|
||||
target_dataset: Dict[str, Any] = Field(..., description="目标数据集信息")
|
||||
created_at: Optional[datetime] = Field(None, description="创建时间")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
import random
|
||||
import json
|
||||
@@ -173,7 +174,7 @@ class RatioTaskService:
|
||||
@staticmethod
|
||||
async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
|
||||
src_path = f.file_path
|
||||
dst_prefix = f"/dataset/{target_ds.id}"
|
||||
dst_prefix = f"/dataset/{target_ds.id}/"
|
||||
file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
|
||||
|
||||
new_path = dst_prefix + file_name
|
||||
@@ -181,18 +182,20 @@ class RatioTaskService:
|
||||
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
|
||||
await asyncio.to_thread(shutil.copy2, src_path, new_path)
|
||||
|
||||
new_file = DatasetFiles(
|
||||
dataset_id=target_ds.id, # type: ignore
|
||||
file_name=file_name,
|
||||
file_path=new_path,
|
||||
file_type=f.file_type,
|
||||
file_size=f.file_size,
|
||||
check_sum=f.check_sum,
|
||||
tags=f.tags,
|
||||
dataset_filemetadata=f.dataset_filemetadata,
|
||||
status="ACTIVE",
|
||||
)
|
||||
session.add(new_file)
|
||||
file_data = {
|
||||
"dataset_id": target_ds.id, # type: ignore
|
||||
"file_name": file_name,
|
||||
"file_path": new_path,
|
||||
"file_type": f.file_type,
|
||||
"file_size": f.file_size,
|
||||
"check_sum": f.check_sum,
|
||||
"tags": f.tags,
|
||||
"tags_updated_at": datetime.now(),
|
||||
"dataset_filemetadata": f.dataset_filemetadata,
|
||||
"status": "ACTIVE",
|
||||
}
|
||||
file_record = {k: v for k, v in file_data.items() if v is not None}
|
||||
session.add(DatasetFiles(**file_record))
|
||||
existing_paths.add(new_path)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user