You've already forked DataMate
feat(auto-annotation): integrate YOLO auto-labeling and enhance data management (#223)
* feat(auto-annotation): initial setup * chore: remove package-lock.json * chore: 清理本地测试脚本与 Maven 设置 * chore: change package-lock.json
This commit is contained in:
@@ -1,60 +1,95 @@
|
||||
"""
|
||||
Tables of Annotation Management Module
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, BigInteger, Boolean, TIMESTAMP, Text, Integer, JSON, Date, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
class AnnotationTemplate(Base):
    """Annotation configuration template model.

    Reusable labeling template (labels, layout, data type) that labeling
    projects reference via their ``template_id`` foreign key.
    """

    __tablename__ = "t_dm_annotation_templates"

    # Random UUID primary key generated in Python (not by the database).
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    name = Column(String(100), nullable=False, comment="模板名称")
    description = Column(String(500), nullable=True, comment="模板描述")
    data_type = Column(String(50), nullable=False, comment="数据类型: image/text/audio/video/timeseries")
    labeling_type = Column(String(50), nullable=False, comment="标注类型: classification/detection/segmentation/ner/relation/etc")
    configuration = Column(JSON, nullable=False, comment="标注配置(包含labels定义等)")
    style = Column(String(32), nullable=False, comment="样式配置: horizontal/vertical")
    category = Column(String(50), default='custom', comment="模板分类: medical/general/custom/system")
    built_in = Column(Boolean, default=False, comment="是否系统内置模板")
    version = Column(String(20), default='1.0', comment="模板版本")
    # Timestamps are maintained by the database (server_default / onupdate).
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    # Soft-delete marker; queries filter on this instead of removing rows.
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self):
        return f"<AnnotationTemplate(id={self.id}, name={self.name}, data_type={self.data_type})>"

    @property
    def is_deleted(self) -> bool:
        """Return True when this row has been soft-deleted."""
        return self.deleted_at is not None
||||
|
||||
class LabelingProject(Base):
|
||||
"""标注项目模型"""
|
||||
|
||||
__tablename__ = "t_dm_labeling_projects"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
|
||||
dataset_id = Column(String(36), nullable=False, comment="数据集ID")
|
||||
name = Column(String(100), nullable=False, comment="项目名称")
|
||||
labeling_project_id = Column(String(8), nullable=False, comment="Label Studio项目ID")
|
||||
template_id = Column(String(36), ForeignKey('t_dm_annotation_templates.id', ondelete='SET NULL'), nullable=True, comment="使用的模板ID")
|
||||
configuration = Column(JSON, nullable=True, comment="项目配置(可能包含对模板的自定义修改)")
|
||||
progress = Column(JSON, nullable=True, comment="项目进度信息")
|
||||
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
|
||||
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
|
||||
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"
|
||||
|
||||
@property
|
||||
def is_deleted(self) -> bool:
|
||||
"""检查是否已被软删除"""
|
||||
"""Tables of Annotation Management Module"""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, TIMESTAMP, Text, Integer, JSON, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
class AnnotationTemplate(Base):
    """Annotation configuration template model.

    Reusable labeling template (labels, layout, data type) that labeling
    projects reference via their ``template_id`` foreign key.
    """

    __tablename__ = "t_dm_annotation_templates"

    # Random UUID primary key generated in Python (not by the database).
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    name = Column(String(100), nullable=False, comment="模板名称")
    description = Column(String(500), nullable=True, comment="模板描述")
    data_type = Column(String(50), nullable=False, comment="数据类型: image/text/audio/video/timeseries")
    labeling_type = Column(String(50), nullable=False, comment="标注类型: classification/detection/segmentation/ner/relation/etc")
    configuration = Column(JSON, nullable=False, comment="标注配置(包含labels定义等)")
    style = Column(String(32), nullable=False, comment="样式配置: horizontal/vertical")
    category = Column(String(50), default='custom', comment="模板分类: medical/general/custom/system")
    built_in = Column(Boolean, default=False, comment="是否系统内置模板")
    version = Column(String(20), default='1.0', comment="模板版本")
    # Timestamps are maintained by the database (server_default / onupdate).
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    # Soft-delete marker; queries filter on this instead of removing rows.
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self):
        return f"<AnnotationTemplate(id={self.id}, name={self.name}, data_type={self.data_type})>"

    @property
    def is_deleted(self) -> bool:
        """Return True when this row has been soft-deleted."""
        return self.deleted_at is not None
||||
|
||||
class LabelingProject(Base):
    """Labeling project model.

    Links a DataMate dataset to an external Label Studio project
    (``labeling_project_id``) and optionally to an AnnotationTemplate.
    """

    __tablename__ = "t_dm_labeling_projects"

    # Random UUID primary key generated in Python (not by the database).
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    dataset_id = Column(String(36), nullable=False, comment="数据集ID")
    name = Column(String(100), nullable=False, comment="项目名称")
    labeling_project_id = Column(String(8), nullable=False, comment="Label Studio项目ID")
    # Deleting the template leaves the project intact (ondelete='SET NULL').
    template_id = Column(String(36), ForeignKey('t_dm_annotation_templates.id', ondelete='SET NULL'), nullable=True, comment="使用的模板ID")
    configuration = Column(JSON, nullable=True, comment="项目配置(可能包含对模板的自定义修改)")
    progress = Column(JSON, nullable=True, comment="项目进度信息")
    # Timestamps are maintained by the database (server_default / onupdate).
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    # Soft-delete marker; queries filter on this instead of removing rows.
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self):
        return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"

    @property
    def is_deleted(self) -> bool:
        """Return True when this row has been soft-deleted."""
        return self.deleted_at is not None
||||
|
||||
|
||||
class AutoAnnotationTask(Base):
    """Auto-annotation task model, mapped to table t_dm_auto_annotation_tasks.

    One row per YOLO auto-labeling job; a worker reads pending rows and
    updates status/progress counters as it processes images.
    """

    __tablename__ = "t_dm_auto_annotation_tasks"

    # Random UUID primary key generated in Python (not by the database).
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    name = Column(String(255), nullable=False, comment="任务名称")
    dataset_id = Column(String(36), nullable=False, comment="数据集ID")
    # Denormalized for cheap list rendering; may drift from the dataset table.
    dataset_name = Column(String(255), nullable=True, comment="数据集名称(冗余字段,方便查询)")
    config = Column(JSON, nullable=False, comment="任务配置(模型规模、置信度等)")
    # NULL/empty means "process every image in the dataset".
    file_ids = Column(JSON, nullable=True, comment="要处理的文件ID列表,为空则处理数据集所有图像")
    status = Column(String(50), nullable=False, default="pending", comment="任务状态: pending/running/completed/failed")
    progress = Column(Integer, default=0, comment="任务进度 0-100")
    total_images = Column(Integer, default=0, comment="总图片数")
    processed_images = Column(Integer, default=0, comment="已处理图片数")
    detected_objects = Column(Integer, default=0, comment="检测到的对象总数")
    output_path = Column(String(500), nullable=True, comment="输出路径")
    error_message = Column(Text, nullable=True, comment="错误信息")
    # Timestamps are maintained by the database (server_default / onupdate).
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        comment="更新时间",
    )
    completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
    # Soft-delete marker; queries filter on this instead of removing rows.
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self) -> str:  # pragma: no cover - trivial repr
        return f"<AutoAnnotationTask(id={self.id}, name={self.name}, status={self.status})>"

    @property
    def is_deleted(self) -> bool:
        """Return True when this row has been soft-deleted."""
        return self.deleted_at is not None
|
||||
@@ -1,16 +1,18 @@
|
||||
"""Aggregate router for the annotation module."""
from fastapi import APIRouter

from .config import router as about_router
from .project import router as project_router
from .task import router as task_router
from .template import router as template_router

# Parent router; every sub-router below is mounted under /annotation.
router = APIRouter(prefix="/annotation", tags=["annotation"])

for _sub_router in (about_router, project_router, task_router, template_router):
    router.include_router(_sub_router)
|
||||
"""Aggregate router for the annotation module."""
from fastapi import APIRouter

from .config import router as about_router
from .project import router as project_router
from .task import router as task_router
from .template import router as template_router
from .auto import router as auto_router

# Parent router; every sub-router below is mounted under /annotation.
router = APIRouter(prefix="/annotation", tags=["annotation"])

for _sub_router in (about_router, project_router, task_router, template_router, auto_router):
    router.include_router(_sub_router)
|
||||
196
runtime/datamate-python/app/module/annotation/interface/auto.py
Normal file
196
runtime/datamate-python/app/module/annotation/interface/auto.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""FastAPI routes for Auto Annotation tasks.
|
||||
|
||||
These routes back the frontend AutoAnnotation module:
|
||||
- GET /api/annotation/auto
|
||||
- POST /api/annotation/auto
|
||||
- DELETE /api/annotation/auto/{task_id}
|
||||
- GET /api/annotation/auto/{task_id}/status (simple wrapper)
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.module.shared.schema import StandardResponse
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
from ..service.auto import AutoAnnotationTaskService
|
||||
|
||||
|
||||
# Router for auto-annotation endpoints, mounted under /annotation by __init__.
router = APIRouter(
    prefix="/auto",
    tags=["annotation/auto"],
)

logger = get_logger(__name__)
# Stateless service instance shared across requests; each handler passes in
# its own request-scoped AsyncSession, so sharing is safe.
service = AutoAnnotationTaskService()
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_auto_annotation_tasks(
    db: AsyncSession = Depends(get_db),
):
    """Return every non-deleted auto-annotation task.

    The frontend does not send pagination parameters yet, so the full list
    is returned in a single response.
    """
    all_tasks = await service.list_tasks(db)
    return StandardResponse(code=200, message="success", data=all_tasks)
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task(
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession = Depends(get_db),
):
    """Create an auto-annotation task.

    Only inserts a task record in ``pending`` state; the actual execution is
    performed later by a scheduler/worker.
    """

    logger.info(
        "Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
        request.name,
        request.dataset_id,
        request.config.model_dump(by_alias=True),
        request.file_ids,
    )

    # Best-effort lookup of the dataset name and file count for the task's
    # denormalized columns; failures must not block task creation.
    dataset_name = None
    total_images = 0
    try:
        dm_client = DatasetManagementService(db)
        # Service.get_dataset returns a DatasetResponse carrying name and fileCount.
        dataset = await dm_client.get_dataset(request.dataset_id)
        if dataset is not None:
            dataset_name = dataset.name
            # With explicit file_ids the total is the selection size; otherwise
            # fall back to the dataset's own file count.
            if request.file_ids:
                total_images = len(request.file_ids)
            else:
                # NOTE(review): assumes the response object exposes a
                # ``fileCount`` attribute — getattr guards its absence; confirm
                # the attribute name against DatasetResponse.
                total_images = getattr(dataset, "fileCount", 0) or 0
    except Exception as e:  # pragma: no cover - defensive fallback
        logger.warning("Failed to fetch dataset name for auto task: %s", e)

    task = await service.create_task(
        db,
        request,
        dataset_name=dataset_name,
        total_images=total_images,
    )

    return StandardResponse(
        code=200,
        message="success",
        data=task,
    )
|
||||
|
||||
|
||||
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
):
    """Fetch one auto-annotation task by id.

    The frontend mainly polls the list endpoint; this is a by-id supplement.
    """
    found = await service.get_task(db, task_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return StandardResponse(code=200, message="success", data=found)
|
||||
|
||||
|
||||
@router.delete("/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
):
    """Soft-delete an auto-annotation task (only marks ``deleted_at``)."""
    if not await service.soft_delete_task(db, task_id):
        raise HTTPException(status_code=404, detail="Task not found")
    return StandardResponse(code=200, message="success", data=True)
|
||||
|
||||
|
||||
@router.get("/{task_id}/download")
async def download_auto_annotation_result(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
):
    """Download the result ZIP of an auto-annotation task.

    Zips the task's ``output_path`` directory into a temporary file and
    streams it back; the temp file is removed once streaming finishes.

    Raises:
        HTTPException 404: unknown task or missing output directory.
        HTTPException 400: task has no output path yet.
        HTTPException 500: generated archive is empty.
    """
    import os
    import tempfile
    import zipfile

    # Reuse the service layer to load task metadata.
    task = await service.get_task(db, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    if not task.output_path:
        raise HTTPException(status_code=400, detail="Task has no output path")

    output_dir = task.output_path
    if not os.path.isdir(output_dir):
        raise HTTPException(status_code=404, detail="Output directory not found")

    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
    os.close(tmp_fd)

    with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(output_dir):
            for entry in files:
                file_path = os.path.join(root, entry)
                # Archive paths are relative to output_dir so the ZIP unpacks cleanly.
                zf.write(file_path, os.path.relpath(file_path, output_dir))

    file_size = os.path.getsize(tmp_path)
    if file_size == 0:
        os.remove(tmp_path)  # don't leak the temp file on the error path
        raise HTTPException(status_code=500, detail="Generated ZIP is empty")

    def iterfile():
        # Stream in 8 KiB chunks; delete the temp file once fully consumed
        # (fixes a temp-file leak: the original never removed tmp_path).
        try:
            with open(tmp_path, "rb") as f:
                while chunk := f.read(8192):
                    yield chunk
        finally:
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    filename = f"{task.name}_annotations.zip"
    headers = {
        # BUG FIX: the computed filename was previously not interpolated into
        # the Content-Disposition header.
        "Content-Disposition": f'attachment; filename="{filename}"',
        "Content-Length": str(file_size),
    }

    return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)
|
||||
73
runtime/datamate-python/app/module/annotation/schema/auto.py
Normal file
73
runtime/datamate-python/app/module/annotation/schema/auto.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Schemas for Auto Annotation tasks"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class AutoAnnotationConfig(BaseModel):
    """Auto-annotation task configuration (aligned with the frontend payload)."""

    # YOLOv8 model scale key (see the operator's MODEL_MAP).
    model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x")
    conf_threshold: float = Field(alias="confThreshold", description="置信度阈值 0-1")
    # Empty list means "detect all classes".
    target_classes: List[int] = Field(
        default_factory=list,
        alias="targetClasses",
        description="目标类别ID列表,空表示全部类别",
    )
    output_dataset_name: Optional[str] = Field(
        default=None,
        alias="outputDatasetName",
        description="自动标注结果要写入的新数据集名称(可选)",
    )

    # Accept both snake_case field names and the camelCase aliases on input.
    model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class CreateAutoAnnotationTaskRequest(BaseModel):
    """Request body for task creation, matching the frontend CreateAutoAnnotationDialog payload."""

    name: str = Field(..., min_length=1, max_length=255, description="任务名称")
    dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
    config: AutoAnnotationConfig = Field(..., description="任务配置")
    # None/empty selection means "process every image in the dataset".
    file_ids: Optional[List[str]] = Field(None, alias="fileIds", description="要处理的文件ID列表,为空则处理数据集中所有图像")

    # Accept both snake_case field names and the camelCase aliases on input.
    model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class AutoAnnotationTaskResponse(BaseModel):
    """Auto-annotation task response model (shared by list and detail endpoints)."""

    id: str = Field(..., description="任务ID")
    name: str = Field(..., description="任务名称")
    dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
    dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
    # Filled in by the service layer after validation (not an ORM column).
    source_datasets: Optional[List[str]] = Field(
        default=None,
        alias="sourceDatasets",
        description="本任务实际处理涉及到的所有数据集名称列表",
    )
    config: Dict[str, Any] = Field(..., description="任务配置")
    status: str = Field(..., description="任务状态")
    progress: int = Field(..., description="任务进度 0-100")
    total_images: int = Field(..., alias="totalImages", description="总图片数")
    processed_images: int = Field(..., alias="processedImages", description="已处理图片数")
    detected_objects: int = Field(..., alias="detectedObjects", description="检测到的对象总数")
    output_path: Optional[str] = Field(None, alias="outputPath", description="输出路径")
    error_message: Optional[str] = Field(None, alias="errorMessage", description="错误信息")
    created_at: datetime = Field(..., alias="createdAt", description="创建时间")
    updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
    completed_at: Optional[datetime] = Field(None, alias="completedAt", description="完成时间")

    # from_attributes lets model_validate() read directly from ORM instances.
    model_config = ConfigDict(populate_by_name=True, from_attributes=True)
|
||||
|
||||
|
||||
class AutoAnnotationTaskListResponse(BaseModel):
    """Paginated list wrapper; currently unused (frontend consumes a bare array) but reserved."""

    content: List[AutoAnnotationTaskResponse] = Field(..., description="任务列表")
    total: int = Field(..., description="总数")

    model_config = ConfigDict(populate_by_name=True)
|
||||
154
runtime/datamate-python/app/module/annotation/service/auto.py
Normal file
154
runtime/datamate-python/app/module/annotation/service/auto.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Service layer for Auto Annotation tasks"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.annotation_management import AutoAnnotationTask
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
|
||||
|
||||
class AutoAnnotationTaskService:
    """Auto-annotation task service.

    Manages task metadata only (create / list / get / soft-delete); the actual
    YOLO inference is executed by the runtime, which reads this table and
    updates progress.
    """

    async def create_task(
        self,
        db: AsyncSession,
        request: CreateAutoAnnotationTaskRequest,
        dataset_name: Optional[str] = None,
        total_images: int = 0,
    ) -> AutoAnnotationTaskResponse:
        """Insert a new task row in ``pending`` state.

        Only the record is created here; a scheduler/worker later picks it up
        and updates its progress.

        Args:
            db: request-scoped async session.
            request: validated creation payload.
            dataset_name: denormalized dataset name (best-effort, may be None).
            total_images: number of images the task will process.
        """
        # NOTE(review): naive local time — consider timezone-aware datetimes.
        now = datetime.now()

        task = AutoAnnotationTask(
            id=str(uuid4()),
            name=request.name,
            dataset_id=request.dataset_id,
            dataset_name=dataset_name,
            config=request.config.model_dump(by_alias=True),
            file_ids=request.file_ids,  # the user's explicit file selection
            status="pending",
            progress=0,
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            created_at=now,
            updated_at=now,
        )

        db.add(task)
        await db.commit()
        await db.refresh(task)

        # Attach sourceDatasets (usually a single originating dataset).
        return await self._build_response(db, task)

    async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]:
        """Return all non-soft-deleted tasks, newest first."""
        result = await db.execute(
            select(AutoAnnotationTask)
            .where(AutoAnnotationTask.deleted_at.is_(None))
            .order_by(AutoAnnotationTask.created_at.desc())
        )
        tasks: List[AutoAnnotationTask] = list(result.scalars().all())
        return [await self._build_response(db, task) for task in tasks]

    async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]:
        """Return one task by id, or None when missing or soft-deleted."""
        result = await db.execute(
            select(AutoAnnotationTask).where(
                AutoAnnotationTask.id == task_id,
                AutoAnnotationTask.deleted_at.is_(None),
            )
        )
        task = result.scalar_one_or_none()
        if not task:
            return None
        return await self._build_response(db, task)

    async def _build_response(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
    ) -> AutoAnnotationTaskResponse:
        """Convert an ORM task into its response model with sourceDatasets filled.

        Consolidates the try/except fallback that was previously duplicated
        verbatim in create_task, list_tasks and get_task.
        """
        resp = AutoAnnotationTaskResponse.model_validate(task)
        try:
            resp.source_datasets = await self._compute_source_datasets(db, task)
        except Exception:
            # Degrade to the single denormalized dataset name (or id).
            fallback_name = getattr(task, "dataset_name", None)
            fallback_id = getattr(task, "dataset_id", "")
            resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
        return resp

    async def _compute_source_datasets(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
    ) -> List[str]:
        """Infer the names of every dataset the task actually touches.

        - With ``file_ids``: resolve dataset ids through t_dm_dataset_files and
          join t_dm_datasets for the names.
        - Without ``file_ids``: fall back to the task's denormalized
          dataset_name / dataset_id.
        """
        file_ids = task.file_ids or []
        if file_ids:
            stmt = (
                select(Dataset.name)
                .join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
                .where(DatasetFiles.id.in_(file_ids))
                .distinct()
            )
            result = await db.execute(stmt)
            names = [row[0] for row in result.fetchall() if row[0]]
            if names:
                return names

        # Fallback: show just the single originating dataset.
        if task.dataset_name:
            return [task.dataset_name]
        if task.dataset_id:
            return [task.dataset_id]
        return []

    async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool:
        """Mark a task deleted by setting ``deleted_at``; return False when not found."""
        result = await db.execute(
            select(AutoAnnotationTask).where(
                AutoAnnotationTask.id == task_id,
                AutoAnnotationTask.deleted_at.is_(None),
            )
        )
        task = result.scalar_one_or_none()
        if not task:
            return False

        task.deleted_at = datetime.now()
        await db.commit()
        return True
|
||||
@@ -1,8 +1,8 @@
|
||||
#!/bin/bash
set -e

# Serve files from the local document root only when the directory exists AND
# serving is explicitly enabled. The guard needs a space after '[', quoted
# variable expansions, and an explicit string comparison — the previous form
# `if [-d $VAR ] && $FLAG` both broke the test and *executed* the flag value.
if [ -d "${LOCAL_FILES_DOCUMENT_ROOT}" ] && [ "${LOCAL_FILES_SERVING_ENABLED}" = "true" ]; then
    echo "Using local document root: ${LOCAL_FILES_DOCUMENT_ROOT}"
fi

# Start the application
|
||||
|
||||
17
runtime/ops/__init__.py
Normal file
17
runtime/ops/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
"""Datamate built-in operators package.

This package contains built-in operators for filtering, slicing, annotation, etc.
It is mounted into the runtime container under ``datamate.ops`` so that
``from datamate.ops.annotation...`` imports work correctly.
"""

# Public operator sub-packages exported by this namespace.
__all__ = [
    "annotation",
    "filter",
    "formatter",
    "llms",
    "mapper",
    "slicer",
    "user",
]
|
||||
6
runtime/ops/annotation/__init__.py
Normal file
6
runtime/ops/annotation/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
"""Annotation-related operators (e.g. YOLO detection)."""

# Operator packages exported by this namespace.
__all__ = [
    "image_object_detection_bounding_box",
]
|
||||
@@ -0,0 +1,9 @@
|
||||
"""Image object detection (YOLOv8) operator package.

This package exposes the ImageObjectDetectionBoundingBox annotator so that
the auto-annotation worker can import it via different module paths.
"""

# Re-export the operator class at package level.
from .process import ImageObjectDetectionBoundingBox

__all__ = ["ImageObjectDetectionBoundingBox"]
|
||||
@@ -0,0 +1,3 @@
|
||||
# Operator manifest read by the Datamate runtime when registering built-in ops.
name: image_object_detection_bounding_box
version: 0.1.0
description: "YOLOv8-based object detection operator for auto annotation"
|
||||
@@ -0,0 +1,214 @@
|
||||
#!/usr/bin/python
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Description: 图像目标检测算子
|
||||
Create: 2025/12/17
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
import cv2
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
except ImportError:
|
||||
logger.warning("ultralytics not installed. Please install it using: pip install ultralytics")
|
||||
YOLO = None
|
||||
|
||||
from datamate.core.base_op import Mapper
|
||||
|
||||
|
||||
# COCO 80-class id -> label mapping (the class set YOLOv8's default weights
# are trained on); used to turn numeric class ids into human-readable labels.
COCO_CLASS_MAP = {
    0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 4: "airplane",
    5: "bus", 6: "train", 7: "truck", 8: "boat", 9: "traffic light",
    10: "fire hydrant", 11: "stop sign", 12: "parking meter", 13: "bench",
    14: "bird", 15: "cat", 16: "dog", 17: "horse", 18: "sheep", 19: "cow",
    20: "elephant", 21: "bear", 22: "zebra", 23: "giraffe", 24: "backpack",
    25: "umbrella", 26: "handbag", 27: "tie", 28: "suitcase", 29: "frisbee",
    30: "skis", 31: "snowboard", 32: "sports ball", 33: "kite",
    34: "baseball bat", 35: "baseball glove", 36: "skateboard",
    37: "surfboard", 38: "tennis racket", 39: "bottle",
    40: "wine glass", 41: "cup", 42: "fork", 43: "knife", 44: "spoon",
    45: "bowl", 46: "banana", 47: "apple", 48: "sandwich", 49: "orange",
    50: "broccoli", 51: "carrot", 52: "hot dog", 53: "pizza",
    54: "donut", 55: "cake", 56: "chair", 57: "couch",
    58: "potted plant", 59: "bed", 60: "dining table", 61: "toilet",
    62: "tv", 63: "laptop", 64: "mouse", 65: "remote",
    66: "keyboard", 67: "cell phone", 68: "microwave", 69: "oven",
    70: "toaster", 71: "sink", 72: "refrigerator", 73: "book",
    74: "clock", 75: "vase", 76: "scissors", 77: "teddy bear",
    78: "hair drier", 79: "toothbrush"
}
|
||||
|
||||
|
||||
class ImageObjectDetectionBoundingBox(Mapper):
    """Image object detection operator (YOLOv8 bounding boxes).

    For each sample it runs YOLOv8 on the image, records detections as JSON,
    draws boxes/labels onto a copy of the image, and writes both under the
    configured output directory.
    """

    # Model-size key -> YOLOv8 checkpoint filename.
    MODEL_MAP = {
        "n": "yolov8n.pt",
        "s": "yolov8s.pt",
        "m": "yolov8m.pt",
        "l": "yolov8l.pt",
        "x": "yolov8x.pt",
    }

    def __init__(self, *args, **kwargs):
        super(ImageObjectDetectionBoundingBox, self).__init__(*args, **kwargs)

        # Operator parameters (camelCase keys mirror the frontend config payload).
        self._model_size = kwargs.get("modelSize", "l")
        self._conf_threshold = kwargs.get("confThreshold", 0.7)
        self._target_classes = kwargs.get("targetClasses", [])
        self._output_dir = kwargs.get("outputDir", None)  # output directory

        # An empty class list means "detect everything"; None disables filtering.
        if not self._target_classes:
            self._target_classes = None
        else:
            # Coerce to ints — ids may arrive as strings from JSON.
            self._target_classes = [int(cls_id) for cls_id in self._target_classes]

        # Resolve the checkpoint path next to this source file.
        model_filename = self.MODEL_MAP.get(self._model_size, "yolov8l.pt")
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, model_filename)

        # Initialize the model; fail fast if ultralytics is missing.
        if YOLO is None:
            raise ImportError("ultralytics is not installed. Please install it.")

        if not os.path.exists(model_path):
            logger.warning(f"Model file {model_path} not found. Downloading from ultralytics...")
            self.model = YOLO(model_filename)  # ultralytics auto-downloads by name
        else:
            self.model = YOLO(model_path)

        logger.info(f"Loaded YOLOv8 model: {model_filename}, "
                    f"conf_threshold: {self._conf_threshold}, "
                    f"target_classes: {self._target_classes}")

    @staticmethod
    def _get_color_by_class_id(class_id: int):
        """Derive a stable BGR color from class_id (for OpenCV drawing).

        NOTE(review): seeds NumPy's *global* RNG as a side effect.
        """
        np.random.seed(class_id)
        color = np.random.randint(0, 255, size=3).tolist()
        return tuple(color)

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Run detection on one sample and write annotated image + JSON outputs.

        Returns the sample unchanged when the image is missing/unreadable;
        otherwise augments it with detection_count, output_image,
        annotations_file and annotations.
        """
        start = time.time()

        # Locate the input image file.
        image_path = sample.get(self.image_key)
        if not image_path or not os.path.exists(image_path):
            logger.warning(f"Image file not found: {image_path}")
            return sample

        # Read the image (BGR, OpenCV convention).
        img = cv2.imread(image_path)
        if img is None:
            logger.warning(f"Failed to read image: {image_path}")
            return sample

        # Run inference; single image in -> single result out.
        results = self.model(img, conf=self._conf_threshold)
        r = results[0]

        # Prepare the annotation payload.
        h, w = img.shape[:2]
        annotations = {
            "image": os.path.basename(image_path),
            "width": w,
            "height": h,
            "model_size": self._model_size,
            "conf_threshold": self._conf_threshold,
            "selected_class_ids": self._target_classes,
            "detections": []
        }

        # Process detection results.
        if r.boxes is not None:
            for box in r.boxes:
                cls_id = int(box.cls[0])

                # Skip classes outside the requested filter (None = keep all).
                if self._target_classes is not None and cls_id not in self._target_classes:
                    continue

                conf = float(box.conf[0])
                x1, y1, x2, y2 = map(float, box.xyxy[0])
                label = COCO_CLASS_MAP.get(cls_id, f"class_{cls_id}")

                # Record the detection in both corner and width/height formats.
                annotations["detections"].append({
                    "label": label,
                    "class_id": cls_id,
                    "confidence": round(conf, 4),
                    "bbox_xyxy": [x1, y1, x2, y2],
                    "bbox_xywh": [x1, y1, x2 - x1, y2 - y1]
                })

                # Draw the box and label onto the image.
                color = self._get_color_by_class_id(cls_id)
                cv2.rectangle(
                    img,
                    (int(x1), int(y1)),
                    (int(x2), int(y2)),
                    color,
                    2
                )

                cv2.putText(
                    img,
                    f"{label} {conf:.2f}",
                    (int(x1), max(int(y1) - 5, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    1
                )

        # Choose the output directory: configured dir if it exists, else
        # fall back to the image's own directory.
        if self._output_dir and os.path.exists(self._output_dir):
            output_dir = self._output_dir
        else:
            output_dir = os.path.dirname(image_path)

        # Create output subdirectories to keep images and annotations apart.
        images_dir = os.path.join(output_dir, "images")
        annotations_dir = os.path.join(output_dir, "annotations")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(annotations_dir, exist_ok=True)

        # Keep the original filename (no suffix) so outputs pair 1:1 with inputs.
        base_name = os.path.basename(image_path)
        name_without_ext = os.path.splitext(base_name)[0]

        # Save the annotated image (original extension preserved).
        output_filename = base_name
        output_path = os.path.join(images_dir, output_filename)
        cv2.imwrite(output_path, img)

        # Save the annotation JSON with a matching filename.
        json_filename = f"{name_without_ext}.json"
        json_path = os.path.join(annotations_dir, json_filename)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotations, f, indent=2, ensure_ascii=False)

        # Augment the sample with the output artifacts.
        sample["detection_count"] = len(annotations["detections"])
        sample["output_image"] = output_path
        sample["annotations_file"] = json_path
        sample["annotations"] = annotations

        logger.info(f"Image: {os.path.basename(image_path)}, "
                    f"Detections: {len(annotations['detections'])}, "
                    f"Time: {(time.time() - start):.4f}s")

        return sample
|
||||
166
runtime/ops/annotation/image_semantic_segmentation/process.py
Normal file
166
runtime/ops/annotation/image_semantic_segmentation/process.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_color_by_class_id(class_id: int):
    """Return a stable pseudo-random BGR color for a class id.

    The same ``class_id`` always yields the same color.  A local
    ``numpy.random.Generator`` seeded with the id is used instead of
    ``np.random.seed`` so the *global* NumPy RNG state is not clobbered
    as a side effect (the original implementation reseeded the global
    RNG on every call).

    Args:
        class_id: Non-negative integer class index, used as the seed.

    Returns:
        A ``(b, g, r)`` tuple of Python ints in ``[0, 255)``.
    """
    rng = np.random.default_rng(class_id)
    return tuple(int(c) for c in rng.integers(0, 255, size=3))
|
||||
|
||||
|
||||
def mask_to_polygons(mask: np.ndarray):
    """Convert a binary mask into COCO-style polygon lists.

    Args:
        mask: 2-D ``uint8`` array where foreground pixels are non-zero.

    Returns:
        A list of flat ``[x1, y1, x2, y2, ...]`` coordinate lists, one per
        external contour.  Contours with fewer than 3 points are dropped,
        since they cannot form a polygon.
    """
    found, _hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    return [contour.flatten().tolist() for contour in found if contour.shape[0] >= 3]
|
||||
|
||||
|
||||
# --- Configuration -----------------------------------------------------------

# Input directory of images to segment.  Overridable via environment variable
# so the script is not hard-wired to one developer's machine; the default is
# kept for backward compatibility.
IMAGE_DIR = os.getenv("SEG_IMAGE_DIR", "C:/Users/meta/Desktop/Datamate/yolo/Photos")
OUT_IMG_DIR = "outputs_seg/images"
OUT_JSON_DIR = "outputs_seg/annotations"

# YOLOv8 segmentation checkpoints, keyed by model size.
MODEL_MAP = {
    "n": "yolov8n-seg.pt",
    "s": "yolov8s-seg.pt",
    "m": "yolov8m-seg.pt",
    "l": "yolov8l-seg.pt",
    "x": "yolov8x-seg.pt",
}
MODEL_KEY = "x"
MODEL_PATH = MODEL_MAP[MODEL_KEY]

CONF_THRES = 0.7
DRAW_BBOX = True  # whether to draw boxes/labels on the visualization output

# COCO class-id -> label name (80 classes).
COCO_CLASS_MAP = {
    0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 4: "airplane",
    5: "bus", 6: "train", 7: "truck", 8: "boat", 9: "traffic light",
    10: "fire hydrant", 11: "stop sign", 12: "parking meter", 13: "bench",
    14: "bird", 15: "cat", 16: "dog", 17: "horse", 18: "sheep", 19: "cow",
    20: "elephant", 21: "bear", 22: "zebra", 23: "giraffe", 24: "backpack",
    25: "umbrella", 26: "handbag", 27: "tie", 28: "suitcase", 29: "frisbee",
    30: "skis", 31: "snowboard", 32: "sports ball", 33: "kite",
    34: "baseball bat", 35: "baseball glove", 36: "skateboard",
    37: "surfboard", 38: "tennis racket", 39: "bottle",
    40: "wine glass", 41: "cup", 42: "fork", 43: "knife", 44: "spoon",
    45: "bowl", 46: "banana", 47: "apple", 48: "sandwich", 49: "orange",
    50: "broccoli", 51: "carrot", 52: "hot dog", 53: "pizza",
    54: "donut", 55: "cake", 56: "chair", 57: "couch",
    58: "potted plant", 59: "bed", 60: "dining table", 61: "toilet",
    62: "tv", 63: "laptop", 64: "mouse", 65: "remote",
    66: "keyboard", 67: "cell phone", 68: "microwave", 69: "oven",
    70: "toaster", 71: "sink", 72: "refrigerator", 73: "book",
    74: "clock", 75: "vase", 76: "scissors", 77: "teddy bear",
    78: "hair drier", 79: "toothbrush"
}

# Only keep detections whose class id is in this list.
TARGET_CLASS_IDS = [0, 2, 5]

os.makedirs(OUT_IMG_DIR, exist_ok=True)
os.makedirs(OUT_JSON_DIR, exist_ok=True)

# Fail fast on a mistyped class id before loading the (large) model.
if TARGET_CLASS_IDS is not None:
    for cid in TARGET_CLASS_IDS:
        if cid not in COCO_CLASS_MAP:
            raise ValueError(f"Invalid class id: {cid}")

model = YOLO(MODEL_PATH)

image_paths = list(Path(IMAGE_DIR).glob("*.*"))

for img_path in image_paths:
    img = cv2.imread(str(img_path))
    if img is None:
        print(f"[WARN] Failed to read {img_path}")
        continue

    results = model(img, conf=CONF_THRES)
    r = results[0]

    h, w = img.shape[:2]
    annotations = {
        "image": img_path.name,
        "width": w,
        "height": h,
        "model_key": MODEL_KEY,
        "conf_threshold": CONF_THRES,
        "supported_classes": COCO_CLASS_MAP,
        "selected_class_ids": TARGET_CLASS_IDS,
        "instances": []
    }

    if r.boxes is not None and r.masks is not None:
        for i, box in enumerate(r.boxes):
            cls_id = int(box.cls[0])
            if TARGET_CLASS_IDS is not None and cls_id not in TARGET_CLASS_IDS:
                continue

            conf = float(box.conf[0])
            x1, y1, x2, y2 = map(float, box.xyxy[0])
            label = COCO_CLASS_MAP[cls_id]

            # Binarize the predicted mask and scale it to the image size.
            mask = r.masks.data[i].cpu().numpy()
            mask = (mask > 0.5).astype(np.uint8)
            mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

            # Alpha-blend the class color over the masked region (50/50).
            color = get_color_by_class_id(cls_id)
            img[mask == 1] = (
                img[mask == 1] * 0.5 + np.array(color) * 0.5
            ).astype(np.uint8)

            # BUGFIX: this was `if True:`, which silently ignored the
            # DRAW_BBOX configuration switch declared above.
            if DRAW_BBOX:
                cv2.rectangle(
                    img,
                    (int(x1), int(y1)),
                    (int(x2), int(y2)),
                    color,
                    2
                )

                cv2.putText(
                    img,
                    f"{label} {conf:.2f}",
                    (int(x1), max(int(y1) - 5, 10)),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    1
                )

            polygons = mask_to_polygons(mask)

            annotations["instances"].append({
                "label": label,
                "class_id": cls_id,
                "confidence": round(conf, 4),
                "bbox_xyxy": [x1, y1, x2, y2],
                "segmentation": polygons
            })

    out_img_path = os.path.join(OUT_IMG_DIR, img_path.name)
    out_json_path = os.path.join(OUT_JSON_DIR, img_path.stem + ".json")

    cv2.imwrite(out_img_path, img)

    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(annotations, f, indent=2, ensure_ascii=False)

    print(f"[OK] {img_path.name}")

print("Segmentation batch finished.")
||||
@@ -31,4 +31,5 @@ dependencies = [
|
||||
"sqlalchemy>=2.0.44",
|
||||
"xmltodict>=1.0.2",
|
||||
"zhconv>=1.4.3",
|
||||
"ultralytics>=8.0.0",
|
||||
]
|
||||
|
||||
603
runtime/python-executor/datamate/auto_annotation_worker.py
Normal file
603
runtime/python-executor/datamate/auto_annotation_worker.py
Normal file
@@ -0,0 +1,603 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Simple background worker for auto-annotation tasks.
|
||||
|
||||
This module runs inside the datamate-runtime container (operator_runtime service).
|
||||
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs YOLO
|
||||
inference using the ImageObjectDetectionBoundingBox operator, updating
|
||||
progress back to the same table so that the datamate-python backend and
|
||||
frontend can display real-time status.
|
||||
|
||||
设计目标(最小可用版本):
|
||||
- 单实例 worker,串行处理 `pending` 状态的任务。
|
||||
- 对指定数据集下的所有已完成文件逐张执行目标检测。
|
||||
- 按已处理图片数更新 `processed_images`、`progress`、`detected_objects`、`status` 等字段。
|
||||
- 失败时将任务标记为 `failed` 并记录 `error_message`。
|
||||
|
||||
注意:
|
||||
- 为了保持简单,目前不处理 "running" 状态的恢复逻辑;容器重启时,
|
||||
已处于 running 的任务不会被重新拉起,需要后续扩展。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import text
|
||||
|
||||
from datamate.sql_manager.sql_manager import SQLManager
|
||||
|
||||
# Try several import paths so the operator resolves regardless of how the
# package was built/installed (image COPY, installed package, or dev checkout).
ImageObjectDetectionBoundingBox = None  # type: ignore
try:
    # Preferred path: datamate.ops (source tree COPY'd to /opt/runtime/datamate/ops).
    from datamate.ops.annotation.image_object_detection_bounding_box.process import (  # type: ignore
        ImageObjectDetectionBoundingBox,
    )
    logger.info(
        "Imported ImageObjectDetectionBoundingBox from datamate.ops.annotation.image_object_detection_bounding_box",
    )
except Exception as e1:  # pragma: no cover - log only on import failure; never crash the module
    logger.error(
        "Failed to import ImageObjectDetectionBoundingBox via datamate.ops: {}",
        e1,
    )
    try:
        # Fallback: top-level `ops` package installation (exposed via ops.pth).
        from ops.annotation.image_object_detection_bounding_box.process import (  # type: ignore
            ImageObjectDetectionBoundingBox,
        )
        logger.info(
            "Imported ImageObjectDetectionBoundingBox from top-level ops.annotation.image_object_detection_bounding_box",
        )
    except Exception as e2:
        logger.error(
            "Failed to import ImageObjectDetectionBoundingBox via top-level ops package: {}",
            e2,
        )
        ImageObjectDetectionBoundingBox = None


# Last resort: load the operator straight from the local runtime/ops directory
# (common in development environments).
if ImageObjectDetectionBoundingBox is None:
    try:
        project_root = Path(__file__).resolve().parents[2]
        ops_root = project_root / "ops"
        if ops_root.is_dir():
            # Make sure ops' parent directory is on sys.path so "ops.xxx" imports work.
            if str(project_root) not in sys.path:
                sys.path.insert(0, str(project_root))

            from ops.annotation.image_object_detection_bounding_box.process import (  # type: ignore
                ImageObjectDetectionBoundingBox,
            )

            logger.info(
                "Imported ImageObjectDetectionBoundingBox from local runtime/ops.annotation.image_object_detection_bounding_box",
            )
        else:
            logger.warning(
                "Local runtime/ops directory not found when trying to import ImageObjectDetectionBoundingBox: {}",
                ops_root,
            )
    except Exception as e3:  # pragma: no cover - log only when the last-resort path fails
        logger.error(
            "Failed to import ImageObjectDetectionBoundingBox from local runtime/ops: {}",
            e3,
        )
        ImageObjectDetectionBoundingBox = None
|
||||
|
||||
|
||||
POLL_INTERVAL_SECONDS = float(os.getenv("AUTO_ANNOTATION_POLL_INTERVAL", "5"))
|
||||
|
||||
DEFAULT_OUTPUT_ROOT = os.getenv(
|
||||
"AUTO_ANNOTATION_OUTPUT_ROOT", "/dataset"
|
||||
)
|
||||
|
||||
|
||||
def _fetch_pending_task() -> Optional[Dict[str, Any]]:
    """Fetch the oldest ``pending`` auto-annotation task, or ``None``.

    Reads at most one row from ``t_dm_auto_annotation_tasks`` (soft-deleted
    rows excluded) and normalizes its ``config`` and ``file_ids`` JSON
    columns into Python objects, falling back to ``{}`` / ``None`` on
    missing or malformed data.
    """

    query = text(
        """
        SELECT id, name, dataset_id, dataset_name, config, file_ids, status,
               total_images, processed_images, detected_objects, output_path
        FROM t_dm_auto_annotation_tasks
        WHERE status = 'pending' AND deleted_at IS NULL
        ORDER BY created_at ASC
        LIMIT 1
        """
    )

    with SQLManager.create_connect() as conn:
        record = conn.execute(query).fetchone()
    if not record:
        return None
    task = dict(record._mapping)  # type: ignore[attr-defined]

    # config: JSON string -> dict; any parse problem degrades to {}.
    try:
        task["config"] = json.loads(task["config"]) if task.get("config") else {}
    except Exception:
        task["config"] = {}

    # file_ids: may be NULL/empty, a JSON-encoded string, or already a list.
    try:
        raw_ids = task.get("file_ids")
        if not raw_ids:
            task["file_ids"] = None
        else:
            task["file_ids"] = json.loads(raw_ids) if isinstance(raw_ids, str) else raw_ids
    except Exception:
        task["file_ids"] = None
    return task
|
||||
|
||||
|
||||
def _update_task_status(
    task_id: str,
    *,
    status: str,
    progress: Optional[int] = None,
    processed_images: Optional[int] = None,
    detected_objects: Optional[int] = None,
    total_images: Optional[int] = None,
    output_path: Optional[str] = None,
    error_message: Optional[str] = None,
    completed: bool = False,
) -> None:
    """Persist status and statistics fields for a single task row.

    Besides ``status`` and ``updated_at`` (always written), only keyword
    arguments that are not ``None`` are included in the UPDATE.
    ``completed=True`` additionally stamps ``completed_at``; the error
    message is truncated to 2000 characters.
    """

    assignments: List[str] = ["status = :status", "updated_at = :updated_at"]
    params: Dict[str, Any] = {
        "task_id": task_id,
        "status": status,
        "updated_at": datetime.now(),
    }

    # Optional integer counters — include each only when supplied.
    optional_ints = {
        "progress": progress,
        "processed_images": processed_images,
        "detected_objects": detected_objects,
        "total_images": total_images,
    }
    for column, value in optional_ints.items():
        if value is not None:
            assignments.append(f"{column} = :{column}")
            params[column] = int(value)

    if output_path is not None:
        assignments.append("output_path = :output_path")
        params["output_path"] = output_path
    if error_message is not None:
        assignments.append("error_message = :error_message")
        params["error_message"] = error_message[:2000]  # keep within the column size
    if completed:
        assignments.append("completed_at = :completed_at")
        params["completed_at"] = datetime.now()

    sql = text(
        f"""
        UPDATE t_dm_auto_annotation_tasks
        SET {', '.join(assignments)}
        WHERE id = :task_id
        """
    )

    with SQLManager.create_connect() as conn:
        conn.execute(sql, params)
|
||||
|
||||
|
||||
def _load_dataset_files(dataset_id: str) -> List[Tuple[str, str, str]]:
    """Load all ACTIVE files belonging to one dataset.

    Returns:
        A list of ``(id, file_path, file_name)`` string triples, ordered by
        creation time ascending.
    """

    query = text(
        """
        SELECT id, file_path, file_name
        FROM t_dm_dataset_files
        WHERE dataset_id = :dataset_id
          AND status = 'ACTIVE'
        ORDER BY created_at ASC
        """
    )

    with SQLManager.create_connect() as conn:
        records = conn.execute(query, {"dataset_id": dataset_id}).fetchall()
    return [(str(rec[0]), str(rec[1]), str(rec[2])) for rec in records]
|
||||
|
||||
|
||||
def _load_files_by_ids(file_ids: List[str]) -> List[Tuple[str, str, str]]:
    """Load ACTIVE file records by id; ids may span multiple datasets.

    Returns:
        A list of ``(id, file_path, file_name)`` string triples, ordered by
        creation time ascending.  Empty input yields an empty list.
    """

    if not file_ids:
        return []

    # Bind each id to its own named parameter (:id0, :id1, ...).
    bound: Dict[str, str] = {f"id{i}": str(fid) for i, fid in enumerate(file_ids)}
    placeholders = ", ".join(f":{name}" for name in bound)
    sql = text(
        f"""
        SELECT id, file_path, file_name
        FROM t_dm_dataset_files
        WHERE id IN ({placeholders})
          AND status = 'ACTIVE'
        ORDER BY created_at ASC
        """
    )

    with SQLManager.create_connect() as conn:
        records = conn.execute(sql, bound).fetchall()
    return [(str(rec[0]), str(rec[1]), str(rec[2])) for rec in records]
|
||||
|
||||
|
||||
def _ensure_output_dir(output_dir: str) -> str:
|
||||
"""确保输出目录及其 images/、annotations/ 子目录存在。"""
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
|
||||
os.makedirs(os.path.join(output_dir, "annotations"), exist_ok=True)
|
||||
return output_dir
|
||||
|
||||
|
||||
def _create_output_dataset(
    source_dataset_id: str,
    source_dataset_name: str,
    output_dataset_name: str,
) -> Tuple[str, str]:
    """Create a new dataset row for auto-annotation results.

    Returns:
        ``(dataset_id, path)`` — a fresh UUID and the directory under the
        configured output root where the results will be written.
    """

    created_id = str(uuid.uuid4())
    base_path = DEFAULT_OUTPUT_ROOT.rstrip("/") or "/dataset"
    target_dir = os.path.join(base_path, created_id)

    # Description is truncated to the 255-char column limit.
    summary = (
        f"Auto annotations for dataset {source_dataset_name or source_dataset_id}"[:255]
    )

    insert_sql = text(
        """
        INSERT INTO t_dm_datasets (id, name, description, dataset_type, path, status)
        VALUES (:id, :name, :description, :dataset_type, :path, :status)
        """
    )

    with SQLManager.create_connect() as conn:
        conn.execute(
            insert_sql,
            {
                "id": created_id,
                "name": output_dataset_name,
                "description": summary,
                "dataset_type": "IMAGE",
                "path": target_dir,
                "status": "ACTIVE",
            },
        )

    return created_id, target_dir
|
||||
|
||||
|
||||
def _scan_output_files(directory: str) -> Tuple[List[Tuple[str, str, int]], int]:
    """List regular files in *directory* as (name, path, size) plus total size.

    Files whose size cannot be stat'ed are recorded with size 0 rather than
    aborting the scan.
    """
    entries: List[Tuple[str, str, int]] = []
    size_sum = 0
    for file_name in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, file_name)
        if not os.path.isfile(file_path):
            continue
        try:
            file_size = os.path.getsize(file_path)
        except OSError:
            file_size = 0
        entries.append((file_name, file_path, int(file_size)))
        size_sum += int(file_size)
    return entries, size_sum


def _register_output_dataset(
    task_id: str,
    output_dataset_id: str,
    output_dir: str,
    output_dataset_name: str,
    total_images: int,
) -> None:
    """Register the auto-annotation output files into the new dataset.

    Scans ``<output_dir>/images`` and ``<output_dir>/annotations``, inserts
    one ``t_dm_dataset_files`` row per file (images first, then annotation
    JSONs), and bumps the dataset's ``file_count`` / ``size_bytes``
    aggregates.  Returns silently with a warning if there are no images.

    The previously duplicated scan/insert loops for the two directories are
    factored into ``_scan_output_files`` and a single insert loop.

    Args:
        task_id: Id of the completed auto-annotation task (for logging).
        output_dataset_id: Id of the dataset created for the results.
        output_dir: Root directory that holds images/ and annotations/.
        output_dataset_name: Display name of the output dataset (for logging).
        total_images: Kept for interface compatibility; not used here.
    """

    images_dir = os.path.join(output_dir, "images")
    if not os.path.isdir(images_dir):
        logger.warning(
            "Auto-annotation images directory not found for task {}: {}",
            task_id,
            images_dir,
        )
        return

    image_files, total_size = _scan_output_files(images_dir)

    annotation_files: List[Tuple[str, str, int]] = []
    annotations_dir = os.path.join(output_dir, "annotations")
    if os.path.isdir(annotations_dir):
        annotation_files, annotations_size = _scan_output_files(annotations_dir)
        total_size += annotations_size

    if not image_files:
        logger.warning(
            "No image files found in auto-annotation output for task {}: {}",
            task_id,
            images_dir,
        )
        return

    insert_file_sql = text(
        """
        INSERT INTO t_dm_dataset_files (
            id, dataset_id, file_name, file_path, file_type, file_size, status
        ) VALUES (
            :id, :dataset_id, :file_name, :file_path, :file_type, :file_size, :status
        )
        """
    )
    update_dataset_stat_sql = text(
        """
        UPDATE t_dm_datasets
        SET file_count = COALESCE(file_count, 0) + :add_count,
            size_bytes = COALESCE(size_bytes, 0) + :add_size
        WHERE id = :dataset_id
        """
    )

    with SQLManager.create_connect() as conn:
        added_count = 0

        # Images first, then annotation JSONs — same order as before, but one
        # loop instead of two identical copies.
        for file_name, file_path, file_size in image_files + annotation_files:
            ext = os.path.splitext(file_name)[1].lstrip(".").upper() or None
            conn.execute(
                insert_file_sql,
                {
                    "id": str(uuid.uuid4()),
                    "dataset_id": output_dataset_id,
                    "file_name": file_name,
                    "file_path": file_path,
                    "file_type": ext,
                    "file_size": int(file_size),
                    "status": "ACTIVE",
                },
            )
            added_count += 1

        if added_count > 0:
            conn.execute(
                update_dataset_stat_sql,
                {
                    "dataset_id": output_dataset_id,
                    "add_count": added_count,
                    "add_size": int(total_size),
                },
            )

    logger.info(
        "Registered auto-annotation output into dataset: dataset_id={}, name={}, added_files={}, added_size_bytes={}, task_id={}, output_dir={}",
        output_dataset_id,
        output_dataset_name,
        len(image_files) + len(annotation_files),
        total_size,
        task_id,
        output_dir,
    )
|
||||
|
||||
|
||||
def _process_single_task(task: Dict[str, Any]) -> None:
    """Execute one auto-annotation task end to end.

    Flow: validate the operator is importable -> mark the task ``running`` ->
    resolve the file list (explicit ``file_ids`` or the whole dataset) ->
    create the output dataset + directories -> run the detector per image,
    updating progress after each one -> mark ``completed`` and register the
    output files as a new dataset.  Any per-image failure is logged and
    skipped; initialization failures mark the task ``failed``.
    """

    # The operator import may have failed at module load; without it the
    # task cannot run at all.
    if ImageObjectDetectionBoundingBox is None:
        logger.error(
            "YOLO operator not available (import failed earlier), skip auto-annotation task: {}",
            task["id"],
        )
        _update_task_status(
            task["id"],
            status="failed",
            error_message="YOLO operator not available in runtime container",
        )
        return

    task_id = str(task["id"])
    dataset_id = str(task["dataset_id"])
    task_name = str(task.get("name") or "")
    source_dataset_name = str(task.get("dataset_name") or "")
    cfg: Dict[str, Any] = task.get("config") or {}
    selected_file_ids: Optional[List[str]] = task.get("file_ids") or None

    # Config keys use the frontend's camelCase naming.
    model_size = cfg.get("modelSize", "l")
    conf_threshold = float(cfg.get("confThreshold", 0.7))
    target_classes = cfg.get("targetClasses", []) or []
    output_dataset_name = cfg.get("outputDatasetName")

    # Derive a default output dataset name when the user did not supply one.
    if not output_dataset_name:
        base_name = source_dataset_name or task_name or f"dataset-{dataset_id[:8]}"
        output_dataset_name = f"{base_name}_auto_{task_id[:8]}"

    logger.info(
        "Start processing auto-annotation task: id={}, dataset_id={}, model_size={}, conf_threshold={}, target_classes={}, output_dataset_name={}",
        task_id,
        dataset_id,
        model_size,
        conf_threshold,
        target_classes,
        output_dataset_name,
    )

    _update_task_status(task_id, status="running", progress=0)

    # Explicit file selection wins over "every file in the dataset".
    if selected_file_ids:
        all_files = _load_files_by_ids(selected_file_ids)
    else:
        all_files = _load_dataset_files(dataset_id)

    # Drop the id column; only (path, name) are needed below.
    files = [(path, name) for _, path, name in all_files]

    total_images = len(files)
    if total_images == 0:
        # Nothing to do: mark the task completed immediately with zero counts.
        logger.warning("No files found for dataset {} when running auto-annotation task {}", dataset_id, task_id)
        _update_task_status(
            task_id,
            status="completed",
            progress=100,
            total_images=0,
            processed_images=0,
            detected_objects=0,
            completed=True,
            output_path=None,
        )
        return

    output_dataset_id, output_dir = _create_output_dataset(
        source_dataset_id=dataset_id,
        source_dataset_name=source_dataset_name,
        output_dataset_name=output_dataset_name,
    )
    output_dir = _ensure_output_dir(output_dir)

    try:
        detector = ImageObjectDetectionBoundingBox(
            modelSize=model_size,
            confThreshold=conf_threshold,
            targetClasses=target_classes,
            outputDir=output_dir,
        )
    except Exception as e:
        # Model/weights initialization failed — fail the whole task.
        logger.error("Failed to init YOLO detector for task {}: {}", task_id, e)
        _update_task_status(
            task_id,
            status="failed",
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            error_message=f"Init YOLO detector failed: {e}",
        )
        return

    processed = 0
    detected_total = 0

    for file_path, file_name in files:
        try:
            sample = {
                "image": file_path,
                "filename": file_name,
            }
            result = detector.execute(sample)

            # The operator returns the sample enriched with an "annotations"
            # dict containing per-image "detections".
            annotations = (result or {}).get("annotations", {})
            detections = annotations.get("detections", [])
            detected_total += len(detections)
            processed += 1

            progress = int(processed * 100 / total_images) if total_images > 0 else 100

            # Per-image progress write-back so the frontend can poll status.
            _update_task_status(
                task_id,
                status="running",
                progress=progress,
                processed_images=processed,
                detected_objects=detected_total,
                total_images=total_images,
                output_path=output_dir,
            )
        except Exception as e:
            # A single bad image must not fail the whole task; note that
            # `processed` is intentionally not incremented for it.
            logger.error(
                "Failed to process image for task {}: file_path={}, error={}",
                task_id,
                file_path,
                e,
            )
            continue

    _update_task_status(
        task_id,
        status="completed",
        progress=100,
        processed_images=processed,
        detected_objects=detected_total,
        total_images=total_images,
        output_path=output_dir,
        completed=True,
    )

    logger.info(
        "Completed auto-annotation task: id={}, total_images={}, processed={}, detected_objects={}, output_path={}",
        task_id,
        total_images,
        processed,
        detected_total,
        output_dir,
    )

    # Best-effort: expose the generated files as a browsable dataset.
    # Registration failure is logged but does not flip the task to failed.
    if output_dataset_name and output_dataset_id:
        try:
            _register_output_dataset(
                task_id=task_id,
                output_dataset_id=output_dataset_id,
                output_dir=output_dir,
                output_dataset_name=output_dataset_name,
                total_images=total_images,
            )
        except Exception as e:  # pragma: no cover - defensive logging
            logger.error(
                "Failed to register auto-annotation output as dataset for task {}: {}",
                task_id,
                e,
            )
|
||||
|
||||
|
||||
def _worker_loop() -> None:
    """Main polling loop; runs forever in a dedicated daemon thread.

    Picks up one pending task at a time and processes it serially; when the
    queue is empty (or an unexpected error occurs) it sleeps for the poll
    interval before trying again.
    """

    logger.info(
        "Auto-annotation worker started with poll interval {} seconds, output root {}",
        POLL_INTERVAL_SECONDS,
        DEFAULT_OUTPUT_ROOT,
    )

    while True:
        try:
            pending = _fetch_pending_task()
            if pending:
                _process_single_task(pending)
            else:
                time.sleep(POLL_INTERVAL_SECONDS)
        except Exception as e:  # pragma: no cover - defensive logging
            logger.error("Auto-annotation worker loop error: {}", e)
            time.sleep(POLL_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def start_auto_annotation_worker() -> None:
    """Launch the auto-annotation worker loop in a background daemon thread."""

    worker = threading.Thread(
        target=_worker_loop,
        name="auto-annotation-worker",
        daemon=True,  # must not block interpreter shutdown
    )
    worker.start()
    logger.info("Auto-annotation worker thread started: {}", worker.name)
|
||||
@@ -1,163 +1,174 @@
|
||||
import os
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from jsonargparse import ArgumentParser
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from datamate.common.error_code import ErrorCode
|
||||
from datamate.scheduler import cmd_scheduler
|
||||
from datamate.scheduler import func_scheduler
|
||||
from datamate.wrappers import WRAPPERS
|
||||
|
||||
# 日志配置
|
||||
LOG_DIR = "/var/log/datamate/runtime"
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
logger.add(
|
||||
f"{LOG_DIR}/runtime.log",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}",
|
||||
level="DEBUG",
|
||||
enqueue=True
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class APIException(Exception):
    """Custom API exception carrying a business error code and payload."""

    def __init__(
        self,
        error_code: ErrorCode,
        detail: Optional[str] = None,
        extra_data: Optional[Dict] = None,
    ):
        self.error_code = error_code
        # Fall back to the error code's default message when no detail given.
        self.detail = detail if detail else error_code.value[1]
        self.code = error_code.value[0]
        self.extra_data = extra_data
        super().__init__(self.detail)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize into the standard error envelope returned by the API."""
        payload: Dict[str, Any] = {
            "code": self.code,
            "message": self.detail,
            "success": False,
        }
        if self.extra_data:
            payload["data"] = self.extra_data
        return payload
||||
|
||||
|
||||
@app.exception_handler(APIException)
|
||||
async def api_exception_handler(request: Request, exc: APIException):
|
||||
return JSONResponse(
|
||||
status_code=200, # 业务错误返回 200,错误信息在响应体中
|
||||
content=exc.to_dict()
|
||||
)
|
||||
|
||||
|
||||
class QueryTaskRequest(BaseModel):
|
||||
task_ids: List[str]
|
||||
|
||||
|
||||
@app.post("/api/task/list")
async def query_task_info(request: QueryTaskRequest):
    """Return the scheduler status for each requested task id."""
    try:
        statuses = []
        for task_id in request.task_ids:
            statuses.append({task_id: cmd_scheduler.get_task_status(task_id)})
        return statuses
    except Exception as e:
        raise APIException(ErrorCode.UNKNOWN_ERROR)
|
||||
|
||||
|
||||
@app.post("/api/task/{task_id}/submit")
async def submit_task(task_id):
    """Submit the task's pipeline (defined in /flow/<task_id>/process.yaml) for execution.

    Validates that the configured dataset path exists, then dispatches to
    the wrapper matching the configured executor_type.  Raises APIException
    (FILE_NOT_FOUND_ERROR / SUBMIT_TASK_ERROR) on failure.
    """
    config_path = f"/flow/{task_id}/process.yaml"
    logger.info("Start submitting job...")

    # The dataset must exist before we hand the job to an executor.
    dataset_path = get_from_cfg(task_id, "dataset_path")
    if not check_valid_path(dataset_path):
        logger.error(f"dataset_path is not existed! please check this path.")
        raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)

    try:
        # Dispatch on the executor backend declared in the task config.
        executor_type = get_from_cfg(task_id, "executor_type")
        await WRAPPERS.get(executor_type).submit(task_id, config_path)

    except Exception as e:
        logger.error(f"Error happens during submitting task. Error Info following: {e}")
        raise APIException(ErrorCode.SUBMIT_TASK_ERROR)

    logger.info(f"task id: {task_id} has been submitted.")
    success_json_info = JSONResponse(
        content={"status": "Success", "message": f"{task_id} has been submitted"},
        status_code=200
    )
    return success_json_info
|
||||
|
||||
|
||||
@app.post("/api/task/{task_id}/stop")
async def stop_task(task_id):
    """Cancel a running task via the wrapper matching its executor_type.

    Raises APIException (CANCEL_TASK_ERROR when the wrapper reports failure,
    UNKNOWN_ERROR for anything unexpected).
    """
    logger.info("Start stopping ray job...")
    # Success response is prepared up front; only returned if no exception.
    success_json_info = JSONResponse(
        content={"status": "Success", "message": f"{task_id} has been stopped"},
        status_code=200
    )

    try:
        executor_type = get_from_cfg(task_id, "executor_type")
        if not WRAPPERS.get(executor_type).cancel(task_id):
            raise APIException(ErrorCode.CANCEL_TASK_ERROR)
    except Exception as e:
        # Re-raise our own business exceptions untouched; wrap the rest.
        if isinstance(e, APIException):
            raise e
        raise APIException(ErrorCode.UNKNOWN_ERROR)

    logger.info(f"{task_id} has been stopped.")
    return success_json_info
|
||||
|
||||
|
||||
def check_valid_path(file_path):
    """Return True when *file_path* (resolved to an absolute path) exists."""
    return os.path.exists(os.path.abspath(file_path))
|
||||
|
||||
|
||||
def get_from_cfg(task_id, key):
    """Read one value from the task's /flow/<task_id>/process.yaml config.

    Raises:
        APIException: FILE_NOT_FOUND_ERROR when the config file is missing.
        KeyError: when *key* is absent from the parsed YAML.
    """
    config_path = f"/flow/{task_id}/process.yaml"
    if not check_valid_path(config_path):
        logger.error(f"config_path is not existed! please check this path.")
        raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)

    with open(config_path, "r", encoding='utf-8') as f:
        cfg = yaml.safe_load(f.read())
    return cfg[key]
|
||||
|
||||
|
||||
def parse_args():
    """Parse CLI options for the runtime API service.

    Returns:
        Parsed namespace with ``ip`` (default ``0.0.0.0``) and ``port``
        (default ``8080``).
    """
    parser = ArgumentParser(description="Create API for Submitting Job to Data-juicer")

    parser.add_argument(
        '--ip',
        type=str,
        default="0.0.0.0",
        help='Service ip for this API, default to use 0.0.0.0.'
    )

    parser.add_argument(
        '--port',
        type=int,
        default=8080,
        # BUGFIX: help text previously claimed the default was 8600 while the
        # actual default has always been 8080.
        help='Service port for this API, default to use 8080.'
    )

    return parser.parse_args()
|
||||
|
||||
|
||||
# Script entry point: parse CLI options and serve the FastAPI app.
if __name__ == '__main__':
    p_args = parse_args()

    uvicorn.run(
        app,
        host=p_args.ip,
        port=p_args.port
    )
|
||||
import os
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from jsonargparse import ArgumentParser
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from datamate.common.error_code import ErrorCode
|
||||
from datamate.scheduler import cmd_scheduler
|
||||
from datamate.scheduler import func_scheduler
|
||||
from datamate.wrappers import WRAPPERS
|
||||
from datamate.auto_annotation_worker import start_auto_annotation_worker
|
||||
|
||||
# 日志配置
|
||||
LOG_DIR = "/var/log/datamate/runtime"
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
logger.add(
|
||||
f"{LOG_DIR}/runtime.log",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} - {message}",
|
||||
level="DEBUG",
|
||||
enqueue=True
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class APIException(Exception):
|
||||
"""自定义API异常"""
|
||||
|
||||
def __init__(self, error_code: ErrorCode, detail: Optional[str] = None,
|
||||
extra_data: Optional[Dict] = None):
|
||||
self.error_code = error_code
|
||||
self.detail = detail or error_code.value[1]
|
||||
self.code = error_code.value[0]
|
||||
self.extra_data = extra_data
|
||||
super().__init__(self.detail)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"code": self.code,
|
||||
"message": self.detail,
|
||||
"success": False
|
||||
}
|
||||
if self.extra_data:
|
||||
result["data"] = self.extra_data
|
||||
return result
|
||||
|
||||
|
||||
@app.on_event("startup")
async def startup_event():
    """Initialize the background auto-annotation worker at FastAPI startup.

    Startup of the worker is best-effort: a failure is logged but does not
    prevent the API service itself from coming up.
    """

    try:
        start_auto_annotation_worker()
    except Exception as e:  # pragma: no cover - defensive logging
        logger.error("Failed to start auto-annotation worker: {}", e)
|
||||
|
||||
@app.exception_handler(APIException)
async def api_exception_handler(request: Request, exc: APIException):
    """Translate APIException into a JSON response.

    Business failures are reported with HTTP 200; the error itself lives in
    the payload's ``code``/``message``/``success`` fields.
    """
    return JSONResponse(status_code=200, content=exc.to_dict())
|
||||
|
||||
|
||||
class QueryTaskRequest(BaseModel):
    """Request body for POST /api/task/list."""

    # IDs of the tasks whose scheduler status should be returned.
    task_ids: List[str]
|
||||
|
||||
|
||||
@app.post("/api/task/list")
async def query_task_info(request: QueryTaskRequest):
    """Return the scheduler status of each requested task.

    The response is a list of single-entry dicts mapping task_id -> status,
    in the same order as ``request.task_ids``.

    Raises:
        APIException(UNKNOWN_ERROR): any failure while querying the scheduler.
    """
    try:
        return [{task_id: cmd_scheduler.get_task_status(task_id)} for task_id in request.task_ids]
    except Exception as e:
        # Log the root cause (the original code discarded it) and chain it so
        # the traceback survives into the business error.
        logger.error(f"Error happens during querying task status. Error Info following: {e}")
        raise APIException(ErrorCode.UNKNOWN_ERROR) from e
|
||||
|
||||
|
||||
@app.post("/api/task/{task_id}/submit")
async def submit_task(task_id):
    """Submit the task described by /flow/{task_id}/process.yaml to its executor.

    Raises:
        APIException(FILE_NOT_FOUND_ERROR): the dataset path from the config
            does not exist on disk.
        APIException(SUBMIT_TASK_ERROR): the executor wrapper failed to submit.
    """
    config_path = f"/flow/{task_id}/process.yaml"
    logger.info("Start submitting job...")

    dataset_path = get_from_cfg(task_id, "dataset_path")
    if not check_valid_path(dataset_path):
        # Include the offending path (the original f-string had no placeholder).
        logger.error(f"dataset_path does not exist: {dataset_path}. Please check this path.")
        raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)

    try:
        executor_type = get_from_cfg(task_id, "executor_type")
        await WRAPPERS.get(executor_type).submit(task_id, config_path)
    except Exception as e:
        logger.error(f"Error happens during submitting task. Error Info following: {e}")
        # Chain so the original traceback is preserved in the business error.
        raise APIException(ErrorCode.SUBMIT_TASK_ERROR) from e

    logger.info(f"task id: {task_id} has been submitted.")
    return JSONResponse(
        content={"status": "Success", "message": f"{task_id} has been submitted"},
        status_code=200
    )
|
||||
|
||||
|
||||
@app.post("/api/task/{task_id}/stop")
async def stop_task(task_id):
    """Cancel a running task via its executor wrapper.

    Raises:
        APIException(CANCEL_TASK_ERROR): the wrapper reported the cancel failed.
        APIException(UNKNOWN_ERROR): any unexpected failure while cancelling.
    """
    logger.info("Start stopping ray job...")

    try:
        executor_type = get_from_cfg(task_id, "executor_type")
        if not WRAPPERS.get(executor_type).cancel(task_id):
            raise APIException(ErrorCode.CANCEL_TASK_ERROR)
    except APIException:
        # Re-raise business errors untouched; bare raise keeps the traceback
        # (the original `raise e` / isinstance dance is unnecessary).
        raise
    except Exception as e:
        raise APIException(ErrorCode.UNKNOWN_ERROR) from e

    logger.info(f"{task_id} has been stopped.")
    # Build the success payload only after the cancel actually succeeded.
    return JSONResponse(
        content={"status": "Success", "message": f"{task_id} has been stopped"},
        status_code=200
    )
|
||||
|
||||
|
||||
def check_valid_path(file_path):
    """Return True when ``file_path`` (resolved to an absolute path) exists."""
    return os.path.exists(os.path.abspath(file_path))
|
||||
|
||||
|
||||
def get_from_cfg(task_id, key):
    """Read one top-level value from /flow/{task_id}/process.yaml.

    Raises:
        APIException(FILE_NOT_FOUND_ERROR): the task's config file is missing.
        KeyError: ``key`` is absent from the YAML document.
    """
    config_path = f"/flow/{task_id}/process.yaml"
    if not check_valid_path(config_path):
        # Include the offending path (the original f-string had no placeholder).
        logger.error(f"config_path does not exist: {config_path}. Please check this path.")
        raise APIException(ErrorCode.FILE_NOT_FOUND_ERROR)

    # safe_load parses the stream directly and never executes arbitrary YAML tags.
    with open(config_path, "r", encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    return cfg[key]
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options (--ip, --port) for the runtime API service."""
    parser = ArgumentParser(description="Create API for Submitting Job to Data-juicer")

    parser.add_argument(
        '--ip',
        type=str,
        default="0.0.0.0",
        help='Service ip for this API, default to use 0.0.0.0.'
    )

    parser.add_argument(
        '--port',
        type=int,
        default=8080,
        # Help text previously claimed 8600 while the actual default is 8080.
        help='Service port for this API, default to use 8080.'
    )

    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    args = parse_args()
    # Serve the FastAPI app on the CLI-configured interface and port.
    uvicorn.run(app, host=args.ip, port=args.port)
|
||||
|
||||
Reference in New Issue
Block a user