You've already forked DataMate
feature: add data-evaluation
* feature: add evaluation task management function * feature: add evaluation task detail page * fix: delete duplicate definition for table t_model_config * refactor: rename package synthesis to ratio * refactor: add eval file table and refactor related code * fix: calling large models in parallel during evaluation
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from fastapi import APIRouter

from .ratio_task import router as ratio_task_router

# Parent router: everything registered below is served under "/synthesis".
router = APIRouter(prefix="/synthesis", tags=["synthesis"])

# Include sub-routers
router.include_router(ratio_task_router)
|
||||
342
runtime/datamate-python/app/module/ratio/interface/ratio_task.py
Normal file
342
runtime/datamate-python/app/module/ratio/interface/ratio_task.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Set
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import or_, func, delete, select
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import Dataset
|
||||
from app.db.session import get_db
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.module.shared.schema import StandardResponse, TaskStatus
|
||||
from app.module.ratio.schema.ratio_task import (
|
||||
CreateRatioTaskResponse,
|
||||
CreateRatioTaskRequest,
|
||||
PagedRatioTaskResponse,
|
||||
RatioTaskItem,
|
||||
TargetDatasetInfo,
|
||||
RatioTaskDetailResponse,
|
||||
)
|
||||
from app.module.ratio.service.ratio_task import RatioTaskService
|
||||
from app.db.models.ratio_task import RatioInstance, RatioRelation, RatioRelation as RatioRelationModel
|
||||
|
||||
# Router holding the ratio-task endpoints defined in this module.
router = APIRouter(prefix="/ratio-task", tags=["synthesis/ratio-task"])

# Module-level logger named after this module.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
# Strong references to in-flight background executions. The event loop keeps
# only weak references to tasks, so a task that nobody else references may be
# garbage-collected before it has finished running.
_background_ratio_tasks: set[asyncio.Task] = set()


