feat(auto-annotation): integrate YOLO auto-labeling and enhance data management (#223)

* feat(auto-annotation): initial setup

* chore: remove package-lock.json

* chore: 清理本地测试脚本与 Maven 设置

* chore: change package-lock.json
This commit is contained in:
Kecheng Sha
2026-01-05 14:22:44 +08:00
committed by GitHub
parent ccfb84c034
commit 3f1ad6a872
44 changed files with 8503 additions and 5238 deletions

View File

@@ -1,60 +1,95 @@
"""
Tables of Annotation Management Module
"""
import uuid
from sqlalchemy import Column, String, BigInteger, Boolean, TIMESTAMP, Text, Integer, JSON, Date, ForeignKey
from sqlalchemy.sql import func
from app.db.session import Base
class AnnotationTemplate(Base):
"""标注配置模板模型"""
__tablename__ = "t_dm_annotation_templates"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
name = Column(String(100), nullable=False, comment="模板名称")
description = Column(String(500), nullable=True, comment="模板描述")
data_type = Column(String(50), nullable=False, comment="数据类型: image/text/audio/video/timeseries")
labeling_type = Column(String(50), nullable=False, comment="标注类型: classification/detection/segmentation/ner/relation/etc")
configuration = Column(JSON, nullable=False, comment="标注配置(包含labels定义等)")
style = Column(String(32), nullable=False, comment="样式配置: horizontal/vertical")
category = Column(String(50), default='custom', comment="模板分类: medical/general/custom/system")
built_in = Column(Boolean, default=False, comment="是否系统内置模板")
version = Column(String(20), default='1.0', comment="模板版本")
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
def __repr__(self):
return f"<AnnotationTemplate(id={self.id}, name={self.name}, data_type={self.data_type})>"
@property
def is_deleted(self) -> bool:
"""检查是否已被软删除"""
return self.deleted_at is not None
class LabelingProject(Base):
"""标注项目模型"""
__tablename__ = "t_dm_labeling_projects"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
dataset_id = Column(String(36), nullable=False, comment="数据集ID")
name = Column(String(100), nullable=False, comment="项目名称")
labeling_project_id = Column(String(8), nullable=False, comment="Label Studio项目ID")
template_id = Column(String(36), ForeignKey('t_dm_annotation_templates.id', ondelete='SET NULL'), nullable=True, comment="使用的模板ID")
configuration = Column(JSON, nullable=True, comment="项目配置(可能包含对模板的自定义修改)")
progress = Column(JSON, nullable=True, comment="项目进度信息")
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
def __repr__(self):
return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"
@property
def is_deleted(self) -> bool:
"""检查是否已被软删除"""
"""Tables of Annotation Management Module"""
import uuid
from sqlalchemy import Column, String, Boolean, TIMESTAMP, Text, Integer, JSON, ForeignKey
from sqlalchemy.sql import func
from app.db.session import Base
class AnnotationTemplate(Base):
"""标注配置模板模型"""
__tablename__ = "t_dm_annotation_templates"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
name = Column(String(100), nullable=False, comment="模板名称")
description = Column(String(500), nullable=True, comment="模板描述")
data_type = Column(String(50), nullable=False, comment="数据类型: image/text/audio/video/timeseries")
labeling_type = Column(String(50), nullable=False, comment="标注类型: classification/detection/segmentation/ner/relation/etc")
configuration = Column(JSON, nullable=False, comment="标注配置(包含labels定义等)")
style = Column(String(32), nullable=False, comment="样式配置: horizontal/vertical")
category = Column(String(50), default='custom', comment="模板分类: medical/general/custom/system")
built_in = Column(Boolean, default=False, comment="是否系统内置模板")
version = Column(String(20), default='1.0', comment="模板版本")
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
def __repr__(self):
return f"<AnnotationTemplate(id={self.id}, name={self.name}, data_type={self.data_type})>"
@property
def is_deleted(self) -> bool:
"""检查是否已被软删除"""
return self.deleted_at is not None
class LabelingProject(Base):
"""标注项目模型"""
__tablename__ = "t_dm_labeling_projects"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
dataset_id = Column(String(36), nullable=False, comment="数据集ID")
name = Column(String(100), nullable=False, comment="项目名称")
labeling_project_id = Column(String(8), nullable=False, comment="Label Studio项目ID")
template_id = Column(String(36), ForeignKey('t_dm_annotation_templates.id', ondelete='SET NULL'), nullable=True, comment="使用的模板ID")
configuration = Column(JSON, nullable=True, comment="项目配置(可能包含对模板的自定义修改)")
progress = Column(JSON, nullable=True, comment="项目进度信息")
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
def __repr__(self):
return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"
@property
def is_deleted(self) -> bool:
"""检查是否已被软删除"""
return self.deleted_at is not None
class AutoAnnotationTask(Base):
"""自动标注任务模型,对应表 t_dm_auto_annotation_tasks"""
__tablename__ = "t_dm_auto_annotation_tasks"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
name = Column(String(255), nullable=False, comment="任务名称")
dataset_id = Column(String(36), nullable=False, comment="数据集ID")
dataset_name = Column(String(255), nullable=True, comment="数据集名称(冗余字段,方便查询)")
config = Column(JSON, nullable=False, comment="任务配置(模型规模、置信度等)")
file_ids = Column(JSON, nullable=True, comment="要处理的文件ID列表,为空则处理数据集所有图像")
status = Column(String(50), nullable=False, default="pending", comment="任务状态: pending/running/completed/failed")
progress = Column(Integer, default=0, comment="任务进度 0-100")
total_images = Column(Integer, default=0, comment="总图片数")
processed_images = Column(Integer, default=0, comment="已处理图片数")
detected_objects = Column(Integer, default=0, comment="检测到的对象总数")
output_path = Column(String(500), nullable=True, comment="输出路径")
error_message = Column(Text, nullable=True, comment="错误信息")
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
updated_at = Column(
TIMESTAMP,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
comment="更新时间",
)
completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
def __repr__(self) -> str: # pragma: no cover - repr 简单返回
return f"<AutoAnnotationTask(id={self.id}, name={self.name}, status={self.status})>"
@property
def is_deleted(self) -> bool:
"""检查是否已被软删除"""
return self.deleted_at is not None

View File

@@ -1,16 +1,18 @@
from fastapi import APIRouter
from .config import router as about_router
from .project import router as project_router
from .task import router as task_router
from .template import router as template_router
router = APIRouter(
prefix="/annotation",
tags = ["annotation"]
)
router.include_router(about_router)
router.include_router(project_router)
router.include_router(task_router)
router.include_router(template_router)
from fastapi import APIRouter
from .config import router as about_router
from .project import router as project_router
from .task import router as task_router
from .template import router as template_router
from .auto import router as auto_router
router = APIRouter(
prefix="/annotation",
tags = ["annotation"]
)
router.include_router(about_router)
router.include_router(project_router)
router.include_router(task_router)
router.include_router(template_router)
router.include_router(auto_router)

View File

@@ -0,0 +1,196 @@
"""FastAPI routes for Auto Annotation tasks.
These routes back the frontend AutoAnnotation module:
- GET /api/annotation/auto
- POST /api/annotation/auto
- DELETE /api/annotation/auto/{task_id}
- GET /api/annotation/auto/{task_id}/status (simple wrapper)
"""
from __future__ import annotations
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.module.shared.schema import StandardResponse
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
)
from ..service.auto import AutoAnnotationTaskService
router = APIRouter(
prefix="/auto",
tags=["annotation/auto"],
)
logger = get_logger(__name__)
service = AutoAnnotationTaskService()
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_auto_annotation_tasks(
db: AsyncSession = Depends(get_db),
):
"""获取自动标注任务列表。
前端当前不传分页参数,这里直接返回所有未删除任务。
"""
tasks = await service.list_tasks(db)
return StandardResponse(
code=200,
message="success",
data=tasks,
)
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task(
request: CreateAutoAnnotationTaskRequest,
db: AsyncSession = Depends(get_db),
):
"""创建自动标注任务。
当前仅创建任务记录并置为 pending,实际执行由后续调度/worker 完成。
"""
logger.info(
"Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
request.name,
request.dataset_id,
request.config.model_dump(by_alias=True),
request.file_ids,
)
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
dataset_name = None
total_images = 0
try:
dm_client = DatasetManagementService(db)
# Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount
dataset = await dm_client.get_dataset(request.dataset_id)
if dataset is not None:
dataset_name = dataset.name
# 如果提供了 file_ids,则 total_images 为选中文件数;否则使用数据集文件数
if request.file_ids:
total_images = len(request.file_ids)
else:
total_images = getattr(dataset, "fileCount", 0) or 0
except Exception as e: # pragma: no cover - 容错
logger.warning("Failed to fetch dataset name for auto task: %s", e)
task = await service.create_task(
db,
request,
dataset_name=dataset_name,
total_images=total_images,
)
return StandardResponse(
code=200,
message="success",
data=task,
)
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
):
"""获取单个自动标注任务状态。
前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。
"""
task = await service.get_task(db, task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return StandardResponse(
code=200,
message="success",
data=task,
)
@router.delete("/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
):
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
ok = await service.soft_delete_task(db, task_id)
if not ok:
raise HTTPException(status_code=404, detail="Task not found")
return StandardResponse(
code=200,
message="success",
data=True,
)
@router.get("/{task_id}/download")
async def download_auto_annotation_result(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
):
"""下载指定自动标注任务的结果 ZIP。"""
import io
import os
import zipfile
import tempfile
# 复用服务层获取任务信息
task = await service.get_task(db, task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
if not task.output_path:
raise HTTPException(status_code=400, detail="Task has no output path")
output_dir = task.output_path
if not os.path.isdir(output_dir):
raise HTTPException(status_code=404, detail="Output directory not found")
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
os.close(tmp_fd)
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
for root, _, files in os.walk(output_dir):
for filename in files:
file_path = os.path.join(root, filename)
arcname = os.path.relpath(file_path, output_dir)
zf.write(file_path, arcname)
file_size = os.path.getsize(tmp_path)
if file_size == 0:
raise HTTPException(status_code=500, detail="Generated ZIP is empty")
def iterfile():
with open(tmp_path, "rb") as f:
while True:
chunk = f.read(8192)
if not chunk:
break
yield chunk
filename = f"{task.name}_annotations.zip"
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(file_size),
}
return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)

