# 数据标注模块:通用算子编排改造
## 功能概述
将数据标注模块从固定 YOLO 算子改造为支持通用算子编排,实现与数据清洗模块类似的灵活算子组合能力。
## 改动内容
### 第 1 步:数据库改造(DDL)
- 新增 SQL migration 脚本:scripts/db/data-annotation-operator-pipeline-migration.sql
- 修改 t_dm_auto_annotation_tasks 表:
- 新增字段:task_mode, executor_type, pipeline, output_dataset_id, created_by, stop_requested, started_at, heartbeat_at, run_token
- 新增索引:idx_status_created, idx_created_by
- 创建 t_dm_annotation_task_operator_instance 表:用于存储算子实例详情
### 第 2 步:API 层改造
- 扩展请求模型(schema/auto.py):
- 新增 OperatorPipelineStep 模型
- 支持 pipeline 字段,保留旧 YOLO 字段向后兼容
- 实现多写法归一(operatorId/operator_id/id, overrides/settingsOverride/settings_override)
- 修改任务创建服务(service/auto.py):
- 新增 validate_file_ids() 校验方法
- 新增 _to_pipeline() 兼容映射方法
- 写入新字段并集成算子实例表
- 修复 fileIds 去重准确性问题
- 新增 API 路由(interface/auto.py):
- 新增 /operator-tasks 系列接口
- 新增 stop API 接口(/auto/{id}/stop 和 /operator-tasks/{id}/stop)
- 保留旧 /auto 接口向后兼容
- ORM 模型对齐(annotation_management.py):
- AutoAnnotationTask 新增所有 DDL 字段
- 新增 AnnotationTaskOperatorInstance 模型
- 状态定义补充 stopped
### 第 3 步:Runtime 层改造
- 修改 worker 执行逻辑(auto_annotation_worker.py):
- 实现原子任务抢占机制(run_token)
- 从硬编码 YOLO 改为通用 pipeline 执行
- 新增算子解析和实例化能力
- 支持 stop_requested 检查
- 保留 legacy_yolo 模式向后兼容
- 支持多种算子调用方式(execute 和 __call__)
### 第 4 步:灰度发布
- 完善 YOLO 算子元数据(metadata.yml):
- 补齐 raw_id, language, modal, inputs, outputs, settings 字段
- 注册标注算子(__init__.py):
- 将 YOLO 算子注册到 OPERATORS 注册表
- 确保 annotation 包被正确加载
- 新增白名单控制:
- 支持环境变量 AUTO_ANNOTATION_OPERATOR_WHITELIST
- 灰度发布时可限制可用算子
## 关键特性
### 向后兼容
- 旧 /auto 接口完全保留
- 旧请求参数自动映射到 pipeline
- legacy_yolo 模式确保旧逻辑正常运行
### 新功能
- 支持通用 pipeline 编排
- 支持多算子组合
- 支持任务停止控制
- 支持白名单灰度发布
### 可靠性
- 原子任务抢占(防止重复执行)
- 完整的错误处理和状态管理
- 详细的审计追踪(算子实例表)
## 部署说明
1. 执行 DDL:mysql < scripts/db/data-annotation-operator-pipeline-migration.sql
2. 配置环境变量:AUTO_ANNOTATION_OPERATOR_WHITELIST=ImageObjectDetectionBoundingBox
3. 重启服务:datamate-runtime 和 datamate-backend-python
## 验证步骤
1. 兼容模式验证:使用旧 /auto 接口创建任务
2. 通用编排验证:使用新 /operator-tasks 接口创建 pipeline 任务
3. 原子 claim 验证:检查 run_token 机制
4. 停止验证:测试 stop API
5. 白名单验证:测试算子白名单拦截
## 相关文件
- DDL: scripts/db/data-annotation-operator-pipeline-migration.sql
- API: runtime/datamate-python/app/module/annotation/
- Worker: runtime/python-executor/datamate/auto_annotation_worker.py
- 算子: runtime/ops/annotation/image_object_detection_bounding_box/
354 lines · 12 KiB · Python
"""Service layer for Auto Annotation tasks"""
|
|
from __future__ import annotations
|
|
|
|
from typing import List, Optional, Dict, Any
|
|
from datetime import datetime
|
|
from uuid import uuid4
|
|
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models.annotation_management import (
|
|
AutoAnnotationTask,
|
|
AnnotationTaskOperatorInstance,
|
|
)
|
|
from app.db.models.dataset_management import Dataset, DatasetFiles
|
|
from app.module.annotation.security import RequestUserContext
|
|
|
|
from ..schema.auto import (
|
|
CreateAutoAnnotationTaskRequest,
|
|
AutoAnnotationTaskResponse,
|
|
)
|
|
|
|
|
|
class AutoAnnotationTaskService:
|
|
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
|
|
|
|
@staticmethod
|
|
def _normalize_file_ids(file_ids: Optional[List[str]]) -> List[str]:
|
|
if not file_ids:
|
|
return []
|
|
return [fid for fid in dict.fromkeys(file_ids) if fid]
|
|
|
|
@staticmethod
|
|
def _extract_operator_id(step: Dict[str, Any]) -> Optional[str]:
|
|
operator_id = step.get("operatorId") or step.get("operator_id") or step.get("id")
|
|
if operator_id is None:
|
|
return None
|
|
operator_id = str(operator_id).strip()
|
|
return operator_id or None
|
|
|
|
@classmethod
|
|
def _to_operator_instances(
|
|
cls,
|
|
task_id: str,
|
|
pipeline: List[Dict[str, Any]],
|
|
) -> List[AnnotationTaskOperatorInstance]:
|
|
instances: List[AnnotationTaskOperatorInstance] = []
|
|
for step in pipeline:
|
|
if not isinstance(step, dict):
|
|
continue
|
|
operator_id = cls._extract_operator_id(step)
|
|
if not operator_id:
|
|
continue
|
|
|
|
settings_override = (
|
|
step.get("overrides")
|
|
or step.get("settingsOverride")
|
|
or step.get("settings_override")
|
|
or {}
|
|
)
|
|
if not isinstance(settings_override, dict):
|
|
settings_override = {}
|
|
|
|
instances.append(
|
|
AnnotationTaskOperatorInstance(
|
|
task_id=task_id,
|
|
op_index=len(instances) + 1,
|
|
operator_id=operator_id,
|
|
settings_override=settings_override,
|
|
inputs=step.get("inputs"),
|
|
outputs=step.get("outputs"),
|
|
)
|
|
)
|
|
return instances
|
|
|
|
@staticmethod
|
|
def _to_pipeline(request: CreateAutoAnnotationTaskRequest) -> Optional[List[Dict[str, Any]]]:
|
|
"""将请求标准化为 pipeline 结构。"""
|
|
|
|
if request.pipeline:
|
|
return [step.model_dump(by_alias=True) for step in request.pipeline]
|
|
|
|
if request.config is None:
|
|
return None
|
|
|
|
# 兼容旧版 YOLO 请求 -> 单步 pipeline
|
|
config = request.config.model_dump(by_alias=True)
|
|
step_overrides: Dict[str, Any] = {
|
|
"modelSize": config.get("modelSize"),
|
|
"confThreshold": config.get("confThreshold"),
|
|
"targetClasses": config.get("targetClasses") or [],
|
|
}
|
|
|
|
output_dataset_name = request.output_dataset_name or config.get("outputDatasetName")
|
|
if output_dataset_name:
|
|
step_overrides["outputDatasetName"] = output_dataset_name
|
|
|
|
return [
|
|
{
|
|
"operatorId": "ImageObjectDetectionBoundingBox",
|
|
"overrides": step_overrides,
|
|
}
|
|
]
|
|
|
|
async def validate_file_ids(
    self,
    db: AsyncSession,
    dataset_id: str,
    file_ids: Optional[List[str]],
) -> List[str]:
    """Validate that every requested file id is an ACTIVE file of *dataset_id*.

    Returns the normalized id list (deduplicated, falsy entries removed).
    Raises HTTP 400 when ids were supplied but none survived normalization,
    or when some ids are unknown, inactive, or belong to another dataset.
    """
    wanted = self._normalize_file_ids(file_ids)
    if not wanted:
        if file_ids:
            # Caller sent a list, but nothing usable remained after cleanup.
            raise HTTPException(status_code=400, detail="fileIds 不能为空列表")
        return []

    lookup = await db.execute(
        select(DatasetFiles.id).where(
            DatasetFiles.id.in_(wanted),
            DatasetFiles.dataset_id == dataset_id,
            DatasetFiles.status == "ACTIVE",
        )
    )
    found = {row[0] for row in lookup.fetchall()}
    missing = [fid for fid in wanted if fid not in found]

    if missing:
        # Only surface the first few offenders to keep the message bounded.
        raise HTTPException(
            status_code=400,
            detail=f"部分 fileIds 不存在、不可用或不属于数据集: {missing[:10]}",
        )

    return wanted
