You've already forked DataMate
feat(auto-annotation): integrate YOLO auto-labeling and enhance data management (#223)
* feat(auto-annotation): initial setup * chore: remove package-lock.json * chore: 清理本地测试脚本与 Maven 设置 * chore: change package-lock.json
This commit is contained in:
@@ -1,60 +1,95 @@
|
||||
"""
|
||||
Tables of Annotation Management Module
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, BigInteger, Boolean, TIMESTAMP, Text, Integer, JSON, Date, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
class AnnotationTemplate(Base):
    """Annotation configuration template (table ``t_dm_annotation_templates``).

    A template describes which kind of data is annotated, which labeling
    type is used, and the JSON configuration for the labeling UI. A non-null
    ``deleted_at`` marks the row as soft-deleted.
    """

    __tablename__ = "t_dm_annotation_templates"

    # Primary key: UUID string generated on the Python side at insert time.
    id = Column(
        String(36),
        primary_key=True,
        default=lambda: str(uuid.uuid4()),
        comment="UUID",
    )
    name = Column(String(100), nullable=False, comment="模板名称")
    description = Column(String(500), nullable=True, comment="模板描述")

    # What is annotated and how.
    data_type = Column(
        String(50),
        nullable=False,
        comment="数据类型: image/text/audio/video/timeseries",
    )
    labeling_type = Column(
        String(50),
        nullable=False,
        comment="标注类型: classification/detection/segmentation/ner/relation/etc",
    )
    configuration = Column(
        JSON,
        nullable=False,
        comment="标注配置(包含labels定义等)",
    )
    style = Column(
        String(32),
        nullable=False,
        comment="样式配置: horizontal/vertical",
    )

    # Classification / provenance of the template.
    category = Column(
        String(50),
        default="custom",
        comment="模板分类: medical/general/custom/system",
    )
    built_in = Column(Boolean, default=False, comment="是否系统内置模板")
    version = Column(String(20), default="1.0", comment="模板版本")

    # Audit timestamps; the database fills created_at / updated_at.
    created_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        comment="创建时间",
    )
    updated_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        comment="更新时间",
    )
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self):
        """Return a short debug representation with id, name and data type."""
        return f"<AnnotationTemplate(id={self.id}, name={self.name}, data_type={self.data_type})>"

    @property
    def is_deleted(self) -> bool:
        """Return True when the row carries a soft-delete timestamp."""
        return self.deleted_at is not None
|
||||
|
||||
class LabelingProject(Base):
    """Labeling project (table ``t_dm_labeling_projects``).

    Links a dataset to a Label Studio project and, optionally, to the
    annotation template it was created from. A non-null ``deleted_at`` marks
    the row as soft-deleted.
    """

    __tablename__ = "t_dm_labeling_projects"

    # Primary key: UUID string generated on the Python side at insert time.
    id = Column(
        String(36),
        primary_key=True,
        default=lambda: str(uuid.uuid4()),
        comment="UUID",
    )
    dataset_id = Column(String(36), nullable=False, comment="数据集ID")
    name = Column(String(100), nullable=False, comment="项目名称")
    labeling_project_id = Column(
        String(8),
        nullable=False,
        comment="Label Studio项目ID",
    )
    # Deleting the referenced template nulls this column (ON DELETE SET NULL).
    template_id = Column(
        String(36),
        ForeignKey("t_dm_annotation_templates.id", ondelete="SET NULL"),
        nullable=True,
        comment="使用的模板ID",
    )
    configuration = Column(
        JSON,
        nullable=True,
        comment="项目配置(可能包含对模板的自定义修改)",
    )
    progress = Column(JSON, nullable=True, comment="项目进度信息")

    # Audit timestamps; the database fills created_at / updated_at.
    created_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        comment="创建时间",
    )
    updated_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        comment="更新时间",
    )
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self):
        """Return a short debug representation with id, name and dataset id."""
        return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"

    @property
    def is_deleted(self) -> bool:
        """Return True when the row carries a soft-delete timestamp."""
        # The property body previously ended after its docstring, so it
        # always evaluated to None; return the actual soft-delete state.
        return self.deleted_at is not None
|
||||
"""Tables of Annotation Management Module"""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, TIMESTAMP, Text, Integer, JSON, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
class AnnotationTemplate(Base):
    """Annotation configuration template (table ``t_dm_annotation_templates``).

    A template describes which kind of data is annotated, which labeling
    type is used, and the JSON configuration for the labeling UI. A non-null
    ``deleted_at`` marks the row as soft-deleted.
    """

    __tablename__ = "t_dm_annotation_templates"

    # Primary key: UUID string generated on the Python side at insert time.
    id = Column(
        String(36),
        primary_key=True,
        default=lambda: str(uuid.uuid4()),
        comment="UUID",
    )
    name = Column(String(100), nullable=False, comment="模板名称")
    description = Column(String(500), nullable=True, comment="模板描述")

    # What is annotated and how.
    data_type = Column(
        String(50),
        nullable=False,
        comment="数据类型: image/text/audio/video/timeseries",
    )
    labeling_type = Column(
        String(50),
        nullable=False,
        comment="标注类型: classification/detection/segmentation/ner/relation/etc",
    )
    configuration = Column(
        JSON,
        nullable=False,
        comment="标注配置(包含labels定义等)",
    )
    style = Column(
        String(32),
        nullable=False,
        comment="样式配置: horizontal/vertical",
    )

    # Classification / provenance of the template.
    category = Column(
        String(50),
        default="custom",
        comment="模板分类: medical/general/custom/system",
    )
    built_in = Column(Boolean, default=False, comment="是否系统内置模板")
    version = Column(String(20), default="1.0", comment="模板版本")

    # Audit timestamps; the database fills created_at / updated_at.
    created_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        comment="创建时间",
    )
    updated_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        comment="更新时间",
    )
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self):
        """Return a short debug representation with id, name and data type."""
        return f"<AnnotationTemplate(id={self.id}, name={self.name}, data_type={self.data_type})>"

    @property
    def is_deleted(self) -> bool:
        """Return True when the row carries a soft-delete timestamp."""
        return self.deleted_at is not None
|
||||
|
||||
class LabelingProject(Base):
    """Labeling project (table ``t_dm_labeling_projects``).

    Links a dataset to a Label Studio project and, optionally, to the
    annotation template it was created from. A non-null ``deleted_at`` marks
    the row as soft-deleted.
    """

    __tablename__ = "t_dm_labeling_projects"

    # Primary key: UUID string generated on the Python side at insert time.
    id = Column(
        String(36),
        primary_key=True,
        default=lambda: str(uuid.uuid4()),
        comment="UUID",
    )
    dataset_id = Column(String(36), nullable=False, comment="数据集ID")
    name = Column(String(100), nullable=False, comment="项目名称")
    labeling_project_id = Column(
        String(8),
        nullable=False,
        comment="Label Studio项目ID",
    )
    # Deleting the referenced template nulls this column (ON DELETE SET NULL).
    template_id = Column(
        String(36),
        ForeignKey("t_dm_annotation_templates.id", ondelete="SET NULL"),
        nullable=True,
        comment="使用的模板ID",
    )
    configuration = Column(
        JSON,
        nullable=True,
        comment="项目配置(可能包含对模板的自定义修改)",
    )
    progress = Column(JSON, nullable=True, comment="项目进度信息")

    # Audit timestamps; the database fills created_at / updated_at.
    created_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        comment="创建时间",
    )
    updated_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        comment="更新时间",
    )
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self):
        """Return a short debug representation with id, name and dataset id."""
        return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"

    @property
    def is_deleted(self) -> bool:
        """Return True when the row carries a soft-delete timestamp."""
        return self.deleted_at is not None
|
||||
|
||||
|
||||
class AutoAnnotationTask(Base):
    """Auto-annotation task record (table ``t_dm_auto_annotation_tasks``).

    Stores the task configuration, the optional list of selected file ids,
    and progress counters. A non-null ``deleted_at`` marks the row as
    soft-deleted.
    """

    __tablename__ = "t_dm_auto_annotation_tasks"

    # Primary key: UUID string generated on the Python side at insert time.
    id = Column(
        String(36),
        primary_key=True,
        default=lambda: str(uuid.uuid4()),
        comment="UUID",
    )
    name = Column(String(255), nullable=False, comment="任务名称")

    # Source dataset; the name is a denormalised copy kept next to the id.
    dataset_id = Column(String(36), nullable=False, comment="数据集ID")
    dataset_name = Column(
        String(255),
        nullable=True,
        comment="数据集名称(冗余字段,方便查询)",
    )

    # Task input: configuration plus optional explicit file selection.
    config = Column(
        JSON,
        nullable=False,
        comment="任务配置(模型规模、置信度等)",
    )
    file_ids = Column(
        JSON,
        nullable=True,
        comment="要处理的文件ID列表,为空则处理数据集所有图像",
    )

    # Execution state and counters.
    status = Column(
        String(50),
        nullable=False,
        default="pending",
        comment="任务状态: pending/running/completed/failed",
    )
    progress = Column(Integer, default=0, comment="任务进度 0-100")
    total_images = Column(Integer, default=0, comment="总图片数")
    processed_images = Column(Integer, default=0, comment="已处理图片数")
    detected_objects = Column(Integer, default=0, comment="检测到的对象总数")

    # Result location and failure details.
    output_path = Column(String(500), nullable=True, comment="输出路径")
    error_message = Column(Text, nullable=True, comment="错误信息")

    # Audit timestamps; the database fills created_at / updated_at.
    created_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        comment="创建时间",
    )
    updated_at = Column(
        TIMESTAMP,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        comment="更新时间",
    )
    completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self) -> str:  # pragma: no cover - trivial repr
        """Return a short debug representation with id, name and status."""
        return f"<AutoAnnotationTask(id={self.id}, name={self.name}, status={self.status})>"

    @property
    def is_deleted(self) -> bool:
        """Return True when the row carries a soft-delete timestamp."""
        return self.deleted_at is not None
|
||||
@@ -1,16 +1,18 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .config import router as about_router
|
||||
from .project import router as project_router
|
||||
from .task import router as task_router
|
||||
from .template import router as template_router
|
||||
|
||||
# Parent router of the annotation module; every sub-router below is served
# under the "/annotation" prefix.
router = APIRouter(prefix="/annotation", tags=["annotation"])

# Registration order is kept: config, project, task, template.
for _sub_router in (about_router, project_router, task_router, template_router):
    router.include_router(_sub_router)
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .config import router as about_router
|
||||
from .project import router as project_router
|
||||
from .task import router as task_router
|
||||
from .template import router as template_router
|
||||
from .auto import router as auto_router
|
||||
|
||||
# Parent router of the annotation module; every sub-router below is served
# under the "/annotation" prefix.
router = APIRouter(prefix="/annotation", tags=["annotation"])

# Registration order is kept: config, project, task, template, auto.
for _sub_router in (
    about_router,
    project_router,
    task_router,
    template_router,
    auto_router,
):
    router.include_router(_sub_router)
|
||||
196
runtime/datamate-python/app/module/annotation/interface/auto.py
Normal file
196
runtime/datamate-python/app/module/annotation/interface/auto.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""FastAPI routes for Auto Annotation tasks.
|
||||
|
||||
These routes back the frontend AutoAnnotation module:
|
||||
- GET /api/annotation/auto
|
||||
- POST /api/annotation/auto
|
||||
- DELETE /api/annotation/auto/{task_id}
|
||||
- GET /api/annotation/auto/{task_id}/status (simple wrapper)
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.module.shared.schema import StandardResponse
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
from ..service.auto import AutoAnnotationTaskService
|
||||
|
||||
|
||||
# All routes declared in this module are relative to the "/auto" prefix.
router = APIRouter(prefix="/auto", tags=["annotation/auto"])

logger = get_logger(__name__)

# One service instance shared by every handler in this module.
service = AutoAnnotationTaskService()
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_auto_annotation_tasks(
    db: AsyncSession = Depends(get_db),
):
    """List auto-annotation tasks.

    No pagination parameters are accepted; whatever the service layer's
    ``list_tasks`` returns is sent back in a single response.
    """
    payload = await service.list_tasks(db)
    return StandardResponse(code=200, message="success", data=payload)
|
||||
|
||||
|
||||
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task(
    request: CreateAutoAnnotationTaskRequest,
    db: AsyncSession = Depends(get_db),
):
    """Create an auto-annotation task.

    The handler looks up the dataset to fill the redundant ``dataset_name``
    and ``total_images`` values, then delegates to the service layer. The
    dataset lookup is best effort: any failure is logged as a warning and
    the task is still created with ``dataset_name=None`` / ``total_images=0``.
    """
    logger.info(
        "Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
        request.name,
        request.dataset_id,
        request.config.model_dump(by_alias=True),
        request.file_ids,
    )

    # Defaults used when the dataset cannot be resolved.
    dataset_name = None
    total_images = 0
    try:
        dataset = await DatasetManagementService(db).get_dataset(request.dataset_id)
        if dataset is not None:
            dataset_name = dataset.name
            # An explicit file selection wins; otherwise fall back to the
            # dataset's own file count (0 when the attribute is absent/None).
            total_images = (
                len(request.file_ids)
                if request.file_ids
                else (getattr(dataset, "fileCount", 0) or 0)
            )
    except Exception as exc:  # pragma: no cover - best-effort lookup
        logger.warning("Failed to fetch dataset name for auto task: %s", exc)

    created = await service.create_task(
        db,
        request,
        dataset_name=dataset_name,
        total_images=total_images,
    )
    return StandardResponse(code=200, message="success", data=created)
|
||||
|
||||
|
||||
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
):
    """Return a single auto-annotation task looked up by id.

    Raises:
        HTTPException: 404 when the service layer finds no such task.
    """
    found = await service.get_task(db, task_id)
    if found:
        return StandardResponse(code=200, message="success", data=found)
    raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
@router.delete("/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
):
    """Soft-delete an auto-annotation task via the service layer.

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    deleted = await service.soft_delete_task(db, task_id)
    if deleted:
        return StandardResponse(code=200, message="success", data=True)
    raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
@router.get("/{task_id}/download")
async def download_auto_annotation_result(
    task_id: str = Path(..., description="任务ID"),
    db: AsyncSession = Depends(get_db),
):
    """Stream the result directory of an auto-annotation task as a ZIP file.

    The task's ``output_path`` directory is packed into a temporary archive
    which is streamed to the client and removed afterwards.

    Raises:
        HTTPException: 404 if the task or its output directory is missing,
            400 if the task has no output path, 500 if the archive is empty.
    """
    import os
    import tempfile
    import zipfile
    from urllib.parse import quote

    # Reuse the service layer to resolve the task.
    task = await service.get_task(db, task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    if not task.output_path:
        raise HTTPException(status_code=400, detail="Task has no output path")

    output_dir = task.output_path
    if not os.path.isdir(output_dir):
        raise HTTPException(status_code=404, detail="Output directory not found")

    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
    os.close(tmp_fd)

    def remove_tmp() -> None:
        """Delete the temporary archive; mkstemp never cleans up by itself."""
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass  # already removed
        except OSError as exc:
            logger.warning("Failed to remove temporary archive %s: %s", tmp_path, exc)

    try:
        with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for root, _, files in os.walk(output_dir):
                for entry in files:
                    file_path = os.path.join(root, entry)
                    # Store paths relative to the output directory.
                    zf.write(file_path, os.path.relpath(file_path, output_dir))

        file_size = os.path.getsize(tmp_path)
        if file_size == 0:
            raise HTTPException(status_code=500, detail="Generated ZIP is empty")
    except Exception:
        # Do not leave the archive behind when building it fails.
        remove_tmp()
        raise

    def iterfile():
        """Yield the archive in 8 KiB chunks, then delete it."""
        try:
            with open(tmp_path, "rb") as f:
                for chunk in iter(lambda: f.read(8192), b""):
                    yield chunk
        finally:
            # Runs when streaming finishes or the generator is closed early.
            remove_tmp()

    download_name = f"{task.name}_annotations.zip"
    # HTTP header values must be latin-1 encodable and must not contain raw
    # quotes or control characters, so send a sanitised ASCII fallback plus
    # the exact UTF-8 name through the RFC 5987 ``filename*`` parameter.
    ascii_fallback = "".join(
        ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
        for ch in download_name
    )
    headers = {
        "Content-Disposition": (
            f'attachment; filename="{ascii_fallback}"; '
            f"filename*=UTF-8''{quote(download_name, safe='')}"
        ),
        "Content-Length": str(file_size),
    }

    return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)
|
||||
73
runtime/datamate-python/app/module/annotation/schema/auto.py
Normal file
73
runtime/datamate-python/app/module/annotation/schema/auto.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Schemas for Auto Annotation tasks"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class AutoAnnotationConfig(BaseModel):
    """Configuration of an auto-annotation task, aligned with the frontend payload.

    Fields are accepted either by their camelCase alias or by their Python
    name (``populate_by_name=True``).
    """

    model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x")
    # The threshold is documented as 0-1, so out-of-range values are rejected
    # at the API boundary instead of being stored with the task.
    conf_threshold: float = Field(
        ge=0.0,
        le=1.0,
        alias="confThreshold",
        description="置信度阈值 0-1",
    )
    target_classes: List[int] = Field(
        default_factory=list,
        alias="targetClasses",
        description="目标类别ID列表,空表示全部类别",
    )
    output_dataset_name: Optional[str] = Field(
        default=None,
        alias="outputDatasetName",
        description="自动标注结果要写入的新数据集名称(可选)",
    )

    # ``model_size`` starts with Pydantic's protected "model_" prefix, which
    # triggers a namespace-conflict warning at class creation; clearing
    # ``protected_namespaces`` keeps the field name and silences the warning.
    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())
