You've already forked DataMate
## 功能概述
将数据标注模块从固定 YOLO 算子改造为支持通用算子编排,实现与数据清洗模块类似的灵活算子组合能力。
## 改动内容
### 第 1 步:数据库改造(DDL)
- 新增 SQL migration 脚本:scripts/db/data-annotation-operator-pipeline-migration.sql
- 修改 t_dm_auto_annotation_tasks 表:
- 新增字段:task_mode, executor_type, pipeline, output_dataset_id, created_by, stop_requested, started_at, heartbeat_at, run_token
- 新增索引:idx_status_created, idx_created_by
- 创建 t_dm_annotation_task_operator_instance 表:用于存储算子实例详情
### 第 2 步:API 层改造
- 扩展请求模型(schema/auto.py):
- 新增 OperatorPipelineStep 模型
- 支持 pipeline 字段,保留旧 YOLO 字段向后兼容
- 支持多种字段写法并统一归一化(operatorId/operator_id/id, overrides/settingsOverride/settings_override)
- 修改任务创建服务(service/auto.py):
- 新增 validate_file_ids() 校验方法
- 新增 _to_pipeline() 兼容映射方法
- 写入新字段并集成算子实例表
- 修复 fileIds 去重准确性问题
- 新增 API 路由(interface/auto.py):
- 新增 /operator-tasks 系列接口
- 新增 stop API 接口(/auto/{task_id}/stop 和 /operator-tasks/{task_id}/stop)
- 保留旧 /auto 接口向后兼容
- ORM 模型对齐(annotation_management.py):
- AutoAnnotationTask 新增所有 DDL 字段
- 新增 AnnotationTaskOperatorInstance 模型
- 状态定义补充 stopped
### 第 3 步:Runtime 层改造
- 修改 worker 执行逻辑(auto_annotation_worker.py):
- 实现原子任务抢占机制(run_token)
- 从硬编码 YOLO 改为通用 pipeline 执行
- 新增算子解析和实例化能力
- 支持 stop_requested 检查
- 保留 legacy_yolo 模式向后兼容
- 支持多种算子调用方式(execute 和 __call__)
### 第 4 步:灰度发布
- 完善 YOLO 算子元数据(metadata.yml):
- 补齐 raw_id, language, modal, inputs, outputs, settings 字段
- 注册标注算子(__init__.py):
- 将 YOLO 算子注册到 OPERATORS 注册表
- 确保 annotation 包被正确加载
- 新增白名单控制:
- 支持环境变量 AUTO_ANNOTATION_OPERATOR_WHITELIST
- 灰度发布时可限制可用算子
## 关键特性
### 向后兼容
- 旧 /auto 接口完全保留
- 旧请求参数自动映射到 pipeline
- legacy_yolo 模式确保旧逻辑正常运行
### 新功能
- 支持通用 pipeline 编排
- 支持多算子组合
- 支持任务停止控制
- 支持白名单灰度发布
### 可靠性
- 原子任务抢占(防止重复执行)
- 完整的错误处理和状态管理
- 详细的审计追踪(算子实例表)
## 部署说明
1. 执行 DDL:mysql < scripts/db/data-annotation-operator-pipeline-migration.sql
2. 配置环境变量:AUTO_ANNOTATION_OPERATOR_WHITELIST=ImageObjectDetectionBoundingBox
3. 重启服务:datamate-runtime 和 datamate-backend-python
## 验证步骤
1. 兼容模式验证:使用旧 /auto 接口创建任务
2. 通用编排验证:使用新 /operator-tasks 接口创建 pipeline 任务
3. 原子 claim 验证:检查 run_token 机制
4. 停止验证:测试 stop API
5. 白名单验证:测试算子白名单拦截
## 相关文件
- DDL: scripts/db/data-annotation-operator-pipeline-migration.sql
- API: runtime/datamate-python/app/module/annotation/
- Worker: runtime/python-executor/datamate/auto_annotation_worker.py
- 算子: runtime/ops/annotation/image_object_detection_bounding_box/
282 lines
9.2 KiB
Python
282 lines
9.2 KiB
Python
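"""FastAPI routes for Annotation Operator Tasks.

Legacy-compatible routes:
- GET/POST /api/annotation/auto
- GET /api/annotation/auto/{task_id}/status
- DELETE /api/annotation/auto/{task_id}
- POST /api/annotation/auto/{task_id}/stop
- GET /api/annotation/auto/{task_id}/download

New routes:
- GET/POST /api/annotation/operator-tasks
- GET/DELETE /api/annotation/operator-tasks/{task_id}
- POST /api/annotation/operator-tasks/{task_id}/stop
- GET /api/annotation/operator-tasks/{task_id}/download
"""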
|
|
from __future__ import annotations
|
|
|
|
from typing import List, Literal
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Path
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.session import get_db
|
|
from app.module.shared.schema import StandardResponse
|
|
from app.module.dataset import DatasetManagementService
|
|
from app.core.logging import get_logger
|
|
|
|
from ..security import (
|
|
RequestUserContext,
|
|
assert_dataset_access,
|
|
get_request_user_context,
|
|
)
|
|
from ..schema.auto import (
|
|
CreateAutoAnnotationTaskRequest,
|
|
AutoAnnotationTaskResponse,
|
|
)
|
|
from ..service.auto import AutoAnnotationTaskService
|
|
|
|
|
|
# Router that collects every handler in this module; no prefix is set here.
router = APIRouter(tags=["annotation/auto"])

# Module-level logger named after this module.
logger = get_logger(__name__)

# One service instance shared by all handlers below.
service = AutoAnnotationTaskService()
|
|
|
|
|
|
def _normalize_request_by_route(
    request: CreateAutoAnnotationTaskRequest,
    route_mode: Literal["legacy_auto", "operator_tasks"],
) -> CreateAutoAnnotationTaskRequest:
    """Return a copy of ``request`` whose ``task_mode`` matches the entry route.

    Args:
        request: The incoming create-task payload.
        route_mode: ``"legacy_auto"`` for the old ``/auto`` entry point,
            ``"operator_tasks"`` for the new one.

    Returns:
        A copy produced by ``request.model_copy`` with ``task_mode`` resolved:

        * legacy route: always ``"legacy_yolo"``, so old behaviour is preserved;
        * new route with a non-empty pipeline and ``task_mode == "legacy_yolo"``:
          promoted to ``"pipeline"``;
        * otherwise: the request's own ``task_mode`` is kept.
    """
    if route_mode == "legacy_auto":
        resolved_mode = "legacy_yolo"
    elif request.pipeline and request.task_mode == "legacy_yolo":
        # A pipeline was supplied but the mode still says legacy: treat it as a
        # pipeline task.
        resolved_mode = "pipeline"
    else:
        resolved_mode = request.task_mode

    return request.model_copy(update={"task_mode": resolved_mode})
|
|
|
|
|
|
async def _create_task_internal(
    *,
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession,
    user_context: RequestUserContext,
    route_mode: Literal["legacy_auto", "operator_tasks"],
) -> AutoAnnotationTaskResponse:
    """Shared creation path used by both the legacy and the new create routes.

    Steps: normalise the request for the entry route, log it, check dataset
    access, look up the dataset name and image count on a best-effort basis,
    then delegate to ``service.create_task``.

    Args:
        request: The raw create-task payload.
        db: Database session handed to the access check and the services.
        user_context: Identity of the caller.
        route_mode: Which entry point received the request.

    Returns:
        Whatever ``service.create_task`` returns for the normalised request.
    """
    req = _normalize_request_by_route(request, route_mode)

    # Serialise the optional parts up front so the log call stays readable.
    config_dump = req.config.model_dump(by_alias=True) if req.config else None
    pipeline_dump = (
        [step.model_dump(by_alias=True) for step in req.pipeline]
        if req.pipeline
        else None
    )
    logger.info(
        "Creating annotation task: route_mode=%s, name=%s, dataset_id=%s, task_mode=%s, executor_type=%s, config=%s, pipeline=%s, file_ids=%s",
        route_mode,
        req.name,
        req.dataset_id,
        req.task_mode,
        req.executor_type,
        config_dump,
        pipeline_dump,
        req.file_ids,
    )

    # Permission check on the target dataset.
    # NOTE(review): the original comment also mentions fileIds ownership
    # validation; it is not visible here, so presumably it lives in
    # assert_dataset_access or in the service. Confirm.
    await assert_dataset_access(db, req.dataset_id, user_context)

    # Denormalised fields: dataset name and image total. An explicit file
    # selection fixes the total; otherwise fall back to the dataset's count.
    dataset_name = None
    total_images = len(req.file_ids) if req.file_ids else 0
    try:
        dataset = await DatasetManagementService(db).get_dataset(req.dataset_id)
        if dataset is not None:
            dataset_name = dataset.name
            if not req.file_ids:
                total_images = getattr(dataset, "fileCount", 0) or 0
    except Exception as exc:  # pragma: no cover - best effort, never blocks creation
        logger.warning("Failed to fetch dataset summary for annotation task: %s", exc)

    return await service.create_task(
        db,
        req,
        user_context=user_context,
        dataset_name=dataset_name,
        total_images=total_images,
    )
|
|
|
|
|
|
@router.get("/auto", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
@router.get("/operator-tasks", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_annotation_operator_tasks(
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """List annotation tasks.

    Served on both the legacy ``/auto`` path and the ``/operator-tasks`` path;
    the list is whatever ``service.list_tasks`` returns for the given user context.
    """
    task_list = await service.list_tasks(db, user_context)
    return StandardResponse(code=200, message="success", data=task_list)
|
|
|
|
|
|
@router.post("/auto", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task(
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Create a task through the legacy ``/auto`` endpoint.

    The request goes through the shared creation path with
    ``route_mode="legacy_auto"``.
    """
    created_task = await _create_task_internal(
        request=request, db=db, user_context=user_context, route_mode="legacy_auto"
    )
    return StandardResponse(code=200, message="success", data=created_task)
|
|
|
|
|
|
@router.post("/operator-tasks", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_annotation_operator_task(
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Create a generic operator-pipeline annotation task (new endpoint).

    The request goes through the shared creation path with
    ``route_mode="operator_tasks"``.
    """
    created_task = await _create_task_internal(
        request=request, db=db, user_context=user_context, route_mode="operator_tasks"
    )
    return StandardResponse(code=200, message="success", data=created_task)
|
|
|
|
|
|
@router.get("/auto/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
@router.get("/operator-tasks/{task_id}", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Return the status/details of a single annotation task.

    Raises:
        HTTPException: 404 when ``service.get_task`` yields nothing for this id.
    """
    found = await service.get_task(db, task_id, user_context)
    if found:
        return StandardResponse(code=200, message="success", data=found)

    raise HTTPException(status_code=404, detail="Task not found")
|
|
|
|
|
|
@router.delete("/auto/{task_id}", response_model=StandardResponse[bool])
@router.delete("/operator-tasks/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Soft-delete an annotation task via ``service.soft_delete_task``.

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    deleted = await service.soft_delete_task(db, task_id, user_context)
    if deleted:
        return StandardResponse(code=200, message="success", data=True)

    raise HTTPException(status_code=404, detail="Task not found")
|
|
|
|
|
|
@router.post("/auto/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse])
@router.post("/operator-tasks/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def stop_auto_annotation_task(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Request that an annotation task be stopped.

    Returns the task as reported by ``service.request_stop_task``.

    Raises:
        HTTPException: 404 when the service returns no task for this id.
    """
    stopped_task = await service.request_stop_task(db, task_id, user_context)
    if stopped_task:
        return StandardResponse(code=200, message="success", data=stopped_task)

    raise HTTPException(status_code=404, detail="Task not found")
|
|
|
|
|
|
def _attachment_disposition(filename: str) -> str:
    """Build a ``Content-Disposition`` attachment value that is safe for any name.

    HTTP header values cannot carry arbitrary Unicode, and a raw double quote or
    control character would corrupt the header. The value therefore carries two
    forms of the name (RFC 6266):

    * ``filename``: an ASCII-only fallback in which every character outside
      printable ASCII, plus double quotes and backslashes, is replaced by ``_``;
    * ``filename*``: the exact name, percent-encoded as UTF-8.
    """
    from urllib.parse import quote

    ascii_fallback = "".join(
        ch if " " <= ch <= "~" and ch not in '"\\' else "_" for ch in filename
    )
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{encoded}"


def _zip_directory_to_tempfile(source_dir: str) -> str:
    """Zip every file under ``source_dir`` into a new temp file and return its path.

    Archive member names are relative to ``source_dir``. If building the archive
    fails, the temp file is removed before the error propagates. On success the
    caller owns the file and must delete it.
    """
    import contextlib
    import os
    import tempfile
    import zipfile

    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
    os.close(tmp_fd)  # ZipFile reopens the path itself
    try:
        with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for root, _, files in os.walk(source_dir):
                for filename in files:
                    file_path = os.path.join(root, filename)
                    zf.write(file_path, os.path.relpath(file_path, source_dir))
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise
    return tmp_path


@router.get("/auto/{task_id}/download")
@router.get("/operator-tasks/{task_id}/download")
async def download_auto_annotation_result(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Download the results of an annotation task as a ZIP archive.

    The task's output directory is zipped into a temporary file, streamed to the
    client, and the temporary file is deleted once streaming ends.

    Raises:
        HTTPException: 404 if the task or its output directory does not exist,
            400 if the task has no output path, 500 if the generated archive
            is empty.
    """
    import contextlib
    import os

    # Reuse the service layer to look the task up.
    task = await service.get_task(db, task_id, user_context)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    if not task.output_path:
        raise HTTPException(status_code=400, detail="Task has no output path")

    output_dir = task.output_path
    if not os.path.isdir(output_dir):
        raise HTTPException(status_code=404, detail="Output directory not found")

    # NOTE(review): archive creation is blocking I/O running inside an async
    # handler; consider offloading it to a worker thread for large outputs.
    tmp_path = _zip_directory_to_tempfile(output_dir)

    def _discard_tmp() -> None:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)

    try:
        file_size = os.path.getsize(tmp_path)
        if file_size == 0:
            raise HTTPException(status_code=500, detail="Generated ZIP is empty")
        headers = {
            "Content-Disposition": _attachment_disposition(f"{task.name}_annotations.zip"),
            "Content-Length": str(file_size),
        }
    except BaseException:
        # Nothing will be streamed, so the archive must not be left behind.
        _discard_tmp()
        raise

    def iterfile():
        # Stream in chunks; delete the archive once the stream is exhausted,
        # fails, or is closed after it has started.
        try:
            with open(tmp_path, "rb") as f:
                yield from iter(lambda: f.read(64 * 1024), b"")
        finally:
            _discard_tmp()

    return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)
|