feature: add data ratio (mixing) functionality (#52)

* refactor: rework the data collection implementation, remove unused code, and tidy the code structure

* feature: scan all datasets at 00:00 every day, check whether each dataset has exceeded its configured retention days, and call the delete API for any dataset past retention

* fix: change the dataset file deletion logic: files uploaded to a dataset have both their database record and filesystem file removed, while collected files have only their database record removed

* fix: add parameter validation and interface definitions; remove unused endpoints

* fix: dataset statistics default to 0

* feature: add dataset status transitions: DRAFT on creation, ACTIVE once files are uploaded or collected

* refactor: rework the paged query code for collection tasks

* fix: re-execute after an update; add transaction control to collection task execution

* feature: optionally create a dataset when creating a collection task, and write to a specified dataset when updating one

* fix: creating a collection task without creating a dataset should not raise an error

* fix: dataset statistics not updating when files are deleted

* feature: expose the file tag distribution in dataset detail queries

* fix: skip analysis when tags is empty

* fix: set status to ACTIVE

* fix: revise the tag parsing method

* feature: implement create, paged query, and delete for ratio tasks

* feature: implement the frontend interactions for create, paged query, and delete of ratio tasks

* fix: page error caused by a progress calculation exception
This commit is contained in:
hefanli
2025-11-03 10:17:39 +08:00
committed by GitHub
parent 07edf16044
commit 08bd4eca5c
32 changed files with 1894 additions and 1028 deletions
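
Before the file-by-file diff, a quick sketch of what the new ratio-task API accepts. This is illustrative only: the base URL, dataset IDs, and counts below are made up, while the path and payload shape follow the routers and schemas in the diff that follows.

    import httpx

    # Hypothetical payload; "datasetId" is the wire alias for dataset_id, and
    # totals/counts are numeric strings, as enforced by the schema validators.
    payload = {
        "name": "demo-ratio-task",
        "description": "mix two datasets 70/30",
        "totals": "100",
        "ratio_method": "DATASET",  # or "TAG" to filter by tags in filter_conditions
        "config": [
            {"datasetId": "ds-001", "counts": "70", "filter_conditions": ""},
            {"datasetId": "ds-002", "counts": "30", "filter_conditions": ""},
        ],
    }

    # Assumed local dev server; the prefix /api/synthesis/ratio-task comes from
    # the router wiring shown in the first three files of this diff.
    resp = httpx.post("http://localhost:8000/api/synthesis/ratio-task", json=payload)
    print(resp.json())  # StandardResponse envelope: {"code": 200, "message": "success", "data": {...}}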

View File

@@ -2,6 +2,7 @@ from fastapi import APIRouter
from .system.interface import router as system_router
from .annotation.interface import router as annotation_router
from .synthesis.interface import router as ratio_router

router = APIRouter(
    prefix="/api"
@@ -9,5 +10,6 @@ router = APIRouter(
router.include_router(system_router)
router.include_router(annotation_router)
router.include_router(ratio_router)

__all__ = ["router"]

View File

@@ -0,0 +1,11 @@
from fastapi import APIRouter

router = APIRouter(
    prefix="/synthesis",
    tags=["synthesis"],
)

# Include sub-routers
from .ratio_task import router as ratio_task_router

router.include_router(ratio_task_router)
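
How the three routers compose, as a self-contained sketch (the handler body here is a placeholder, not the real implementation):

    from fastapi import APIRouter, FastAPI

    api = APIRouter(prefix="/api")
    synthesis = APIRouter(prefix="/synthesis", tags=["synthesis"])
    ratio_task = APIRouter(prefix="/ratio-task")

    @ratio_task.post("")
    async def create_stub():
        # placeholder handler, just to make the sketch runnable
        return {"ok": True}

    synthesis.include_router(ratio_task)
    api.include_router(synthesis)

    app = FastAPI()
    app.include_router(api)
    # Final route: POST /api/synthesis/ratio-task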

View File

@@ -0,0 +1,253 @@
import asyncio
from typing import Set
from datetime import datetime

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import func, delete, select

from app.core.logging import get_logger
from app.db.models import Dataset
from app.db.session import get_db
from app.module.dataset import DatasetManagementService
from app.module.shared.schema import StandardResponse
from app.module.synthesis.schema.ratio_task import (
    CreateRatioTaskResponse,
    CreateRatioTaskRequest,
    PagedRatioTaskResponse,
    RatioTaskItem,
    TargetDatasetInfo,
)
from app.module.synthesis.service.ratio_task import RatioTaskService
from app.db.models.ratio_task import RatioInstance, RatioRelation

router = APIRouter(
    prefix="/ratio-task",
    tags=["synthesis/ratio-task"],
)
logger = get_logger(__name__)


@router.post("", response_model=StandardResponse[CreateRatioTaskResponse], status_code=200)
async def create_ratio_task(
    req: CreateRatioTaskRequest,
    db: AsyncSession = Depends(get_db),
):
    """
    Create a ratio task.

    Path: /api/synthesis/ratio-task
    """
    try:
        # Validate that every dataset_id in config exists
        dm_service = DatasetManagementService(db)
        source_types = await get_dataset_types(dm_service, req)
        await valid_exists(db, req)
        # Create the target dataset, named "<task name>-配比生成-<timestamp>"
        target_dataset_name = f"{req.name}-配比生成-{datetime.now().strftime('%Y%m%d%H%M%S')}"
        target_type = get_target_dataset_type(source_types)
        target_dataset = Dataset(
            name=target_dataset_name,
            description=req.description or "",
            dataset_type=target_type,
            status="DRAFT",
        )
        db.add(target_dataset)
        await db.flush()  # populate target_dataset.id
        service = RatioTaskService(db)
        instance = await service.create_task(
            name=req.name,
            description=req.description,
            totals=int(req.totals),
            ratio_method=req.ratio_method,
            config=[
                {
                    "dataset_id": item.dataset_id,
                    "counts": int(item.counts),
                    "filter_conditions": item.filter_conditions,
                }
                for item in req.config
            ],
            target_dataset_id=target_dataset.id,
        )
        # Execute the ratio task asynchronously (DATASET / TAG supported)
        asyncio.create_task(RatioTaskService.execute_dataset_ratio_task(instance.id))
        return StandardResponse(
            code=200,
            message="success",
            data=CreateRatioTaskResponse(
                id=instance.id,
                name=instance.name,
                description=instance.description,
                totals=instance.totals or 0,
                ratio_method=instance.ratio_method or req.ratio_method,
                status=instance.status or "PENDING",
                config=req.config,
                targetDataset=TargetDatasetInfo(
                    id=str(target_dataset.id),
                    name=str(target_dataset.name),
                    datasetType=str(target_dataset.dataset_type),
                    status=str(target_dataset.status),
                ),
            ),
        )
    except HTTPException:
        await db.rollback()
        raise
    except Exception as e:
        await db.rollback()
        logger.error(f"Failed to create ratio task: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")


@router.get("", response_model=StandardResponse[PagedRatioTaskResponse], status_code=200)
async def list_ratio_tasks(
    page: int = 1,
    size: int = 10,
    name: str | None = None,
    status: str | None = None,
    db: AsyncSession = Depends(get_db),
):
    """Paged query of ratio tasks, with optional name and status filters."""
    try:
        query = select(RatioInstance)
        # filters
        if name:
            # simple contains filter
            query = query.where(RatioInstance.name.like(f"%{name}%"))
        if status:
            query = query.where(RatioInstance.status == status)
        # count
        count_q = select(func.count()).select_from(query.subquery())
        total = (await db.execute(count_q)).scalar_one()
        # page (1-based)
        page_index = max(page, 1) - 1
        query = query.order_by(RatioInstance.created_at.desc()).offset(page_index * size).limit(size)
        result = await db.execute(query)
        items = result.scalars().all()
        # map to DTOs and attach dataset names;
        # preload the referenced target datasets
        ds_ids = {i.target_dataset_id for i in items if i.target_dataset_id}
        ds_map = {}
        if ds_ids:
            ds_res = await db.execute(select(Dataset).where(Dataset.id.in_(list(ds_ids))))
            for d in ds_res.scalars().all():
                ds_map[d.id] = d
        content: list[RatioTaskItem] = []
        for i in items:
            ds = ds_map.get(i.target_dataset_id) if i.target_dataset_id else None
            content.append(
                RatioTaskItem(
                    id=i.id,
                    name=i.name or "",
                    description=i.description,
                    status=i.status,
                    totals=i.totals,
                    ratio_method=i.ratio_method,
                    target_dataset_id=i.target_dataset_id,
                    target_dataset_name=(ds.name if ds else None),
                    created_at=str(i.created_at) if getattr(i, "created_at", None) else None,
                    updated_at=str(i.updated_at) if getattr(i, "updated_at", None) else None,
                )
            )
        total_pages = (total + size - 1) // size if size > 0 else 0
        return StandardResponse(
            code=200,
            message="success",
            data=PagedRatioTaskResponse(
                content=content,
                totalElements=total,
                totalPages=total_pages,
                page=page,
                size=size,
            ),
        )
    except Exception as e:
        logger.error(f"Failed to list ratio tasks: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")


@router.delete("", response_model=StandardResponse[str], status_code=200)
async def delete_ratio_tasks(
    ids: list[str] = Query(..., description="IDs of the ratio tasks to delete"),
    db: AsyncSession = Depends(get_db),
):
    """Delete ratio tasks; returns a simple result string."""
    try:
        if not ids:
            raise HTTPException(status_code=400, detail="ids is required")
        # Delete the relations first
        await db.execute(
            delete(RatioRelation).where(RatioRelation.ratio_instance_id.in_(ids))
        )
        # Then delete the instances
        await db.execute(
            delete(RatioInstance).where(RatioInstance.id.in_(ids))
        )
        await db.commit()
        return StandardResponse(code=200, message="success", data="success")
    except HTTPException:
        await db.rollback()
        raise
    except Exception as e:
        await db.rollback()
        logger.error(f"Failed to delete ratio tasks: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to delete ratio tasks: {e}")


async def valid_exists(db: AsyncSession, req: CreateRatioTaskRequest):
    # The ratio task name must be unique
    exist_task_q = await db.execute(
        select(RatioInstance).where(RatioInstance.name == req.name)
    )
    if exist_task_q.scalar_one_or_none() is not None:
        logger.error(f"create ratio task failed: ratio task {req.name} already exists")
        raise HTTPException(status_code=400, detail=f"ratio task {req.name} already exists")


async def get_dataset_types(dm_service: DatasetManagementService, req: CreateRatioTaskRequest) -> Set[str]:
    source_types: Set[str] = set()
    for item in req.config:
        dataset = await dm_service.get_dataset(item.dataset_id)
        if not dataset:
            raise HTTPException(status_code=400, detail=f"dataset_id not found: {item.dataset_id}")
        dtype = getattr(dataset, "dataset_type", None) or getattr(dataset, "datasetType", None)
        source_types.add(str(dtype).upper())
    return source_types


def get_target_dataset_type(source_types: Set[str]) -> str:
    # Decide the target dataset type from the source dataset types.
    # Rules:
    # 1) all sources are TEXT -> TEXT
    # 2) exactly one media type (IMAGE/AUDIO/VIDEO), with no other types -> that media type
    # 3) anything else -> OTHER
    media_modalities = {"IMAGE", "AUDIO", "VIDEO"}
    target_type = "OTHER"
    if source_types == {"TEXT"}:
        target_type = "TEXT"
    else:
        media_involved = source_types & media_modalities
        if len(media_involved) == 1 and source_types == media_involved:
            # Exactly one media type and nothing else
            target_type = next(iter(media_involved))
    return target_type
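
The type-resolution rules are easy to sanity-check. Below is a condensed standalone copy of get_target_dataset_type (restructured with early returns, same behavior) together with its expected outputs:

    def get_target_dataset_type(source_types: set[str]) -> str:
        # same rules as the function above, compressed for illustration
        media_modalities = {"IMAGE", "AUDIO", "VIDEO"}
        if source_types == {"TEXT"}:
            return "TEXT"
        media_involved = source_types & media_modalities
        if len(media_involved) == 1 and source_types == media_involved:
            return next(iter(media_involved))
        return "OTHER"

    assert get_target_dataset_type({"TEXT"}) == "TEXT"
    assert get_target_dataset_type({"VIDEO"}) == "VIDEO"           # one media type, nothing else
    assert get_target_dataset_type({"TEXT", "IMAGE"}) == "OTHER"   # mixed modalities
    assert get_target_dataset_type({"IMAGE", "AUDIO"}) == "OTHER"  # two media types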

View File

@@ -0,0 +1,86 @@
from typing import List, Optional

from pydantic import BaseModel, Field, field_validator


class RatioConfigItem(BaseModel):
    dataset_id: str = Field(..., alias="datasetId", description="Dataset ID")
    counts: str = Field(..., description="Number of files to take")
    filter_conditions: str = Field(..., description="Filter conditions")

    @field_validator("counts")
    @classmethod
    def validate_counts(cls, v: str) -> str:
        # ensure it's a numeric string
        try:
            int(v)
        except Exception:
            raise ValueError("counts must be a numeric string")
        return v


class CreateRatioTaskRequest(BaseModel):
    name: str = Field(..., description="Name")
    description: Optional[str] = Field(None, description="Description")
    totals: str = Field(..., description="Target total count")
    ratio_method: str = Field(..., description="Ratio method", alias="ratio_method")
    config: List[RatioConfigItem] = Field(..., description="List of ratio config items")

    @field_validator("ratio_method")
    @classmethod
    def validate_ratio_method(cls, v: str) -> str:
        allowed = {"TAG", "DATASET"}
        if v not in allowed:
            raise ValueError(f"ratio_method must be one of {allowed}")
        return v

    @field_validator("totals")
    @classmethod
    def validate_totals(cls, v: str) -> str:
        # Parse first, then range-check, so a negative value reports
        # "totals must be >= 0" instead of being swallowed by the numeric check
        try:
            iv = int(v)
        except Exception:
            raise ValueError("totals must be a numeric string")
        if iv < 0:
            raise ValueError("totals must be >= 0")
        return v


class TargetDatasetInfo(BaseModel):
    id: str
    name: str
    datasetType: str
    status: str


class CreateRatioTaskResponse(BaseModel):
    # task info
    id: str
    name: str
    description: Optional[str] = None
    totals: int
    ratio_method: str
    status: str
    # echoed config
    config: List[RatioConfigItem]
    # created dataset
    targetDataset: TargetDatasetInfo


class RatioTaskItem(BaseModel):
    id: str
    name: str
    description: Optional[str] = None
    status: Optional[str] = None
    totals: Optional[int] = None
    ratio_method: Optional[str] = None
    target_dataset_id: Optional[str] = None
    target_dataset_name: Optional[str] = None
    created_at: Optional[str] = None
    updated_at: Optional[str] = None


class PagedRatioTaskResponse(BaseModel):
    content: List[RatioTaskItem]
    totalElements: int
    totalPages: int
    page: int
    size: int
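
A minimal sketch of how these validators behave at the boundary (pydantic v2, per the field_validator import; the import path is the one the interface module above already uses, so this runs inside this codebase):

    from pydantic import ValidationError
    from app.module.synthesis.schema.ratio_task import CreateRatioTaskRequest

    # Accepted: numeric strings, a known ratio_method, config items keyed by the alias
    req = CreateRatioTaskRequest(
        name="t1",
        totals="10",
        ratio_method="DATASET",
        config=[{"datasetId": "ds-001", "counts": "5", "filter_conditions": ""}],
    )

    # Rejected: non-numeric totals
    try:
        CreateRatioTaskRequest(name="t2", totals="ten", ratio_method="DATASET", config=[])
    except ValidationError as err:
        print(err)  # Value error, totals must be a numeric string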

View File

@@ -0,0 +1,282 @@
from typing import List, Optional, Dict, Any
import random
import os
import shutil
import asyncio

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models.ratio_task import RatioInstance, RatioRelation
from app.db.models import Dataset, DatasetFiles
from app.db.session import AsyncSessionLocal

logger = get_logger(__name__)


class RatioTaskService:
    """Service for Ratio Task DB operations."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_task(
        self,
        *,
        name: str,
        description: Optional[str],
        totals: int,
        ratio_method: str,
        config: List[Dict[str, Any]],
        target_dataset_id: Optional[str] = None,
    ) -> RatioInstance:
        """Create a ratio task instance and its relations.

        config item format: {"dataset_id": str, "counts": int, "filter_conditions": str}
        """
        logger.info(f"Creating ratio task: name={name}, method={ratio_method}, totals={totals}, items={len(config or [])}")
        instance = RatioInstance(
            name=name,
            description=description,
            ratio_method=ratio_method,
            totals=totals,
            target_dataset_id=target_dataset_id,
            status="PENDING",
        )
        self.db.add(instance)
        await self.db.flush()  # populate instance.id
        for item in config or []:
            relation = RatioRelation(
                ratio_instance_id=instance.id,
                source_dataset_id=item.get("dataset_id"),
                counts=int(item.get("counts", 0)),
                filter_conditions=item.get("filter_conditions"),
            )
            self.db.add(relation)
        await self.db.commit()
        await self.db.refresh(instance)
        logger.info(f"Ratio task created: {instance.id}")
        return instance

    # ========================= Execution (Background) ========================= #

    @staticmethod
    async def execute_dataset_ratio_task(instance_id: str) -> None:
        """Execute a ratio task in the background.

        Supported ratio_method:
        - DATASET: randomly select counts files from each source dataset
        - TAG: randomly select counts files matching relation.filter_conditions tags

        Steps:
        - Mark instance RUNNING
        - For each relation: fetch ACTIVE files, optionally filter by tags
        - Copy selected files into the target dataset
        - Update dataset statistics and mark instance SUCCESS/FAILED
        """
        async with AsyncSessionLocal() as session:  # type: AsyncSession
            try:
                # Load instance and relations
                inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
                instance: Optional[RatioInstance] = inst_res.scalar_one_or_none()
                if not instance:
                    logger.error(f"Ratio instance not found: {instance_id}")
                    return
                logger.info(f"start execute ratio task: {instance_id}")
                rel_res = await session.execute(
                    select(RatioRelation).where(RatioRelation.ratio_instance_id == instance_id)
                )
                relations: List[RatioRelation] = list(rel_res.scalars().all())
                # Mark running
                instance.status = "RUNNING"
                if instance.ratio_method not in {"DATASET", "TAG"}:
                    logger.info(f"Instance {instance_id} ratio_method={instance.ratio_method} not supported yet")
                    instance.status = "SUCCESS"
                    return
                # Load target dataset
                ds_res = await session.execute(select(Dataset).where(Dataset.id == instance.target_dataset_id))
                target_ds: Optional[Dataset] = ds_res.scalar_one_or_none()
                if not target_ds:
                    logger.error(f"Target dataset not found for instance {instance_id}")
                    instance.status = "FAILED"
                    return
                # Preload existing target file paths for deduplication
                existing_path_rows = await session.execute(
                    select(DatasetFiles.file_path).where(DatasetFiles.dataset_id == target_ds.id)
                )
                existing_paths = set(p for p in existing_path_rows.scalars().all() if p)
                added_count = 0
                added_size = 0
                for rel in relations:
                    if not rel.source_dataset_id or not rel.counts or rel.counts <= 0:
                        continue
                    # Fetch all files for the source dataset (ACTIVE only)
                    files_res = await session.execute(
                        select(DatasetFiles).where(
                            DatasetFiles.dataset_id == rel.source_dataset_id,
                            DatasetFiles.status == "ACTIVE",
                        )
                    )
                    files = list(files_res.scalars().all())
                    # TAG mode: filter by tags according to relation.filter_conditions
                    if instance.ratio_method == "TAG":
                        required_tags = RatioTaskService._parse_required_tags(rel.filter_conditions)
                        if required_tags:
                            files = [f for f in files if RatioTaskService._file_contains_tags(f, required_tags)]
                    if not files:
                        continue
                    pick_n = min(rel.counts or 0, len(files))
                    chosen = random.sample(files, pick_n) if pick_n < len(files) else files
                    # Copy into the target dataset, de-duplicated by target path
                    for f in chosen:
                        src_path = f.file_path
                        new_path = src_path
                        needs_copy = False
                        src_prefix = f"/dataset/{rel.source_dataset_id}"
                        if isinstance(src_path, str) and src_path.startswith(src_prefix):
                            dst_prefix = f"/dataset/{target_ds.id}"
                            new_path = src_path.replace(src_prefix, dst_prefix, 1)
                            needs_copy = True
                        # De-dup by target path
                        if new_path in existing_paths:
                            continue
                        # Perform the copy only when needed
                        if needs_copy:
                            dst_dir = os.path.dirname(new_path)
                            await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
                            await asyncio.to_thread(shutil.copy2, src_path, new_path)
                        new_file = DatasetFiles(
                            dataset_id=target_ds.id,  # type: ignore
                            file_name=f.file_name,
                            file_path=new_path,
                            file_type=f.file_type,
                            file_size=f.file_size,
                            check_sum=f.check_sum,
                            tags=f.tags,
                            dataset_filemetadata=f.dataset_filemetadata,
                            status="ACTIVE",
                        )
                        session.add(new_file)
                        existing_paths.add(new_path)
                        added_count += 1
                        added_size += int(f.file_size or 0)
                    # Flush per relation to avoid one huge transaction
                    await session.flush()
                # Update target dataset statistics
                target_ds.file_count = (target_ds.file_count or 0) + added_count  # type: ignore
                target_ds.size_bytes = (target_ds.size_bytes or 0) + added_size  # type: ignore
                # If the target dataset now has files, mark it ACTIVE
                if (target_ds.file_count or 0) > 0:  # type: ignore
                    target_ds.status = "ACTIVE"
                # Done
                instance.status = "SUCCESS"
                logger.info(f"Dataset ratio execution completed: instance={instance_id}, files={added_count}, size={added_size}")
            except Exception as e:
                logger.exception(f"Dataset ratio execution failed for {instance_id}: {e}")
                # Best effort: mark the instance FAILED
                inst_res = await session.execute(select(RatioInstance).where(RatioInstance.id == instance_id))
                instance = inst_res.scalar_one_or_none()
                if instance:
                    instance.status = "FAILED"
            finally:
                await session.commit()

    # ------------------------- helpers for TAG filtering ------------------------- #

    @staticmethod
    def _parse_required_tags(conditions: Optional[str]) -> set[str]:
        """Parse filter_conditions into a set of required tag strings.

        Supports simple separators: comma, semicolon, space. Empty/None -> empty set.
        """
        if not conditions:
            return set()
        raw = conditions.replace("\n", " ")
        seps = [",", ";", " "]
        tokens = [raw]
        for sep in seps:
            nxt = []
            for t in tokens:
                nxt.extend(t.split(sep))
            tokens = nxt
        return {t.strip() for t in tokens if t and t.strip()}

    @staticmethod
    def _file_contains_tags(f: DatasetFiles, required: set[str]) -> bool:
        if not required:
            return True
        tags = f.tags
        if not tags:
            return False
        try:
            # tags could be a list of strings or a list of objects with a 'name'
            tag_names = set()
            if isinstance(tags, list):
                for item in tags:
                    if isinstance(item, str):
                        tag_names.add(item)
                    elif isinstance(item, dict):
                        name = item.get("name") or item.get("label") or item.get("tag")
                        if isinstance(name, str):
                            tag_names.add(name)
            elif isinstance(tags, dict):
                # flat dict of name->...; treat keys as tags
                tag_names = set(map(str, tags.keys()))
            else:
                return False
            logger.debug(f"tag match: required={required}, file_tags={tag_names}")
            return required.issubset(tag_names)
        except Exception:
            return False

    @staticmethod
    async def get_new_file(f, rel: RatioRelation, target_ds: Dataset) -> DatasetFiles:
        """Copy a source file into the target dataset tree and return the new row."""
        new_path = f.file_path
        src_prefix = f"/dataset/{rel.source_dataset_id}"
        if isinstance(f.file_path, str) and f.file_path.startswith(src_prefix):
            dst_prefix = f"/dataset/{target_ds.id}"
            new_path = f.file_path.replace(src_prefix, dst_prefix, 1)
            dst_dir = os.path.dirname(new_path)
            # Ensure the directory exists and copy in a thread to avoid blocking the event loop
            await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
            await asyncio.to_thread(shutil.copy2, f.file_path, new_path)
        new_file = DatasetFiles(
            dataset_id=target_ds.id,  # type: ignore
            file_name=f.file_name,
            file_path=new_path,
            file_type=f.file_type,
            file_size=f.file_size,
            check_sum=f.check_sum,
            tags=f.tags,
            dataset_filemetadata=f.dataset_filemetadata,
            status="ACTIVE",
        )
        return new_file
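
To make the TAG-mode matching concrete, here is a standalone restatement of the splitting rules in _parse_required_tags (a local copy for illustration, not an import), with expected results:

    def parse_required_tags(conditions: str | None) -> set[str]:
        # Same rules as RatioTaskService._parse_required_tags: split on
        # comma/semicolon/space (newlines become spaces) and drop blanks.
        if not conditions:
            return set()
        tokens = [conditions.replace("\n", " ")]
        for sep in (",", ";", " "):
            tokens = [piece for t in tokens for piece in t.split(sep)]
        return {t.strip() for t in tokens if t.strip()}

    assert parse_required_tags("nlp, qa;zh-cn") == {"nlp", "qa", "zh-cn"}
    assert parse_required_tags(None) == set()
    # A file then matches only if it carries *all* required tags (issubset),
    # whether its tags are plain strings or dicts with a "name"/"label"/"tag" key.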