|
||||
|
||||
|
||||
class CreateAutoAnnotationTaskRequest(BaseModel):
    """Request body for creating an auto-annotation task.

    Mirrors the structure sent by the frontend create dialog; fields are
    accepted by camelCase alias or by Python name.
    """

    name: str = Field(
        ...,
        min_length=1,
        max_length=255,
        description="任务名称",
    )
    dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
    config: AutoAnnotationConfig = Field(..., description="任务配置")
    # Optional explicit file selection; None means no selection was sent.
    file_ids: Optional[List[str]] = Field(
        default=None,
        alias="fileIds",
        description="要处理的文件ID列表,为空则处理数据集中所有图像",
    )

    model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class AutoAnnotationTaskResponse(BaseModel):
    """Response model for an auto-annotation task (used for list and detail).

    ``from_attributes=True`` lets the model be built directly from an ORM
    object; ``source_datasets`` has no default source and is left ``None``
    unless the caller sets it.
    """

    # Identity and source dataset.
    id: str = Field(..., description="任务ID")
    name: str = Field(..., description="任务名称")
    dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
    dataset_name: Optional[str] = Field(
        default=None,
        alias="datasetName",
        description="数据集名称",
    )
    source_datasets: Optional[List[str]] = Field(
        default=None,
        alias="sourceDatasets",
        description="本任务实际处理涉及到的所有数据集名称列表",
    )

    # Configuration and execution state.
    config: Dict[str, Any] = Field(..., description="任务配置")
    status: str = Field(..., description="任务状态")
    progress: int = Field(..., description="任务进度 0-100")
    total_images: int = Field(..., alias="totalImages", description="总图片数")
    processed_images: int = Field(
        ...,
        alias="processedImages",
        description="已处理图片数",
    )
    detected_objects: int = Field(
        ...,
        alias="detectedObjects",
        description="检测到的对象总数",
    )

    # Result location and failure details.
    output_path: Optional[str] = Field(
        default=None,
        alias="outputPath",
        description="输出路径",
    )
    error_message: Optional[str] = Field(
        default=None,
        alias="errorMessage",
        description="错误信息",
    )

    # Timestamps.
    created_at: datetime = Field(..., alias="createdAt", description="创建时间")
    updated_at: Optional[datetime] = Field(
        default=None,
        alias="updatedAt",
        description="更新时间",
    )
    completed_at: Optional[datetime] = Field(
        default=None,
        alias="completedAt",
        description="完成时间",
    )

    model_config = ConfigDict(populate_by_name=True, from_attributes=True)