@router.post("", response_model=StandardResponse[CreateRatioTaskResponse], status_code=200)
async def create_ratio_task(
    req: CreateRatioTaskRequest,
    db: AsyncSession = Depends(get_db),
):
    """
    Create a ratio task and schedule its execution in the background.

    Path: /api/synthesis/ratio-task

    Steps: verify every source dataset in ``req.config`` exists, reject a
    duplicate task name, create the target dataset, persist the task with its
    relations, then start ``RatioTaskService.execute_dataset_ratio_task`` as an
    asyncio task.

    Raises:
        HTTPException: re-raised unchanged when a validation helper raises it;
            status 500 for any other failure. The session is rolled back in
            both cases.
    """
    try:
        # Validate that every dataset_id in the config exists.
        dm_service = DatasetManagementService(db)
        source_types = await get_dataset_types(dm_service, req)

        await valid_exists(db, req)

        target_dataset = await create_target_dataset(db, req, source_types)

        instance = await create_ratio_instance(db, req, target_dataset)

        # Keep a reference until the task completes, then drop it again.
        background = asyncio.create_task(
            RatioTaskService.execute_dataset_ratio_task(instance.id)
        )
        _background_ratio_tasks.add(background)
        background.add_done_callback(_background_ratio_tasks.discard)

        response_data = CreateRatioTaskResponse(
            id=instance.id,
            name=instance.name,
            description=instance.description,
            totals=instance.totals or 0,
            status=instance.status or TaskStatus.PENDING.name,
            config=req.config,
            targetDataset=TargetDatasetInfo(
                id=str(target_dataset.id),
                name=str(target_dataset.name),
                datasetType=str(target_dataset.dataset_type),
                status=str(target_dataset.status),
            ),
        )
        return StandardResponse(
            code=200,
            message="success",
            data=response_data,
        )
    except HTTPException:
        await db.rollback()
        raise
    except Exception as e:
        await db.rollback()
        # logger.exception keeps the traceback, which the generic 500 hides.
        logger.exception(f"Failed to create ratio task: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance:
    """Persist the ratio task described by ``req`` through ``RatioTaskService``.

    Each request config item is flattened into a plain dict with integer
    ``counts`` before being handed to ``RatioTaskService.create_task``.
    """
    service = RatioTaskService(db)
    logger.info(f"create_ratio_instance: {req}")

    config_payload = []
    for entry in req.config:
        config_payload.append(
            {
                "dataset_id": entry.dataset_id,
                "counts": int(entry.counts),
                "filter_conditions": entry.filter_conditions,
            }
        )

    return await service.create_task(
        name=req.name,
        description=req.description,
        totals=int(req.totals),
        config=config_payload,
        target_dataset_id=target_dataset.id,
    )
|
||||
|
||||
|
||||
async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset:
    """Add a DRAFT target dataset for the task to the session and flush it.

    The dataset is named "<task name>-<timestamp>", gets a fresh UUID as its id
    and a type derived from the source dataset types.
    """
    new_id = uuid.uuid4()
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')

    dataset = Dataset(
        id=str(new_id),
        name=f"{req.name}-{timestamp}",
        description=req.description or "",
        dataset_type=get_target_dataset_type(source_types),
        status="DRAFT",
        path=f"/dataset/{new_id}",
    )
    db.add(dataset)
    # Flush so the row is sent to the database before the caller continues.
    await db.flush()
    return dataset
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)
async def list_ratio_tasks(
    page: int = 1,
    size: int = 10,
    name: str | None = None,
    status: str | None = None,
    db: AsyncSession = Depends(get_db),
):
    """List ratio tasks page by page, optionally filtered by name and status.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        size: page size; must not be negative.
        name: substring the task name must contain (matched literally).
        status: exact status value to match.

    Raises:
        HTTPException: 400 when ``size`` is negative, 500 on any other failure.
    """
    # A negative size would otherwise produce a negative OFFSET/LIMIT and
    # surface as an opaque 500 from the database.
    if size < 0:
        raise HTTPException(status_code=400, detail="size must be >= 0")

    try:
        query = select(RatioInstance)
        if name:
            # autoescape keeps '%' and '_' typed by the user literal instead of
            # letting them act as LIKE wildcards.
            query = query.where(RatioInstance.name.contains(name, autoescape=True))
        if status:
            query = query.where(RatioInstance.status == status)

        # Total number of matching rows, before pagination.
        count_q = select(func.count()).select_from(query.subquery())
        total = (await db.execute(count_q)).scalar_one()

        # The API is 1-based, OFFSET is 0-based.
        current_page = max(page, 1)
        query = (
            query.order_by(RatioInstance.created_at.desc())
            .offset((current_page - 1) * size)
            .limit(size)
        )
        items = (await db.execute(query)).scalars().all()

        # Load all referenced target datasets in one query to attach their names.
        ds_ids = {i.target_dataset_id for i in items if i.target_dataset_id}
        ds_map = {}
        if ds_ids:
            ds_res = await db.execute(select(Dataset).where(Dataset.id.in_(list(ds_ids))))
            ds_map = {d.id: d for d in ds_res.scalars().all()}

        content: list[RatioTaskItem] = []
        for i in items:
            ds = ds_map.get(i.target_dataset_id) if i.target_dataset_id else None
            content.append(
                RatioTaskItem(
                    id=i.id,
                    name=i.name or "",
                    description=i.description,
                    status=i.status,
                    totals=i.totals,
                    target_dataset_id=i.target_dataset_id,
                    target_dataset_name=(ds.name if ds else None),
                    created_at=str(i.created_at) if getattr(i, "created_at", None) else None,
                    updated_at=str(i.updated_at) if getattr(i, "updated_at", None) else None,
                )
            )

        # Ceiling division; size == 0 yields zero pages.
        total_pages = (total + size - 1) // size if size > 0 else 0
        return StandardResponse(
            code=200,
            message="success",
            data=PagedRatioTaskResponse(
                content=content,
                totalElements=total,
                totalPages=total_pages,
                # Report the page that was actually served.
                page=current_page,
                size=size,
            ),
        )
    except Exception as e:
        logger.exception(f"Failed to list ratio tasks: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("", response_model=StandardResponse[str], status_code=200)
async def delete_ratio_tasks(
    ids: list[str] = Query(..., description="要删除的配比任务ID列表"),
    db: AsyncSession = Depends(get_db),
):
    """Delete the given ratio tasks and return a simple result string.

    Relations are removed before the instances that own them; both deletes are
    committed together. Any failure rolls the session back.
    """
    try:
        if not ids:
            raise HTTPException(status_code=400, detail="ids is required")

        # Child rows first, then the task instances themselves.
        relation_cleanup = delete(RatioRelation).where(RatioRelation.ratio_instance_id.in_(ids))
        instance_cleanup = delete(RatioInstance).where(RatioInstance.id.in_(ids))
        await db.execute(relation_cleanup)
        await db.execute(instance_cleanup)
        await db.commit()

        return StandardResponse(code=200, message="success", data="success")
    except Exception as e:
        await db.rollback()
        if isinstance(e, HTTPException):
            # Already a client-facing error: propagate unchanged.
            raise
        logger.error(f"Failed to delete ratio tasks: {e}")
        raise HTTPException(status_code=500, detail=f"Fail to delete ratio task: {e}")
|
||||
|
||||
|
||||
async def valid_exists(db: AsyncSession, req: CreateRatioTaskRequest) -> None:
    """Reject a request whose task name is empty or already taken.

    The name is compared exactly, after stripping leading and trailing
    whitespace.

    Raises:
        HTTPException: 400 when the name is blank or a task with that name
            already exists.
    """
    name = (req.name or "").strip()
    if not name:
        raise HTTPException(status_code=400, detail="ratio task name is required")

    # One matching row is enough. LIMIT 1 + first() also stays correct when
    # several rows already share the name, where scalar_one_or_none() would
    # raise MultipleResultsFound and turn this check into a 500.
    result = await db.execute(
        select(RatioInstance.id).where(RatioInstance.name == name).limit(1)
    )
    existing_id = result.scalars().first()
    if existing_id is not None:
        logger.error(f"create ratio task failed: ratio task '{name}' already exists (id={existing_id})")
        raise HTTPException(status_code=400, detail=f"ratio task '{name}' already exists")
|
||||
|
||||
|
||||
async def get_dataset_types(dm_service: DatasetManagementService, req: CreateRatioTaskRequest) -> Set[str]:
    """Collect the upper-cased types of all source datasets named in ``req.config``.

    Raises:
        HTTPException: 400 as soon as one ``dataset_id`` cannot be resolved.
    """
    collected: Set[str] = set()
    for entry in req.config:
        found = await dm_service.get_dataset(entry.dataset_id)
        if not found:
            raise HTTPException(status_code=400, detail=f"dataset_id not found: {entry.dataset_id}")
        # The dataset object may expose the type under either spelling.
        raw_type = getattr(found, "dataset_type", None) or getattr(found, "datasetType", None)
        collected.add(str(raw_type).upper())
    return collected
|
||||
|
||||
|
||||
def get_target_dataset_type(source_types: Set[str]) -> str:
    """Derive the target dataset type from the set of source dataset types.

    Rules:
      1) all sources are TEXT -> TEXT
      2) all sources share exactly one media type (IMAGE/AUDIO/VIDEO) -> that type
      3) anything else (mixed, unknown or empty) -> OTHER
    """
    if source_types == {"TEXT"}:
        return "TEXT"
    for modality in ("IMAGE", "AUDIO", "VIDEO"):
        # Equality with a singleton set means: this modality and nothing else.
        if source_types == {modality}:
            return modality
    return "OTHER"
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=StandardResponse[RatioTaskDetailResponse], status_code=200)
async def get_ratio_task(
    task_id: str,
    db: AsyncSession = Depends(get_db),
):
    """
    Return the details of one ratio task.

    Path: /api/synthesis/ratio-task/{task_id}

    Responds with 404 when the task does not exist and with 500 on any
    unexpected failure.
    """
    try:
        # The task itself.
        instance = (
            await db.execute(select(RatioInstance).where(RatioInstance.id == task_id))
        ).scalar_one_or_none()
        if not instance:
            raise HTTPException(status_code=404, detail="Ratio task not found")

        # Its source-dataset relations.
        relation_rows = await db.execute(
            select(RatioRelationModel).where(RatioRelationModel.ratio_instance_id == task_id)
        )
        relations = list(relation_rows.scalars().all())

        # Its target dataset, if one is linked.
        target_ds = None
        if instance.target_dataset_id:
            dataset_rows = await db.execute(
                select(Dataset).where(Dataset.id == instance.target_dataset_id)
            )
            target_ds = dataset_rows.scalar_one_or_none()

        # Assemble the response payload.
        config = []
        for rel in relations:
            config.append(
                {
                    "dataset_id": rel.source_dataset_id,
                    "counts": str(rel.counts) if rel.counts is not None else "0",
                    "filter_conditions": rel.filter_conditions or "",
                }
            )

        if target_ds:
            target_dataset_info = {
                "id": str(target_ds.id),
                "name": target_ds.name,
                "type": target_ds.dataset_type,
                "status": target_ds.status,
                "file_count": target_ds.file_count,
                "size_bytes": target_ds.size_bytes,
            }
        else:
            # Placeholder values when no target dataset could be loaded.
            target_dataset_info = {
                "id": None,
                "name": None,
                "type": None,
                "status": None,
                "file_count": 0,
                "size_bytes": 0,
            }

        detail = RatioTaskDetailResponse(
            id=instance.id,
            name=instance.name or "",
            description=instance.description,
            status=instance.status or "UNKNOWN",
            totals=instance.totals or 0,
            config=config,
            target_dataset=target_dataset_info,
            created_at=instance.created_at,
            updated_at=instance.updated_at,
        )
        return StandardResponse(code=200, message="success", data=detail)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to get ratio task {task_id}: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
