"""FastAPI routes for Annotation Operator Tasks. 兼容路由: - GET/POST/DELETE /api/annotation/auto - GET /api/annotation/auto/{task_id}/status 新路由: - GET/POST/DELETE /api/annotation/operator-tasks - GET /api/annotation/operator-tasks/{task_id} """ from __future__ import annotations from typing import List, Literal from fastapi import APIRouter, Depends, HTTPException, Path from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession from app.db.session import get_db from app.module.shared.schema import StandardResponse from app.module.dataset import DatasetManagementService from app.core.logging import get_logger from ..security import ( RequestUserContext, assert_dataset_access, get_request_user_context, ) from ..schema.auto import ( CreateAutoAnnotationTaskRequest, AutoAnnotationTaskResponse, ) from ..service.auto import AutoAnnotationTaskService router = APIRouter( tags=["annotation/auto"], ) logger = get_logger(__name__) service = AutoAnnotationTaskService() def _normalize_request_by_route( request: CreateAutoAnnotationTaskRequest, route_mode: Literal["legacy_auto", "operator_tasks"], ) -> CreateAutoAnnotationTaskRequest: """根据路由入口做请求标准化。""" if route_mode == "legacy_auto": # 旧接口强制走 legacy_yolo 模式,保持行为一致 return request.model_copy(update={"task_mode": "legacy_yolo"}) # 新接口默认走 pipeline 模式(若请求未显式指定 taskMode) task_mode = request.task_mode if request.pipeline and task_mode == "legacy_yolo": task_mode = "pipeline" return request.model_copy(update={"task_mode": task_mode}) async def _create_task_internal( *, request: CreateAutoAnnotationTaskRequest, db: AsyncSession, user_context: RequestUserContext, route_mode: Literal["legacy_auto", "operator_tasks"], ) -> AutoAnnotationTaskResponse: normalized_request = _normalize_request_by_route(request, route_mode) logger.info( "Creating annotation task: route_mode=%s, name=%s, dataset_id=%s, task_mode=%s, executor_type=%s, config=%s, pipeline=%s, file_ids=%s", route_mode, normalized_request.name, normalized_request.dataset_id, normalized_request.task_mode, normalized_request.executor_type, normalized_request.config.model_dump(by_alias=True) if normalized_request.config else None, [step.model_dump(by_alias=True) for step in normalized_request.pipeline] if normalized_request.pipeline else None, normalized_request.file_ids, ) # 权限 + fileIds 归属校验 await assert_dataset_access(db, normalized_request.dataset_id, user_context) # 尝试获取数据集名称和总量用于冗余字段 dataset_name = None dataset_type = "IMAGE" total_images = len(normalized_request.file_ids) if normalized_request.file_ids else 0 try: dm_client = DatasetManagementService(db) dataset = await dm_client.get_dataset(normalized_request.dataset_id) if dataset is not None: dataset_name = dataset.name dataset_type = getattr(dataset, "datasetType", None) or "IMAGE" if not normalized_request.file_ids: total_images = getattr(dataset, "fileCount", 0) or 0 except Exception as e: # pragma: no cover - 容错 logger.warning("Failed to fetch dataset summary for annotation task: %s", e) resolved_dataset_type = normalized_request.dataset_type or dataset_type return await service.create_task( db, normalized_request, user_context=user_context, dataset_name=dataset_name, total_images=total_images, dataset_type=resolved_dataset_type, ) @router.get("/auto", response_model=StandardResponse[List[AutoAnnotationTaskResponse]]) @router.get("/operator-tasks", response_model=StandardResponse[List[AutoAnnotationTaskResponse]]) async def list_annotation_operator_tasks( db: AsyncSession = Depends(get_db), user_context: RequestUserContext = Depends(get_request_user_context), ): """获取标注任务列表。""" tasks = await service.list_tasks(db, user_context) return StandardResponse( code=200, message="success", data=tasks, ) @router.post("/auto", response_model=StandardResponse[AutoAnnotationTaskResponse]) async def create_auto_annotation_task( request: CreateAutoAnnotationTaskRequest, db: AsyncSession = Depends(get_db), user_context: RequestUserContext = Depends(get_request_user_context), ): """兼容旧版 /auto 接口创建任务。""" task = await _create_task_internal( request=request, db=db, user_context=user_context, route_mode="legacy_auto", ) return StandardResponse( code=200, message="success", data=task, ) @router.post("/operator-tasks", response_model=StandardResponse[AutoAnnotationTaskResponse]) async def create_annotation_operator_task( request: CreateAutoAnnotationTaskRequest, db: AsyncSession = Depends(get_db), user_context: RequestUserContext = Depends(get_request_user_context), ): """新接口:创建通用算子编排标注任务。""" task = await _create_task_internal( request=request, db=db, user_context=user_context, route_mode="operator_tasks", ) return StandardResponse( code=200, message="success", data=task, ) @router.get("/auto/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse]) @router.get("/operator-tasks/{task_id}", response_model=StandardResponse[AutoAnnotationTaskResponse]) async def get_auto_annotation_task_status( task_id: str = Path(..., description="任务ID"), db: AsyncSession = Depends(get_db), user_context: RequestUserContext = Depends(get_request_user_context), ): """获取单个标注任务状态/详情。""" task = await service.get_task(db, task_id, user_context) if not task: raise HTTPException(status_code=404, detail="Task not found") return StandardResponse( code=200, message="success", data=task, ) @router.delete("/auto/{task_id}", response_model=StandardResponse[bool]) @router.delete("/operator-tasks/{task_id}", response_model=StandardResponse[bool]) async def delete_auto_annotation_task( task_id: str = Path(..., description="任务ID"), db: AsyncSession = Depends(get_db), user_context: RequestUserContext = Depends(get_request_user_context), ): """删除(软删除)自动标注任务,仅标记 deleted_at。""" ok = await service.soft_delete_task(db, task_id, user_context) if not ok: raise HTTPException(status_code=404, detail="Task not found") return StandardResponse( code=200, message="success", data=True, ) @router.post("/auto/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse]) @router.post("/operator-tasks/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse]) async def stop_auto_annotation_task( task_id: str = Path(..., description="任务ID"), db: AsyncSession = Depends(get_db), user_context: RequestUserContext = Depends(get_request_user_context), ): """请求停止自动标注任务。""" task = await service.request_stop_task(db, task_id, user_context) if not task: raise HTTPException(status_code=404, detail="Task not found") return StandardResponse( code=200, message="success", data=task, ) @router.get("/auto/{task_id}/download") @router.get("/operator-tasks/{task_id}/download") async def download_auto_annotation_result( task_id: str = Path(..., description="任务ID"), db: AsyncSession = Depends(get_db), user_context: RequestUserContext = Depends(get_request_user_context), ): """下载指定自动标注任务的结果 ZIP。""" import os import zipfile import tempfile # 复用服务层获取任务信息 task = await service.get_task(db, task_id, user_context) if not task: raise HTTPException(status_code=404, detail="Task not found") if not task.output_path: raise HTTPException(status_code=400, detail="Task has no output path") output_dir = task.output_path if not os.path.isdir(output_dir): raise HTTPException(status_code=404, detail="Output directory not found") tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip") os.close(tmp_fd) with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf: for root, _, files in os.walk(output_dir): for filename in files: file_path = os.path.join(root, filename) arcname = os.path.relpath(file_path, output_dir) zf.write(file_path, arcname) file_size = os.path.getsize(tmp_path) if file_size == 0: raise HTTPException(status_code=500, detail="Generated ZIP is empty") def iterfile(): with open(tmp_path, "rb") as f: while True: chunk = f.read(8192) if not chunk: break yield chunk filename = f"{task.name}_annotations.zip" headers = { "Content-Disposition": f'attachment; filename="{filename}"', "Content-Length": str(file_size), } return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)