|
||||
|
||||
|
||||
class AutoAnnotationTaskListResponse(BaseModel):
    """Paged list wrapper for auto-annotation tasks.

    Reserved structure: per the original note, the frontend currently
    consumes a plain array instead of this wrapper.
    """

    content: List[AutoAnnotationTaskResponse] = Field(
        ...,
        description="任务列表",
    )
    total: int = Field(..., description="总数")

    model_config = ConfigDict(populate_by_name=True)
|
||||
154
runtime/datamate-python/app/module/annotation/service/auto.py
Normal file
154
runtime/datamate-python/app/module/annotation/service/auto.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Service layer for Auto Annotation tasks"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models.annotation_management import AutoAnnotationTask
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
|
||||
|
||||
class AutoAnnotationTaskService:
    """Service for auto-annotation task metadata.

    Only the task records are managed here (create / list / get /
    soft-delete); per the original design note, actual execution is handled
    by the runtime.
    """

    async def create_task(
        self,
        db: AsyncSession,
        request: CreateAutoAnnotationTaskRequest,
        dataset_name: Optional[str] = None,
        total_images: int = 0,
    ) -> AutoAnnotationTaskResponse:
        """Insert a new task in ``pending`` state and return it.

        Only the record is created; no inference is started here.

        Args:
            db: Session used to persist the task.
            request: Validated creation payload.
            dataset_name: Redundant dataset name stored next to the id.
            total_images: Initial value of the ``total_images`` counter.

        Returns:
            The stored task, including ``source_datasets``.
        """
        now = datetime.now()

        task = AutoAnnotationTask(
            id=str(uuid4()),
            name=request.name,
            dataset_id=request.dataset_id,
            dataset_name=dataset_name,
            config=request.config.model_dump(by_alias=True),
            file_ids=request.file_ids,  # file ids selected by the user, if any
            status="pending",
            progress=0,
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            created_at=now,
            updated_at=now,
        )

        db.add(task)
        await db.commit()
        await db.refresh(task)

        return await self._to_response(db, task)

    async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]:
        """Return all tasks that are not soft-deleted, newest first."""
        result = await db.execute(
            select(AutoAnnotationTask)
            .where(AutoAnnotationTask.deleted_at.is_(None))
            .order_by(AutoAnnotationTask.created_at.desc())
        )
        tasks: List[AutoAnnotationTask] = list(result.scalars().all())

        # Responses are built one after another, in query order.
        return [await self._to_response(db, task) for task in tasks]

    async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]:
        """Return the non-deleted task with ``task_id``, or None if absent."""
        result = await db.execute(
            select(AutoAnnotationTask).where(
                AutoAnnotationTask.id == task_id,
                AutoAnnotationTask.deleted_at.is_(None),
            )
        )
        task = result.scalar_one_or_none()
        if not task:
            return None

        return await self._to_response(db, task)

    async def _to_response(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
    ) -> AutoAnnotationTaskResponse:
        """Convert an ORM task into its response model with ``source_datasets``.

        Resolving the source datasets is best effort: on any failure the
        list degrades to the task's own dataset name (or id), and the
        failure is logged instead of being dropped silently.
        """
        import logging

        resp = AutoAnnotationTaskResponse.model_validate(task)
        # Read the fallback values before querying so a failed query cannot
        # interfere with them.
        fallback_name = getattr(task, "dataset_name", None)
        fallback_id = getattr(task, "dataset_id", "")
        try:
            resp.source_datasets = await self._compute_source_datasets(db, task)
        except Exception as exc:
            logging.getLogger(__name__).warning(
                "Failed to resolve source datasets for auto annotation task %s: %s",
                resp.id,
                exc,
            )
            resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
        return resp

    async def _compute_source_datasets(
        self,
        db: AsyncSession,
        task: AutoAnnotationTask,
    ) -> List[str]:
        """Infer the names of all datasets the task's files belong to.

        - With ``file_ids``: join dataset files to datasets and return the
          distinct, non-empty dataset names.
        - Without ``file_ids`` (or when the query yields no name): fall back
          to the task's ``dataset_name``, then ``dataset_id``, then ``[]``.
        """
        file_ids = task.file_ids or []
        if file_ids:
            stmt = (
                select(Dataset.name)
                .join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
                .where(DatasetFiles.id.in_(file_ids))
                .distinct()
            )
            result = await db.execute(stmt)
            names = [row[0] for row in result.fetchall() if row[0]]
            if names:
                return names

        # Fallback: report a single dataset.
        if task.dataset_name:
            return [task.dataset_name]
        if task.dataset_id:
            return [task.dataset_id]
        return []

    async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool:
        """Mark the task as deleted by setting ``deleted_at``.

        Returns:
            True if a non-deleted task was found and marked, False otherwise.
        """
        result = await db.execute(
            select(AutoAnnotationTask).where(
                AutoAnnotationTask.id == task_id,
                AutoAnnotationTask.deleted_at.is_(None),
            )
        )
        task = result.scalar_one_or_none()
        if not task:
            return False

        task.deleted_at = datetime.now()
        await db.commit()
        return True
|
||||
@@ -1,8 +1,8 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Announce the local document root only when the directory exists and local
# file serving is explicitly set to "true". The superseded form
# (`[-d $VAR ] && $VAR`) is dropped: "[-d" is not the test command, and the
# unquoted flag was executed as a command, which fails when it is unset.
if [ -d "${LOCAL_FILES_DOCUMENT_ROOT}" ] && [ "${LOCAL_FILES_SERVING_ENABLED}" = "true" ]; then
    echo "Using local document root: ${LOCAL_FILES_DOCUMENT_ROOT}"
fi
|
||||
|
||||
# 启动应用
|
||||
|
||||
Reference in New Issue
Block a user