View File

@@ -0,0 +1,73 @@
"""Schemas for Auto Annotation tasks"""
from __future__ import annotations
from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict
class AutoAnnotationConfig(BaseModel):
"""自动标注任务配置(与前端 payload 对齐)"""
model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x")
conf_threshold: float = Field(alias="confThreshold", description="置信度阈值 0-1")
target_classes: List[int] = Field(
default_factory=list,
alias="targetClasses",
description="目标类别ID列表,空表示全部类别",
)
output_dataset_name: Optional[str] = Field(
default=None,
alias="outputDatasetName",
description="自动标注结果要写入的新数据集名称(可选)",
)
model_config = ConfigDict(populate_by_name=True)
class CreateAutoAnnotationTaskRequest(BaseModel):
"""创建自动标注任务的请求体,对齐前端 CreateAutoAnnotationDialog 发送的结构"""
name: str = Field(..., min_length=1, max_length=255, description="任务名称")
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
config: AutoAnnotationConfig = Field(..., description="任务配置")
file_ids: Optional[List[str]] = Field(None, alias="fileIds", description="要处理的文件ID列表,为空则处理数据集中所有图像")
model_config = ConfigDict(populate_by_name=True)
class AutoAnnotationTaskResponse(BaseModel):
"""自动标注任务响应模型(列表/详情均可复用)"""
id: str = Field(..., description="任务ID")
name: str = Field(..., description="任务名称")
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
source_datasets: Optional[List[str]] = Field(
default=None,
alias="sourceDatasets",
description="本任务实际处理涉及到的所有数据集名称列表",
)
config: Dict[str, Any] = Field(..., description="任务配置")
status: str = Field(..., description="任务状态")
progress: int = Field(..., description="任务进度 0-100")
total_images: int = Field(..., alias="totalImages", description="总图片数")
processed_images: int = Field(..., alias="processedImages", description="已处理图片数")
detected_objects: int = Field(..., alias="detectedObjects", description="检测到的对象总数")
output_path: Optional[str] = Field(None, alias="outputPath", description="输出路径")
error_message: Optional[str] = Field(None, alias="errorMessage", description="错误信息")
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
completed_at: Optional[datetime] = Field(None, alias="completedAt", description="完成时间")
model_config = ConfigDict(populate_by_name=True, from_attributes=True)
class AutoAnnotationTaskListResponse(BaseModel):
"""自动标注任务列表响应,目前前端直接使用数组,这里预留分页结构"""
content: List[AutoAnnotationTaskResponse] = Field(..., description="任务列表")
total: int = Field(..., description="总数")
model_config = ConfigDict(populate_by_name=True)

View File

