feature: add data-evaluation

* feature: add evaluation task management function

* feature: add evaluation task detail page

* fix: delete duplicate definition for table t_model_config

* refactor: rename package synthesis to ratio

* refactor: add eval file table and refactor related code

* fix: calling large models in parallel during evaluation
This commit is contained in:
hefanli
2025-12-04 09:23:54 +08:00
committed by GitHub
parent 265e284fb8
commit 1d19cd3a62
52 changed files with 2882 additions and 1244 deletions

View File

@@ -0,0 +1,11 @@
from fastapi import APIRouter

# Parent router for all synthesis endpoints; sub-routers are mounted below.
router = APIRouter(
    prefix="/synthesis",
    tags=["synthesis"],
)

# Include sub-routers (imported after `router` is defined to avoid a
# circular import with modules that import this package's router).
from .ratio_task import router as ratio_task_router

router.include_router(ratio_task_router)

View File

@@ -0,0 +1,342 @@
import asyncio
import uuid
from typing import Set
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import or_, func, delete, select
from app.core.logging import get_logger
from app.db.models import Dataset
from app.db.session import get_db
from app.module.dataset import DatasetManagementService
from app.module.shared.schema import StandardResponse, TaskStatus
from app.module.ratio.schema.ratio_task import (
CreateRatioTaskResponse,
CreateRatioTaskRequest,
PagedRatioTaskResponse,
RatioTaskItem,
TargetDatasetInfo,
RatioTaskDetailResponse,
)
from app.module.ratio.service.ratio_task import RatioTaskService
from app.db.models.ratio_task import RatioInstance, RatioRelation, RatioRelation as RatioRelationModel
# Router for ratio-task endpoints; mounted under the parent /synthesis
# router, so the effective path prefix is /synthesis/ratio-task.
router = APIRouter(
    prefix="/ratio-task",
    tags=["synthesis/ratio-task"],
)
# Module-level logger for this router module.
logger = get_logger(__name__)
# Strong references to in-flight background tasks. The event loop only keeps
# a weak reference to tasks created with asyncio.create_task, so without this
# set a running ratio execution could be garbage-collected before finishing.
_background_tasks: set[asyncio.Task] = set()