124
runtime/datamate-python/app/module/ratio/schema/ratio_task.py
Normal file
124
runtime/datamate-python/app/module/ratio/schema/ratio_task.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.shared.schema.common import TaskStatus
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class LabelFilter(BaseModel):
    """A label / value pair used as a filter criterion."""

    # The key must be present in the payload, but an explicit null is accepted.
    label: Optional[str] = Field(..., description="标签")
    # Optional; None when omitted.
    value: Optional[str] = Field(default=None, description="标签值")
|
||||
|
||||
class FilterCondition(BaseModel):
    """Optional filter criteria attached to one ratio config item."""

    # Numeric string (validated below); exposed as "dateRange" in payloads.
    date_range: Optional[str] = Field(default=None, description="数据范围", alias="dateRange")
    # Label / value pair to match.
    label: Optional[LabelFilter] = Field(default=None, description="标签")

    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, v: Optional[str]) -> Optional[str]:
        """Accept empty values as-is; otherwise require an integer-like string."""
        if v:
            try:
                int(v)
            except (ValueError, TypeError):
                raise ValueError("date_range must be a numeric string")
        return v

    class Config:
        # allow population by field name when constructing model programmatically
        validate_by_name = True
|
||||
|
||||
|
||||
class RatioConfigItem(BaseModel):
    """One source dataset entry of a ratio task: which dataset, how many, which filter."""

    dataset_id: str = Field(..., alias="datasetId", description="数据集id")
    # Transported as a string; must parse as an integer (see validator).
    counts: str = Field(..., description="数量")
    filter_conditions: FilterCondition = Field(..., alias="filterConditions", description="过滤条件")

    @field_validator("counts")
    @classmethod
    def validate_counts(cls, v: str) -> str:
        """Return ``v`` unchanged if it parses as an integer, else reject it."""
        try:
            int(v)
        except Exception:
            raise ValueError("counts must be a numeric string")
        else:
            return v