@@ -0,0 +1,154 @@
"""Service layer for Auto Annotation tasks"""
from __future__ import annotations
from typing import List, Optional
from datetime import datetime
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.annotation_management import AutoAnnotationTask
from app.db.models.dataset_management import Dataset, DatasetFiles
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
)
class AutoAnnotationTaskService:
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
async def create_task(
self,
db: AsyncSession,
request: CreateAutoAnnotationTaskRequest,
dataset_name: Optional[str] = None,
total_images: int = 0,
) -> AutoAnnotationTaskResponse:
"""创建自动标注任务,初始状态为 pending。
这里仅插入任务记录,不负责真正执行 YOLO 推理,
后续可以由调度器/worker 读取该表并更新进度。
"""
now = datetime.now()
task = AutoAnnotationTask(
id=str(uuid4()),
name=request.name,
dataset_id=request.dataset_id,
dataset_name=dataset_name,
config=request.config.model_dump(by_alias=True),
file_ids=request.file_ids, # 存储用户选择的文件ID列表
status="pending",
progress=0,
total_images=total_images,
processed_images=0,
detected_objects=0,
created_at=now,
updated_at=now,
)
db.add(task)
await db.commit()
await db.refresh(task)
# 创建后附带 sourceDatasets 信息(通常只有一个原始数据集)
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
return resp
async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]:
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
result = await db.execute(
select(AutoAnnotationTask)
.where(AutoAnnotationTask.deleted_at.is_(None))
.order_by(AutoAnnotationTask.created_at.desc())
)
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
responses: List[AutoAnnotationTaskResponse] = []
for task in tasks:
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
# 出错时降级为单个 datasetName/datasetId
fallback_name = getattr(task, "dataset_name", None)
fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
responses.append(resp)
return responses
async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]:
result = await db.execute(
select(AutoAnnotationTask).where(
AutoAnnotationTask.id == task_id,
AutoAnnotationTask.deleted_at.is_(None),
)
)
task = result.scalar_one_or_none()
if not task:
return None
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
fallback_name = getattr(task, "dataset_name", None)
fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
return resp
async def _compute_source_datasets(
self,
db: AsyncSession,
task: AutoAnnotationTask,
) -> List[str]:
"""根据任务的 file_ids 推断实际涉及到的所有数据集名称。
- 如果存在 file_ids,则通过 t_dm_dataset_files 反查 dataset_id,再关联 t_dm_datasets 获取名称;
- 如果没有 file_ids,则退回到任务上冗余的 dataset_name/dataset_id。
"""
file_ids = task.file_ids or []
if file_ids:
stmt = (
select(Dataset.name)
.join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
.where(DatasetFiles.id.in_(file_ids))
.distinct()
)
result = await db.execute(stmt)
names = [row[0] for row in result.fetchall() if row[0]]
if names:
return names
# 回退:只显示一个数据集
if task.dataset_name:
return [task.dataset_name]
if task.dataset_id:
return [task.dataset_id]
return []
async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool:
result = await db.execute(
select(AutoAnnotationTask).where(
AutoAnnotationTask.id == task_id,
AutoAnnotationTask.deleted_at.is_(None),
)
)
task = result.scalar_one_or_none()
if not task:
return False
task.deleted_at = datetime.now()
await db.commit()
return True

View File

@@ -1,8 +1,8 @@
#!/bin/bash
set -e
if [-d $LOCAL_FILES_DOCUMENT_ROOT ] && $LOCAL_FILES_SERVING_ENABLED; then
echo "Using local document root: $LOCAL_FILES_DOCUMENT_ROOT"
if [ -d "${LOCAL_FILES_DOCUMENT_ROOT}" ] && [ "${LOCAL_FILES_SERVING_ENABLED}" = "true" ]; then
echo "Using local document root: ${LOCAL_FILES_DOCUMENT_ROOT}"
fi
# 启动应用

17
runtime/ops/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
"""Datamate built-in operators package.
This package contains built-in operators for filtering, slicing, annotation, etc.
It is mounted into the runtime container under ``datamate.ops`` so that
``from datamate.ops.annotation...`` imports work correctly.
"""
__all__ = [
"annotation",
"filter",
"formatter",
"llms",
"mapper",
"slicer",
"user",
]

View File

@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
"""Annotation-related operators (e.g. YOLO detection)."""
__all__ = [
"image_object_detection_bounding_box",
]

View File

@@ -0,0 +1,9 @@
"""Image object detection (YOLOv8) operator package.
This package exposes the ImageObjectDetectionBoundingBox annotator so that
the auto-annotation worker can import it via different module paths.
"""
from .process import ImageObjectDetectionBoundingBox
__all__ = ["ImageObjectDetectionBoundingBox"]

View File

@@ -0,0 +1,3 @@
name: image_object_detection_bounding_box
version: 0.1.0
description: "YOLOv8-based object detection operator for auto annotation"

View File