|
|
|
|
async def create_task(
    self,
    db: AsyncSession,
    request: CreateAutoAnnotationTaskRequest,
    user_context: RequestUserContext,
    dataset_name: Optional[str] = None,
    total_images: int = 0,
) -> AutoAnnotationTaskResponse:
    """Persist a new auto-annotation task in ``pending`` state.

    Only the task record (plus its operator-instance audit rows) is written
    here; the actual YOLO/pipeline execution is picked up later by the
    runtime worker reading this table.

    Raises HTTP 400 when file ids fail validation or no pipeline can be
    derived from the request.
    """
    created_at = datetime.now()

    file_ids = await self.validate_file_ids(
        db,
        request.dataset_id,
        request.file_ids,
    )
    if file_ids:
        # An explicit file selection overrides the caller-supplied count.
        total_images = len(file_ids)

    pipeline = self._to_pipeline(request)
    if not pipeline:
        raise HTTPException(status_code=400, detail="pipeline 不能为空")

    config = request.config.model_dump(by_alias=True) if request.config else {}

    new_id = str(uuid4())
    record = AutoAnnotationTask(
        id=new_id,
        name=request.name,
        dataset_id=request.dataset_id,
        dataset_name=dataset_name,
        created_by=user_context.user_id,
        config=config,
        task_mode=request.task_mode,
        executor_type=request.executor_type,
        pipeline=pipeline,
        file_ids=file_ids or None,
        status="pending",
        progress=0,
        total_images=total_images,
        processed_images=0,
        detected_objects=0,
        stop_requested=False,
        created_at=created_at,
        updated_at=created_at,
    )

    db.add(record)
    instances = self._to_operator_instances(new_id, pipeline)
    if instances:
        db.add_all(instances)
    await db.commit()
    await db.refresh(record)

    # Attach sourceDatasets info (usually the single origin dataset).
    response = AutoAnnotationTaskResponse.model_validate(record)
    try:
        response.source_datasets = await self._compute_source_datasets(db, record)
    except Exception:
        response.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
    return response