|
||||
|
||||
|
||||
class CreateRatioTaskRequest(BaseModel):
    """Request body for creating a ratio task."""

    name: str = Field(..., description="名称")
    description: Optional[str] = Field(None, description="描述")
    # Transported as a string; must parse as a non-negative integer.
    totals: str = Field(..., description="目标数量")
    config: List[RatioConfigItem] = Field(..., description="配比设置列表")

    @field_validator("totals")
    @classmethod
    def validate_totals(cls, v: str) -> str:
        """Ensure ``totals`` is an integer-like string that is not negative.

        Raises:
            ValueError: "totals must be a numeric string" when ``v`` does not
                parse as an integer; "totals must be >= 0" when it is negative.
        """
        # Only the parse is guarded: the range check lives outside the try so
        # its own message is not overwritten by the "numeric string" one.
        try:
            parsed = int(v)
        except (TypeError, ValueError):
            raise ValueError("totals must be a numeric string") from None
        if parsed < 0:
            raise ValueError("totals must be >= 0")
        return v
|
||||
|
||||
|
||||
class TargetDatasetInfo(BaseModel):
    """Summary of the dataset created to receive a ratio task's output."""

    id: str
    name: str
    # camelCase on purpose: this is the field name used in the response.
    datasetType: str
    status: str
|
||||
|
||||
|
||||
class CreateRatioTaskResponse(BaseModel):
    """Response returned after a ratio task has been created."""

    # --- the task that was created ---
    id: str
    name: str
    description: Optional[str] = None
    totals: int
    status: TaskStatus

    # --- the config exactly as submitted in the request ---
    config: List[RatioConfigItem]

    # --- the dataset created for the task's output ---
    targetDataset: TargetDatasetInfo
