Files
DataMate/runtime/datamate-python/app/module/annotation/interface/auto.py
Jerry Yan 8ffa131fad feat(annotation): 自动标注任务支持非图像类型数据集(TEXT/AUDIO/VIDEO)
移除自动标注任务创建流程中的 IMAGE-only 限制,使 TEXT、AUDIO、VIDEO
类型数据集均可用于自动标注任务。

- 新增数据库迁移:t_dm_auto_annotation_tasks 表添加 dataset_type 列
- 后端 schema/API/service 全链路传递 dataset_type
- Worker 动态构建 sample key(image/text/audio/video)和输出目录
- 前端移除数据集类型校验,下拉框显示数据集类型标识
- 输出数据集继承源数据集类型,不再硬编码为 IMAGE
- 保持向后兼容:默认值为 IMAGE,worker 有元数据回退和目录 fallback

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 23:23:05 +08:00

287 lines
9.4 KiB
Python

"""FastAPI routes for Annotation Operator Tasks.
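
Legacy-compatible routes:
- GET/POST /api/annotation/auto
- DELETE /api/annotation/auto/{task_id}
- GET /api/annotation/auto/{task_id}/status
- POST /api/annotation/auto/{task_id}/stop
- GET /api/annotation/auto/{task_id}/download

New routes:
- GET/POST /api/annotation/operator-tasks
- GET/DELETE /api/annotation/operator-tasks/{task_id}
- POST /api/annotation/operator-tasks/{task_id}/stop
- GET /api/annotation/operator-tasks/{task_id}/download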
"""
from __future__ import annotations
from typing import List, Literal
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.module.shared.schema import StandardResponse
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
)
from ..service.auto import AutoAnnotationTaskService
# Router carrying both the legacy ``/auto`` endpoints and the newer
# ``/operator-tasks`` endpoints declared in this module.
router = APIRouter(tags=["annotation/auto"])

logger = get_logger(__name__)

# Single service instance shared by every route handler in this module.
service = AutoAnnotationTaskService()
def _normalize_request_by_route(
    request: CreateAutoAnnotationTaskRequest,
    route_mode: Literal["legacy_auto", "operator_tasks"],
) -> CreateAutoAnnotationTaskRequest:
    """Return a copy of ``request`` whose ``task_mode`` fits the entry route.

    Args:
        request: The incoming create-task payload. It is not mutated; a copy
            produced by ``model_copy`` is returned.
        route_mode: ``"legacy_auto"`` for the old ``/auto`` endpoint,
            ``"operator_tasks"`` for the new endpoint.

    Returns:
        A copy of ``request`` with ``task_mode`` set as follows:

        * legacy route: always ``"legacy_yolo"``, so the old endpoint keeps
          its previous behavior;
        * new route: ``"pipeline"`` when the request carries a non-empty
          ``pipeline`` while ``task_mode`` is still ``"legacy_yolo"``;
          otherwise the request's own ``task_mode`` is kept.
    """
    if route_mode == "legacy_auto":
        effective_mode = "legacy_yolo"
    else:
        effective_mode = request.task_mode
        # A pipeline definition combined with the legacy mode is promoted to
        # the pipeline mode on the new route.
        if effective_mode == "legacy_yolo" and request.pipeline:
            effective_mode = "pipeline"

    return request.model_copy(update={"task_mode": effective_mode})
async def _create_task_internal(
    *,
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession,
    user_context: RequestUserContext,
    route_mode: Literal["legacy_auto", "operator_tasks"],
) -> AutoAnnotationTaskResponse:
    """Shared implementation behind both task-creation endpoints.

    Steps, in order:

    1. Normalize the request for the entry route.
    2. Log the normalized request.
    3. Run ``assert_dataset_access`` for the target dataset.
    4. Best-effort lookup of the dataset to fill the redundant fields
       (name, type, file count); a failure here is logged and ignored.
    5. Delegate to ``service.create_task``.

    Args:
        request: Raw create-task payload as received by the endpoint.
        db: Database session passed to the access check, the dataset lookup
            and the service.
        user_context: Caller identity used for the access check and passed
            on to the service.
        route_mode: Which endpoint the request arrived on.

    Returns:
        Whatever ``service.create_task`` returns for the new task.
    """
    req = _normalize_request_by_route(request, route_mode)

    config_dump = req.config.model_dump(by_alias=True) if req.config else None
    if req.pipeline:
        pipeline_dump = [step.model_dump(by_alias=True) for step in req.pipeline]
    else:
        pipeline_dump = None
    logger.info(
        "Creating annotation task: route_mode=%s, name=%s, dataset_id=%s, task_mode=%s, executor_type=%s, config=%s, pipeline=%s, file_ids=%s",
        route_mode,
        req.name,
        req.dataset_id,
        req.task_mode,
        req.executor_type,
        config_dump,
        pipeline_dump,
        req.file_ids,
    )

    # Access check for the dataset; done before any dataset details are read.
    await assert_dataset_access(db, req.dataset_id, user_context)

    # Defaults used when the dataset summary cannot be fetched.
    dataset_name = None
    dataset_type = "IMAGE"
    if req.file_ids:
        total_images = len(req.file_ids)
    else:
        total_images = 0

    try:
        dataset = await DatasetManagementService(db).get_dataset(req.dataset_id)
        if dataset is not None:
            dataset_name = dataset.name
            dataset_type = getattr(dataset, "datasetType", None) or "IMAGE"
            # Without an explicit file selection, fall back to the dataset's
            # own file count.
            if not req.file_ids:
                total_images = getattr(dataset, "fileCount", 0) or 0
    except Exception as e:  # pragma: no cover - tolerated failure
        logger.warning("Failed to fetch dataset summary for annotation task: %s", e)

    # A dataset type given in the request wins over the one read from the dataset.
    resolved_dataset_type = req.dataset_type or dataset_type

    return await service.create_task(
        db,
        req,
        user_context=user_context,
        dataset_name=dataset_name,
        total_images=total_images,
        dataset_type=resolved_dataset_type,
    )
@router.get("/auto", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
@router.get("/operator-tasks", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_annotation_operator_tasks(
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """List annotation tasks.

    Served on both the legacy ``/auto`` path and the ``/operator-tasks`` path.
    Returns the result of ``service.list_tasks`` for the current user context,
    wrapped in a ``StandardResponse``.
    """
    task_list = await service.list_tasks(db, user_context)
    return StandardResponse(code=200, message="success", data=task_list)
@router.post("/auto", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task(
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Create a task through the legacy ``/auto`` endpoint.

    Delegates to ``_create_task_internal`` with ``route_mode="legacy_auto"``
    and wraps the created task in a ``StandardResponse``.
    """
    created = await _create_task_internal(
        request=request,
        db=db,
        user_context=user_context,
        route_mode="legacy_auto",
    )
    return StandardResponse(code=200, message="success", data=created)
@router.post("/operator-tasks", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_annotation_operator_task(
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Create a task through the new ``/operator-tasks`` endpoint.

    Delegates to ``_create_task_internal`` with ``route_mode="operator_tasks"``
    and wraps the created task in a ``StandardResponse``.
    """
    created = await _create_task_internal(
        request=request,
        db=db,
        user_context=user_context,
        route_mode="operator_tasks",
    )
    return StandardResponse(code=200, message="success", data=created)
@router.get("/auto/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
@router.get("/operator-tasks/{task_id}", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Return the status/details of a single annotation task.

    Raises:
        HTTPException: 404 when ``service.get_task`` yields no task for
            ``task_id`` under the current user context.
    """
    found = await service.get_task(db, task_id, user_context)
    if found:
        return StandardResponse(code=200, message="success", data=found)
    raise HTTPException(status_code=404, detail="Task not found")
@router.delete("/auto/{task_id}", response_model=StandardResponse[bool])
@router.delete("/operator-tasks/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Soft-delete an annotation task via ``service.soft_delete_task``.

    According to the original note on this endpoint, the deletion only marks
    ``deleted_at``; the service implementation is not visible from here.

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    deleted = await service.soft_delete_task(db, task_id, user_context)
    if deleted:
        return StandardResponse(code=200, message="success", data=True)
    raise HTTPException(status_code=404, detail="Task not found")
@router.post("/auto/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse])
@router.post("/operator-tasks/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def stop_auto_annotation_task(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Request that an annotation task be stopped.

    Delegates to ``service.request_stop_task`` and returns the task it yields.

    Raises:
        HTTPException: 404 when the service yields no task for ``task_id``.
    """
    stopped = await service.request_stop_task(db, task_id, user_context)
    if stopped:
        return StandardResponse(code=200, message="success", data=stopped)
    raise HTTPException(status_code=404, detail="Task not found")
def _remove_file_quietly(path: str) -> None:
    """Delete ``path``; a missing or undeletable file is silently ignored."""
    import contextlib
    import os

    with contextlib.suppress(OSError):
        os.remove(path)


def _build_result_zip(output_dir: str) -> str:
    """Pack every file under ``output_dir`` into a new temporary ZIP file.

    Archive member names are relative to ``output_dir``. The caller owns the
    returned file and must delete it. If building the archive fails, the
    temporary file is removed before the exception propagates.
    """
    import os
    import tempfile
    import zipfile

    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
    os.close(tmp_fd)
    try:
        with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for root, _, files in os.walk(output_dir):
                for filename in files:
                    file_path = os.path.join(root, filename)
                    zf.write(file_path, os.path.relpath(file_path, output_dir))
    except BaseException:
        _remove_file_quietly(tmp_path)
        raise
    return tmp_path


def _attachment_content_disposition(filename: str) -> str:
    """Build a ``Content-Disposition: attachment`` value for ``filename``.

    Names made only of printable ASCII (without ``"`` or ``\\``) produce the
    plain ``filename="..."`` form. Any other name gets an ASCII fallback plus
    a percent-encoded UTF-8 ``filename*`` parameter, so the header value stays
    ASCII and therefore encodable as an HTTP header.
    """
    from urllib.parse import quote

    def is_plain(ch: str) -> bool:
        return " " <= ch <= "~" and ch not in '"\\'

    if all(is_plain(ch) for ch in filename):
        return f'attachment; filename="{filename}"'

    fallback = "".join(ch if is_plain(ch) else "_" for ch in filename)
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{fallback}\"; filename*=UTF-8''{encoded}"


@router.get("/auto/{task_id}/download")
@router.get("/operator-tasks/{task_id}/download")
async def download_auto_annotation_result(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Download the result of an annotation task as a ZIP archive.

    The task's ``output_path`` directory is zipped into a temporary file,
    which is streamed back and deleted once streaming ends.

    Raises:
        HTTPException: 404 when the task is not found or its output directory
            does not exist; 400 when the task has no output path; 500 when the
            generated archive file has size zero.
    """
    import os

    from fastapi.concurrency import run_in_threadpool

    # Task lookup goes through the service layer, like the other endpoints.
    task = await service.get_task(db, task_id, user_context)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    if not task.output_path:
        raise HTTPException(status_code=400, detail="Task has no output path")

    output_dir = task.output_path
    if not os.path.isdir(output_dir):
        raise HTTPException(status_code=404, detail="Output directory not found")

    # Walking and compressing the directory is blocking I/O; run it in a worker
    # thread so the event loop is not stalled.
    tmp_path = await run_in_threadpool(_build_result_zip, output_dir)

    try:
        file_size = os.path.getsize(tmp_path)
        # NOTE(review): zipfile writes an end-of-central-directory record even
        # when no member was added, so this guard probably never fires for an
        # empty output directory - confirm whether an archive without members
        # should be rejected instead.
        if file_size == 0:
            raise HTTPException(status_code=500, detail="Generated ZIP is empty")

        headers = {
            "Content-Disposition": _attachment_content_disposition(
                f"{task.name}_annotations.zip"
            ),
            "Content-Length": str(file_size),
        }
    except BaseException:
        # Nothing will stream the file, so remove it here.
        _remove_file_quietly(tmp_path)
        raise

    def iterfile():
        # The ``finally`` runs when the generator is exhausted or closed,
        # which is where the temporary archive gets removed.
        try:
            with open(tmp_path, "rb") as f:
                for chunk in iter(lambda: f.read(8192), b""):
                    yield chunk
        finally:
            _remove_file_quietly(tmp_path)

    return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)