@@ -0,0 +1,214 @@
#!/user/bin/python
# -- encoding: utf-8 --
"""
Description: 图像目标检测算子
Create: 2025/12/17
"""
import os
import json
import time
from typing import Dict, Any
import cv2
import numpy as np
from loguru import logger
try:
from ultralytics import YOLO
except ImportError:
logger.warning("ultralytics not installed. Please install it using: pip install ultralytics")
YOLO = None
from datamate.core.base_op import Mapper
# COCO 80 类别映射
COCO_CLASS_MAP = {
0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 4: "airplane",
5: "bus", 6: "train", 7: "truck", 8: "boat", 9: "traffic light",
10: "fire hydrant", 11: "stop sign", 12: "parking meter", 13: "bench",
14: "bird", 15: "cat", 16: "dog", 17: "horse", 18: "sheep", 19: "cow",
20: "elephant", 21: "bear", 22: "zebra", 23: "giraffe", 24: "backpack",
25: "umbrella", 26: "handbag", 27: "tie", 28: "suitcase", 29: "frisbee",
30: "skis", 31: "snowboard", 32: "sports ball", 33: "kite",
34: "baseball bat", 35: "baseball glove", 36: "skateboard",
37: "surfboard", 38: "tennis racket", 39: "bottle",
40: "wine glass", 41: "cup", 42: "fork", 43: "knife", 44: "spoon",
45: "bowl", 46: "banana", 47: "apple", 48: "sandwich", 49: "orange",
50: "broccoli", 51: "carrot", 52: "hot dog", 53: "pizza",
54: "donut", 55: "cake", 56: "chair", 57: "couch",
58: "potted plant", 59: "bed", 60: "dining table", 61: "toilet",
62: "tv", 63: "laptop", 64: "mouse", 65: "remote",
66: "keyboard", 67: "cell phone", 68: "microwave", 69: "oven",
70: "toaster", 71: "sink", 72: "refrigerator", 73: "book",
74: "clock", 75: "vase", 76: "scissors", 77: "teddy bear",
78: "hair drier", 79: "toothbrush"
}
class ImageObjectDetectionBoundingBox(Mapper):
"""图像目标检测算子"""
# 模型映射
MODEL_MAP = {
"n": "yolov8n.pt",
"s": "yolov8s.pt",
"m": "yolov8m.pt",
"l": "yolov8l.pt",
"x": "yolov8x.pt",
}
def __init__(self, *args, **kwargs):
super(ImageObjectDetectionBoundingBox, self).__init__(*args, **kwargs)
# 获取参数
self._model_size = kwargs.get("modelSize", "l")
self._conf_threshold = kwargs.get("confThreshold", 0.7)
self._target_classes = kwargs.get("targetClasses", [])
self._output_dir = kwargs.get("outputDir", None) # 输出目录
# 如果目标类别为空列表,则检测所有类别
if not self._target_classes:
self._target_classes = None
else:
# 确保是整数列表
self._target_classes = [int(cls_id) for cls_id in self._target_classes]
# 获取模型路径
model_filename = self.MODEL_MAP.get(self._model_size, "yolov8l.pt")
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, model_filename)
# 初始化模型
if YOLO is None:
raise ImportError("ultralytics is not installed. Please install it.")
if not os.path.exists(model_path):
logger.warning(f"Model file {model_path} not found. Downloading from ultralytics...")
self.model = YOLO(model_filename) # 自动下载
else:
self.model = YOLO(model_path)
logger.info(f"Loaded YOLOv8 model: {model_filename}, "
f"conf_threshold: {self._conf_threshold}, "
f"target_classes: {self._target_classes}")
@staticmethod
def _get_color_by_class_id(class_id: int):
"""根据 class_id 生成稳定颜色(BGR,OpenCV 用)"""
np.random.seed(class_id)
color = np.random.randint(0, 255, size=3).tolist()
return tuple(color)
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""执行目标检测"""
start = time.time()
# 读取图像文件
image_path = sample.get(self.image_key)
if not image_path or not os.path.exists(image_path):
logger.warning(f"Image file not found: {image_path}")
return sample
# 读取图像
img = cv2.imread(image_path)
if img is None:
logger.warning(f"Failed to read image: {image_path}")
return sample
# 执行目标检测
results = self.model(img, conf=self._conf_threshold)
r = results[0]
# 准备标注数据
h, w = img.shape[:2]
annotations = {
"image": os.path.basename(image_path),
"width": w,
"height": h,
"model_size": self._model_size,
"conf_threshold": self._conf_threshold,
"selected_class_ids": self._target_classes,
"detections": []
}
# 处理检测结果
if r.boxes is not None:
for box in r.boxes:
cls_id = int(box.cls[0])
# 过滤目标类别
if self._target_classes is not None and cls_id not in self._target_classes:
continue
conf = float(box.conf[0])
x1, y1, x2, y2 = map(float, box.xyxy[0])
label = COCO_CLASS_MAP.get(cls_id, f"class_{cls_id}")
# 记录检测结果
annotations["detections"].append({
"label": label,
"class_id": cls_id,
"confidence": round(conf, 4),
"bbox_xyxy": [x1, y1, x2, y2],
"bbox_xywh": [x1, y1, x2 - x1, y2 - y1]
})
# 在图像上绘制
color = self._get_color_by_class_id(cls_id)
cv2.rectangle(
img,
(int(x1), int(y1)),
(int(x2), int(y2)),
color,
2
)
cv2.putText(
img,
f"{label} {conf:.2f}",
(int(x1), max(int(y1) - 5, 10)),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
color,
1
)
# 确定输出目录
if self._output_dir and os.path.exists(self._output_dir):
output_dir = self._output_dir
else:
output_dir = os.path.dirname(image_path)
# 创建输出子目录(可选,用于组织文件)
images_dir = os.path.join(output_dir, "images")
annotations_dir = os.path.join(output_dir, "annotations")
os.makedirs(images_dir, exist_ok=True)
os.makedirs(annotations_dir, exist_ok=True)
# 保持原始文件名(不添加后缀),确保一一对应
base_name = os.path.basename(image_path)
name_without_ext = os.path.splitext(base_name)[0]
# 保存标注图像(保持原始扩展名或使用jpg)
output_filename = base_name
output_path = os.path.join(images_dir, output_filename)
cv2.imwrite(output_path, img)
# 保存标注 JSON(文件名与图像对应)
json_filename = f"{name_without_ext}.json"
json_path = os.path.join(annotations_dir, json_filename)
with open(json_path, "w", encoding="utf-8") as f:
json.dump(annotations, f, indent=2, ensure_ascii=False)
# 更新样本数据
sample["detection_count"] = len(annotations["detections"])
sample["output_image"] = output_path
sample["annotations_file"] = json_path
sample["annotations"] = annotations
logger.info(f"Image: {os.path.basename(image_path)}, "
f"Detections: {len(annotations['detections'])}, "
f"Time: {(time.time() - start):.4f}s")
return sample

View File

@@ -0,0 +1,166 @@
import os
import json
from pathlib import Path
from ultralytics import YOLO
import cv2
import numpy as np
def get_color_by_class_id(class_id: int):
"""根据 class_id 生成稳定颜色(BGR)"""
np.random.seed(class_id)
color = np.random.randint(0, 255, size=3).tolist()
return tuple(color)
def mask_to_polygons(mask: np.ndarray):
"""将二值 mask 转换为 COCO 风格多边形列表"""
contours, _ = cv2.findContours(
mask,
cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE
)
polygons = []
for contour in contours:
if contour.shape[0] < 3:
continue
polygon = contour.flatten().tolist()
polygons.append(polygon)
return polygons
IMAGE_DIR = "C:/Users/meta/Desktop/Datamate/yolo/Photos"
OUT_IMG_DIR = "outputs_seg/images"
OUT_JSON_DIR = "outputs_seg/annotations"
MODEL_MAP = {
"n": "yolov8n-seg.pt",
"s": "yolov8s-seg.pt",
"m": "yolov8m-seg.pt",
"l": "yolov8l-seg.pt",
"x": "yolov8x-seg.pt",
}
MODEL_KEY = "x"
MODEL_PATH = MODEL_MAP[MODEL_KEY]
CONF_THRES = 0.7
DRAW_BBOX = True
COCO_CLASS_MAP = {
0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 4: "airplane",
5: "bus", 6: "train", 7: "truck", 8: "boat", 9: "traffic light",
10: "fire hydrant", 11: "stop sign", 12: "parking meter", 13: "bench",
14: "bird", 15: "cat", 16: "dog", 17: "horse", 18: "sheep", 19: "cow",
20: "elephant", 21: "bear", 22: "zebra", 23: "giraffe", 24: "backpack",
25: "umbrella", 26: "handbag", 27: "tie", 28: "suitcase", 29: "frisbee",
30: "skis", 31: "snowboard", 32: "sports ball", 33: "kite",
34: "baseball bat", 35: "baseball glove", 36: "skateboard",
37: "surfboard", 38: "tennis racket", 39: "bottle",
40: "wine glass", 41: "cup", 42: "fork", 43: "knife", 44: "spoon",
45: "bowl", 46: "banana", 47: "apple", 48: "sandwich", 49: "orange",
50: "broccoli", 51: "carrot", 52: "hot dog", 53: "pizza",
54: "donut", 55: "cake", 56: "chair", 57: "couch",
58: "potted plant", 59: "bed", 60: "dining table", 61: "toilet",
62: "tv", 63: "laptop", 64: "mouse", 65: "remote",
66: "keyboard", 67: "cell phone", 68: "microwave", 69: "oven",
70: "toaster", 71: "sink", 72: "refrigerator", 73: "book",
74: "clock", 75: "vase", 76: "scissors", 77: "teddy bear",
78: "hair drier", 79: "toothbrush"
}
TARGET_CLASS_IDS = [0, 2, 5]
os.makedirs(OUT_IMG_DIR, exist_ok=True)
os.makedirs(OUT_JSON_DIR, exist_ok=True)
if TARGET_CLASS_IDS is not None:
for cid in TARGET_CLASS_IDS:
if cid not in COCO_CLASS_MAP:
raise ValueError(f"Invalid class id: {cid}")
model = YOLO(MODEL_PATH)
image_paths = list(Path(IMAGE_DIR).glob("*.*"))
for img_path in image_paths:
img = cv2.imread(str(img_path))
if img is None:
print(f"[WARN] Failed to read {img_path}")
continue
results = model(img, conf=CONF_THRES)
r = results[0]
h, w = img.shape[:2]
annotations = {
"image": img_path.name,
"width": w,
"height": h,
"model_key": MODEL_KEY,
"conf_threshold": CONF_THRES,
"supported_classes": COCO_CLASS_MAP,
"selected_class_ids": TARGET_CLASS_IDS,
"instances": []
}
if r.boxes is not None and r.masks is not None:
for i, box in enumerate(r.boxes):
cls_id = int(box.cls[0])
if TARGET_CLASS_IDS is not None and cls_id not in TARGET_CLASS_IDS:
continue
conf = float(box.conf[0])
x1, y1, x2, y2 = map(float, box.xyxy[0])
label = COCO_CLASS_MAP[cls_id]
mask = r.masks.data[i].cpu().numpy()
mask = (mask > 0.5).astype(np.uint8)
mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
color = get_color_by_class_id(cls_id)
img[mask == 1] = (
img[mask == 1] * 0.5 + np.array(color) * 0.5
).astype(np.uint8)
if True:
cv2.rectangle(
img,
(int(x1), int(y1)),
(int(x2), int(y2)),
color,
2
)
cv2.putText(
img,
f"{label} {conf:.2f}",
(int(x1), max(int(y1) - 5, 10)),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
color,
1
)
polygons = mask_to_polygons(mask)
annotations["instances"].append({
"label": label,
"class_id": cls_id,
"confidence": round(conf, 4),
"bbox_xyxy": [x1, y1, x2, y2],
"segmentation": polygons
})
out_img_path = os.path.join(OUT_IMG_DIR, img_path.name)
out_json_path = os.path.join(OUT_JSON_DIR, img_path.stem + ".json")
cv2.imwrite(out_img_path, img)
with open(out_json_path, "w", encoding="utf-8") as f:
json.dump(annotations, f, indent=2, ensure_ascii=False)
print(f"[OK] {img_path.name}")
print("Segmentation batch finished.")

View File

@@ -31,4 +31,5 @@ dependencies = [
"sqlalchemy>=2.0.44",
"xmltodict>=1.0.2",
"zhconv>=1.4.3",
"ultralytics>=8.0.0",
]

View File

@@ -0,0 +1,603 @@
# -*- coding: utf-8 -*-
"""Simple background worker for auto-annotation tasks.
This module runs inside the datamate-runtime container (operator_runtime service).
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs YOLO
inference using the ImageObjectDetectionBoundingBox operator, updating
progress back to the same table so that the datamate-python backend and
frontend can display real-time status.
设计目标(最小可用版本):
- 单实例 worker,串行处理 `pending` 状态的任务。
- 对指定数据集下的所有已完成文件逐张执行目标检测。
- 按已处理图片数更新 `processed_images`、`progress`、`detected_objects`、`status` 等字段。
- 失败时将任务标记为 `failed` 并记录 `error_message`。
注意:
- 为了保持简单,目前不处理 "running" 状态的恢复逻辑;容器重启时,
已处于 running 的任务不会被重新拉起,需要后续扩展。
"""
from __future__ import annotations
import json
import os
import sys
import threading
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from loguru import logger
from sqlalchemy import text
from datamate.sql_manager.sql_manager import SQLManager
# 尝试多种导入路径,适配不同的打包/安装方式
ImageObjectDetectionBoundingBox = None # type: ignore
try:
# 优先使用 datamate.ops 路径(源码 COPY 到 /opt/runtime/datamate/ops 情况)
from datamate.ops.annotation.image_object_detection_bounding_box.process import ( # type: ignore
ImageObjectDetectionBoundingBox,
)
logger.info(
"Imported ImageObjectDetectionBoundingBox from datamate.ops.annotation.image_object_detection_bounding_box",
)
except Exception as e1: # pragma: no cover - 导入失败时仅记录日志,避免整体崩溃
logger.error(
"Failed to import ImageObjectDetectionBoundingBox via datamate.ops: {}",
e1,
)
try:
# 兼容顶层 ops 包安装的情况(通过 ops.pth 暴露)
from ops.annotation.image_object_detection_bounding_box.process import ( # type: ignore
ImageObjectDetectionBoundingBox,
)
logger.info(
"Imported ImageObjectDetectionBoundingBox from top-level ops.annotation.image_object_detection_bounding_box",
)
except Exception as e2:
logger.error(
"Failed to import ImageObjectDetectionBoundingBox via top-level ops package: {}",
e2,
)
ImageObjectDetectionBoundingBox = None
# 进一步兜底:直接从本地 runtime/ops 目录加载算子(开发环境常用场景)
if ImageObjectDetectionBoundingBox is None:
try:
project_root = Path(__file__).resolve().parents[2]
ops_root = project_root / "ops"
if ops_root.is_dir():
# 确保 ops 的父目录在 sys.path 中,这样可以按 "ops.xxx" 导入
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from ops.annotation.image_object_detection_bounding_box.process import ( # type: ignore
ImageObjectDetectionBoundingBox,
)
logger.info(
"Imported ImageObjectDetectionBoundingBox from local runtime/ops.annotation.image_object_detection_bounding_box",
)
else:
logger.warning(
"Local runtime/ops directory not found when trying to import ImageObjectDetectionBoundingBox: {}",
ops_root,
)
except Exception as e3: # pragma: no cover - 兜底失败仅记录日志
logger.error(
"Failed to import ImageObjectDetectionBoundingBox from local runtime/ops: {}",
e3,
)
ImageObjectDetectionBoundingBox = None
POLL_INTERVAL_SECONDS = float(os.getenv("AUTO_ANNOTATION_POLL_INTERVAL", "5"))
DEFAULT_OUTPUT_ROOT = os.getenv(
"AUTO_ANNOTATION_OUTPUT_ROOT", "/dataset"
)
def _fetch_pending_task() -> Optional[Dict[str, Any]]:
"""从 t_dm_auto_annotation_tasks 中取出一个 pending 任务。"""
sql = text(
"""
SELECT id, name, dataset_id, dataset_name, config, file_ids, status,
total_images, processed_images, detected_objects, output_path
FROM t_dm_auto_annotation_tasks
WHERE status = 'pending' AND deleted_at IS NULL
ORDER BY created_at ASC
LIMIT 1
"""
)
with SQLManager.create_connect() as conn:
result = conn.execute(sql).fetchone()
if not result:
return None
row = dict(result._mapping) # type: ignore[attr-defined]
try:
row["config"] = json.loads(row["config"]) if row.get("config") else {}
except Exception:
row["config"] = {}
try:
raw_ids = row.get("file_ids")
if not raw_ids:
row["file_ids"] = None
elif isinstance(raw_ids, str):
row["file_ids"] = json.loads(raw_ids)
else:
row["file_ids"] = raw_ids
except Exception:
row["file_ids"] = None
return row
def _update_task_status(
task_id: str,
*,
status: str,
progress: Optional[int] = None,
processed_images: Optional[int] = None,
detected_objects: Optional[int] = None,
total_images: Optional[int] = None,
output_path: Optional[str] = None,
error_message: Optional[str] = None,
completed: bool = False,
) -> None:
"""更新任务的状态和统计字段。"""
fields: List[str] = ["status = :status", "updated_at = :updated_at"]
params: Dict[str, Any] = {
"task_id": task_id,
"status": status,
"updated_at": datetime.now(),
}
if progress is not None:
fields.append("progress = :progress")
params["progress"] = int(progress)
if processed_images is not None:
fields.append("processed_images = :processed_images")
params["processed_images"] = int(processed_images)
if detected_objects is not None:
fields.append("detected_objects = :detected_objects")
params["detected_objects"] = int(detected_objects)
if total_images is not None:
fields.append("total_images = :total_images")
params["total_images"] = int(total_images)
if output_path is not None:
fields.append("output_path = :output_path")
params["output_path"] = output_path
if error_message is not None:
fields.append("error_message = :error_message")
params["error_message"] = error_message[:2000]
if completed:
fields.append("completed_at = :completed_at")
params["completed_at"] = datetime.now()
sql = text(
f"""
UPDATE t_dm_auto_annotation_tasks
SET {', '.join(fields)}
WHERE id = :task_id
"""
)
with SQLManager.create_connect() as conn:
conn.execute(sql, params)
def _load_dataset_files(dataset_id: str) -> List[Tuple[str, str, str]]:
"""加载指定数据集下的所有已完成文件。"""
sql = text(
"""
SELECT id, file_path, file_name
FROM t_dm_dataset_files
WHERE dataset_id = :dataset_id
AND status = 'ACTIVE'
ORDER BY created_at ASC
"""
)
with SQLManager.create_connect() as conn:
rows = conn.execute(sql, {"dataset_id": dataset_id}).fetchall()
return [(str(r[0]), str(r[1]), str(r[2])) for r in rows]
def _load_files_by_ids(file_ids: List[str]) -> List[Tuple[str, str, str]]:
"""根据文件ID列表加载文件记录,支持跨多个数据集。"""
if not file_ids:
return []
placeholders = ", ".join(f":id{i}" for i in range(len(file_ids)))
sql = text(
f"""
SELECT id, file_path, file_name
FROM t_dm_dataset_files
WHERE id IN ({placeholders})
AND status = 'ACTIVE'
ORDER BY created_at ASC
"""
)
params = {f"id{i}": str(fid) for i, fid in enumerate(file_ids)}
with SQLManager.create_connect() as conn:
rows = conn.execute(sql, params).fetchall()
return [(str(r[0]), str(r[1]), str(r[2])) for r in rows]
def _ensure_output_dir(output_dir: str) -> str:
"""确保输出目录及其 images/、annotations/ 子目录存在。"""
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "annotations"), exist_ok=True)
return output_dir
def _create_output_dataset(
source_dataset_id: str,
source_dataset_name: str,
output_dataset_name: str,
) -> Tuple[str, str]:
"""为自动标注结果创建一个新的数据集并返回 (dataset_id, path)。"""
new_dataset_id = str(uuid.uuid4())
dataset_base_path = DEFAULT_OUTPUT_ROOT.rstrip("/") or "/dataset"
output_dir = os.path.join(dataset_base_path, new_dataset_id)
description = (
f"Auto annotations for dataset {source_dataset_name or source_dataset_id}"[:255]
)
sql = text(
"""
INSERT INTO t_dm_datasets (id, name, description, dataset_type, path, status)
VALUES (:id, :name, :description, :dataset_type, :path, :status)
"""
)
params = {
"id": new_dataset_id,
"name": output_dataset_name,
"description": description,
"dataset_type": "IMAGE",
"path": output_dir,
"status": "ACTIVE",
}
with SQLManager.create_connect() as conn:
conn.execute(sql, params)
return new_dataset_id, output_dir
def _register_output_dataset(
task_id: str,
output_dataset_id: str,
output_dir: str,
output_dataset_name: str,
total_images: int,
) -> None:
"""将自动标注结果注册到新建的数据集。"""
images_dir = os.path.join(output_dir, "images")
if not os.path.isdir(images_dir):
logger.warning(
"Auto-annotation images directory not found for task {}: {}",
task_id,
images_dir,
)
return
image_files: List[Tuple[str, str, int]] = []
annotation_files: List[Tuple[str, str, int]] = []
total_size = 0
for file_name in sorted(os.listdir(images_dir)):
file_path = os.path.join(images_dir, file_name)
if not os.path.isfile(file_path):
continue
try:
file_size = os.path.getsize(file_path)
except OSError:
file_size = 0
image_files.append((file_name, file_path, int(file_size)))
total_size += int(file_size)
annotations_dir = os.path.join(output_dir, "annotations")
if os.path.isdir(annotations_dir):
for file_name in sorted(os.listdir(annotations_dir)):
file_path = os.path.join(annotations_dir, file_name)
if not os.path.isfile(file_path):
continue
try:
file_size = os.path.getsize(file_path)
except OSError:
file_size = 0
annotation_files.append((file_name, file_path, int(file_size)))
total_size += int(file_size)
if not image_files:
logger.warning(
"No image files found in auto-annotation output for task {}: {}",
task_id,
images_dir,
)
return
insert_file_sql = text(
"""
INSERT INTO t_dm_dataset_files (
id, dataset_id, file_name, file_path, file_type, file_size, status
) VALUES (
:id, :dataset_id, :file_name, :file_path, :file_type, :file_size, :status
)
"""
)
update_dataset_stat_sql = text(
"""
UPDATE t_dm_datasets
SET file_count = COALESCE(file_count, 0) + :add_count,
size_bytes = COALESCE(size_bytes, 0) + :add_size
WHERE id = :dataset_id
"""
)
with SQLManager.create_connect() as conn:
added_count = 0
for file_name, file_path, file_size in image_files:
ext = os.path.splitext(file_name)[1].lstrip(".").upper() or None
conn.execute(
insert_file_sql,
{
"id": str(uuid.uuid4()),
"dataset_id": output_dataset_id,
"file_name": file_name,
"file_path": file_path,
"file_type": ext,
"file_size": int(file_size),
"status": "ACTIVE",
},
)
added_count += 1
for file_name, file_path, file_size in annotation_files:
ext = os.path.splitext(file_name)[1].lstrip(".").upper() or None
conn.execute(
insert_file_sql,
{
"id": str(uuid.uuid4()),
"dataset_id": output_dataset_id,
"file_name": file_name,
"file_path": file_path,
"file_type": ext,
"file_size": int(file_size),
"status": "ACTIVE",
},
)
added_count += 1
if added_count > 0:
conn.execute(
update_dataset_stat_sql,
{
"dataset_id": output_dataset_id,
"add_count": added_count,
"add_size": int(total_size),
},
)
logger.info(
"Registered auto-annotation output into dataset: dataset_id={}, name={}, added_files={}, added_size_bytes={}, task_id={}, output_dir={}",
output_dataset_id,
output_dataset_name,
len(image_files) + len(annotation_files),
total_size,
task_id,
output_dir,
)
def _process_single_task(task: Dict[str, Any]) -> None:
"""执行单个自动标注任务。"""
if ImageObjectDetectionBoundingBox is None:
logger.error(
"YOLO operator not available (import failed earlier), skip auto-annotation task: {}",
task["id"],
)
_update_task_status(
task["id"],
status="failed",
error_message="YOLO operator not available in runtime container",
)
return
task_id = str(task["id"])
dataset_id = str(task["dataset_id"])
task_name = str(task.get("name") or "")
source_dataset_name = str(task.get("dataset_name") or "")
cfg: Dict[str, Any] = task.get("config") or {}
selected_file_ids: Optional[List[str]] = task.get("file_ids") or None
model_size = cfg.get("modelSize", "l")
conf_threshold = float(cfg.get("confThreshold", 0.7))
target_classes = cfg.get("targetClasses", []) or []
output_dataset_name = cfg.get("outputDatasetName")
if not output_dataset_name:
base_name = source_dataset_name or task_name or f"dataset-{dataset_id[:8]}"
output_dataset_name = f"{base_name}_auto_{task_id[:8]}"
logger.info(
"Start processing auto-annotation task: id={}, dataset_id={}, model_size={}, conf_threshold={}, target_classes={}, output_dataset_name={}",
task_id,
dataset_id,
model_size,
conf_threshold,
target_classes,
output_dataset_name,
)
_update_task_status(task_id, status="running", progress=0)
if selected_file_ids:
all_files = _load_files_by_ids(selected_file_ids)
else:
all_files = _load_dataset_files(dataset_id)
files = [(path, name) for _, path, name in all_files]
total_images = len(files)
if total_images == 0:
logger.warning("No files found for dataset {} when running auto-annotation task {}", dataset_id, task_id)
_update_task_status(
task_id,
status="completed",
progress=100,
total_images=0,
processed_images=0,
detected_objects=0,
completed=True,
output_path=None,
)
return
output_dataset_id, output_dir = _create_output_dataset(
source_dataset_id=dataset_id,
source_dataset_name=source_dataset_name,
output_dataset_name=output_dataset_name,
)
output_dir = _ensure_output_dir(output_dir)
try:
detector = ImageObjectDetectionBoundingBox(
modelSize=model_size,
confThreshold=conf_threshold,
targetClasses=target_classes,
outputDir=output_dir,
)
except Exception as e:
logger.error("Failed to init YOLO detector for task {}: {}", task_id, e)
_update_task_status(
task_id,
status="failed",
total_images=total_images,
processed_images=0,
detected_objects=0,
error_message=f"Init YOLO detector failed: {e}",
)
return
processed = 0
detected_total = 0
for file_path, file_name in files:
try:
sample = {
"image": file_path,
"filename": file_name,
}
result = detector.execute(sample)
annotations = (result or {}).get("annotations", {})
detections = annotations.get("detections", [])
detected_total += len(detections)
processed += 1
progress = int(processed * 100 / total_images) if total_images > 0 else 100
_update_task_status(
task_id,
status="running",
progress=progress,
processed_images=processed,
detected_objects=detected_total,
total_images=total_images,
output_path=output_dir,
)
except Exception as e:
logger.error(
"Failed to process image for task {}: file_path={}, error={}",
task_id,
file_path,
e,
)
continue
_update_task_status(
task_id,
status="completed",
progress=100,
processed_images=processed,
detected_objects=detected_total,
total_images=total_images,
output_path=output_dir,
completed=True,
)
logger.info(
"Completed auto-annotation task: id={}, total_images={}, processed={}, detected_objects={}, output_path={}",
task_id,
total_images,
processed,
detected_total,
output_dir,
)
if output_dataset_name and output_dataset_id:
try:
_register_output_dataset(
task_id=task_id,
output_dataset_id=output_dataset_id,
output_dir=output_dir,
output_dataset_name=output_dataset_name,
total_images=total_images,
)
except Exception as e: # pragma: no cover - 防御性日志
logger.error(
"Failed to register auto-annotation output as dataset for task {}: {}",
task_id,
e,
)
def _worker_loop() -> None:
"""Worker 主循环,在独立线程中运行。"""
logger.info(
"Auto-annotation worker started with poll interval {} seconds, output root {}",
POLL_INTERVAL_SECONDS,
DEFAULT_OUTPUT_ROOT,
)
while True:
try:
task = _fetch_pending_task()
if not task:
time.sleep(POLL_INTERVAL_SECONDS)
continue
_process_single_task(task)
except Exception as e: # pragma: no cover - 防御性日志
logger.error("Auto-annotation worker loop error: {}", e)
time.sleep(POLL_INTERVAL_SECONDS)
def start_auto_annotation_worker() -> None:
"""在后台线程中启动自动标注 worker。"""
thread = threading.Thread(target=_worker_loop, name="auto-annotation-worker", daemon=True)
thread.start()
logger.info("Auto-annotation worker thread started: {}", thread.name)

View File

@@ -1,163 +1,174 @@
import os
from typing import Optional, Dict, Any, List
import uvicorn
import yaml
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from jsonargparse import ArgumentParser
from loguru import logger
from pydantic import BaseModel
from datamate.common.error_code import ErrorCode
from datamate.scheduler import cmd_scheduler
from datamate.scheduler import func_scheduler
from datamate.wrappers import WRAPPERS
# 日志配置
LOG_DIR = "/var/log/datamate/runtime"
os.makedirs(LOG_DIR, exist_ok=True)
logger.add(
f"{LOG_DIR}/runtime.log",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}",
level="DEBUG",
enqueue=True
)
app = FastAPI()
class APIException(Exception):
"""自定义API异常"""
def __init__(self, error_code: ErrorCode, detail: Optional[str] = None,
extra_data: Optional[Dict] = None):
self.error_code = error_code
self.detail = detail or error_code.value[1]
self.code = error_code.value[0]
self.extra_data = extra_data
super().__init__(self.detail)
def to_dict(self) -> Dict[str, Any]:
result = {
"code": self.code,
"message": self.detail,
"success": False
}
if self.extra_data:
result["data"] = self.extra_data
return result
@app.exception_handler(APIException)
async def api_exception_handler(request: Request, exc: APIException):
return JSONResponse(
status_code=200, # 业务错误返回 200,错误信息在响应体中
content=exc.to_dict()
)
class QueryTaskRequest(BaseModel):
task_ids: List[str]
@app.post("/api/task/list")
async def query_task_info(request: QueryTaskRequest):
try:
return [{task_id: cmd_scheduler.get_task_status(task_id)} for task_id in request.task_ids]
except Exception as e:
raise APIException(ErrorCode.UNKNOWN_ERROR)
@app.post("/api/task/{task_id}/submit")
async def submit_task(task_id):
config_path = f"/flow/{task_id}/process.yaml"
logger.info("Start submitting job...")
dataset_path = get_from_cfg(task_id, "dataset_path")
if not check_valid_path(dataset_path):
logger.error(f"dataset_path is not existed! please check this path.")
raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)
try:
executor_type = get_from_cfg(task_id, "executor_type")
await WRAPPERS.get(executor_type).submit(task_id, config_path)
except Exception as e:
logger.error(f"Error happens during submitting task. Error Info following: {e}")
raise APIException(ErrorCode.SUBMIT_TASK_ERROR)
logger.info(f"task id: {task_id} has been submitted.")
success_json_info = JSONResponse(
content={"status": "Success", "message": f"{task_id} has been submitted"},
status_code=200
)
return success_json_info
@app.post("/api/task/{task_id}/stop")
async def stop_task(task_id):
logger.info("Start stopping ray job...")
success_json_info = JSONResponse(
content={"status": "Success", "message": f"{task_id} has been stopped"},
status_code=200
)
try:
executor_type = get_from_cfg(task_id, "executor_type")
if not WRAPPERS.get(executor_type).cancel(task_id):
raise APIException(ErrorCode.CANCEL_TASK_ERROR)
except Exception as e:
if isinstance(e, APIException):
raise e
raise APIException(ErrorCode.UNKNOWN_ERROR)
logger.info(f"{task_id} has been stopped.")
return success_json_info
def check_valid_path(file_path):
full_path = os.path.abspath(file_path)
return os.path.exists(full_path)
def get_from_cfg(task_id, key):
config_path = f"/flow/{task_id}/process.yaml"
if not check_valid_path(config_path):
logger.error(f"config_path is not existed! please check this path.")
raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)
with open(config_path, "r", encoding='utf-8') as f:
content = f.read()
cfg = yaml.safe_load(content)
return cfg[key]
def parse_args():
parser = ArgumentParser(description="Create API for Submitting Job to Data-juicer")
parser.add_argument(
'--ip',
type=str,
default="0.0.0.0",
help='Service ip for this API, default to use 0.0.0.0.'
)
parser.add_argument(
'--port',
type=int,
default=8080,
help='Service port for this API, default to use 8600.'
)
return parser.parse_args()
if __name__ == '__main__':
p_args = parse_args()
uvicorn.run(
app,
host=p_args.ip,
port=p_args.port
)
import os
from typing import Optional, Dict, Any, List
import uvicorn
import yaml
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from jsonargparse import ArgumentParser
from loguru import logger
from pydantic import BaseModel
from datamate.common.error_code import ErrorCode
from datamate.scheduler import cmd_scheduler
from datamate.scheduler import func_scheduler
from datamate.wrappers import WRAPPERS
from datamate.auto_annotation_worker import start_auto_annotation_worker
# 日志配置
LOG_DIR = "/var/log/datamate/runtime"
os.makedirs(LOG_DIR, exist_ok=True)
logger.add(
f"{LOG_DIR}/runtime.log",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}",
level="DEBUG",
enqueue=True
)
app = FastAPI()
class APIException(Exception):
"""自定义API异常"""
def __init__(self, error_code: ErrorCode, detail: Optional[str] = None,
extra_data: Optional[Dict] = None):
self.error_code = error_code
self.detail = detail or error_code.value[1]
self.code = error_code.value[0]
self.extra_data = extra_data
super().__init__(self.detail)
def to_dict(self) -> Dict[str, Any]:
result = {
"code": self.code,
"message": self.detail,
"success": False
}
if self.extra_data:
result["data"] = self.extra_data
return result
@app.on_event("startup")
async def startup_event():
"""FastAPI 启动时初始化后台自动标注 worker。"""
try:
start_auto_annotation_worker()
except Exception as e: # pragma: no cover - 防御性日志
logger.error("Failed to start auto-annotation worker: {}", e)
@app.exception_handler(APIException)
async def api_exception_handler(request: Request, exc: APIException):
return JSONResponse(
status_code=200, # 业务错误返回 200,错误信息在响应体中
content=exc.to_dict()
)
class QueryTaskRequest(BaseModel):
task_ids: List[str]
@app.post("/api/task/list")
async def query_task_info(request: QueryTaskRequest):
try:
return [{task_id: cmd_scheduler.get_task_status(task_id)} for task_id in request.task_ids]
except Exception as e:
raise APIException(ErrorCode.UNKNOWN_ERROR)
@app.post("/api/task/{task_id}/submit")
async def submit_task(task_id):
config_path = f"/flow/{task_id}/process.yaml"
logger.info("Start submitting job...")
dataset_path = get_from_cfg(task_id, "dataset_path")
if not check_valid_path(dataset_path):
logger.error(f"dataset_path is not existed! please check this path.")
raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)
try:
executor_type = get_from_cfg(task_id, "executor_type")
await WRAPPERS.get(executor_type).submit(task_id, config_path)
except Exception as e:
logger.error(f"Error happens during submitting task. Error Info following: {e}")
raise APIException(ErrorCode.SUBMIT_TASK_ERROR)
logger.info(f"task id: {task_id} has been submitted.")
success_json_info = JSONResponse(
content={"status": "Success", "message": f"{task_id} has been submitted"},
status_code=200
)
return success_json_info
@app.post("/api/task/{task_id}/stop")
async def stop_task(task_id):
logger.info("Start stopping ray job...")
success_json_info = JSONResponse(
content={"status": "Success", "message": f"{task_id} has been stopped"},
status_code=200
)
try:
executor_type = get_from_cfg(task_id, "executor_type")
if not WRAPPERS.get(executor_type).cancel(task_id):
raise APIException(ErrorCode.CANCEL_TASK_ERROR)
except Exception as e:
if isinstance(e, APIException):
raise e
raise APIException(ErrorCode.UNKNOWN_ERROR)
logger.info(f"{task_id} has been stopped.")
return success_json_info
def check_valid_path(file_path):
full_path = os.path.abspath(file_path)
return os.path.exists(full_path)
def get_from_cfg(task_id, key):
config_path = f"/flow/{task_id}/process.yaml"
if not check_valid_path(config_path):
logger.error(f"config_path is not existed! please check this path.")
raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)
with open(config_path, "r", encoding='utf-8') as f:
content = f.read()
cfg = yaml.safe_load(content)
return cfg[key]
def parse_args():
parser = ArgumentParser(description="Create API for Submitting Job to Data-juicer")
parser.add_argument(
'--ip',
type=str,
default="0.0.0.0",
help='Service ip for this API, default to use 0.0.0.0.'
)
parser.add_argument(
'--port',
type=int,
default=8080,
help='Service port for this API, default to use 8600.'
)
return parser.parse_args()
if __name__ == '__main__':
p_args = parse_args()
uvicorn.run(
app,
host=p_args.ip,
port=p_args.port
)