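feat(annotation): support general operator orchestration for data annotation

## Overview
Reworks the data annotation module from a fixed YOLO operator into general operator orchestration, giving it the same flexible operator composition as the data cleaning module.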

## Changes

### Step 1: Database changes (DDL)
- Add SQL migration script: scripts/db/data-annotation-operator-pipeline-migration.sql
- Alter table t_dm_auto_annotation_tasks:
  - New columns: task_mode, executor_type, pipeline, output_dataset_id, created_by, stop_requested, started_at, heartbeat_at, run_token
  - New indexes: idx_status_created, idx_created_by
- Create table t_dm_annotation_task_operator_instance to store per-operator instance details

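### Step 2: API layer
- Extend the request models (schema/auto.py):
  - Add the OperatorPipelineStep model
  - Accept a pipeline field while keeping the old YOLO fields for backward compatibility
  - Normalize alternative spellings (operatorId/operator_id/id, overrides/settingsOverride/settings_override)
- Update the task creation service (service/auto.py):
  - Add validate_file_ids() for validation
  - Add _to_pipeline() to map legacy requests onto a pipeline
  - Write the new columns and populate the operator instance table
  - Fix inaccurate fileIds de-duplication
- Add API routes (interface/auto.py):
  - Add the /operator-tasks endpoints
  - Add stop endpoints (/auto/{id}/stop and /operator-tasks/{id}/stop)
  - Keep the old /auto endpoints for backward compatibility
- Align the ORM models (annotation_management.py):
  - Add all DDL columns to AutoAnnotationTask
  - Add the AnnotationTaskOperatorInstance model
  - Add stopped to the status definitions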
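
The alias normalization and fileIds de-duplication listed above can be sketched as plain functions. This is a minimal illustration using only the standard library, not the Pydantic validator in schema/auto.py; the names `normalize_step` and `dedupe_file_ids` are hypothetical.

```python
import json
from typing import Any, Dict, List, Optional


def normalize_step(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Map the accepted spellings onto operatorId / overrides (hypothetical helper)."""
    step = dict(raw)
    if "operatorId" not in step:
        for key in ("operator_id", "id"):
            if step.get(key):
                step["operatorId"] = step[key]
                break
    if "overrides" not in step:
        for key in ("settingsOverride", "settings_override"):
            candidate = step.get(key)
            if isinstance(candidate, str):
                # Overrides may arrive as a JSON string; ignore it if it does not parse.
                try:
                    candidate = json.loads(candidate)
                except ValueError:
                    candidate = None
            if isinstance(candidate, dict):
                step["overrides"] = candidate
                break
    return step


def dedupe_file_ids(file_ids: Optional[List[str]]) -> List[str]:
    """Drop duplicates and empty ids while keeping first-seen order."""
    return [fid for fid in dict.fromkeys(file_ids or []) if fid]


step = normalize_step(
    {"id": "ImageObjectDetectionBoundingBox", "settings_override": '{"confThreshold": 0.5}'}
)
unique_ids = dedupe_file_ids(["a", "b", "a", "", "c"])
```

`dict.fromkeys` keeps insertion order, so de-duplication does not reshuffle the user's selection the way a `set` would.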

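### Step 3: Runtime layer
- Update the worker execution logic (auto_annotation_worker.py):
  - Claim tasks atomically (run_token)
  - Replace the hard-coded YOLO path with general pipeline execution
  - Resolve and instantiate operators
  - Check stop_requested
  - Keep legacy_yolo mode for backward compatibility
  - Support both operator calling conventions (execute and __call__)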
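
The worker itself is not shown in this part of the diff, so the following is only a sketch of how a pipeline loop might dispatch to `execute` or `__call__` and honor a stop request. `invoke_operator`, `run_pipeline`, and the toy operators are all hypothetical names, and checking the stop flag once per sample is an assumption.

```python
from typing import Any, Callable, Dict, Iterable, List


def invoke_operator(operator: Any, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Call an operator through execute() if it has one, otherwise through __call__."""
    execute = getattr(operator, "execute", None)
    if callable(execute):
        return execute(sample)
    if callable(operator):
        return operator(sample)
    raise TypeError(f"{type(operator).__name__} has neither execute() nor __call__()")


def run_pipeline(
    operators: Iterable[Any],
    samples: Iterable[Dict[str, Any]],
    stop_requested: Callable[[], bool],
) -> List[Dict[str, Any]]:
    """Run each sample through the operators in order; stop early when asked to."""
    operator_list = list(operators)
    results: List[Dict[str, Any]] = []
    for sample in samples:
        if stop_requested():  # assumed granularity: checked once per sample
            break
        for operator in operator_list:
            sample = invoke_operator(operator, sample)
        results.append(sample)
    return results


class AddLabel:
    """Toy operator exposing execute()."""

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        return {**sample, "label": "cat"}


def upper_name(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Toy operator that is a plain callable."""
    return {**sample, "name": sample["name"].upper()}


out = run_pipeline([AddLabel(), upper_name], [{"name": "a.jpg"}], stop_requested=lambda: False)
```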

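### Step 4: Gradual rollout
- Complete the YOLO operator metadata (metadata.yml):
  - Fill in raw_id, language, modal, inputs, outputs, settings
- Register the annotation operator (__init__.py):
  - Register the YOLO operator in the OPERATORS registry
  - Make sure the annotation package is loaded
- Add whitelist control:
  - Read the AUTO_ANNOTATION_OPERATOR_WHITELIST environment variable
  - Restrict which operators are available during rollout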
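
A whitelist check of this kind might look like the sketch below. The commit only names the environment variable; the comma-separated format, the "unset means no restriction" rule, and the function names `load_whitelist` and `is_allowed` are assumptions made for illustration.

```python
import os
from typing import Mapping, Optional, Set


def load_whitelist(env: Mapping[str, str] = os.environ) -> Optional[Set[str]]:
    """Parse AUTO_ANNOTATION_OPERATOR_WHITELIST (assumed to be comma-separated)."""
    raw = env.get("AUTO_ANNOTATION_OPERATOR_WHITELIST", "")
    names = {item.strip() for item in raw.split(",") if item.strip()}
    # Assumption: an unset or empty variable means every operator is allowed.
    return names or None


def is_allowed(operator_id: str, whitelist: Optional[Set[str]]) -> bool:
    """Return True when no whitelist is configured or the operator is listed."""
    return whitelist is None or operator_id in whitelist


whitelist = load_whitelist(
    {"AUTO_ANNOTATION_OPERATOR_WHITELIST": "ImageObjectDetectionBoundingBox, SomeOtherOperator"}
)
```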

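## Key features

### Backward compatibility
- The old /auto endpoints are fully preserved
- Old request parameters are mapped to a pipeline automatically
- legacy_yolo mode keeps the old logic working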

### New capabilities
- General pipeline orchestration
- Multi-operator composition
- Task stop control
- Whitelist-gated rollout
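
The stop control follows the state rules in request_stop_task: terminal tasks are left alone, pending tasks are closed immediately, and any other task only gets its stop flag set for the worker to notice. A minimal in-memory sketch, where the `Task` dataclass and `request_stop` are hypothetical stand-ins for the ORM model and service method, and the `"running"` status name is an assumption:

```python
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

TERMINAL_STATES = {"completed", "failed", "stopped"}


@dataclass
class Task:
    """Hypothetical stand-in for the AutoAnnotationTask row."""

    status: str
    stop_requested: bool = False
    completed_at: Optional[datetime] = None


def request_stop(task: Task, now: datetime) -> Task:
    """Apply the stop rules to a task in place and return it."""
    if task.status in TERMINAL_STATES:
        return task  # already finished, nothing to stop
    task.stop_requested = True
    if task.status == "pending":
        # No worker has picked the task up yet, so it can be closed right away.
        task.status = "stopped"
        task.completed_at = now
    return task


now = datetime(2026, 2, 7, 22, 35, 33)
pending = request_stop(Task(status="pending"), now)
running = request_stop(Task(status="running"), now)  # "running" is an assumed status name
done = request_stop(Task(status="completed"), now)
```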

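### Reliability
- Atomic task claiming (prevents duplicate execution)
- Comprehensive error handling and state management
- Detailed audit trail (operator instance table)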
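
Atomic claiming is typically done with a single conditional UPDATE whose affected-row count tells the worker whether it won. The sketch below shows that pattern with SQLite in memory; the simplified `tasks` table, the `"running"` status, and `claim_task` are assumptions, and the real worker targets MySQL with its own schema.

```python
import sqlite3
import uuid
from typing import Optional


def claim_task(conn: sqlite3.Connection, task_id: str) -> Optional[str]:
    """Try to claim a pending task; return the run token, or None if it was already taken."""
    token = uuid.uuid4().hex
    cur = conn.execute(
        "UPDATE tasks SET status = 'running', run_token = ? "
        "WHERE id = ? AND status = 'pending' AND run_token IS NULL",
        (token, task_id),
    )
    conn.commit()
    # Exactly one updated row means this caller owns the task.
    return token if cur.rowcount == 1 else None


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tasks (id TEXT PRIMARY KEY, status TEXT, run_token TEXT)")
conn.execute("INSERT INTO tasks VALUES ('t1', 'pending', NULL)")

first = claim_task(conn, "t1")   # wins the claim
second = claim_task(conn, "t1")  # loses: the row no longer matches the WHERE clause
stored = conn.execute("SELECT run_token FROM tasks WHERE id = 't1'").fetchone()[0]
```

Because the condition and the write happen in one statement, two workers racing on the same row cannot both see a row count of 1.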

## Deployment

1. Run the DDL: mysql < scripts/db/data-annotation-operator-pipeline-migration.sql
2. Set the environment variable: AUTO_ANNOTATION_OPERATOR_WHITELIST=ImageObjectDetectionBoundingBox
3. Restart the services: datamate-runtime and datamate-backend-python

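## Verification

1. Compatibility mode: create a task through the old /auto endpoint
2. General orchestration: create a pipeline task through the new /operator-tasks endpoint
3. Atomic claim: check the run_token mechanism
4. Stop: exercise the stop API
5. Whitelist: confirm that operators outside the whitelist are rejected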
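
For steps 1 and 2, request bodies along the following lines could be POSTed to /api/annotation/auto and /api/annotation/operator-tasks respectively. This is a sketch only: the field names come from the request schema in this commit, but the task names, the dataset ID placeholder, and the config values are made up, and AutoAnnotationConfig may require fields not visible in this diff.

```python
import json

# Step 1: legacy body, which the service maps to a single-step pipeline.
legacy_body = {
    "name": "legacy-yolo-check",      # hypothetical task name
    "datasetId": "<dataset-id>",      # placeholder, substitute a real dataset ID
    "config": {"modelSize": "n", "confThreshold": 0.5, "targetClasses": []},
}

# Step 2: new body with an explicit pipeline.
pipeline_body = {
    "name": "pipeline-check",         # hypothetical task name
    "datasetId": "<dataset-id>",
    "taskMode": "pipeline",
    "pipeline": [
        {
            "operatorId": "ImageObjectDetectionBoundingBox",
            "overrides": {"confThreshold": 0.5},
        }
    ],
}

legacy_json = json.dumps(legacy_body)
pipeline_json = json.dumps(pipeline_body)
```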

## Related files

- DDL: scripts/db/data-annotation-operator-pipeline-migration.sql
- API: runtime/datamate-python/app/module/annotation/
- Worker: runtime/python-executor/datamate/auto_annotation_worker.py
- Operator: runtime/ops/annotation/image_object_detection_bounding_box/
2026-02-07 22:35:33 +08:00
parent 9efc07935f
commit 2f49fc4199
9 changed files with 1606 additions and 480 deletions


@@ -1,19 +1,20 @@
"""FastAPI routes for Auto Annotation tasks.
These routes back the frontend AutoAnnotation module:
- GET /api/annotation/auto
- POST /api/annotation/auto
- DELETE /api/annotation/auto/{task_id}
- GET /api/annotation/auto/{task_id}/status (simple wrapper)
"""
from __future__ import annotations
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
"""FastAPI routes for Annotation Operator Tasks.
Compatibility routes:
- GET/POST/DELETE /api/annotation/auto
- GET /api/annotation/auto/{task_id}/status
New routes:
- GET/POST/DELETE /api/annotation/operator-tasks
- GET /api/annotation/operator-tasks/{task_id}
"""
from __future__ import annotations
from typing import List, Literal
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.module.shared.schema import StandardResponse
@@ -29,110 +30,163 @@ from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
)
from ..service.auto import AutoAnnotationTaskService
router = APIRouter(
prefix="/auto",
tags=["annotation/auto"],
)
from ..service.auto import AutoAnnotationTaskService
router = APIRouter(
tags=["annotation/auto"],
)
logger = get_logger(__name__)
service = AutoAnnotationTaskService()
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_auto_annotation_tasks(
service = AutoAnnotationTaskService()
def _normalize_request_by_route(
request: CreateAutoAnnotationTaskRequest,
route_mode: Literal["legacy_auto", "operator_tasks"],
) -> CreateAutoAnnotationTaskRequest:
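"""Normalize the request according to the route it arrived on."""
if route_mode == "legacy_auto":
# The legacy endpoint always runs in legacy_yolo mode so its behavior stays unchanged
return request.model_copy(update={"task_mode": "legacy_yolo"})
# The new endpoint switches to pipeline mode when a pipeline is supplied and taskMode is still the legacy_yolo default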
task_mode = request.task_mode
if request.pipeline and task_mode == "legacy_yolo":
task_mode = "pipeline"
return request.model_copy(update={"task_mode": task_mode})
async def _create_task_internal(
*,
request: CreateAutoAnnotationTaskRequest,
db: AsyncSession,
user_context: RequestUserContext,
route_mode: Literal["legacy_auto", "operator_tasks"],
) -> AutoAnnotationTaskResponse:
normalized_request = _normalize_request_by_route(request, route_mode)
logger.info(
"Creating annotation task: route_mode=%s, name=%s, dataset_id=%s, task_mode=%s, executor_type=%s, config=%s, pipeline=%s, file_ids=%s",
route_mode,
normalized_request.name,
normalized_request.dataset_id,
normalized_request.task_mode,
normalized_request.executor_type,
normalized_request.config.model_dump(by_alias=True) if normalized_request.config else None,
[step.model_dump(by_alias=True) for step in normalized_request.pipeline]
if normalized_request.pipeline else None,
normalized_request.file_ids,
)
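# Permission check (fileIds ownership is validated later, in service.create_task)
await assert_dataset_access(db, normalized_request.dataset_id, user_context)
# Try to fetch the dataset name and file count for the denormalized fields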
dataset_name = None
total_images = len(normalized_request.file_ids) if normalized_request.file_ids else 0
try:
dm_client = DatasetManagementService(db)
dataset = await dm_client.get_dataset(normalized_request.dataset_id)
if dataset is not None:
dataset_name = dataset.name
if not normalized_request.file_ids:
total_images = getattr(dataset, "fileCount", 0) or 0
except Exception as e: # pragma: no cover - fault tolerance
logger.warning("Failed to fetch dataset summary for annotation task: %s", e)
return await service.create_task(
db,
normalized_request,
user_context=user_context,
dataset_name=dataset_name,
total_images=total_images,
)
@router.get("/auto", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
@router.get("/operator-tasks", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_annotation_operator_tasks(
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""List auto annotation tasks.
The frontend currently sends no pagination parameters, so all non-deleted tasks are returned.
"""
"""List annotation tasks."""
tasks = await service.list_tasks(db, user_context)
return StandardResponse(
code=200,
message="success",
data=tasks,
)
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
return StandardResponse(
code=200,
message="success",
data=tasks,
)
@router.post("/auto", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task(
request: CreateAutoAnnotationTaskRequest,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""Create an auto annotation task.
Only creates the task record in the pending state; execution is handled later by the scheduler/worker.
"""
logger.info(
"Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
request.name,
request.dataset_id,
request.config.model_dump(by_alias=True),
request.file_ids,
)
# Try to fetch the dataset name and file count for the denormalized fields; a failure does not block task creation
dataset_name = None
total_images = 0
await assert_dataset_access(db, request.dataset_id, user_context)
try:
dm_client = DatasetManagementService(db)
# Service.get_dataset returns a DatasetResponse that includes name and fileCount
dataset = await dm_client.get_dataset(request.dataset_id)
if dataset is not None:
dataset_name = dataset.name
# If file_ids is provided, total_images is the number of selected files; otherwise use the dataset file count
if request.file_ids:
total_images = len(request.file_ids)
else:
total_images = getattr(dataset, "fileCount", 0) or 0
except Exception as e: # pragma: no cover - fault tolerance
logger.warning("Failed to fetch dataset name for auto task: %s", e)
task = await service.create_task(
db,
request,
dataset_name=dataset_name,
total_images=total_images,
)
return StandardResponse(
code=200,
message="success",
data=task,
)
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
"""Create a task through the legacy /auto endpoint (kept for compatibility)."""
task = await _create_task_internal(
request=request,
db=db,
user_context=user_context,
route_mode="legacy_auto",
)
return StandardResponse(
code=200,
message="success",
data=task,
)
@router.post("/operator-tasks", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_annotation_operator_task(
request: CreateAutoAnnotationTaskRequest,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""New endpoint: create a general operator-orchestration annotation task."""
task = await _create_task_internal(
request=request,
db=db,
user_context=user_context,
route_mode="operator_tasks",
)
return StandardResponse(
code=200,
message="success",
data=task,
)
@router.get("/auto/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
@router.get("/operator-tasks/{task_id}", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status(
task_id: str = Path(..., description="Task ID"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
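"""Get the status of a single auto annotation task.
The frontend currently relies mainly on polling the list; this is a supplementary lookup by ID.
"""
"""Get the status/details of a single annotation task."""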
task = await service.get_task(db, task_id, user_context)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return StandardResponse(
code=200,
message="success",
data=task,
)
@router.delete("/{task_id}", response_model=StandardResponse[bool])
data=task,
)
@router.delete("/auto/{task_id}", response_model=StandardResponse[bool])
@router.delete("/operator-tasks/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task(
task_id: str = Path(..., description="Task ID"),
db: AsyncSession = Depends(get_db),
@@ -144,14 +198,35 @@ async def delete_auto_annotation_task(
if not ok:
raise HTTPException(status_code=404, detail="Task not found")
return StandardResponse(
code=200,
message="success",
data=True,
)
@router.get("/{task_id}/download")
return StandardResponse(
code=200,
message="success",
data=True,
)
@router.post("/auto/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse])
@router.post("/operator-tasks/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def stop_auto_annotation_task(
task_id: str = Path(..., description="Task ID"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""Request that an auto annotation task be stopped."""
task = await service.request_stop_task(db, task_id, user_context)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return StandardResponse(
code=200,
message="success",
data=task,
)
@router.get("/auto/{task_id}/download")
@router.get("/operator-tasks/{task_id}/download")
async def download_auto_annotation_result(
task_id: str = Path(..., description="Task ID"),
db: AsyncSession = Depends(get_db),
@@ -159,10 +234,9 @@ async def download_auto_annotation_result(
):
"""Download the result ZIP of the given auto annotation task."""
import io
import os
import zipfile
import tempfile
import os
import zipfile
import tempfile
# Reuse the service layer to fetch the task
task = await service.get_task(db, task_id, user_context)


@@ -1,13 +1,15 @@
"""Schemas for Auto Annotation tasks"""
from __future__ import annotations
from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict
from __future__ import annotations
import json
from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict, model_validator
class AutoAnnotationConfig(BaseModel):
class AutoAnnotationConfig(BaseModel):
"""Auto annotation task configuration (aligned with the frontend payload)."""
model_size: str = Field(alias="modelSize", description="Model size: n/s/m/l/x")
@@ -23,43 +25,133 @@ class AutoAnnotationConfig(BaseModel):
description="Name of the new dataset to write auto annotation results to (optional)",
)
model_config = ConfigDict(populate_by_name=True)
model_config = ConfigDict(populate_by_name=True)
class OperatorPipelineStep(BaseModel):
"""A single operator node in a general operator pipeline."""
operator_id: str = Field(alias="operatorId", description="Operator ID (raw_id)")
overrides: Dict[str, Any] = Field(
default_factory=dict,
alias="overrides",
description="Operator parameter overrides (corresponds to settings override)",
)
@model_validator(mode="before")
@classmethod
def normalize_compatible_fields(cls, value: Any):
if not isinstance(value, dict):
return value
normalized = dict(value)
if "operatorId" not in normalized:
for key in ("operator_id", "id"):
candidate = normalized.get(key)
if candidate:
normalized["operatorId"] = candidate
break
if "overrides" not in normalized:
for key in ("settingsOverride", "settings_override"):
candidate = normalized.get(key)
if isinstance(candidate, str):
try:
candidate = json.loads(candidate)
except Exception:
candidate = None
if isinstance(candidate, dict):
normalized["overrides"] = candidate
break
return normalized
model_config = ConfigDict(populate_by_name=True)
class CreateAutoAnnotationTaskRequest(BaseModel):
"""Request body for creating an auto annotation task; mirrors the structure sent by the frontend CreateAutoAnnotationDialog."""
name: str = Field(..., min_length=1, max_length=255, description="Task name")
dataset_id: str = Field(..., alias="datasetId", description="Dataset ID")
config: AutoAnnotationConfig = Field(..., description="Task configuration")
file_ids: Optional[List[str]] = Field(None, alias="fileIds", description="IDs of the files to process; if empty, all images in the dataset are processed")
model_config = ConfigDict(populate_by_name=True)
class CreateAutoAnnotationTaskRequest(BaseModel):
"""Request body for creating an auto annotation task; mirrors the structure sent by the frontend CreateAutoAnnotationDialog."""
name: str = Field(..., min_length=1, max_length=255, description="Task name")
dataset_id: str = Field(..., alias="datasetId", description="Dataset ID")
config: Optional[AutoAnnotationConfig] = Field(
default=None,
description="Legacy YOLO task configuration (kept for compatibility)",
)
pipeline: Optional[List[OperatorPipelineStep]] = Field(
default=None,
description="General operator pipeline definition",
)
task_mode: str = Field(
default="legacy_yolo",
alias="taskMode",
description="Task mode: legacy_yolo/pipeline",
)
executor_type: str = Field(
default="annotation_local",
alias="executorType",
description="Executor type",
)
output_dataset_name: Optional[str] = Field(
default=None,
alias="outputDatasetName",
description="Output dataset name (takes precedence over config.outputDatasetName)",
)
file_ids: Optional[List[str]] = Field(
None,
alias="fileIds",
description="IDs of the files to process; if empty, all images in the dataset are processed",
)
@model_validator(mode="after")
def validate_config_or_pipeline(self):
if self.config is None and not self.pipeline:
raise ValueError("Either config or pipeline must be provided")
return self
model_config = ConfigDict(populate_by_name=True)
class AutoAnnotationTaskResponse(BaseModel):
class AutoAnnotationTaskResponse(BaseModel):
"""Auto annotation task response model (shared by list and detail views)."""
id: str = Field(..., description="Task ID")
name: str = Field(..., description="Task name")
dataset_id: str = Field(..., alias="datasetId", description="Dataset ID")
dataset_name: Optional[str] = Field(None, alias="datasetName", description="Dataset name")
source_datasets: Optional[List[str]] = Field(
default=None,
alias="sourceDatasets",
description="Names of all datasets actually involved in this task",
)
dataset_name: Optional[str] = Field(None, alias="datasetName", description="Dataset name")
task_mode: Optional[str] = Field(None, alias="taskMode", description="Task mode")
executor_type: Optional[str] = Field(None, alias="executorType", description="Executor type")
pipeline: Optional[List[Dict[str, Any]]] = Field(None, description="Operator pipeline definition")
source_datasets: Optional[List[str]] = Field(
default=None,
alias="sourceDatasets",
description="Names of all datasets actually involved in this task",
)
config: Dict[str, Any] = Field(..., description="Task configuration")
status: str = Field(..., description="Task status")
progress: int = Field(..., description="Task progress, 0-100")
total_images: int = Field(..., alias="totalImages", description="Total number of images")
processed_images: int = Field(..., alias="processedImages", description="Number of images processed")
detected_objects: int = Field(..., alias="detectedObjects", description="Total number of detected objects")
output_path: Optional[str] = Field(None, alias="outputPath", description="Output path")
error_message: Optional[str] = Field(None, alias="errorMessage", description="Error message")
created_at: datetime = Field(..., alias="createdAt", description="Creation time")
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="Last update time")
completed_at: Optional[datetime] = Field(None, alias="completedAt", description="Completion time")
detected_objects: int = Field(..., alias="detectedObjects", description="Total number of detected objects")
output_path: Optional[str] = Field(None, alias="outputPath", description="Output path")
output_dataset_id: Optional[str] = Field(
None,
alias="outputDatasetId",
description="Output dataset ID",
)
stop_requested: Optional[bool] = Field(
None,
alias="stopRequested",
description="Whether a stop has been requested",
)
error_message: Optional[str] = Field(None, alias="errorMessage", description="Error message")
created_by: Optional[str] = Field(None, alias="createdBy", description="Creator")
started_at: Optional[datetime] = Field(None, alias="startedAt", description="Start time")
heartbeat_at: Optional[datetime] = Field(None, alias="heartbeatAt", description="Last heartbeat time")
created_at: datetime = Field(..., alias="createdAt", description="Creation time")
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="Last update time")
completed_at: Optional[datetime] = Field(None, alias="completedAt", description="Completion time")
model_config = ConfigDict(populate_by_name=True, from_attributes=True)


@@ -1,14 +1,18 @@
"""Service layer for Auto Annotation tasks"""
from __future__ import annotations
from typing import List, Optional
from datetime import datetime
from uuid import uuid4
"""Service layer for Auto Annotation tasks"""
from __future__ import annotations
from typing import List, Optional, Dict, Any
from datetime import datetime
from uuid import uuid4
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.annotation_management import AutoAnnotationTask
from app.db.models.annotation_management import (
AutoAnnotationTask,
AnnotationTaskOperatorInstance,
)
from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.security import RequestUserContext
@@ -19,42 +23,175 @@ from ..schema.auto import (
class AutoAnnotationTaskService:
"""Auto annotation task service (manages task metadata only; execution is the runtime's job)."""
async def create_task(
self,
db: AsyncSession,
request: CreateAutoAnnotationTaskRequest,
dataset_name: Optional[str] = None,
total_images: int = 0,
) -> AutoAnnotationTaskResponse:
"""Auto annotation task service (manages task metadata only; execution is the runtime's job)."""
@staticmethod
def _normalize_file_ids(file_ids: Optional[List[str]]) -> List[str]:
if not file_ids:
return []
return [fid for fid in dict.fromkeys(file_ids) if fid]
@staticmethod
def _extract_operator_id(step: Dict[str, Any]) -> Optional[str]:
operator_id = step.get("operatorId") or step.get("operator_id") or step.get("id")
if operator_id is None:
return None
operator_id = str(operator_id).strip()
return operator_id or None
@classmethod
def _to_operator_instances(
cls,
task_id: str,
pipeline: List[Dict[str, Any]],
) -> List[AnnotationTaskOperatorInstance]:
instances: List[AnnotationTaskOperatorInstance] = []
for step in pipeline:
if not isinstance(step, dict):
continue
operator_id = cls._extract_operator_id(step)
if not operator_id:
continue
settings_override = (
step.get("overrides")
or step.get("settingsOverride")
or step.get("settings_override")
or {}
)
if not isinstance(settings_override, dict):
settings_override = {}
instances.append(
AnnotationTaskOperatorInstance(
task_id=task_id,
op_index=len(instances) + 1,
operator_id=operator_id,
settings_override=settings_override,
inputs=step.get("inputs"),
outputs=step.get("outputs"),
)
)
return instances
@staticmethod
def _to_pipeline(request: CreateAutoAnnotationTaskRequest) -> Optional[List[Dict[str, Any]]]:
"""Normalize the request into a pipeline structure."""
if request.pipeline:
return [step.model_dump(by_alias=True) for step in request.pipeline]
if request.config is None:
return None
# Legacy YOLO request -> single-step pipeline
config = request.config.model_dump(by_alias=True)
step_overrides: Dict[str, Any] = {
"modelSize": config.get("modelSize"),
"confThreshold": config.get("confThreshold"),
"targetClasses": config.get("targetClasses") or [],
}
output_dataset_name = request.output_dataset_name or config.get("outputDatasetName")
if output_dataset_name:
step_overrides["outputDatasetName"] = output_dataset_name
return [
{
"operatorId": "ImageObjectDetectionBoundingBox",
"overrides": step_overrides,
}
]
async def validate_file_ids(
self,
db: AsyncSession,
dataset_id: str,
file_ids: Optional[List[str]],
) -> List[str]:
"""Check that every fileId belongs to the dataset and is active."""
normalized_ids = self._normalize_file_ids(file_ids)
if not normalized_ids:
if file_ids:
raise HTTPException(status_code=400, detail="fileIds must not be an empty list")
return []
stmt = select(DatasetFiles.id).where(
DatasetFiles.id.in_(normalized_ids),
DatasetFiles.dataset_id == dataset_id,
DatasetFiles.status == "ACTIVE",
)
result = await db.execute(stmt)
found_ids = {row[0] for row in result.fetchall()}
missing = [fid for fid in normalized_ids if fid not in found_ids]
if missing:
raise HTTPException(
status_code=400,
detail=f"Some fileIds do not exist, are unavailable, or do not belong to the dataset: {missing[:10]}",
)
return normalized_ids
async def create_task(
self,
db: AsyncSession,
request: CreateAutoAnnotationTaskRequest,
user_context: RequestUserContext,
dataset_name: Optional[str] = None,
total_images: int = 0,
) -> AutoAnnotationTaskResponse:
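"""Create an auto annotation task with initial status pending.
This only inserts the task record and does not run YOLO inference;
a scheduler/worker can later read this table and update progress.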
"""
now = datetime.now()
task = AutoAnnotationTask(
id=str(uuid4()),
name=request.name,
dataset_id=request.dataset_id,
dataset_name=dataset_name,
config=request.config.model_dump(by_alias=True),
file_ids=request.file_ids, # store the file IDs selected by the user
status="pending",
progress=0,
total_images=total_images,
processed_images=0,
detected_objects=0,
created_at=now,
updated_at=now,
)
db.add(task)
await db.commit()
await db.refresh(task)
"""
now = datetime.now()
validated_file_ids = await self.validate_file_ids(
db,
request.dataset_id,
request.file_ids,
)
if validated_file_ids:
total_images = len(validated_file_ids)
normalized_pipeline = self._to_pipeline(request)
if not normalized_pipeline:
raise HTTPException(status_code=400, detail="pipeline must not be empty")
normalized_config = request.config.model_dump(by_alias=True) if request.config else {}
task_id = str(uuid4())
task = AutoAnnotationTask(
id=task_id,
name=request.name,
dataset_id=request.dataset_id,
dataset_name=dataset_name,
created_by=user_context.user_id,
config=normalized_config,
task_mode=request.task_mode,
executor_type=request.executor_type,
pipeline=normalized_pipeline,
file_ids=validated_file_ids or None,
status="pending",
progress=0,
total_images=total_images,
processed_images=0,
detected_objects=0,
stop_requested=False,
created_at=now,
updated_at=now,
)
operator_instances = self._to_operator_instances(task_id, normalized_pipeline)
db.add(task)
if operator_instances:
db.add_all(operator_instances)
await db.commit()
await db.refresh(task)
# Attach sourceDatasets after creation (usually just the one original dataset)
resp = AutoAnnotationTaskResponse.model_validate(task)
@@ -152,9 +289,48 @@ class AutoAnnotationTaskService:
# Fallback: show a single dataset
if task.dataset_name:
return [task.dataset_name]
if task.dataset_id:
return [task.dataset_id]
return []
if task.dataset_id:
return [task.dataset_id]
return []
async def request_stop_task(
self,
db: AsyncSession,
task_id: str,
user_context: RequestUserContext,
) -> Optional[AutoAnnotationTaskResponse]:
query = select(AutoAnnotationTask).where(
AutoAnnotationTask.id == task_id,
AutoAnnotationTask.deleted_at.is_(None),
)
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return None
now = datetime.now()
terminal_states = {"completed", "failed", "stopped"}
if task.status not in terminal_states:
task.stop_requested = True
task.error_message = "Task stop requested"
if task.status == "pending":
task.status = "stopped"
task.progress = task.progress or 0
task.completed_at = now
task.run_token = None
task.updated_at = now
await db.commit()
await db.refresh(task)
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
fallback_name = getattr(task, "dataset_name", None)
fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
return resp
async def soft_delete_task(
self,