|
|
|
|
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
|
|
if user_context.is_admin:
|
|
return query
|
|
return query.join(
|
|
Dataset,
|
|
AutoAnnotationTask.dataset_id == Dataset.id,
|
|
).where(Dataset.created_by == user_context.user_id)
|
|
|
|
async def list_tasks(
    self,
    db: AsyncSession,
    user_context: RequestUserContext,
) -> List[AutoAnnotationTaskResponse]:
    """List non-soft-deleted tasks visible to the caller, newest first."""
    stmt = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
    stmt = self._apply_dataset_scope(stmt, user_context)
    rows = await db.execute(
        stmt.order_by(AutoAnnotationTask.created_at.desc())
    )

    responses: List[AutoAnnotationTaskResponse] = []
    for record in rows.scalars().all():
        item = AutoAnnotationTaskResponse.model_validate(record)
        try:
            item.source_datasets = await self._compute_source_datasets(db, record)
        except Exception:
            # On lookup failure, degrade to the single denormalized
            # datasetName/datasetId carried on the task row.
            name = getattr(record, "dataset_name", None)
            ds_id = getattr(record, "dataset_id", "")
            item.source_datasets = [name] if name else [ds_id]
        responses.append(item)

    return responses
|
|
|
|
async def get_task(
    self,
    db: AsyncSession,
    task_id: str,
    user_context: RequestUserContext,
) -> Optional[AutoAnnotationTaskResponse]:
    """Fetch one visible task by id; None when absent, deleted, or out of scope."""
    stmt = self._apply_dataset_scope(
        select(AutoAnnotationTask).where(
            AutoAnnotationTask.id == task_id,
            AutoAnnotationTask.deleted_at.is_(None),
        ),
        user_context,
    )
    record = (await db.execute(stmt)).scalar_one_or_none()
    if record is None:
        return None

    response = AutoAnnotationTaskResponse.model_validate(record)
    try:
        response.source_datasets = await self._compute_source_datasets(db, record)
    except Exception:
        # Degrade to the denormalized datasetName/datasetId on the row.
        name = getattr(record, "dataset_name", None)
        ds_id = getattr(record, "dataset_id", "")
        response.source_datasets = [name] if name else [ds_id]
    return response