|
||||
|
||||
|
||||
class RatioTaskItem(BaseModel):
    """One row of the paged ratio task listing."""

    id: str
    name: str
    description: Optional[str] = None
    status: Optional[str] = None
    totals: Optional[int] = None
    # Target dataset reference and its display name.
    target_dataset_id: Optional[str] = None
    target_dataset_name: Optional[str] = None
    # Timestamps are carried as strings in this model.
    created_at: Optional[str] = None
    updated_at: Optional[str] = None
|
||||
|
||||
|
||||
class PagedRatioTaskResponse(BaseModel):
    """A single page of ratio tasks plus paging metadata."""

    # Items on the current page.
    content: List[RatioTaskItem]
    # Number of items matching the query across all pages.
    totalElements: int
    totalPages: int
    page: int
    size: int
|
||||
|
||||
|
||||
class RatioTaskDetailResponse(BaseModel):
    """Full detail view of one ratio task, including config and target dataset."""

    id: str = Field(..., description="任务ID")
    name: str = Field(..., description="任务名称")
    description: Optional[str] = Field(default=None, description="任务描述")
    status: str = Field(..., description="任务状态")
    totals: int = Field(..., description="目标总数")
    # Free-form dicts: one per source dataset relation.
    config: List[Dict[str, Any]] = Field(..., description="配比配置")
    target_dataset: Dict[str, Any] = Field(..., description="目标数据集信息")
    created_at: Optional[datetime] = Field(default=None, description="创建时间")
    updated_at: Optional[datetime] = Field(default=None, description="更新时间")

    class Config:
        # Serialize datetimes as ISO-8601 strings.
        json_encoders = {
            datetime: lambda v: v.isoformat() if v else None
        }