@router.post("", response_model=StandardResponse[CreateRatioTaskResponse], status_code=200)
async def create_ratio_task(
    req: CreateRatioTaskRequest,
    db: AsyncSession = Depends(get_db),
):
    """Create a ratio task.

    Path: /api/synthesis/ratio-task

    Validates the configured source datasets and the task name, creates the
    target dataset and the task instance, then launches the ratio execution
    as a background task.

    Raises:
        HTTPException(400): invalid dataset_id or duplicate task name.
        HTTPException(500): unexpected failure (transaction rolled back).
    """
    try:
        # Validate that every dataset_id in config exists, collecting their
        # types to decide the target dataset type.
        dm_service = DatasetManagementService(db)
        source_types = await get_dataset_types(dm_service, req)
        # Reject duplicate task names up front.
        await valid_exists(db, req)
        target_dataset = await create_target_dataset(db, req, source_types)
        instance = await create_ratio_instance(db, req, target_dataset)
        # FIX: keep a strong reference to the background task. The original
        # discarded the asyncio.create_task(...) result, which allows the
        # task to be garbage-collected mid-execution.
        task = asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        response_data = CreateRatioTaskResponse(
            id=instance.id,
            name=instance.name,
            description=instance.description,
            totals=instance.totals or 0,
            status=instance.status or TaskStatus.PENDING.name,
            config=req.config,
            targetDataset=TargetDatasetInfo(
                id=str(target_dataset.id),
                name=str(target_dataset.name),
                datasetType=str(target_dataset.dataset_type),
                status=str(target_dataset.status),
            )
        )
        return StandardResponse(
            code=200,
            message="success",
            data=response_data
        )
    except HTTPException:
        await db.rollback()
        raise
    except Exception as e:
        await db.rollback()
        logger.error(f"Failed to create ratio task: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
async def create_ratio_instance(db, req: CreateRatioTaskRequest, target_dataset: Dataset) -> RatioInstance:
    """Persist a RatioInstance (plus relations) via RatioTaskService.create_task."""
    logger.info(f"create_ratio_instance: {req}")
    # Normalize the request config into plain dicts for the service layer.
    config_payload = [
        {
            "dataset_id": entry.dataset_id,
            "counts": int(entry.counts),
            "filter_conditions": entry.filter_conditions,
        }
        for entry in req.config
    ]
    task_service = RatioTaskService(db)
    return await task_service.create_task(
        name=req.name,
        description=req.description,
        totals=int(req.totals),
        config=config_payload,
        target_dataset_id=target_dataset.id,
    )
async def create_target_dataset(db, req: CreateRatioTaskRequest, source_types: set[str]) -> Dataset:
    """Create the DRAFT target dataset the ratio task will fill.

    The dataset name follows the pattern "<task name>-<timestamp>".
    """
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    new_id = uuid.uuid4()
    dataset = Dataset(
        id=str(new_id),
        name=f"{req.name}-{timestamp}",
        description=req.description or "",
        dataset_type=get_target_dataset_type(source_types),
        status="DRAFT",
        path=f"/dataset/{new_id}",
    )
    db.add(dataset)
    # Flush (not commit) so the generated row is usable within the
    # surrounding transaction.
    await db.flush()
    return dataset
@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)
async def list_ratio_tasks(
    page: int = 1,
    size: int = 10,
    name: str | None = None,
    status: str | None = None,
    db: AsyncSession = Depends(get_db),
):
    """Paginated listing of ratio tasks with optional name/status filters.

    Args:
        page: 1-based page index (values below 1 are clamped to 1).
        size: page size.
        name: substring filter on the task name (SQL LIKE).
        status: exact-match filter on the task status.
    """
    try:
        query = select(RatioInstance)
        # filters
        if name:
            # simple contains filter
            query = query.where(RatioInstance.name.like(f"%{name}%"))
        if status:
            query = query.where(RatioInstance.status == status)
        # count matching rows before pagination is applied
        count_q = select(func.count()).select_from(query.subquery())
        total = (await db.execute(count_q)).scalar_one()
        # page (1-based)
        page_index = max(page, 1) - 1
        query = query.order_by(RatioInstance.created_at.desc()).offset(page_index * size).limit(size)
        result = await db.execute(query)
        items = result.scalars().all()
        # map to DTOs and attach dataset name
        # preload datasets in a single query to avoid per-row lookups
        ds_ids = {i.target_dataset_id for i in items if i.target_dataset_id}
        ds_map = {}
        if ds_ids:
            ds_res = await db.execute(select(Dataset).where(Dataset.id.in_(list(ds_ids))))
            for d in ds_res.scalars().all():
                ds_map[d.id] = d
        content: list[RatioTaskItem] = []
        for i in items:
            ds = ds_map.get(i.target_dataset_id) if i.target_dataset_id else None
            content.append(
                RatioTaskItem(
                    id=i.id,
                    name=i.name or "",
                    description=i.description,
                    status=i.status,
                    totals=i.totals,
                    target_dataset_id=i.target_dataset_id,
                    target_dataset_name=(ds.name if ds else None),
                    created_at=str(i.created_at) if getattr(i, "created_at", None) else None,
                    updated_at=str(i.updated_at) if getattr(i, "updated_at", None) else None,
                )
            )
        # ceiling division; zero pages when size is non-positive
        total_pages = (total + size - 1) // size if size > 0 else 0
        return StandardResponse(
            code=200,
            message="success",
            data=PagedRatioTaskResponse(
                content=content,
                totalElements=total,
                totalPages=total_pages,
                page=page,
                size=size,
            ),
        )
    except Exception as e:
        logger.error(f"Failed to list ratio tasks: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("", response_model=StandardResponse[str], status_code=200)
async def delete_ratio_tasks(
    ids: list[str] = Query(..., description="要删除的配比任务ID列表"),
    db: AsyncSession = Depends(get_db),
):
    """Delete ratio tasks by id; returns a plain "success" string.

    Relations are removed before their parent instances to satisfy any
    foreign-key constraints.

    Raises:
        HTTPException(400): empty id list.
        HTTPException(500): unexpected failure (transaction rolled back).
    """
    try:
        if not ids:
            raise HTTPException(status_code=400, detail="ids is required")
        # Delete the relations first (children of the instances).
        await db.execute(
            delete(RatioRelation).where(RatioRelation.ratio_instance_id.in_(ids))
        )
        # Then delete the instances themselves.
        await db.execute(
            delete(RatioInstance).where(RatioInstance.id.in_(ids))
        )
        await db.commit()
        return StandardResponse(code=200, message="success", data="success")
    except HTTPException:
        await db.rollback()
        raise
    except Exception as e:
        await db.rollback()
        logger.error(f"Failed to delete ratio tasks: {e}")
        # FIX: do not echo the raw exception back to the client (internal
        # detail leak); match the generic message used by the other handlers.
        raise HTTPException(status_code=500, detail="Internal server error")
async def valid_exists(db: AsyncSession, req: CreateRatioTaskRequest) -> None:
    """Validate that the ratio task name is unique (exact match, trimmed).

    Raises:
        HTTPException(400): name is empty or already taken.
    """
    name = (req.name or "").strip()
    if not name:
        raise HTTPException(status_code=400, detail="ratio task name is required")
    # FIX: use LIMIT 1 + .scalars().first() instead of scalar_one_or_none().
    # No unique constraint on name is visible here, so if duplicates already
    # exist scalar_one_or_none() raises MultipleResultsFound and the caller
    # surfaces a 500 instead of the intended 400.
    result = await db.execute(
        select(RatioInstance.id).where(RatioInstance.name == name).limit(1)
    )
    existing_id = result.scalars().first()
    if existing_id is not None:
        logger.error(f"create ratio task failed: ratio task '{name}' already exists (id={existing_id})")
        raise HTTPException(status_code=400, detail=f"ratio task '{name}' already exists")
async def get_dataset_types(dm_service: DatasetManagementService, req: CreateRatioTaskRequest) -> Set[str]:
    """Resolve every configured dataset and collect its upper-cased type.

    Raises:
        HTTPException(400): a configured dataset_id does not exist.
    """
    types_found: Set[str] = set()
    for cfg in req.config:
        ds = await dm_service.get_dataset(cfg.dataset_id)
        # Guard clause instead of if/else: missing dataset aborts the request.
        if not ds:
            raise HTTPException(status_code=400, detail=f"dataset_id not found: {cfg.dataset_id}")
        # The dataset model may expose either snake_case or camelCase type.
        dtype = getattr(ds, "dataset_type", None) or getattr(ds, "datasetType", None)
        types_found.add(str(dtype).upper())
    return types_found
def get_target_dataset_type(source_types: Set[str]) -> str:
    """Pick the target dataset type from the set of source dataset types.

    Rules:
        1) all sources are TEXT            -> "TEXT"
        2) exactly one media type (IMAGE /
           AUDIO / VIDEO) and nothing else -> that media type
        3) anything else (incl. empty set) -> "OTHER"
    """
    media_modalities = {"IMAGE", "AUDIO", "VIDEO"}
    if source_types == {"TEXT"}:
        return "TEXT"
    # A single source type that is one of the media modalities.
    if len(source_types) == 1 and source_types <= media_modalities:
        return next(iter(source_types))
    return "OTHER"
@router.get("/{task_id}", response_model=StandardResponse[RatioTaskDetailResponse], status_code=200)
async def get_ratio_task(
    task_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Get the detail of a single ratio task.

    Path: /api/synthesis/ratio-task/{task_id}

    Raises:
        HTTPException(404): task not found.
        HTTPException(500): unexpected failure.
    """
    try:
        # Load the task instance.
        instance_res = await db.execute(
            select(RatioInstance).where(RatioInstance.id == task_id)
        )
        instance = instance_res.scalar_one_or_none()
        if not instance:
            raise HTTPException(status_code=404, detail="Ratio task not found")
        # Load the associated ratio relations (one per source dataset).
        relations_res = await db.execute(
            select(RatioRelationModel).where(RatioRelationModel.ratio_instance_id == task_id)
        )
        relations = list(relations_res.scalars().all())
        # Load the target dataset, if any.
        target_ds = None
        if instance.target_dataset_id:
            ds_res = await db.execute(
                select(Dataset).where(Dataset.id == instance.target_dataset_id)
            )
            target_ds = ds_res.scalar_one_or_none()
        # Build the response: config items echo the stored relations.
        config = [
            {
                "dataset_id": rel.source_dataset_id,
                "counts": str(rel.counts) if rel.counts is not None else "0",
                "filter_conditions": rel.filter_conditions or "",
            }
            for rel in relations
        ]
        # Target dataset summary; fields default to None/0 when missing.
        target_dataset_info = {
            "id": str(target_ds.id) if target_ds else None,
            "name": target_ds.name if target_ds else None,
            "type": target_ds.dataset_type if target_ds else None,
            "status": target_ds.status if target_ds else None,
            "file_count": target_ds.file_count if target_ds else 0,
            "size_bytes": target_ds.size_bytes if target_ds else 0,
        }
        return StandardResponse(
            code=200,
            message="success",
            data=RatioTaskDetailResponse(
                id=instance.id,
                name=instance.name or "",
                description=instance.description,
                status=instance.status or "UNKNOWN",
                totals=instance.totals or 0,
                config=config,
                target_dataset=target_dataset_info,
                created_at=instance.created_at,
                updated_at=instance.updated_at,
            )
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to get ratio task {task_id}: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,124 @@
from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, field_validator
from app.core.logging import get_logger
from app.module.shared.schema.common import TaskStatus
# Module-level logger for this schema module.
logger = get_logger(__name__)
class LabelFilter(BaseModel):
    """A tag filter: a label name plus an optional value to match."""
    # NOTE(review): declared Optional but Field(...) makes it required, so
    # the key must be present (an explicit null is still accepted) — confirm
    # this is intended.
    label: Optional[str] = Field(..., description="标签")
    value: Optional[str] = Field(None, description="标签值")
class FilterCondition(BaseModel):
    """Filter applied when sampling files from a source dataset."""
    # Look-back window as a numeric string (JSON alias "dateRange").
    date_range: Optional[str] = Field(None, description="数据范围", alias="dateRange")
    # Optional label/value tag filter.
    label: Optional[LabelFilter] = Field(None, description="标签")
    @field_validator("date_range")
    @classmethod
    def validate_date_range(cls, v: Optional[str]) -> Optional[str]:
        # ensure it's a numeric string if provided; None/empty passes through
        if not v:
            return v
        try:
            int(v)
            return v
        except (ValueError, TypeError) as e:
            raise ValueError("date_range must be a numeric string")
    class Config:
        # allow population by field name when constructing model programmatically
        # NOTE(review): pydantic v2 usually spells this `populate_by_name`
        # (in model_config); confirm `validate_by_name` is honored by the
        # pydantic version in use.
        validate_by_name = True
class RatioConfigItem(BaseModel):
    """One source-dataset entry in a ratio task's configuration."""
    # Source dataset id (JSON alias "datasetId").
    dataset_id: str = Field(..., alias="datasetId", description="数据集id")
    # Number of files to draw from this dataset, as a numeric string.
    counts: str = Field(..., description="数量")
    # Filter applied when sampling files (JSON alias "filterConditions").
    filter_conditions: FilterCondition = Field(..., alias="filterConditions", description="过滤条件")
    @field_validator("counts")
    @classmethod
    def validate_counts(cls, v: str) -> str:
        # ensure it's a numeric string
        try:
            int(v)
        except Exception:
            raise ValueError("counts must be a numeric string")
        return v
class CreateRatioTaskRequest(BaseModel):
    """Request payload for creating a ratio task."""
    name: str = Field(..., description="名称")
    description: Optional[str] = Field(None, description="描述")
    # Target total number of files, as a numeric string.
    totals: str = Field(..., description="目标数量")
    config: List[RatioConfigItem] = Field(..., description="配比设置列表")
    @field_validator("totals")
    @classmethod
    def validate_totals(cls, v: str) -> str:
        # Parse first; only conversion failures map to the "numeric string"
        # error.
        try:
            iv = int(v)
        except (ValueError, TypeError):
            raise ValueError("totals must be a numeric string")
        # FIX: the negative check now sits outside the try block. The
        # original raised "totals must be >= 0" inside the try, where the
        # broad `except Exception` swallowed it and re-raised the wrong
        # "must be a numeric string" message.
        if iv < 0:
            raise ValueError("totals must be >= 0")
        return v
class TargetDatasetInfo(BaseModel):
    """Summary of the dataset created as a ratio task's output."""
    id: str
    name: str
    datasetType: str
    status: str
class CreateRatioTaskResponse(BaseModel):
    """Response returned after creating a ratio task."""
    # task info
    id: str
    name: str
    description: Optional[str] = None
    totals: int
    status: TaskStatus
    # echoed config
    config: List[RatioConfigItem]
    # created dataset
    targetDataset: TargetDatasetInfo
class RatioTaskItem(BaseModel):
    """One row in the paginated ratio-task listing."""
    id: str
    name: str
    description: Optional[str] = None
    status: Optional[str] = None
    totals: Optional[int] = None
    target_dataset_id: Optional[str] = None
    # Denormalized target dataset name, resolved at query time.
    target_dataset_name: Optional[str] = None
    # Timestamps are pre-stringified by the listing endpoint.
    created_at: Optional[str] = None
    updated_at: Optional[str] = None
class PagedRatioTaskResponse(BaseModel):
    """Page of ratio tasks plus pagination metadata."""
    content: List[RatioTaskItem]
    totalElements: int
    totalPages: int
    page: int
    size: int
class RatioTaskDetailResponse(BaseModel):
    """Detailed response for a ratio task."""
    id: str = Field(..., description="任务ID")
    name: str = Field(..., description="任务名称")
    description: Optional[str] = Field(None, description="任务描述")
    status: str = Field(..., description="任务状态")
    totals: int = Field(..., description="目标总数")
    config: List[Dict[str, Any]] = Field(..., description="配比配置")
    target_dataset: Dict[str, Any] = Field(..., description="目标数据集信息")
    created_at: Optional[datetime] = Field(None, description="创建时间")
    updated_at: Optional[datetime] = Field(None, description="更新时间")
    class Config:
        # Serialize datetimes as ISO-8601 strings (None stays null).
        # NOTE(review): json_encoders is deprecated in pydantic v2; confirm
        # the installed version still honors it.
        json_encoders = {
            datetime: lambda v: v.isoformat() if v else None
        }

View File

@@ -0,0 +1,318 @@
from datetime import datetime
from typing import List, Optional, Dict, Any
import random
import json
import os
import shutil
import asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal
from app.module.dataset.schema.dataset_file import DatasetFileTag
from app.module.shared.schema import TaskStatus
from app.module.ratio.schema.ratio_task import FilterCondition
# Module-level logger for the ratio-task service.
logger = get_logger(__name__)
class RatioTaskService:
    """Service for Ratio Task DB operations and background execution."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_task(
        self,
        *,
        name: str,
        description: Optional[str],
        totals: int,
        config: List[Dict[str, Any]],
        target_dataset_id: Optional[str] = None,
    ) -> RatioInstance:
        """Create a ratio task instance and its relations, then commit.

        config item format:
            {"dataset_id": str, "counts": int,
             "filter_conditions": FilterCondition | None}
        """
        logger.info(f"Creating ratio task: name={name}, totals={totals}, items={len(config or [])}")
        instance = RatioInstance(
            name=name,
            description=description,
            totals=totals,
            target_dataset_id=target_dataset_id,
            status="PENDING",
        )
        self.db.add(instance)
        await self.db.flush()  # populate instance.id
        for item in config or []:
            # FIX: tolerate a missing filter_conditions object or a missing
            # label inside it. The original dereferenced .label.label
            # unconditionally and crashed with AttributeError when label was
            # None (label is Optional in the FilterCondition schema).
            fc = item.get("filter_conditions")
            label = getattr(fc, "label", None) if fc is not None else None
            relation = RatioRelation(
                ratio_instance_id=instance.id,
                source_dataset_id=item.get("dataset_id"),
                counts=int(item.get("counts", 0)),
                filter_conditions=json.dumps({
                    'date_range': getattr(fc, "date_range", None) if fc is not None else None,
                    'label': {
                        "label": label.label,
                        "value": label.value,
                    } if label is not None else None,
                })
            )
            logger.info(f"Relation created: {relation.id}, {relation}, {item}, {config}")
            self.db.add(relation)
        await self.db.commit()
        await self.db.refresh(instance)
        logger.info(f"Ratio task created: {instance.id}")
        return instance

    # ========================= Execution (Background) ========================= #
    @staticmethod
    async def execute_dataset_ratio_task(instance_id: str) -> None:
        """Execute a ratio task in background.

        Steps:
        - Mark instance RUNNING
        - For each relation: fetch ACTIVE files, optionally filter by
          relation.filter_conditions (date range / tags)
        - Copy selected files into the target dataset
        - Update dataset statistics and mark instance COMPLETED/FAILED

        All state changes are committed once in the `finally` block.
        """
        async with AsyncSessionLocal() as session:  # type: AsyncSession
            try:
                # Load instance and relations
                inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
                instance: Optional[RatioInstance] = inst_res.scalar_one_or_none()
                if not instance:
                    logger.error(f"Ratio instance not found: {instance_id}")
                    return
                logger.info(f"start execute ratio task: {instance_id}")
                rel_res = await session.execute(
                    select(RatioRelation).where(RatioRelation.ratio_instance_id == instance_id)
                )
                relations: List[RatioRelation] = list(rel_res.scalars().all())
                # Mark running
                instance.status = TaskStatus.RUNNING.name
                # Load target dataset
                ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
                target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
                if not target_ds:
                    logger.error(f"Target dataset not found for instance {instance_id}")
                    instance.status = TaskStatus.FAILED.name
                    return
                added_count, added_size = await RatioTaskService.handle_ratio_relations(relations, session, target_ds)
                # Update target dataset statistics
                target_ds.file_count = (target_ds.file_count or 0) + added_count  # type: ignore
                target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size  # type: ignore
                # If target dataset has files, mark it ACTIVE
                if (target_ds.file_count or 0) > 0:  # type: ignore
                    target_ds.status = "ACTIVE"
                # Done
                instance.status = TaskStatus.COMPLETED.name
                logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}, {instance.status}")
            except Exception as e:
                logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
                # FIX: the original wrapped this in `try/.../finally: pass`,
                # which let a second failure escape the handler; catch and log
                # it so the commit in `finally` still runs.
                try:
                    inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
                    instance = inst_res.scalar_one_or_none()
                    if instance:
                        instance.status = TaskStatus.FAILED.name
                except Exception as mark_err:
                    logger.error(f"Failed to mark instance {instance_id} as FAILED: {mark_err}")
            finally:
                await session.commit()

    @staticmethod
    async def handle_ratio_relations(relations: list[RatioRelation], session, target_ds: Dataset) -> tuple[int, int]:
        """Sample files per relation and copy them into the target dataset.

        Returns:
            (added_count, added_size): number of files copied and their
            total size in bytes.
        """
        # Preload existing target file paths for deduplication
        existing_path_rows = await session.execute(
            select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
        )
        existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
        added_count = 0
        added_size = 0
        for rel in relations:
            # Skip relations with no source or a non-positive quota.
            if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
                continue
            files = await RatioTaskService.get_files(rel, session)
            if not files:
                continue
            # Random sample without replacement, capped at what's available.
            pick_n = min(rel.counts or 0, len(files))
            chosen = random.sample(files, pick_n) if pick_n < len(files) else files
            # Copy into target dataset with de-dup by target path
            for f in chosen:
                await RatioTaskService.handle_selected_file(existing_paths, f, session, target_ds)
                added_count += 1
                added_size += int(f.file_size or 0)
            # Periodically flush to avoid huge transactions
            await session.flush()
        return added_count, added_size

    @staticmethod
    async def handle_selected_file(existing_paths: set[Any], f, session, target_ds: Dataset):
        """Copy one source file into the target dataset and record its row."""
        src_path = f.file_path
        dst_prefix = f"/dataset/{target_ds.id}/"
        file_name = RatioTaskService.get_new_file_name(dst_prefix, existing_paths, f)
        new_path = dst_prefix + file_name
        dst_dir = os.path.dirname(new_path)
        # Blocking filesystem work runs in a thread to keep the loop free.
        await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
        await asyncio.to_thread(shutil.copy2, src_path, new_path)
        file_data = {
            "dataset_id": target_ds.id,  # type: ignore
            "file_name": file_name,
            "file_path": new_path,
            "file_type": f.file_type,
            "file_size": f.file_size,
            "check_sum": f.check_sum,
            "tags": f.tags,
            "tags_updated_at": datetime.now(),
            "dataset_filemetadata": f.dataset_filemetadata,
            "status": "ACTIVE",
        }
        # Drop None values so column defaults apply.
        file_record = {k: v for k, v in file_data.items() if v is not None}
        session.add(DatasetFiles(**file_record))
        existing_paths.add(new_path)

    @staticmethod
    def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
        """Return a file name unique within the target dataset.

        On conflict, appends "_<n>" before the extension, giving up (and
        reusing the last candidate) after 1000 attempts.
        """
        file_name = f.file_name
        new_path = dst_prefix + file_name
        # Handle file path conflicts by appending a number to the filename
        if new_path in existing_paths:
            file_name_base, file_ext = os.path.splitext(file_name)
            counter = 1
            original_file_name = file_name
            while new_path in existing_paths:
                file_name = f"{file_name_base}_{counter}{file_ext}"
                new_path = f"{dst_prefix}{file_name}"
                counter += 1
                if counter > 1000:  # Safety check to prevent infinite loops
                    logger.error(f"Could not find unique filename for {original_file_name} after 1000 attempts")
                    break
        return file_name

    @staticmethod
    async def get_files(rel: RatioRelation, session) -> list[Any]:
        """Fetch the relation's ACTIVE source files, applying its filters."""
        # Fetch all files for the source dataset (ACTIVE only)
        files_res = await session.execute(
            select(DatasetFiles).where(
                DatasetFiles.dataset_id == rel.source_dataset_id,
                DatasetFiles.status == "ACTIVE",
            )
        )
        files = list(files_res.scalars().all())
        # Filter by date range / tags according to relation.filter_conditions
        conditions = RatioTaskService._parse_conditions(rel.filter_conditions)
        if conditions:
            files = [f for f in files if RatioTaskService._filter_file(f, conditions)]
        return files

    # ------------------------- helpers for TAG filtering ------------------------- #
    @staticmethod
    def _parse_conditions(conditions: Optional[str]) -> Optional[FilterCondition]:
        """Parse filter_conditions JSON string into a FilterCondition object.

        Args:
            conditions: JSON string containing filter conditions

        Returns:
            FilterCondition object if conditions is not None/empty and
            parses cleanly, otherwise None (errors are logged, not raised).
        """
        if not conditions:
            return None
        try:
            data = json.loads(conditions)
            return FilterCondition(**data)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse filter conditions: {e}")
            return None
        except Exception as e:
            logger.error(f"Error creating FilterCondition: {e}")
            return None

    @staticmethod
    def _filter_file(file: DatasetFiles, conditions: FilterCondition) -> bool:
        """Return True if *file* satisfies the date-range and label filters."""
        if not conditions:
            return True
        logger.info(f"start filter file: {file}, conditions: {conditions}")
        # Check date range condition if provided: exclude files whose tags
        # were last updated before the cutoff.
        if conditions.date_range:
            try:
                # datetime is already imported at module level; only timedelta
                # is needed locally.
                from datetime import timedelta
                data_range_days = int(conditions.date_range)
                if data_range_days > 0:
                    cutoff_date = datetime.now() - timedelta(days=data_range_days)
                    if file.tags_updated_at and file.tags_updated_at < cutoff_date:
                        return False
            except (ValueError, TypeError) as e:
                # FIX: the original passed `e` as a stray positional logging
                # arg with no %-placeholder, which itself raises a logging
                # error; fold it into the message instead.
                logger.warning(f"Invalid data_range value: {conditions.date_range}: {e}")
                return False
        # Check label condition if provided
        if conditions.label:
            tags = file.tags
            if not tags:
                return False
            try:
                # tags could be a list of strings or list of objects with 'name'
                tag_names = RatioTaskService.get_all_tags(tags)
                return f"{conditions.label.label}@{conditions.label.value}" in tag_names
            except Exception:
                # FIX: logger.exception already records the traceback; the
                # original passed `e` as an extra positional arg with no
                # placeholder, corrupting the log call.
                logger.exception(f"Failed to get tags for {file}")
                return False
        return True

    @staticmethod
    def get_all_tags(tags) -> set[str]:
        """Return the set of tag strings produced by DatasetFileTag.get_tags().

        Each element of *tags* is a mapping used to construct a
        DatasetFileTag.
        """
        all_tags: set[str] = set()
        if not tags:
            return all_tags
        for tag_data in tags:
            # FIX: the original copied tag_data key-by-key into a new dict as
            # a "naming style conversion" that never changed any key — a
            # no-op; construct the model directly.
            file_tag = DatasetFileTag(**tag_data)
            for tag_str in file_tag.get_tags():
                all_tags.add(tag_str)
        return all_tags