|
|
|
|
async def _compute_source_datasets(
    self,
    db: AsyncSession,
    task: AutoAnnotationTask,
) -> List[str]:
    """Resolve the names of all datasets the task's files actually come from.

    When the task stores explicit ``file_ids``, dataset names are looked up
    through t_dm_dataset_files joined to t_dm_datasets; otherwise (or when
    the lookup yields nothing) fall back to the denormalized
    dataset_name/dataset_id stored on the task itself.
    """
    selected_files = task.file_ids or []
    if selected_files:
        rows = await db.execute(
            select(Dataset.name)
            .join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
            .where(DatasetFiles.id.in_(selected_files))
            .distinct()
        )
        names = [row[0] for row in rows.fetchall() if row[0]]
        if names:
            return names

    # Fallback: a single dataset reference.
    if task.dataset_name:
        return [task.dataset_name]
    if task.dataset_id:
        return [task.dataset_id]
    return []
|
|
|
|
async def request_stop_task(
    self,
    db: AsyncSession,
    task_id: str,
    user_context: RequestUserContext,
) -> Optional[AutoAnnotationTaskResponse]:
    """Flag a task for cooperative stop; return the refreshed task or None.

    Non-terminal tasks get ``stop_requested=True`` (the runtime worker is
    expected to poll this flag — TODO confirm against the worker). A task
    still in ``pending`` has presumably not been claimed by any worker, so
    it is transitioned straight to ``stopped`` here.
    """
    query = select(AutoAnnotationTask).where(
        AutoAnnotationTask.id == task_id,
        AutoAnnotationTask.deleted_at.is_(None),
    )
    query = self._apply_dataset_scope(query, user_context)
    result = await db.execute(query)
    task = result.scalar_one_or_none()
    if not task:
        # Missing, soft-deleted, or not visible to this caller.
        return None

    now = datetime.now()
    terminal_states = {"completed", "failed", "stopped"}
    if task.status not in terminal_states:
        task.stop_requested = True
        task.error_message = "Task stop requested"
        if task.status == "pending":
            # Never picked up by a worker: finalize immediately.
            task.status = "stopped"
            task.progress = task.progress or 0
            task.completed_at = now
            task.run_token = None
        task.updated_at = now
    await db.commit()
    await db.refresh(task)

    resp = AutoAnnotationTaskResponse.model_validate(task)
    try:
        resp.source_datasets = await self._compute_source_datasets(db, task)
    except Exception:
        # Best-effort enrichment: fall back to the denormalized fields.
        fallback_name = getattr(task, "dataset_name", None)
        fallback_id = getattr(task, "dataset_id", "")
        resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
    return resp
|
|
|
|
async def soft_delete_task(
    self,
    db: AsyncSession,
    task_id: str,
    user_context: RequestUserContext,
) -> bool:
    """Soft-delete a task by stamping ``deleted_at``; True when a row was hit."""
    stmt = self._apply_dataset_scope(
        select(AutoAnnotationTask).where(
            AutoAnnotationTask.id == task_id,
            AutoAnnotationTask.deleted_at.is_(None),
        ),
        user_context,
    )
    record = (await db.execute(stmt)).scalar_one_or_none()
    if record is None:
        return False

    record.deleted_at = datetime.now()
    await db.commit()
    return True
|