|
||||
318
runtime/datamate-python/app/module/ratio/service/ratio_task.py
Normal file
318
runtime/datamate-python/app/module/ratio/service/ratio_task.py
Normal file
@@ -0,0 +1,318 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models.ratio_task import RatioInstance, RatioRelation
|
||||
from app.db.models import Dataset, DatasetFiles
|
||||
from app.db.session import AsyncSessionLocal
|
||||
from app.module.dataset.schema.dataset_file import DatasetFileTag
|
||||
from app.module.shared.schema import TaskStatus
|
||||
from app.module.ratio.schema.ratio_task import FilterCondition
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RatioTaskService:
    """Database operations and background execution for ratio tasks."""

    def __init__(self, db: AsyncSession):
        self.db = db

    @staticmethod
    def _serialize_filter_conditions(conditions: Any) -> str:
        """Serialize a filter-condition object to the JSON stored on a relation.

        ``conditions`` is read through its ``date_range`` and ``label``
        attributes. A missing ``conditions`` or a missing ``label`` is stored
        as null instead of raising AttributeError.
        """
        label = getattr(conditions, "label", None)
        label_payload = None
        if label is not None:
            label_payload = {"label": label.label, "value": label.value}
        return json.dumps({
            "date_range": getattr(conditions, "date_range", None),
            "label": label_payload,
        })

    async def create_task(
        self,
        *,
        name: str,
        description: Optional[str],
        totals: int,
        config: List[Dict[str, Any]],
        target_dataset_id: Optional[str] = None,
    ) -> RatioInstance:
        """Create a ratio task instance and its relations, then commit.

        config item format:
            {"dataset_id": str, "counts": int, "filter_conditions": <object with
            ``date_range`` and optional ``label`` attributes>}

        Returns:
            The refreshed ``RatioInstance`` in PENDING status.
        """
        logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")

        instance = RatioInstance(
            name=name,
            description=description,
            totals=totals,
            target_dataset_id=target_dataset_id,
            status="PENDING",
        )
        self.db.add(instance)
        await self.db.flush()  # populate instance.id

        for item in config or []:
            relation = RatioRelation(
                ratio_instance_id=instance.id,
                source_dataset_id=item.get("dataset_id"),
                counts=int(item.get("counts", 0)),
                filter_conditions=RatioTaskService._serialize_filter_conditions(
                    item.get("filter_conditions")
                ),
            )
            logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
            self.db.add(relation)

        await self.db.commit()
        await self.db.refresh(instance)
        logger.info(f"Ratio task created: {instance.id}")
        return instance

    # ========================= Execution (Background) ========================= #

    @staticmethod
    async def execute_dataset_ratio_task(instance_id: str) -> None:
        """Execute a ratio task in the background using its own session.

        Steps:
            - Mark the instance RUNNING and commit, so the status is visible
              while files are being copied
            - For each relation: fetch ACTIVE files, optionally filter them
            - Copy the selected files into the target dataset
            - Update dataset statistics and mark the instance COMPLETED

        On any exception the transaction is rolled back (discarding file rows
        added so far) and the instance is marked FAILED.
        """
        async with AsyncSessionLocal() as session:  # type: AsyncSession
            try:
                inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
                instance: Optional[RatioInstance] = inst_res.scalar_one_or_none()
                if not instance:
                    logger.error(f"Ratio instance not found: {instance_id}")
                    return
                logger.info(f"start execute ratio task: {instance_id}")

                # Mark running; commit + refresh mirrors create_task so the
                # instance stays usable after the commit.
                instance.status = TaskStatus.RUNNING.name
                await session.commit()
                await session.refresh(instance)

                rel_res = await session.execute(
                    select(RatioRelation).where(RatioRelation.ratio_instance_id == instance_id)
                )
                relations: List[RatioRelation] = list(rel_res.scalars().all())

                # Load target dataset
                ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
                target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
                if not target_ds:
                    logger.error(f"Target dataset not found for instance {instance_id}")
                    instance.status = TaskStatus.FAILED.name
                    await session.commit()
                    return

                added_count, added_size = await RatioTaskService.handle_ratio_relations(relations, session, target_ds)

                # Update target dataset statistics
                target_ds.file_count = (target_ds.file_count or 0) + added_count  # type: ignore
                target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size  # type: ignore
                # If target dataset has files, mark it ACTIVE
                if (target_ds.file_count or 0) > 0:  # type: ignore
                    target_ds.status = "ACTIVE"

                # Done. Log before committing: attributes may be expired after
                # the commit and must not be lazily loaded here.
                instance.status = TaskStatus.COMPLETED.name
                logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
                await session.commit()
            except Exception as e:
                logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
                await RatioTaskService._mark_failed(session, instance_id)

    @staticmethod
    async def _mark_failed(session, instance_id: str) -> None:
        """Roll back the failed transaction and persist FAILED on the instance.

        The rollback comes first so the session is usable again. Errors raised
        here are logged rather than propagated, because this runs inside a
        fire-and-forget background task.
        """
        try:
            await session.rollback()
            inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
            instance = inst_res.scalar_one_or_none()
            if instance:
                instance.status = TaskStatus.FAILED.name
                await session.commit()
        except Exception as e:
            logger.exception(f"Failed to mark ratio task {instance_id} as FAILED: {e}")

    @staticmethod
    async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
        """Copy randomly chosen files of every relation into the target dataset.

        Returns:
            ``(added_count, added_size)``: number of files copied and the sum
            of their ``file_size`` values.
        """
        # Preload existing target file paths so new names can avoid collisions
        existing_path_rows = await session.execute(
            select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
        )
        existing_paths = {p for p in existing_path_rows.scalars().all() if p}

        added_count = 0
        added_size = 0

        for rel in relations:
            # Skip relations without a source dataset or a positive count.
            if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
                continue

            files = await RatioTaskService.get_files(rel, session)
            if not files:
                continue

            # Take at most `counts` files; sample only when there is a surplus.
            pick_n = min(rel.counts or 0, len(files))
            chosen = random.sample(files, pick_n) if pick_n < len(files) else files

            for f in chosen:
                await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
                added_count += 1
                added_size += int(f.file_size or 0)

            # Flush once per relation to avoid accumulating huge pending state
            await session.flush()
        return added_count, added_size

    @staticmethod
    async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
        """Copy one source file into the target dataset directory and register it.

        The file is copied on a worker thread, a ``DatasetFiles`` row is added
        to the session, and the new path is recorded in ``existing_paths``.
        """
        src_path = f.file_path
        dst_prefix = f"/dataset/{target_ds.id}/"
        file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)

        new_path = dst_prefix + file_name
        dst_dir = os.path.dirname(new_path)
        # Blocking filesystem work is moved off the event loop.
        await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
        await asyncio.to_thread(shutil.copy2, src_path, new_path)

        file_data = {
            "dataset_id": target_ds.id,  # type: ignore
            "file_name": file_name,
            "file_path": new_path,
            "file_type": f.file_type,
            "file_size": f.file_size,
            "check_sum": f.check_sum,
            "tags": f.tags,
            "tags_updated_at": datetime.now(),
            "dataset_filemetadata": f.dataset_filemetadata,
            "status": "ACTIVE",
        }
        # Leave out None values so column defaults apply.
        file_record = {k: v for k, v in file_data.items() if v is not None}
        session.add(DatasetFiles(**file_record))
        existing_paths.add(new_path)

    @staticmethod
    def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
        """Return a file name for ``f`` whose path under ``dst_prefix`` is unused.

        On a collision "_<n>" is inserted before the extension, with n counting
        up from 1. After 1000 attempts the search stops and the last candidate
        is returned even though it still collides.
        """
        file_name = f.file_name
        new_path = dst_prefix + file_name

        if new_path in existing_paths:
            file_name_base, file_ext = os.path.splitext(file_name)
            counter = 1
            original_file_name = file_name
            while new_path in existing_paths:
                file_name = f"{file_name_base}_{counter}{file_ext}"
                new_path = f"{dst_prefix}{file_name}"
                counter += 1
                if counter > 1000:  # Safety check to prevent infinite loops
                    logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
                    break
        return file_name

    @staticmethod
    async def get_files(rel: RatioRelation, session) -> list[Any]:
        """Load the ACTIVE files of the relation's source dataset, filtered if configured."""
        files_res = await session.execute(
            select(DatasetFiles).where(
                DatasetFiles.dataset_id == rel.source_dataset_id,
                DatasetFiles.status == "ACTIVE",
            )
        )
        files = list(files_res.scalars().all())

        # Apply relation.filter_conditions when they parse successfully.
        conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
        if conditions:
            files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
        return files

    # ------------------------- helpers for TAG filtering ------------------------- #

    @staticmethod
    def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
        """Parse a filter_conditions JSON string into a FilterCondition object.

        Args:
            conditions: JSON string containing filter conditions

        Returns:
            A FilterCondition, or None when ``conditions`` is empty, is not
            valid JSON, or cannot be turned into a FilterCondition (the error
            is logged).
        """
        if not conditions:
            return None
        try:
            data = json.loads(conditions)
            return FilterCondition(**data)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse filter conditions: {e}")
            return None
        except Exception as e:
            logger.error(f"Error creating FilterCondition: {e}")
            return None

    @staticmethod
    def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
        """Return True when ``file`` satisfies every condition that is set.

        - ``date_range`` (days, > 0): the file is rejected when its
          ``tags_updated_at`` is older than that many days.
        - ``label``: the file must carry the tag "<label>@<value>".
        """
        # timedelta is not imported at module level, so import locally.
        from datetime import datetime, timedelta

        if not conditions:
            return True
        logger.info(f"start filter file: {file}, conditions: {conditions}")

        # Check date range condition if provided
        if conditions.date_range:
            try:
                data_range_days = int(conditions.date_range)
                if data_range_days > 0:
                    cutoff_date = datetime.now() - timedelta(days=data_range_days)
                    if file.tags_updated_at and file.tags_updated_at < cutoff_date:
                        return False
            except (ValueError, TypeError) as e:
                # The error goes into the message itself; a stray positional
                # argument without a placeholder breaks log formatting.
                logger.warning(f"Invalid data_range value: {conditions.date_range}: {e}")
                return False

        # Check label condition if provided
        if conditions.label:
            tags = file.tags
            if not tags:
                return False
            try:
                tag_names = RatioTaskService.get_all_tags(tags)
                return f"{conditions.label.label}@{conditions.label.value}" in tag_names
            except Exception as e:
                logger.exception(f"Failed to get tags for {file}: {e}")
                return False

        return True

    @staticmethod
    def get_all_tags(tags) -> set[str]:
        """Collect the tag strings produced by every entry of ``tags``.

        Each entry is a mapping passed as keyword arguments to
        ``DatasetFileTag``; the values yielded by its ``get_tags()`` are merged
        into one set.
        """
        all_tags: set[str] = set()
        if not tags:
            return all_tags

        # Build every DatasetFileTag first so a malformed entry fails before
        # any get_tags() call, as before.
        file_tags = [DatasetFileTag(**dict(tag_data.items())) for tag_data in tags]
        for file_tag in file_tags:
            all_tags.update(file_tag.get_tags())
        return all_tags
|
||||
Reference in New Issue
Block a user