You've already forked DataMate
feat(annotation): 支持通用算子编排的数据标注功能
## 功能概述
将数据标注模块从固定 YOLO 算子改造为支持通用算子编排,实现与数据清洗模块类似的灵活算子组合能力。
## 改动内容
### 第 1 步:数据库改造(DDL)
- 新增 SQL migration 脚本:scripts/db/data-annotation-operator-pipeline-migration.sql
- 修改 t_dm_auto_annotation_tasks 表:
- 新增字段:task_mode, executor_type, pipeline, output_dataset_id, created_by, stop_requested, started_at, heartbeat_at, run_token
- 新增索引:idx_status_created, idx_created_by
- 创建 t_dm_annotation_task_operator_instance 表:用于存储算子实例详情
### 第 2 步:API 层改造
- 扩展请求模型(schema/auto.py):
- 新增 OperatorPipelineStep 模型
- 支持 pipeline 字段,保留旧 YOLO 字段向后兼容
- 实现多写法归一(operatorId/operator_id/id, overrides/settingsOverride/settings_override)
- 修改任务创建服务(service/auto.py):
- 新增 validate_file_ids() 校验方法
- 新增 _to_pipeline() 兼容映射方法
- 写入新字段并集成算子实例表
- 修复 fileIds 去重准确性问题
- 新增 API 路由(interface/auto.py):
- 新增 /operator-tasks 系列接口
- 新增 stop API 接口(/auto/{id}/stop 和 /operator-tasks/{id}/stop)
- 保留旧 /auto 接口向后兼容
- ORM 模型对齐(annotation_management.py):
- AutoAnnotationTask 新增所有 DDL 字段
- 新增 AnnotationTaskOperatorInstance 模型
- 状态定义补充 stopped
### 第 3 步:Runtime 层改造
- 修改 worker 执行逻辑(auto_annotation_worker.py):
- 实现原子任务抢占机制(run_token)
- 从硬编码 YOLO 改为通用 pipeline 执行
- 新增算子解析和实例化能力
- 支持 stop_requested 检查
- 保留 legacy_yolo 模式向后兼容
- 支持多种算子调用方式(execute 和 __call__)
### 第 4 步:灰度发布
- 完善 YOLO 算子元数据(metadata.yml):
- 补齐 raw_id, language, modal, inputs, outputs, settings 字段
- 注册标注算子(__init__.py):
- 将 YOLO 算子注册到 OPERATORS 注册表
- 确保 annotation 包被正确加载
- 新增白名单控制:
- 支持环境变量 AUTO_ANNOTATION_OPERATOR_WHITELIST
- 灰度发布时可限制可用算子
## 关键特性
### 向后兼容
- 旧 /auto 接口完全保留
- 旧请求参数自动映射到 pipeline
- legacy_yolo 模式确保旧逻辑正常运行
### 新功能
- 支持通用 pipeline 编排
- 支持多算子组合
- 支持任务停止控制
- 支持白名单灰度发布
### 可靠性
- 原子任务抢占(防止重复执行)
- 完整的错误处理和状态管理
- 详细的审计追踪(算子实例表)
## 部署说明
1. 执行 DDL:mysql < scripts/db/data-annotation-operator-pipeline-migration.sql
2. 配置环境变量:AUTO_ANNOTATION_OPERATOR_WHITELIST=ImageObjectDetectionBoundingBox
3. 重启服务:datamate-runtime 和 datamate-backend-python
## 验证步骤
1. 兼容模式验证:使用旧 /auto 接口创建任务
2. 通用编排验证:使用新 /operator-tasks 接口创建 pipeline 任务
3. 原子 claim 验证:检查 run_token 机制
4. 停止验证:测试 stop API
5. 白名单验证:测试算子白名单拦截
## 相关文件
- DDL: scripts/db/data-annotation-operator-pipeline-migration.sql
- API: runtime/datamate-python/app/module/annotation/
- Worker: runtime/python-executor/datamate/auto_annotation_worker.py
- 算子: runtime/ops/annotation/image_object_detection_bounding_box/
This commit is contained in:
@@ -210,6 +210,7 @@ class AutoAnnotationTask(Base):
|
||||
dataset_name = Column(
|
||||
String(255), nullable=True, comment="数据集名称(冗余字段,方便查询)"
|
||||
)
|
||||
created_by = Column(String(255), nullable=True, comment="任务创建人")
|
||||
config = Column(JSON, nullable=False, comment="任务配置(模型规模、置信度等)")
|
||||
file_ids = Column(
|
||||
JSON, nullable=True, comment="要处理的文件ID列表,为空则处理数据集所有图像"
|
||||
@@ -218,13 +219,28 @@ class AutoAnnotationTask(Base):
|
||||
String(50),
|
||||
nullable=False,
|
||||
default="pending",
|
||||
comment="任务状态: pending/running/completed/failed",
|
||||
comment="任务状态: pending/running/completed/failed/stopped",
|
||||
)
|
||||
task_mode = Column(
|
||||
String(32),
|
||||
nullable=False,
|
||||
default="legacy_yolo",
|
||||
comment="任务模式: legacy_yolo/pipeline",
|
||||
)
|
||||
executor_type = Column(
|
||||
String(32),
|
||||
nullable=False,
|
||||
default="annotation_local",
|
||||
comment="执行器类型",
|
||||
)
|
||||
pipeline = Column(JSON, nullable=True, comment="算子编排定义")
|
||||
progress = Column(Integer, default=0, comment="任务进度 0-100")
|
||||
stop_requested = Column(Boolean, default=False, comment="是否请求停止")
|
||||
total_images = Column(Integer, default=0, comment="总图片数")
|
||||
processed_images = Column(Integer, default=0, comment="已处理图片数")
|
||||
detected_objects = Column(Integer, default=0, comment="检测到的对象总数")
|
||||
output_path = Column(String(500), nullable=True, comment="输出路径")
|
||||
output_dataset_id = Column(String(36), nullable=True, comment="输出数据集ID")
|
||||
error_message = Column(Text, nullable=True, comment="错误信息")
|
||||
created_at = Column(
|
||||
TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间"
|
||||
@@ -235,6 +251,9 @@ class AutoAnnotationTask(Base):
|
||||
onupdate=func.current_timestamp(),
|
||||
comment="更新时间",
|
||||
)
|
||||
started_at = Column(TIMESTAMP, nullable=True, comment="任务启动时间")
|
||||
heartbeat_at = Column(TIMESTAMP, nullable=True, comment="worker心跳时间")
|
||||
run_token = Column(String(64), nullable=True, comment="运行令牌")
|
||||
completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
|
||||
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
|
||||
|
||||
@@ -245,3 +264,36 @@ class AutoAnnotationTask(Base):
|
||||
def is_deleted(self) -> bool:
|
||||
"""检查是否已被软删除"""
|
||||
return self.deleted_at is not None
|
||||
|
||||
|
||||
class AnnotationTaskOperatorInstance(Base):
    """Operator-instance record for an auto-annotation task.

    Maps table ``t_dm_annotation_task_operator_instance``: one row per
    pipeline step, ordered by ``op_index`` (unique within a task) and used
    as an audit trail of which operator ran with which overrides.
    """

    __tablename__ = "t_dm_annotation_task_operator_instance"

    # Surrogate key; steps are addressed by (task_id, op_index), not by id.
    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="自增主键")
    task_id = Column(String(36), nullable=False, comment="自动标注任务ID")
    op_index = Column(Integer, nullable=False, comment="算子顺序(从1开始)")
    operator_id = Column(String(64), nullable=False, comment="算子ID(raw_id)")
    settings_override = Column(JSON, nullable=True, comment="任务级算子参数覆盖")
    inputs = Column(String(64), nullable=True, comment="输入模态")
    outputs = Column(String(64), nullable=True, comment="输出模态")
    created_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        nullable=False,
        comment="创建时间",
    )
    updated_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        nullable=False,
        comment="更新时间",
    )

    __table_args__ = (
        # One row per step position within a task.
        UniqueConstraint("task_id", "op_index", name="uk_task_op_index"),
        Index("idx_task_id", "task_id"),
        Index("idx_operator_id", "operator_id"),
    )
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
"""FastAPI routes for Auto Annotation tasks.
|
||||
"""FastAPI routes for Annotation Operator Tasks.
|
||||
|
||||
These routes back the frontend AutoAnnotation module:
|
||||
- GET /api/annotation/auto
|
||||
- POST /api/annotation/auto
|
||||
- DELETE /api/annotation/auto/{task_id}
|
||||
- GET /api/annotation/auto/{task_id}/status (simple wrapper)
|
||||
兼容路由:
|
||||
- GET/POST/DELETE /api/annotation/auto
|
||||
- GET /api/annotation/auto/{task_id}/status
|
||||
|
||||
新路由:
|
||||
- GET/POST/DELETE /api/annotation/operator-tasks
|
||||
- GET /api/annotation/operator-tasks/{task_id}
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import List, Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
@@ -33,7 +34,6 @@ from ..service.auto import AutoAnnotationTaskService
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/auto",
|
||||
tags=["annotation/auto"],
|
||||
)
|
||||
|
||||
@@ -41,15 +41,77 @@ logger = get_logger(__name__)
|
||||
service = AutoAnnotationTaskService()
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
|
||||
async def list_auto_annotation_tasks(
|
||||
def _normalize_request_by_route(
|
||||
request: CreateAutoAnnotationTaskRequest,
|
||||
route_mode: Literal["legacy_auto", "operator_tasks"],
|
||||
) -> CreateAutoAnnotationTaskRequest:
|
||||
"""根据路由入口做请求标准化。"""
|
||||
|
||||
if route_mode == "legacy_auto":
|
||||
# 旧接口强制走 legacy_yolo 模式,保持行为一致
|
||||
return request.model_copy(update={"task_mode": "legacy_yolo"})
|
||||
|
||||
# 新接口默认走 pipeline 模式(若请求未显式指定 taskMode)
|
||||
task_mode = request.task_mode
|
||||
if request.pipeline and task_mode == "legacy_yolo":
|
||||
task_mode = "pipeline"
|
||||
|
||||
return request.model_copy(update={"task_mode": task_mode})
|
||||
|
||||
|
||||
async def _create_task_internal(
    *,
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession,
    user_context: RequestUserContext,
    route_mode: Literal["legacy_auto", "operator_tasks"],
) -> AutoAnnotationTaskResponse:
    """Shared creation path behind both the legacy and the new route.

    Normalizes the request for its entry route, checks dataset access,
    best-effort-resolves the dataset name/size for denormalized columns,
    then delegates persistence to the service layer.
    """
    normalized_request = _normalize_request_by_route(request, route_mode)

    config_dump = (
        normalized_request.config.model_dump(by_alias=True)
        if normalized_request.config
        else None
    )
    pipeline_dump = (
        [step.model_dump(by_alias=True) for step in normalized_request.pipeline]
        if normalized_request.pipeline
        else None
    )
    logger.info(
        "Creating annotation task: route_mode=%s, name=%s, dataset_id=%s, task_mode=%s, executor_type=%s, config=%s, pipeline=%s, file_ids=%s",
        route_mode,
        normalized_request.name,
        normalized_request.dataset_id,
        normalized_request.task_mode,
        normalized_request.executor_type,
        config_dump,
        pipeline_dump,
        normalized_request.file_ids,
    )

    # Permission check before anything else touches the dataset.
    await assert_dataset_access(db, normalized_request.dataset_id, user_context)

    # Best-effort lookup of dataset name and image count for redundant
    # columns; a failure here must not block task creation.
    dataset_name = None
    total_images = len(normalized_request.file_ids) if normalized_request.file_ids else 0
    try:
        dm_client = DatasetManagementService(db)
        dataset = await dm_client.get_dataset(normalized_request.dataset_id)
        if dataset is not None:
            dataset_name = dataset.name
            if not normalized_request.file_ids:
                total_images = getattr(dataset, "fileCount", 0) or 0
    except Exception as exc:  # pragma: no cover - best effort only
        logger.warning("Failed to fetch dataset summary for annotation task: %s", exc)

    return await service.create_task(
        db,
        normalized_request,
        user_context=user_context,
        dataset_name=dataset_name,
        total_images=total_images,
    )
|
||||
|
||||
|
||||
@router.get("/auto", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
|
||||
@router.get("/operator-tasks", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
|
||||
async def list_annotation_operator_tasks(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""获取自动标注任务列表。
|
||||
|
||||
前端当前不传分页参数,这里直接返回所有未删除任务。
|
||||
"""
|
||||
"""获取标注任务列表。"""
|
||||
|
||||
tasks = await service.list_tasks(db, user_context)
|
||||
return StandardResponse(
|
||||
@@ -59,48 +121,19 @@ async def list_auto_annotation_tasks(
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
|
||||
@router.post("/auto", response_model=StandardResponse[AutoAnnotationTaskResponse])
|
||||
async def create_auto_annotation_task(
|
||||
request: CreateAutoAnnotationTaskRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""创建自动标注任务。
|
||||
"""兼容旧版 /auto 接口创建任务。"""
|
||||
|
||||
当前仅创建任务记录并置为 pending,实际执行由后续调度/worker 完成。
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
"Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
|
||||
request.name,
|
||||
request.dataset_id,
|
||||
request.config.model_dump(by_alias=True),
|
||||
request.file_ids,
|
||||
)
|
||||
|
||||
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
|
||||
dataset_name = None
|
||||
total_images = 0
|
||||
await assert_dataset_access(db, request.dataset_id, user_context)
|
||||
try:
|
||||
dm_client = DatasetManagementService(db)
|
||||
# Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount
|
||||
dataset = await dm_client.get_dataset(request.dataset_id)
|
||||
if dataset is not None:
|
||||
dataset_name = dataset.name
|
||||
# 如果提供了 file_ids,则 total_images 为选中文件数;否则使用数据集文件数
|
||||
if request.file_ids:
|
||||
total_images = len(request.file_ids)
|
||||
else:
|
||||
total_images = getattr(dataset, "fileCount", 0) or 0
|
||||
except Exception as e: # pragma: no cover - 容错
|
||||
logger.warning("Failed to fetch dataset name for auto task: %s", e)
|
||||
|
||||
task = await service.create_task(
|
||||
db,
|
||||
request,
|
||||
dataset_name=dataset_name,
|
||||
total_images=total_images,
|
||||
task = await _create_task_internal(
|
||||
request=request,
|
||||
db=db,
|
||||
user_context=user_context,
|
||||
route_mode="legacy_auto",
|
||||
)
|
||||
|
||||
return StandardResponse(
|
||||
@@ -110,16 +143,36 @@ async def create_auto_annotation_task(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
|
||||
@router.post("/operator-tasks", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_annotation_operator_task(
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """New endpoint: create a generic operator-pipeline annotation task."""
    created = await _create_task_internal(
        request=request,
        db=db,
        user_context=user_context,
        route_mode="operator_tasks",
    )
    return StandardResponse(
        code=200,
        message="success",
        data=created,
    )
|
||||
|
||||
|
||||
@router.get("/auto/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
|
||||
@router.get("/operator-tasks/{task_id}", response_model=StandardResponse[AutoAnnotationTaskResponse])
|
||||
async def get_auto_annotation_task_status(
|
||||
task_id: str = Path(..., description="任务ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""获取单个自动标注任务状态。
|
||||
|
||||
前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。
|
||||
"""
|
||||
"""获取单个标注任务状态/详情。"""
|
||||
|
||||
task = await service.get_task(db, task_id, user_context)
|
||||
if not task:
|
||||
@@ -132,7 +185,8 @@ async def get_auto_annotation_task_status(
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{task_id}", response_model=StandardResponse[bool])
|
||||
@router.delete("/auto/{task_id}", response_model=StandardResponse[bool])
|
||||
@router.delete("/operator-tasks/{task_id}", response_model=StandardResponse[bool])
|
||||
async def delete_auto_annotation_task(
|
||||
task_id: str = Path(..., description="任务ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -151,7 +205,28 @@ async def delete_auto_annotation_task(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}/download")
|
||||
@router.post("/auto/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse])
@router.post("/operator-tasks/{task_id}/stop", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def stop_auto_annotation_task(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
    user_context: RequestUserContext = Depends(get_request_user_context),
):
    """Request a cooperative stop of an annotation task (both route families)."""
    stopped = await service.request_stop_task(db, task_id, user_context)
    if stopped is None:
        raise HTTPException(status_code=404, detail="Task not found")

    return StandardResponse(
        code=200,
        message="success",
        data=stopped,
    )
|
||||
|
||||
|
||||
@router.get("/auto/{task_id}/download")
|
||||
@router.get("/operator-tasks/{task_id}/download")
|
||||
async def download_auto_annotation_result(
|
||||
task_id: str = Path(..., description="任务ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -159,7 +234,6 @@ async def download_auto_annotation_result(
|
||||
):
|
||||
"""下载指定自动标注任务的结果 ZIP。"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
import tempfile
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""Schemas for Auto Annotation tasks"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||||
|
||||
|
||||
class AutoAnnotationConfig(BaseModel):
|
||||
@@ -26,13 +28,87 @@ class AutoAnnotationConfig(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class OperatorPipelineStep(BaseModel):
    """One operator node inside a generic annotation pipeline."""

    operator_id: str = Field(alias="operatorId", description="算子ID(raw_id)")
    overrides: Dict[str, Any] = Field(
        default_factory=dict,
        alias="overrides",
        description="算子参数覆盖(对应 settings override)",
    )

    model_config = ConfigDict(populate_by_name=True)

    @model_validator(mode="before")
    @classmethod
    def normalize_compatible_fields(cls, value: Any):
        """Accept legacy spellings before validation.

        operatorId may arrive as operator_id/id; overrides may arrive as
        settingsOverride/settings_override, possibly JSON-encoded as a string.
        """
        if not isinstance(value, dict):
            return value

        data = dict(value)

        if "operatorId" not in data:
            # First truthy legacy key wins.
            alias_value = next(
                (data[key] for key in ("operator_id", "id") if data.get(key)),
                None,
            )
            if alias_value:
                data["operatorId"] = alias_value

        if "overrides" not in data:
            for legacy_key in ("settingsOverride", "settings_override"):
                raw = data.get(legacy_key)
                if isinstance(raw, str):
                    # Tolerate JSON-encoded override payloads.
                    try:
                        raw = json.loads(raw)
                    except Exception:
                        raw = None
                if isinstance(raw, dict):
                    data["overrides"] = raw
                    break

        return data
|
||||
|
||||
|
||||
class CreateAutoAnnotationTaskRequest(BaseModel):
    """Create-task request body, aligned with the frontend CreateAutoAnnotationDialog payload.

    Either ``config`` (legacy YOLO) or ``pipeline`` (generic orchestration)
    must be supplied; the model validator enforces this.
    """

    name: str = Field(..., min_length=1, max_length=255, description="任务名称")
    dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
    config: Optional[AutoAnnotationConfig] = Field(
        default=None,
        description="兼容旧版 YOLO 任务配置",
    )
    pipeline: Optional[List[OperatorPipelineStep]] = Field(
        default=None,
        description="通用算子编排定义",
    )
    task_mode: str = Field(
        default="legacy_yolo",
        alias="taskMode",
        description="任务模式: legacy_yolo/pipeline",
    )
    executor_type: str = Field(
        default="annotation_local",
        alias="executorType",
        description="执行器类型",
    )
    output_dataset_name: Optional[str] = Field(
        default=None,
        alias="outputDatasetName",
        description="输出数据集名称(优先级高于 config.outputDatasetName)",
    )
    file_ids: Optional[List[str]] = Field(
        None,
        alias="fileIds",
        description="要处理的文件ID列表,为空则处理数据集中所有图像",
    )

    @model_validator(mode="after")
    def validate_config_or_pipeline(self):
        # Reject requests that carry neither a legacy config nor a pipeline.
        if self.config is None and not self.pipeline:
            raise ValueError("Either config or pipeline must be provided")
        return self

    model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
@@ -44,6 +120,9 @@ class AutoAnnotationTaskResponse(BaseModel):
|
||||
name: str = Field(..., description="任务名称")
|
||||
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
|
||||
dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
|
||||
task_mode: Optional[str] = Field(None, alias="taskMode", description="任务模式")
|
||||
executor_type: Optional[str] = Field(None, alias="executorType", description="执行器类型")
|
||||
pipeline: Optional[List[Dict[str, Any]]] = Field(None, description="算子编排定义")
|
||||
source_datasets: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
alias="sourceDatasets",
|
||||
@@ -56,7 +135,20 @@ class AutoAnnotationTaskResponse(BaseModel):
|
||||
processed_images: int = Field(..., alias="processedImages", description="已处理图片数")
|
||||
detected_objects: int = Field(..., alias="detectedObjects", description="检测到的对象总数")
|
||||
output_path: Optional[str] = Field(None, alias="outputPath", description="输出路径")
|
||||
output_dataset_id: Optional[str] = Field(
|
||||
None,
|
||||
alias="outputDatasetId",
|
||||
description="输出数据集ID",
|
||||
)
|
||||
stop_requested: Optional[bool] = Field(
|
||||
None,
|
||||
alias="stopRequested",
|
||||
description="是否请求停止",
|
||||
)
|
||||
error_message: Optional[str] = Field(None, alias="errorMessage", description="错误信息")
|
||||
created_by: Optional[str] = Field(None, alias="createdBy", description="创建人")
|
||||
started_at: Optional[datetime] = Field(None, alias="startedAt", description="启动时间")
|
||||
heartbeat_at: Optional[datetime] = Field(None, alias="heartbeatAt", description="心跳时间")
|
||||
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
|
||||
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
|
||||
completed_at: Optional[datetime] = Field(None, alias="completedAt", description="完成时间")
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
"""Service layer for Auto Annotation tasks"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.annotation_management import AutoAnnotationTask
|
||||
from app.db.models.annotation_management import (
|
||||
AutoAnnotationTask,
|
||||
AnnotationTaskOperatorInstance,
|
||||
)
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
from app.module.annotation.security import RequestUserContext
|
||||
|
||||
@@ -21,10 +25,120 @@ from ..schema.auto import (
|
||||
class AutoAnnotationTaskService:
|
||||
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
|
||||
|
||||
@staticmethod
|
||||
def _normalize_file_ids(file_ids: Optional[List[str]]) -> List[str]:
|
||||
if not file_ids:
|
||||
return []
|
||||
return [fid for fid in dict.fromkeys(file_ids) if fid]
|
||||
|
||||
@staticmethod
|
||||
def _extract_operator_id(step: Dict[str, Any]) -> Optional[str]:
|
||||
operator_id = step.get("operatorId") or step.get("operator_id") or step.get("id")
|
||||
if operator_id is None:
|
||||
return None
|
||||
operator_id = str(operator_id).strip()
|
||||
return operator_id or None
|
||||
|
||||
@classmethod
def _to_operator_instances(
    cls,
    task_id: str,
    pipeline: List[Dict[str, Any]],
) -> List[AnnotationTaskOperatorInstance]:
    """Convert normalized pipeline steps into ORM operator-instance rows.

    Non-dict steps and steps without a resolvable operator id are skipped;
    ``op_index`` is assigned 1-based over the surviving steps, so gaps in
    the input do not leave gaps in the stored ordering.
    """
    rows: List[AnnotationTaskOperatorInstance] = []
    for step in pipeline:
        if not isinstance(step, dict):
            continue
        op_id = cls._extract_operator_id(step)
        if not op_id:
            continue

        # First truthy spelling of the overrides wins; anything that is
        # not a dict is discarded rather than persisted.
        override = (
            step.get("overrides")
            or step.get("settingsOverride")
            or step.get("settings_override")
            or {}
        )
        if not isinstance(override, dict):
            override = {}

        rows.append(
            AnnotationTaskOperatorInstance(
                task_id=task_id,
                op_index=len(rows) + 1,
                operator_id=op_id,
                settings_override=override,
                inputs=step.get("inputs"),
                outputs=step.get("outputs"),
            )
        )
    return rows
|
||||
|
||||
@staticmethod
|
||||
def _to_pipeline(request: CreateAutoAnnotationTaskRequest) -> Optional[List[Dict[str, Any]]]:
|
||||
"""将请求标准化为 pipeline 结构。"""
|
||||
|
||||
if request.pipeline:
|
||||
return [step.model_dump(by_alias=True) for step in request.pipeline]
|
||||
|
||||
if request.config is None:
|
||||
return None
|
||||
|
||||
# 兼容旧版 YOLO 请求 -> 单步 pipeline
|
||||
config = request.config.model_dump(by_alias=True)
|
||||
step_overrides: Dict[str, Any] = {
|
||||
"modelSize": config.get("modelSize"),
|
||||
"confThreshold": config.get("confThreshold"),
|
||||
"targetClasses": config.get("targetClasses") or [],
|
||||
}
|
||||
|
||||
output_dataset_name = request.output_dataset_name or config.get("outputDatasetName")
|
||||
if output_dataset_name:
|
||||
step_overrides["outputDatasetName"] = output_dataset_name
|
||||
|
||||
return [
|
||||
{
|
||||
"operatorId": "ImageObjectDetectionBoundingBox",
|
||||
"overrides": step_overrides,
|
||||
}
|
||||
]
|
||||
|
||||
async def validate_file_ids(
    self,
    db: AsyncSession,
    dataset_id: str,
    file_ids: Optional[List[str]],
) -> List[str]:
    """Validate that every requested file id belongs to *dataset_id* and is ACTIVE.

    Returns the deduplicated id list. Raises HTTP 400 when the caller sent
    ids but all of them normalized away, or when any id is missing,
    inactive, or owned by another dataset.
    """
    normalized = self._normalize_file_ids(file_ids)
    if not normalized:
        if file_ids:
            # Caller sent something, but nothing survived normalization.
            raise HTTPException(status_code=400, detail="fileIds 不能为空列表")
        return []

    stmt = select(DatasetFiles.id).where(
        DatasetFiles.id.in_(normalized),
        DatasetFiles.dataset_id == dataset_id,
        DatasetFiles.status == "ACTIVE",
    )
    result = await db.execute(stmt)
    existing = {row[0] for row in result.fetchall()}

    unknown = [fid for fid in normalized if fid not in existing]
    if unknown:
        # Cap the error payload at ten offending ids.
        raise HTTPException(
            status_code=400,
            detail=f"部分 fileIds 不存在、不可用或不属于数据集: {unknown[:10]}",
        )

    return normalized
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
request: CreateAutoAnnotationTaskRequest,
|
||||
user_context: RequestUserContext,
|
||||
dataset_name: Optional[str] = None,
|
||||
total_images: int = 0,
|
||||
) -> AutoAnnotationTaskResponse:
|
||||
@@ -36,23 +150,46 @@ class AutoAnnotationTaskService:
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
validated_file_ids = await self.validate_file_ids(
|
||||
db,
|
||||
request.dataset_id,
|
||||
request.file_ids,
|
||||
)
|
||||
if validated_file_ids:
|
||||
total_images = len(validated_file_ids)
|
||||
|
||||
normalized_pipeline = self._to_pipeline(request)
|
||||
if not normalized_pipeline:
|
||||
raise HTTPException(status_code=400, detail="pipeline 不能为空")
|
||||
|
||||
normalized_config = request.config.model_dump(by_alias=True) if request.config else {}
|
||||
|
||||
task_id = str(uuid4())
|
||||
task = AutoAnnotationTask(
|
||||
id=str(uuid4()),
|
||||
id=task_id,
|
||||
name=request.name,
|
||||
dataset_id=request.dataset_id,
|
||||
dataset_name=dataset_name,
|
||||
config=request.config.model_dump(by_alias=True),
|
||||
file_ids=request.file_ids, # 存储用户选择的文件ID列表
|
||||
created_by=user_context.user_id,
|
||||
config=normalized_config,
|
||||
task_mode=request.task_mode,
|
||||
executor_type=request.executor_type,
|
||||
pipeline=normalized_pipeline,
|
||||
file_ids=validated_file_ids or None,
|
||||
status="pending",
|
||||
progress=0,
|
||||
total_images=total_images,
|
||||
processed_images=0,
|
||||
detected_objects=0,
|
||||
stop_requested=False,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
operator_instances = self._to_operator_instances(task_id, normalized_pipeline)
|
||||
db.add(task)
|
||||
if operator_instances:
|
||||
db.add_all(operator_instances)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
|
||||
@@ -156,6 +293,45 @@ class AutoAnnotationTaskService:
|
||||
return [task.dataset_id]
|
||||
return []
|
||||
|
||||
async def request_stop_task(
    self,
    db: AsyncSession,
    task_id: str,
    user_context: RequestUserContext,
) -> Optional[AutoAnnotationTaskResponse]:
    """Flag a task for cooperative stop; pending tasks are finalized immediately.

    Returns None when the task does not exist, is soft-deleted, or falls
    outside the caller's dataset scope; otherwise the refreshed response.
    Tasks already in a terminal state are left untouched.
    """
    stmt = select(AutoAnnotationTask).where(
        AutoAnnotationTask.id == task_id,
        AutoAnnotationTask.deleted_at.is_(None),
    )
    stmt = self._apply_dataset_scope(stmt, user_context)
    task = (await db.execute(stmt)).scalar_one_or_none()
    if task is None:
        return None

    now = datetime.now()
    if task.status not in {"completed", "failed", "stopped"}:
        task.stop_requested = True
        task.error_message = "Task stop requested"
        if task.status == "pending":
            # Never claimed by a worker: finalize as stopped right away.
            task.status = "stopped"
            task.progress = task.progress or 0
            task.completed_at = now
            task.run_token = None
        task.updated_at = now
    await db.commit()
    await db.refresh(task)

    resp = AutoAnnotationTaskResponse.model_validate(task)
    try:
        resp.source_datasets = await self._compute_source_datasets(db, task)
    except Exception:
        # Fall back to the denormalized name, or the raw dataset id.
        name_fallback = getattr(task, "dataset_name", None)
        id_fallback = getattr(task, "dataset_id", "")
        resp.source_datasets = [name_fallback] if name_fallback else [id_fallback]
    return resp
|
||||
|
||||
async def soft_delete_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Annotation-related operators (e.g. YOLO detection)."""
|
||||
|
||||
from . import image_object_detection_bounding_box
|
||||
|
||||
__all__ = [
|
||||
"image_object_detection_bounding_box",
|
||||
]
|
||||
|
||||
@@ -4,6 +4,13 @@ This package exposes the ImageObjectDetectionBoundingBox annotator so that
|
||||
the auto-annotation worker can import it via different module paths.
|
||||
"""
|
||||
|
||||
from datamate.core.base_op import OPERATORS
|
||||
|
||||
from .process import ImageObjectDetectionBoundingBox
|
||||
|
||||
OPERATORS.register_module(
|
||||
module_name="ImageObjectDetectionBoundingBox",
|
||||
module_path="ops.annotation.image_object_detection_bounding_box.process",
|
||||
)
|
||||
|
||||
__all__ = ["ImageObjectDetectionBoundingBox"]
|
||||
|
||||
@@ -1,3 +1,48 @@
|
||||
name: image_object_detection_bounding_box
|
||||
version: 0.1.0
|
||||
description: "YOLOv8-based object detection operator for auto annotation"
|
||||
name: '图像目标检测(YOLOv8)'
|
||||
name_en: 'Image Object Detection (YOLOv8)'
|
||||
description: '基于 YOLOv8 的目标检测算子,输出带框图像与标注 JSON。'
|
||||
description_en: 'YOLOv8-based object detection operator that outputs boxed images and annotation JSON files.'
|
||||
language: 'python'
|
||||
vendor: 'huawei'
|
||||
raw_id: 'ImageObjectDetectionBoundingBox'
|
||||
version: '1.0.0'
|
||||
types:
|
||||
- 'annotation'
|
||||
modal: 'image'
|
||||
inputs: 'image'
|
||||
outputs: 'image'
|
||||
settings:
|
||||
modelSize:
|
||||
name: '模型规模'
|
||||
description: 'YOLOv8 模型规模:n/s/m/l/x。'
|
||||
type: 'select'
|
||||
defaultVal: 'l'
|
||||
options:
|
||||
- label: 'n'
|
||||
value: 'n'
|
||||
- label: 's'
|
||||
value: 's'
|
||||
- label: 'm'
|
||||
value: 'm'
|
||||
- label: 'l'
|
||||
value: 'l'
|
||||
- label: 'x'
|
||||
value: 'x'
|
||||
confThreshold:
|
||||
name: '置信度阈值'
|
||||
description: '检测结果最小置信度,范围 0~1。'
|
||||
type: 'slider'
|
||||
defaultVal: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
step: 0.01
|
||||
targetClasses:
|
||||
name: '目标类别'
|
||||
description: 'COCO 类别 ID 列表;为空表示全部类别。'
|
||||
type: 'input'
|
||||
defaultVal: '[]'
|
||||
outputDir:
|
||||
name: '输出目录'
|
||||
description: '算子输出目录(由运行时注入)。'
|
||||
type: 'input'
|
||||
defaultVal: ''
|
||||
|
||||
@@ -19,6 +19,7 @@ frontend can display real-time status.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@@ -32,6 +33,24 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
from loguru import logger
|
||||
from sqlalchemy import text
|
||||
|
||||
try:
|
||||
import datamate.ops # noqa: F401
|
||||
except Exception as import_ops_err: # pragma: no cover - 兜底日志
|
||||
logger.warning("Failed to import datamate.ops package for operator registry: {}", import_ops_err)
|
||||
|
||||
try:
|
||||
import ops.annotation # type: ignore # noqa: F401
|
||||
except Exception as import_annotation_ops_err: # pragma: no cover - 兜底日志
|
||||
logger.warning(
|
||||
"Failed to import ops.annotation package for operator registry: {}",
|
||||
import_annotation_ops_err,
|
||||
)
|
||||
|
||||
try:
|
||||
from datamate.core.base_op import OPERATORS
|
||||
except Exception: # pragma: no cover - 兜底
|
||||
OPERATORS = None # type: ignore
|
||||
|
||||
from datamate.sql_manager.sql_manager import SQLManager
|
||||
|
||||
# 尝试多种导入路径,适配不同的打包/安装方式
|
||||
@@ -101,42 +120,91 @@ DEFAULT_OUTPUT_ROOT = os.getenv(
|
||||
"AUTO_ANNOTATION_OUTPUT_ROOT", "/dataset"
|
||||
)
|
||||
|
||||
DEFAULT_OPERATOR_WHITELIST = os.getenv(
|
||||
"AUTO_ANNOTATION_OPERATOR_WHITELIST",
|
||||
"ImageObjectDetectionBoundingBox",
|
||||
)
|
||||
|
||||
|
||||
def _fetch_pending_task() -> Optional[Dict[str, Any]]:
|
||||
"""从 t_dm_auto_annotation_tasks 中取出一个 pending 任务。"""
|
||||
"""原子 claim 一个 pending 任务并返回任务详情。"""
|
||||
|
||||
sql = text(
|
||||
def _parse_json_field(value: Any, default: Any) -> Any:
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, (dict, list)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
text_value = value.strip()
|
||||
if not text_value:
|
||||
return default
|
||||
try:
|
||||
return json.loads(text_value)
|
||||
except Exception:
|
||||
return default
|
||||
return default
|
||||
|
||||
run_token = str(uuid.uuid4())
|
||||
now = datetime.now()
|
||||
|
||||
claim_sql = text(
|
||||
"""
|
||||
SELECT id, name, dataset_id, dataset_name, config, file_ids, status,
|
||||
total_images, processed_images, detected_objects, output_path
|
||||
UPDATE t_dm_auto_annotation_tasks
|
||||
SET status = 'running',
|
||||
run_token = :run_token,
|
||||
started_at = COALESCE(started_at, :now),
|
||||
heartbeat_at = :now,
|
||||
updated_at = :now,
|
||||
error_message = NULL
|
||||
WHERE id = (
|
||||
SELECT id FROM (
|
||||
SELECT id
|
||||
FROM t_dm_auto_annotation_tasks
|
||||
WHERE status = 'pending' AND deleted_at IS NULL
|
||||
WHERE status = 'pending'
|
||||
AND deleted_at IS NULL
|
||||
AND COALESCE(stop_requested, 0) = 0
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
) AS pending_task
|
||||
)
|
||||
AND status = 'pending'
|
||||
AND deleted_at IS NULL
|
||||
AND COALESCE(stop_requested, 0) = 0
|
||||
"""
|
||||
)
|
||||
|
||||
query_sql = text(
|
||||
"""
|
||||
SELECT id, name, dataset_id, dataset_name, created_by,
|
||||
config, file_ids, pipeline,
|
||||
task_mode, executor_type,
|
||||
status, stop_requested, run_token,
|
||||
total_images, processed_images, detected_objects,
|
||||
output_path, output_dataset_id
|
||||
FROM t_dm_auto_annotation_tasks
|
||||
WHERE run_token = :run_token
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
|
||||
with SQLManager.create_connect() as conn:
|
||||
result = conn.execute(sql).fetchone()
|
||||
claim_result = conn.execute(claim_sql, {"run_token": run_token, "now": now})
|
||||
if not claim_result or int(getattr(claim_result, "rowcount", 0) or 0) <= 0:
|
||||
return None
|
||||
|
||||
result = conn.execute(query_sql, {"run_token": run_token}).fetchone()
|
||||
if not result:
|
||||
return None
|
||||
|
||||
row = dict(result._mapping) # type: ignore[attr-defined]
|
||||
|
||||
try:
|
||||
row["config"] = json.loads(row["config"]) if row.get("config") else {}
|
||||
except Exception:
|
||||
row["config"] = {}
|
||||
row["config"] = _parse_json_field(row.get("config"), {})
|
||||
|
||||
try:
|
||||
raw_ids = row.get("file_ids")
|
||||
if not raw_ids:
|
||||
row["file_ids"] = None
|
||||
elif isinstance(raw_ids, str):
|
||||
row["file_ids"] = json.loads(raw_ids)
|
||||
else:
|
||||
row["file_ids"] = raw_ids
|
||||
except Exception:
|
||||
row["file_ids"] = None
|
||||
parsed_file_ids = _parse_json_field(row.get("file_ids"), None)
|
||||
row["file_ids"] = parsed_file_ids if parsed_file_ids else None
|
||||
|
||||
parsed_pipeline = _parse_json_field(row.get("pipeline"), None)
|
||||
row["pipeline"] = parsed_pipeline if parsed_pipeline else None
|
||||
return row
|
||||
|
||||
|
||||
@@ -144,13 +212,16 @@ def _update_task_status(
|
||||
task_id: str,
|
||||
*,
|
||||
status: str,
|
||||
run_token: Optional[str] = None,
|
||||
progress: Optional[int] = None,
|
||||
processed_images: Optional[int] = None,
|
||||
detected_objects: Optional[int] = None,
|
||||
total_images: Optional[int] = None,
|
||||
output_path: Optional[str] = None,
|
||||
output_dataset_id: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
completed: bool = False,
|
||||
clear_run_token: bool = False,
|
||||
) -> None:
|
||||
"""更新任务的状态和统计字段。"""
|
||||
|
||||
@@ -176,23 +247,318 @@ def _update_task_status(
|
||||
if output_path is not None:
|
||||
fields.append("output_path = :output_path")
|
||||
params["output_path"] = output_path
|
||||
if output_dataset_id is not None:
|
||||
fields.append("output_dataset_id = :output_dataset_id")
|
||||
params["output_dataset_id"] = output_dataset_id
|
||||
if error_message is not None:
|
||||
fields.append("error_message = :error_message")
|
||||
params["error_message"] = error_message[:2000]
|
||||
if status == "running":
|
||||
fields.append("heartbeat_at = :heartbeat_at")
|
||||
params["heartbeat_at"] = datetime.now()
|
||||
if completed:
|
||||
fields.append("completed_at = :completed_at")
|
||||
params["completed_at"] = datetime.now()
|
||||
if clear_run_token:
|
||||
fields.append("run_token = NULL")
|
||||
|
||||
where_clause = "id = :task_id"
|
||||
if run_token:
|
||||
where_clause += " AND run_token = :run_token"
|
||||
params["run_token"] = run_token
|
||||
|
||||
sql = text(
|
||||
f"""
|
||||
UPDATE t_dm_auto_annotation_tasks
|
||||
SET {', '.join(fields)}
|
||||
WHERE id = :task_id
|
||||
WHERE {where_clause}
|
||||
"""
|
||||
)
|
||||
|
||||
with SQLManager.create_connect() as conn:
|
||||
conn.execute(sql, params)
|
||||
result = conn.execute(sql, params)
|
||||
if int(getattr(result, "rowcount", 0) or 0) <= 0:
|
||||
logger.warning(
|
||||
"No rows updated for task status change: task_id={}, status={}, run_token={}",
|
||||
task_id,
|
||||
status,
|
||||
run_token,
|
||||
)
|
||||
|
||||
|
||||
def _is_stop_requested(task_id: str, run_token: Optional[str] = None) -> bool:
|
||||
"""检查任务是否请求停止。"""
|
||||
|
||||
where_clause = "id = :task_id"
|
||||
params: Dict[str, Any] = {"task_id": task_id}
|
||||
if run_token:
|
||||
where_clause += " AND run_token = :run_token"
|
||||
params["run_token"] = run_token
|
||||
|
||||
sql = text(
|
||||
f"""
|
||||
SELECT COALESCE(stop_requested, 0)
|
||||
FROM t_dm_auto_annotation_tasks
|
||||
WHERE {where_clause}
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
|
||||
with SQLManager.create_connect() as conn:
|
||||
row = conn.execute(sql, params).fetchone()
|
||||
if not row:
|
||||
# 找不到任务(或 run_token 已失效)时保守停止
|
||||
return True
|
||||
return bool(row[0])
|
||||
|
||||
|
||||
def _extract_step_overrides(step: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""合并 pipeline 节点中的参数覆盖。"""
|
||||
|
||||
overrides: Dict[str, Any] = {}
|
||||
for key in ("overrides", "settingsOverride", "settings_override"):
|
||||
value = step.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
value = json.loads(value)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(value, dict):
|
||||
overrides.update(value)
|
||||
return overrides
|
||||
|
||||
|
||||
def _build_legacy_pipeline(config: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""将 legacy_yolo 配置映射为单步 pipeline。"""
|
||||
|
||||
return [
|
||||
{
|
||||
"operatorId": "ImageObjectDetectionBoundingBox",
|
||||
"overrides": {
|
||||
"modelSize": config.get("modelSize", "l"),
|
||||
"confThreshold": float(config.get("confThreshold", 0.7)),
|
||||
"targetClasses": config.get("targetClasses", []) or [],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _get_output_dataset_name(
|
||||
task_id: str,
|
||||
dataset_id: str,
|
||||
source_dataset_name: str,
|
||||
task_name: str,
|
||||
config: Dict[str, Any],
|
||||
pipeline_raw: Optional[List[Any]],
|
||||
) -> str:
|
||||
"""确定输出数据集名称。"""
|
||||
|
||||
output_name = config.get("outputDatasetName")
|
||||
if output_name:
|
||||
return str(output_name)
|
||||
|
||||
if pipeline_raw:
|
||||
for step in pipeline_raw:
|
||||
if not isinstance(step, dict):
|
||||
continue
|
||||
overrides = _extract_step_overrides(step)
|
||||
output_name = overrides.get("outputDatasetName") or overrides.get("output_dataset_name")
|
||||
if output_name:
|
||||
return str(output_name)
|
||||
|
||||
base_name = source_dataset_name or task_name or f"dataset-{dataset_id[:8]}"
|
||||
return f"{base_name}_auto_{task_id[:8]}"
|
||||
|
||||
|
||||
def _normalize_pipeline(
|
||||
task_mode: str,
|
||||
config: Dict[str, Any],
|
||||
pipeline_raw: Optional[List[Any]],
|
||||
output_dir: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""标准化 pipeline 结构并注入 outputDir。"""
|
||||
|
||||
source_pipeline = pipeline_raw
|
||||
if task_mode == "legacy_yolo" or not source_pipeline:
|
||||
source_pipeline = _build_legacy_pipeline(config)
|
||||
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for step in source_pipeline:
|
||||
if not isinstance(step, dict):
|
||||
continue
|
||||
|
||||
operator_id: Optional[str] = None
|
||||
overrides: Dict[str, Any] = {}
|
||||
|
||||
# 兼容 [{"OpName": {...}}] 风格
|
||||
if (
|
||||
"operatorId" not in step
|
||||
and "operator_id" not in step
|
||||
and "id" not in step
|
||||
and len(step) == 1
|
||||
):
|
||||
first_key = next(iter(step.keys()))
|
||||
first_value = step.get(first_key)
|
||||
if isinstance(first_key, str):
|
||||
operator_id = first_key
|
||||
if isinstance(first_value, dict):
|
||||
overrides.update(first_value)
|
||||
|
||||
operator_id = operator_id or step.get("operatorId") or step.get("operator_id") or step.get("id")
|
||||
if not operator_id:
|
||||
continue
|
||||
|
||||
overrides.update(_extract_step_overrides(step))
|
||||
overrides.setdefault("outputDir", output_dir)
|
||||
|
||||
normalized.append(
|
||||
{
|
||||
"operator_id": str(operator_id),
|
||||
"overrides": overrides,
|
||||
}
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _resolve_operator_class(operator_id: str):
|
||||
"""根据 operator_id 解析算子类。"""
|
||||
|
||||
if operator_id == "ImageObjectDetectionBoundingBox":
|
||||
if ImageObjectDetectionBoundingBox is None:
|
||||
raise ImportError("ImageObjectDetectionBoundingBox is not available")
|
||||
return ImageObjectDetectionBoundingBox
|
||||
|
||||
registry_item = OPERATORS.get(operator_id) if OPERATORS is not None else None
|
||||
if registry_item is None:
|
||||
try:
|
||||
from core.base_op import OPERATORS as relative_operators # type: ignore
|
||||
registry_item = relative_operators.get(operator_id)
|
||||
except Exception:
|
||||
registry_item = None
|
||||
|
||||
if registry_item is None:
|
||||
raise ImportError(f"Operator not found in registry: {operator_id}")
|
||||
|
||||
if isinstance(registry_item, str):
|
||||
submodule = importlib.import_module(registry_item)
|
||||
operator_cls = getattr(submodule, operator_id, None)
|
||||
if operator_cls is None:
|
||||
raise ImportError(
|
||||
f"Operator class {operator_id} not found in module {registry_item}"
|
||||
)
|
||||
return operator_cls
|
||||
|
||||
return registry_item
|
||||
|
||||
|
||||
def _build_operator_chain(pipeline: List[Dict[str, Any]]) -> List[Tuple[str, Any]]:
|
||||
"""初始化算子链。"""
|
||||
|
||||
chain: List[Tuple[str, Any]] = []
|
||||
for step in pipeline:
|
||||
operator_id = step.get("operator_id")
|
||||
overrides = dict(step.get("overrides") or {})
|
||||
if not operator_id:
|
||||
continue
|
||||
|
||||
operator_cls = _resolve_operator_class(str(operator_id))
|
||||
operator = operator_cls(**overrides)
|
||||
chain.append((str(operator_id), operator))
|
||||
return chain
|
||||
|
||||
|
||||
def _run_pipeline_sample(sample: Dict[str, Any], chain: List[Tuple[str, Any]]) -> Dict[str, Any]:
|
||||
"""在单个样本上执行 pipeline。"""
|
||||
|
||||
current_sample: Dict[str, Any] = dict(sample)
|
||||
for operator_id, operator in chain:
|
||||
if hasattr(operator, "execute") and callable(getattr(operator, "execute")):
|
||||
result = operator.execute(current_sample)
|
||||
elif callable(operator):
|
||||
result = operator(current_sample)
|
||||
else:
|
||||
raise RuntimeError(f"Operator {operator_id} is not executable")
|
||||
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
if isinstance(result, dict):
|
||||
current_sample.update(result)
|
||||
continue
|
||||
|
||||
if isinstance(result, list):
|
||||
# 仅取第一个 dict 结果,兼容部分返回 list 的算子
|
||||
if result and isinstance(result[0], dict):
|
||||
current_sample.update(result[0])
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
"Operator {} returned unsupported result type: {}",
|
||||
operator_id,
|
||||
type(result).__name__,
|
||||
)
|
||||
return current_sample
|
||||
|
||||
|
||||
def _count_detections(sample: Dict[str, Any]) -> int:
|
||||
"""从样本中提取检测数量。"""
|
||||
|
||||
annotations = sample.get("annotations")
|
||||
if isinstance(annotations, dict):
|
||||
detections = annotations.get("detections")
|
||||
if isinstance(detections, list):
|
||||
return len(detections)
|
||||
|
||||
detection_count = sample.get("detection_count")
|
||||
if detection_count is None:
|
||||
return 0
|
||||
try:
|
||||
return max(int(detection_count), 0)
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _get_operator_whitelist() -> Optional[set[str]]:
|
||||
"""获取灰度白名单;返回 None 表示放开全部。"""
|
||||
|
||||
raw = str(DEFAULT_OPERATOR_WHITELIST or "").strip()
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
normalized = raw.lower()
|
||||
if normalized in {"*", "all", "any"}:
|
||||
return None
|
||||
|
||||
allow_set = {
|
||||
item.strip()
|
||||
for item in raw.split(",")
|
||||
if item and item.strip()
|
||||
}
|
||||
return allow_set or None
|
||||
|
||||
|
||||
def _validate_pipeline_whitelist(pipeline: List[Dict[str, Any]]) -> None:
|
||||
"""校验 pipeline 是否命中灰度白名单。"""
|
||||
|
||||
allow_set = _get_operator_whitelist()
|
||||
if allow_set is None:
|
||||
return
|
||||
|
||||
blocked = []
|
||||
for step in pipeline:
|
||||
operator_id = str(step.get("operator_id") or "")
|
||||
if not operator_id:
|
||||
continue
|
||||
if operator_id not in allow_set:
|
||||
blocked.append(operator_id)
|
||||
|
||||
if blocked:
|
||||
raise RuntimeError(
|
||||
"Operator not in whitelist: " + ", ".join(sorted(set(blocked)))
|
||||
)
|
||||
|
||||
|
||||
def _load_dataset_files(dataset_id: str) -> List[Tuple[str, str, str]]:
|
||||
@@ -455,45 +821,48 @@ def _register_output_dataset(
|
||||
def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
"""执行单个自动标注任务。"""
|
||||
|
||||
if ImageObjectDetectionBoundingBox is None:
|
||||
logger.error(
|
||||
"YOLO operator not available (import failed earlier), skip auto-annotation task: {}",
|
||||
task["id"],
|
||||
)
|
||||
_update_task_status(
|
||||
task["id"],
|
||||
status="failed",
|
||||
error_message="YOLO operator not available in runtime container",
|
||||
)
|
||||
return
|
||||
|
||||
task_id = str(task["id"])
|
||||
dataset_id = str(task["dataset_id"])
|
||||
task_name = str(task.get("name") or "")
|
||||
source_dataset_name = str(task.get("dataset_name") or "")
|
||||
run_token = str(task.get("run_token") or "")
|
||||
task_mode = str(task.get("task_mode") or "legacy_yolo")
|
||||
executor_type = str(task.get("executor_type") or "annotation_local")
|
||||
cfg: Dict[str, Any] = task.get("config") or {}
|
||||
pipeline_raw = task.get("pipeline")
|
||||
selected_file_ids: Optional[List[str]] = task.get("file_ids") or None
|
||||
|
||||
model_size = cfg.get("modelSize", "l")
|
||||
conf_threshold = float(cfg.get("confThreshold", 0.7))
|
||||
target_classes = cfg.get("targetClasses", []) or []
|
||||
output_dataset_name = cfg.get("outputDatasetName")
|
||||
|
||||
if not output_dataset_name:
|
||||
base_name = source_dataset_name or task_name or f"dataset-{dataset_id[:8]}"
|
||||
output_dataset_name = f"{base_name}_auto_{task_id[:8]}"
|
||||
output_dataset_name = _get_output_dataset_name(
|
||||
task_id=task_id,
|
||||
dataset_id=dataset_id,
|
||||
source_dataset_name=source_dataset_name,
|
||||
task_name=task_name,
|
||||
config=cfg,
|
||||
pipeline_raw=pipeline_raw if isinstance(pipeline_raw, list) else None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Start processing auto-annotation task: id={}, dataset_id={}, model_size={}, conf_threshold={}, target_classes={}, output_dataset_name={}",
|
||||
"Start processing auto-annotation task: id={}, dataset_id={}, task_mode={}, executor_type={}, output_dataset_name={}",
|
||||
task_id,
|
||||
dataset_id,
|
||||
model_size,
|
||||
conf_threshold,
|
||||
target_classes,
|
||||
task_mode,
|
||||
executor_type,
|
||||
output_dataset_name,
|
||||
)
|
||||
|
||||
_update_task_status(task_id, status="running", progress=0)
|
||||
if _is_stop_requested(task_id, run_token):
|
||||
logger.info("Task stop requested before processing started: {}", task_id)
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="stopped",
|
||||
completed=True,
|
||||
clear_run_token=True,
|
||||
error_message="Task stopped before start",
|
||||
)
|
||||
return
|
||||
|
||||
_update_task_status(task_id, run_token=run_token, status="running", progress=0)
|
||||
|
||||
if selected_file_ids:
|
||||
all_files = _load_files_by_ids(selected_file_ids)
|
||||
@@ -507,6 +876,7 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
logger.warning("No files found for dataset {} when running auto-annotation task {}", dataset_id, task_id)
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="completed",
|
||||
progress=100,
|
||||
total_images=0,
|
||||
@@ -514,6 +884,7 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
detected_objects=0,
|
||||
completed=True,
|
||||
output_path=None,
|
||||
clear_run_token=True,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -524,51 +895,91 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
)
|
||||
output_dir = _ensure_output_dir(output_dir)
|
||||
|
||||
try:
|
||||
detector = ImageObjectDetectionBoundingBox(
|
||||
modelSize=model_size,
|
||||
confThreshold=conf_threshold,
|
||||
targetClasses=target_classes,
|
||||
outputDir=output_dir,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to init YOLO detector for task {}: {}", task_id, e)
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="running",
|
||||
total_images=total_images,
|
||||
output_path=output_dir,
|
||||
output_dataset_id=output_dataset_id,
|
||||
)
|
||||
|
||||
try:
|
||||
normalized_pipeline = _normalize_pipeline(
|
||||
task_mode=task_mode,
|
||||
config=cfg,
|
||||
pipeline_raw=pipeline_raw if isinstance(pipeline_raw, list) else None,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
if not normalized_pipeline:
|
||||
raise RuntimeError("Pipeline is empty after normalization")
|
||||
|
||||
_validate_pipeline_whitelist(normalized_pipeline)
|
||||
|
||||
chain = _build_operator_chain(normalized_pipeline)
|
||||
if not chain:
|
||||
raise RuntimeError("No valid operator instances initialized")
|
||||
except Exception as e:
|
||||
logger.error("Failed to init operator pipeline for task {}: {}", task_id, e)
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="failed",
|
||||
total_images=total_images,
|
||||
processed_images=0,
|
||||
detected_objects=0,
|
||||
error_message=f"Init YOLO detector failed: {e}",
|
||||
error_message=f"Init pipeline failed: {e}",
|
||||
clear_run_token=True,
|
||||
)
|
||||
return
|
||||
|
||||
processed = 0
|
||||
detected_total = 0
|
||||
|
||||
try:
|
||||
|
||||
for file_path, file_name in files:
|
||||
if _is_stop_requested(task_id, run_token):
|
||||
logger.info("Task stop requested during processing: {}", task_id)
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="stopped",
|
||||
progress=int(processed * 100 / total_images) if total_images > 0 else 0,
|
||||
processed_images=processed,
|
||||
detected_objects=detected_total,
|
||||
total_images=total_images,
|
||||
output_path=output_dir,
|
||||
output_dataset_id=output_dataset_id,
|
||||
completed=True,
|
||||
clear_run_token=True,
|
||||
error_message="Task stopped by request",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
sample = {
|
||||
"image": file_path,
|
||||
"filename": file_name,
|
||||
}
|
||||
result = detector.execute(sample)
|
||||
|
||||
annotations = (result or {}).get("annotations", {})
|
||||
detections = annotations.get("detections", [])
|
||||
detected_total += len(detections)
|
||||
result = _run_pipeline_sample(sample, chain)
|
||||
detected_total += _count_detections(result)
|
||||
processed += 1
|
||||
|
||||
progress = int(processed * 100 / total_images) if total_images > 0 else 100
|
||||
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="running",
|
||||
progress=progress,
|
||||
processed_images=processed,
|
||||
detected_objects=detected_total,
|
||||
total_images=total_images,
|
||||
output_path=output_dir,
|
||||
output_dataset_id=output_dataset_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -581,13 +992,16 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="completed",
|
||||
progress=100,
|
||||
processed_images=processed,
|
||||
detected_objects=detected_total,
|
||||
total_images=total_images,
|
||||
output_path=output_dir,
|
||||
output_dataset_id=output_dataset_id,
|
||||
completed=True,
|
||||
clear_run_token=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -614,6 +1028,21 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
||||
task_id,
|
||||
e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Task execution failed for task {}: {}", task_id, e)
|
||||
_update_task_status(
|
||||
task_id,
|
||||
run_token=run_token,
|
||||
status="failed",
|
||||
progress=int(processed * 100 / total_images) if total_images > 0 else 0,
|
||||
processed_images=processed,
|
||||
detected_objects=detected_total,
|
||||
total_images=total_images,
|
||||
output_path=output_dir,
|
||||
output_dataset_id=output_dataset_id,
|
||||
error_message=f"Execute pipeline failed: {e}",
|
||||
clear_run_token=True,
|
||||
)
|
||||
|
||||
|
||||
def _worker_loop() -> None:
|
||||
|
||||
249
scripts/db/data-annotation-operator-pipeline-migration.sql
Normal file
249
scripts/db/data-annotation-operator-pipeline-migration.sql
Normal file
@@ -0,0 +1,249 @@
|
||||
-- DataMate 数据标注模块 - 通用算子编排改造(第1步:DDL)
|
||||
-- 说明:
|
||||
-- 1) 修改 t_dm_auto_annotation_tasks,新增编排相关字段和索引
|
||||
-- 2) 新建 t_dm_annotation_task_operator_instance,用于记录任务内算子实例
|
||||
-- 3) 本脚本按“幂等”方式编写,可重复执行
|
||||
|
||||
USE datamate;
|
||||
|
||||
SET @db_name = DATABASE();
|
||||
|
||||
-- =====================================================
|
||||
-- 1) 修改 t_dm_auto_annotation_tasks 表
|
||||
-- =====================================================
|
||||
|
||||
-- task_mode: 任务模式(legacy_yolo / pipeline)
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'task_mode'
|
||||
),
|
||||
'SELECT ''skip: column task_mode exists''',
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN task_mode VARCHAR(32) NOT NULL DEFAULT ''legacy_yolo'' COMMENT ''任务模式: legacy_yolo/pipeline'' AFTER status'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- executor_type: 执行器类型
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'executor_type'
|
||||
),
|
||||
'SELECT ''skip: column executor_type exists''',
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN executor_type VARCHAR(32) NOT NULL DEFAULT ''annotation_local'' COMMENT ''执行器类型'' AFTER task_mode'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- pipeline: 算子编排定义(JSON)
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'pipeline'
|
||||
),
|
||||
'SELECT ''skip: column pipeline exists''',
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN pipeline JSON NULL COMMENT ''算子编排定义'' AFTER executor_type'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- output_dataset_id: 输出数据集ID
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'output_dataset_id'
|
||||
),
|
||||
'SELECT ''skip: column output_dataset_id exists''',
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN output_dataset_id VARCHAR(36) NULL COMMENT ''输出数据集ID'' AFTER output_path'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- created_by: 任务创建人
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'created_by'
|
||||
),
|
||||
'SELECT ''skip: column created_by exists''',
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN created_by VARCHAR(255) NULL COMMENT ''任务创建人'' AFTER dataset_name'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- stop_requested: 停止请求标记
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'stop_requested'
|
||||
),
|
||||
'SELECT ''skip: column stop_requested exists''',
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN stop_requested TINYINT(1) NOT NULL DEFAULT 0 COMMENT ''是否请求停止: 0否/1是'' AFTER progress'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- started_at: 启动时间
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'started_at'
|
||||
),
|
||||
'SELECT ''skip: column started_at exists''',
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN started_at TIMESTAMP NULL COMMENT ''任务启动时间'' AFTER updated_at'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- heartbeat_at: worker 心跳时间
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'heartbeat_at'
|
||||
),
|
||||
'SELECT ''skip: column heartbeat_at exists''',
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN heartbeat_at TIMESTAMP NULL COMMENT ''worker心跳时间'' AFTER started_at'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- run_token: 运行令牌(用于任务 claim)
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'run_token'
|
||||
),
|
||||
'SELECT ''skip: column run_token exists''',
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN run_token VARCHAR(64) NULL COMMENT ''运行令牌'' AFTER heartbeat_at'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- status 注释补全 stopped(若字段已存在)
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND COLUMN_NAME = 'status'
|
||||
),
|
||||
'ALTER TABLE t_dm_auto_annotation_tasks MODIFY COLUMN status VARCHAR(50) NOT NULL DEFAULT ''pending'' COMMENT ''任务状态: pending/running/completed/failed/stopped''',
|
||||
'SELECT ''skip: column status not found'''
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- 索引:按状态 + 创建时间查询任务
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.STATISTICS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND INDEX_NAME = 'idx_status_created'
|
||||
),
|
||||
'SELECT ''skip: index idx_status_created exists''',
|
||||
'CREATE INDEX idx_status_created ON t_dm_auto_annotation_tasks (status, created_at)'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
-- 索引:按创建人过滤任务
|
||||
SET @ddl = (
|
||||
SELECT IF(
|
||||
EXISTS(
|
||||
SELECT 1
|
||||
FROM information_schema.STATISTICS
|
||||
WHERE TABLE_SCHEMA = @db_name
|
||||
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||
AND INDEX_NAME = 'idx_created_by'
|
||||
),
|
||||
'SELECT ''skip: index idx_created_by exists''',
|
||||
'CREATE INDEX idx_created_by ON t_dm_auto_annotation_tasks (created_by)'
|
||||
)
|
||||
);
|
||||
PREPARE stmt FROM @ddl;
|
||||
EXECUTE stmt;
|
||||
DEALLOCATE PREPARE stmt;
|
||||
|
||||
|
||||
-- =====================================================
|
||||
-- 2) 创建 t_dm_annotation_task_operator_instance 表
|
||||
-- =====================================================
|
||||
|
||||
CREATE TABLE IF NOT EXISTS t_dm_annotation_task_operator_instance (
|
||||
id BIGINT AUTO_INCREMENT PRIMARY KEY COMMENT '自增主键',
|
||||
task_id VARCHAR(36) NOT NULL COMMENT '自动标注任务ID',
|
||||
op_index INT NOT NULL COMMENT '算子顺序(从1开始)',
|
||||
operator_id VARCHAR(64) NOT NULL COMMENT '算子ID(raw_id)',
|
||||
settings_override JSON NULL COMMENT '任务级算子参数覆盖',
|
||||
inputs VARCHAR(64) NULL COMMENT '输入模态',
|
||||
outputs VARCHAR(64) NULL COMMENT '输出模态',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
|
||||
UNIQUE KEY uk_task_op_index (task_id, op_index),
|
||||
KEY idx_task_id (task_id),
|
||||
KEY idx_operator_id (operator_id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='标注任务算子实例表';
|
||||
Reference in New Issue
Block a user