feat(annotation): 自动标注任务支持非图像类型数据集(TEXT/AUDIO/VIDEO)

移除自动标注任务创建流程中的 IMAGE-only 限制,使 TEXT、AUDIO、VIDEO
类型数据集均可用于自动标注任务。

- 新增数据库迁移:t_dm_auto_annotation_tasks 表添加 dataset_type 列
- 后端 schema/API/service 全链路传递 dataset_type
- Worker 动态构建 sample key(image/text/audio/video)和输出目录
- 前端移除数据集类型校验,下拉框显示数据集类型标识
- 输出数据集继承源数据集类型,不再硬编码为 IMAGE
- 保持向后兼容:默认值为 IMAGE,worker 有元数据回退和目录 fallback

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 23:23:05 +08:00
parent 807c2289e2
commit 8ffa131fad
7 changed files with 1161 additions and 1082 deletions

View File

@@ -6,7 +6,7 @@ import { ArrowLeft } from "lucide-react";
import { Link, useNavigate } from "react-router"; import { Link, useNavigate } from "react-router";
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api"; import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
import { mapDataset } from "@/pages/DataManagement/dataset.const"; import { mapDataset } from "@/pages/DataManagement/dataset.const";
import { Dataset, DatasetType } from "@/pages/DataManagement/dataset.model"; import { Dataset } from "@/pages/DataManagement/dataset.model";
import { createAnnotationOperatorTaskUsingPost } from "../annotation.api"; import { createAnnotationOperatorTaskUsingPost } from "../annotation.api";
import { useCreateStepTwo } from "./hooks/useCreateStepTwo"; import { useCreateStepTwo } from "./hooks/useCreateStepTwo";
import PipelinePreview from "./components/PipelinePreview"; import PipelinePreview from "./components/PipelinePreview";
@@ -85,11 +85,6 @@ export default function AnnotationOperatorTaskCreate() {
try { try {
if (currentStep === 1) { if (currentStep === 1) {
await form.validateFields(); await form.validateFields();
if (selectedDataset?.datasetType !== DatasetType.IMAGE) {
message.error("自动标注算子编排当前仅支持图片数据集");
return;
}
} }
setCurrentStep((prev) => Math.min(prev + 1, 2)); setCurrentStep((prev) => Math.min(prev + 1, 2));
} catch { } catch {
@@ -109,11 +104,6 @@ export default function AnnotationOperatorTaskCreate() {
return; return;
} }
if (selectedDataset?.datasetType !== DatasetType.IMAGE) {
message.error("自动标注算子编排当前仅支持图片数据集");
return;
}
const outputDatasetName = values.outputDatasetName?.trim(); const outputDatasetName = values.outputDatasetName?.trim();
const pipeline = selectedOperators.map((operator, index) => { const pipeline = selectedOperators.map((operator, index) => {
const overrides = { const overrides = {
@@ -200,10 +190,10 @@ export default function AnnotationOperatorTaskCreate() {
label="选择数据集" label="选择数据集"
name="datasetId" name="datasetId"
rules={[{ required: true, message: "请选择数据集" }]} rules={[{ required: true, message: "请选择数据集" }]}
extra="自动标注算子编排当前仅支持图片数据集" extra="请选择用于自动标注的数据集"
> >
<Select <Select
placeholder="请选择图片数据集" placeholder="请选择数据集"
optionFilterProp="label" optionFilterProp="label"
options={datasets.map((dataset) => ({ options={datasets.map((dataset) => ({
label: ( label: (
@@ -215,12 +205,11 @@ export default function AnnotationOperatorTaskCreate() {
{dataset.name} {dataset.name}
</div> </div>
<div className="text-xs text-gray-500"> <div className="text-xs text-gray-500">
{dataset?.fileCount} {dataset.size} {dataset.datasetType} &bull; {dataset?.fileCount} &bull; {dataset.size}
</div> </div>
</div> </div>
), ),
value: dataset.id, value: dataset.id,
disabled: dataset.datasetType !== DatasetType.IMAGE,
}))} }))}
/> />
</Form.Item> </Form.Item>

View File

@@ -197,7 +197,7 @@ class AnnotationResult(Base):
return f"<AnnotationResult(id={self.id}, project_id={self.project_id}, file_id={self.file_id})>" return f"<AnnotationResult(id={self.id}, project_id={self.project_id}, file_id={self.file_id})>"
class AutoAnnotationTask(Base): class AutoAnnotationTask(Base):
"""自动标注任务模型,对应表 t_dm_auto_annotation_tasks""" """自动标注任务模型,对应表 t_dm_auto_annotation_tasks"""
__tablename__ = "t_dm_auto_annotation_tasks" __tablename__ = "t_dm_auto_annotation_tasks"
@@ -206,94 +206,98 @@ class AutoAnnotationTask(Base):
String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID" String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID"
) )
name = Column(String(255), nullable=False, comment="任务名称") name = Column(String(255), nullable=False, comment="任务名称")
dataset_id = Column(String(36), nullable=False, comment="数据集ID") dataset_id = Column(String(36), nullable=False, comment="数据集ID")
dataset_name = Column( dataset_name = Column(
String(255), nullable=True, comment="数据集名称(冗余字段,方便查询)" String(255), nullable=True, comment="数据集名称(冗余字段,方便查询)"
) )
created_by = Column(String(255), nullable=True, comment="任务创建人") dataset_type = Column(
config = Column(JSON, nullable=False, comment="任务配置(模型规模、置信度等)") String(50), nullable=False, default="IMAGE",
file_ids = Column( comment="数据集类型: IMAGE/TEXT/AUDIO/VIDEO",
JSON, nullable=True, comment="要处理的文件ID列表,为空则处理数据集所有图像" )
) created_by = Column(String(255), nullable=True, comment="任务创建人")
status = Column( config = Column(JSON, nullable=False, comment="任务配置(模型规模、置信度等)")
String(50), file_ids = Column(
nullable=False, JSON, nullable=True, comment="要处理的文件ID列表,为空则处理数据集所有图像"
default="pending", )
comment="任务状态: pending/running/completed/failed/stopped", status = Column(
) String(50),
task_mode = Column( nullable=False,
String(32), default="pending",
nullable=False, comment="任务状态: pending/running/completed/failed/stopped",
default="legacy_yolo", )
comment="任务模式: legacy_yolo/pipeline", task_mode = Column(
) String(32),
executor_type = Column( nullable=False,
String(32), default="legacy_yolo",
nullable=False, comment="任务模式: legacy_yolo/pipeline",
default="annotation_local", )
comment="执行器类型", executor_type = Column(
) String(32),
pipeline = Column(JSON, nullable=True, comment="算子编排定义") nullable=False,
progress = Column(Integer, default=0, comment="任务进度 0-100") default="annotation_local",
stop_requested = Column(Boolean, default=False, comment="是否请求停止") comment="执行器类型",
total_images = Column(Integer, default=0, comment="总图片数") )
processed_images = Column(Integer, default=0, comment="已处理图片数") pipeline = Column(JSON, nullable=True, comment="算子编排定义")
detected_objects = Column(Integer, default=0, comment="检测到的对象总数") progress = Column(Integer, default=0, comment="任务进度 0-100")
output_path = Column(String(500), nullable=True, comment="输出路径") stop_requested = Column(Boolean, default=False, comment="是否请求停止")
output_dataset_id = Column(String(36), nullable=True, comment="输出数据集ID") total_images = Column(Integer, default=0, comment="总图片数")
error_message = Column(Text, nullable=True, comment="错误信息") processed_images = Column(Integer, default=0, comment="已处理图片数")
created_at = Column( detected_objects = Column(Integer, default=0, comment="检测到的对象总数")
TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间" output_path = Column(String(500), nullable=True, comment="输出路径")
) output_dataset_id = Column(String(36), nullable=True, comment="输出数据集ID")
error_message = Column(Text, nullable=True, comment="错误信息")
created_at = Column(
TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间"
)
updated_at = Column( updated_at = Column(
TIMESTAMP, TIMESTAMP,
server_default=func.current_timestamp(), server_default=func.current_timestamp(),
onupdate=func.current_timestamp(), onupdate=func.current_timestamp(),
comment="更新时间", comment="更新时间",
) )
started_at = Column(TIMESTAMP, nullable=True, comment="任务启动时间") started_at = Column(TIMESTAMP, nullable=True, comment="任务启动时间")
heartbeat_at = Column(TIMESTAMP, nullable=True, comment="worker心跳时间") heartbeat_at = Column(TIMESTAMP, nullable=True, comment="worker心跳时间")
run_token = Column(String(64), nullable=True, comment="运行令牌") run_token = Column(String(64), nullable=True, comment="运行令牌")
completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间") completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)") deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
def __repr__(self) -> str: # pragma: no cover - repr 简单返回 def __repr__(self) -> str: # pragma: no cover - repr 简单返回
return f"<AutoAnnotationTask(id={self.id}, name={self.name}, status={self.status})>" return f"<AutoAnnotationTask(id={self.id}, name={self.name}, status={self.status})>"
@property @property
def is_deleted(self) -> bool: def is_deleted(self) -> bool:
"""检查是否已被软删除""" """检查是否已被软删除"""
return self.deleted_at is not None return self.deleted_at is not None
class AnnotationTaskOperatorInstance(Base): class AnnotationTaskOperatorInstance(Base):
"""自动标注任务内算子实例模型,对应表 t_dm_annotation_task_operator_instance""" """自动标注任务内算子实例模型,对应表 t_dm_annotation_task_operator_instance"""
__tablename__ = "t_dm_annotation_task_operator_instance" __tablename__ = "t_dm_annotation_task_operator_instance"
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="自增主键") id = Column(BigInteger, primary_key=True, autoincrement=True, comment="自增主键")
task_id = Column(String(36), nullable=False, comment="自动标注任务ID") task_id = Column(String(36), nullable=False, comment="自动标注任务ID")
op_index = Column(Integer, nullable=False, comment="算子顺序(从1开始)") op_index = Column(Integer, nullable=False, comment="算子顺序(从1开始)")
operator_id = Column(String(64), nullable=False, comment="算子ID(raw_id)") operator_id = Column(String(64), nullable=False, comment="算子ID(raw_id)")
settings_override = Column(JSON, nullable=True, comment="任务级算子参数覆盖") settings_override = Column(JSON, nullable=True, comment="任务级算子参数覆盖")
inputs = Column(String(64), nullable=True, comment="输入模态") inputs = Column(String(64), nullable=True, comment="输入模态")
outputs = Column(String(64), nullable=True, comment="输出模态") outputs = Column(String(64), nullable=True, comment="输出模态")
created_at = Column( created_at = Column(
TIMESTAMP, TIMESTAMP,
server_default=func.current_timestamp(), server_default=func.current_timestamp(),
nullable=False, nullable=False,
comment="创建时间", comment="创建时间",
) )
updated_at = Column( updated_at = Column(
TIMESTAMP, TIMESTAMP,
server_default=func.current_timestamp(), server_default=func.current_timestamp(),
onupdate=func.current_timestamp(), onupdate=func.current_timestamp(),
nullable=False, nullable=False,
comment="更新时间", comment="更新时间",
) )
__table_args__ = ( __table_args__ = (
UniqueConstraint("task_id", "op_index", name="uk_task_op_index"), UniqueConstraint("task_id", "op_index", name="uk_task_op_index"),
Index("idx_task_id", "task_id"), Index("idx_task_id", "task_id"),
Index("idx_operator_id", "operator_id"), Index("idx_operator_id", "operator_id"),
) )

View File

@@ -15,9 +15,9 @@ from typing import List, Literal
from fastapi import APIRouter, Depends, HTTPException, Path from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db from app.db.session import get_db
from app.module.shared.schema import StandardResponse from app.module.shared.schema import StandardResponse
from app.module.dataset import DatasetManagementService from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger from app.core.logging import get_logger
@@ -29,15 +29,15 @@ from ..security import (
from ..schema.auto import ( from ..schema.auto import (
CreateAutoAnnotationTaskRequest, CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse, AutoAnnotationTaskResponse,
) )
from ..service.auto import AutoAnnotationTaskService from ..service.auto import AutoAnnotationTaskService
router = APIRouter( router = APIRouter(
tags=["annotation/auto"], tags=["annotation/auto"],
) )
logger = get_logger(__name__) logger = get_logger(__name__)
service = AutoAnnotationTaskService() service = AutoAnnotationTaskService()
@@ -85,23 +85,28 @@ async def _create_task_internal(
await assert_dataset_access(db, normalized_request.dataset_id, user_context) await assert_dataset_access(db, normalized_request.dataset_id, user_context)
# 尝试获取数据集名称和总量用于冗余字段 # 尝试获取数据集名称和总量用于冗余字段
dataset_name = None dataset_name = None
dataset_type = "IMAGE"
total_images = len(normalized_request.file_ids) if normalized_request.file_ids else 0 total_images = len(normalized_request.file_ids) if normalized_request.file_ids else 0
try: try:
dm_client = DatasetManagementService(db) dm_client = DatasetManagementService(db)
dataset = await dm_client.get_dataset(normalized_request.dataset_id) dataset = await dm_client.get_dataset(normalized_request.dataset_id)
if dataset is not None: if dataset is not None:
dataset_name = dataset.name dataset_name = dataset.name
dataset_type = getattr(dataset, "datasetType", None) or "IMAGE"
if not normalized_request.file_ids: if not normalized_request.file_ids:
total_images = getattr(dataset, "fileCount", 0) or 0 total_images = getattr(dataset, "fileCount", 0) or 0
except Exception as e: # pragma: no cover - 容错 except Exception as e: # pragma: no cover - 容错
logger.warning("Failed to fetch dataset summary for annotation task: %s", e) logger.warning("Failed to fetch dataset summary for annotation task: %s", e)
resolved_dataset_type = normalized_request.dataset_type or dataset_type
return await service.create_task( return await service.create_task(
db, db,
normalized_request, normalized_request,
user_context=user_context, user_context=user_context,
dataset_name=dataset_name, dataset_name=dataset_name,
total_images=total_images, total_images=total_images,
dataset_type=resolved_dataset_type,
) )
@@ -177,10 +182,10 @@ async def get_auto_annotation_task_status(
task = await service.get_task(db, task_id, user_context) task = await service.get_task(db, task_id, user_context)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
data=task, data=task,
) )
@@ -192,12 +197,12 @@ async def delete_auto_annotation_task(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context), user_context: RequestUserContext = Depends(get_request_user_context),
): ):
"""删除(软删除)自动标注任务,仅标记 deleted_at。""" """删除(软删除)自动标注任务,仅标记 deleted_at。"""
ok = await service.soft_delete_task(db, task_id, user_context) ok = await service.soft_delete_task(db, task_id, user_context)
if not ok: if not ok:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
@@ -232,50 +237,50 @@ async def download_auto_annotation_result(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context), user_context: RequestUserContext = Depends(get_request_user_context),
): ):
"""下载指定自动标注任务的结果 ZIP。""" """下载指定自动标注任务的结果 ZIP。"""
import os import os
import zipfile import zipfile
import tempfile import tempfile
# 复用服务层获取任务信息 # 复用服务层获取任务信息
task = await service.get_task(db, task_id, user_context) task = await service.get_task(db, task_id, user_context)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
if not task.output_path: if not task.output_path:
raise HTTPException(status_code=400, detail="Task has no output path") raise HTTPException(status_code=400, detail="Task has no output path")
output_dir = task.output_path output_dir = task.output_path
if not os.path.isdir(output_dir): if not os.path.isdir(output_dir):
raise HTTPException(status_code=404, detail="Output directory not found") raise HTTPException(status_code=404, detail="Output directory not found")
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip") tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
os.close(tmp_fd) os.close(tmp_fd)
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf: with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
for root, _, files in os.walk(output_dir): for root, _, files in os.walk(output_dir):
for filename in files: for filename in files:
file_path = os.path.join(root, filename) file_path = os.path.join(root, filename)
arcname = os.path.relpath(file_path, output_dir) arcname = os.path.relpath(file_path, output_dir)
zf.write(file_path, arcname) zf.write(file_path, arcname)
file_size = os.path.getsize(tmp_path) file_size = os.path.getsize(tmp_path)
if file_size == 0: if file_size == 0:
raise HTTPException(status_code=500, detail="Generated ZIP is empty") raise HTTPException(status_code=500, detail="Generated ZIP is empty")
def iterfile(): def iterfile():
with open(tmp_path, "rb") as f: with open(tmp_path, "rb") as f:
while True: while True:
chunk = f.read(8192) chunk = f.read(8192)
if not chunk: if not chunk:
break break
yield chunk yield chunk
filename = f"{task.name}_annotations.zip" filename = f"{task.name}_annotations.zip"
headers = { headers = {
"Content-Disposition": f'attachment; filename="{filename}"', "Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(file_size), "Content-Length": str(file_size),
} }
return StreamingResponse(iterfile(), media_type="application/zip", headers=headers) return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)

View File

@@ -1,4 +1,4 @@
"""Schemas for Auto Annotation tasks""" """Schemas for Auto Annotation tasks"""
from __future__ import annotations from __future__ import annotations
import json import json
@@ -7,24 +7,24 @@ from typing import List, Optional, Dict, Any
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict, model_validator from pydantic import BaseModel, Field, ConfigDict, model_validator
class AutoAnnotationConfig(BaseModel): class AutoAnnotationConfig(BaseModel):
"""自动标注任务配置(与前端 payload 对齐)""" """自动标注任务配置(与前端 payload 对齐)"""
model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x") model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x")
conf_threshold: float = Field(alias="confThreshold", description="置信度阈值 0-1") conf_threshold: float = Field(alias="confThreshold", description="置信度阈值 0-1")
target_classes: List[int] = Field( target_classes: List[int] = Field(
default_factory=list, default_factory=list,
alias="targetClasses", alias="targetClasses",
description="目标类别ID列表,空表示全部类别", description="目标类别ID列表,空表示全部类别",
) )
output_dataset_name: Optional[str] = Field( output_dataset_name: Optional[str] = Field(
default=None, default=None,
alias="outputDatasetName", alias="outputDatasetName",
description="自动标注结果要写入的新数据集名称(可选)", description="自动标注结果要写入的新数据集名称(可选)",
) )
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)
@@ -68,13 +68,18 @@ class OperatorPipelineStep(BaseModel):
return normalized return normalized
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)
class CreateAutoAnnotationTaskRequest(BaseModel): class CreateAutoAnnotationTaskRequest(BaseModel):
"""创建自动标注任务的请求体,对齐前端 CreateAutoAnnotationDialog 发送的结构""" """创建自动标注任务的请求体,对齐前端 CreateAutoAnnotationDialog 发送的结构"""
name: str = Field(..., min_length=1, max_length=255, description="任务名称") name: str = Field(..., min_length=1, max_length=255, description="任务名称")
dataset_id: str = Field(..., alias="datasetId", description="数据集ID") dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
dataset_type: Optional[str] = Field(
default=None,
alias="datasetType",
description="数据集类型: IMAGE/TEXT/AUDIO/VIDEO(不传时由后端自动获取)",
)
config: Optional[AutoAnnotationConfig] = Field( config: Optional[AutoAnnotationConfig] = Field(
default=None, default=None,
description="兼容旧版 YOLO 任务配置", description="兼容旧版 YOLO 任务配置",
@@ -111,15 +116,16 @@ class CreateAutoAnnotationTaskRequest(BaseModel):
return self return self
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)
class AutoAnnotationTaskResponse(BaseModel): class AutoAnnotationTaskResponse(BaseModel):
"""自动标注任务响应模型(列表/详情均可复用)""" """自动标注任务响应模型(列表/详情均可复用)"""
id: str = Field(..., description="任务ID") id: str = Field(..., description="任务ID")
name: str = Field(..., description="任务名称") name: str = Field(..., description="任务名称")
dataset_id: str = Field(..., alias="datasetId", description="数据集ID") dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称") dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
dataset_type: Optional[str] = Field(None, alias="datasetType", description="数据集类型: IMAGE/TEXT/AUDIO/VIDEO")
task_mode: Optional[str] = Field(None, alias="taskMode", description="任务模式") task_mode: Optional[str] = Field(None, alias="taskMode", description="任务模式")
executor_type: Optional[str] = Field(None, alias="executorType", description="执行器类型") executor_type: Optional[str] = Field(None, alias="executorType", description="执行器类型")
pipeline: Optional[List[Dict[str, Any]]] = Field(None, description="算子编排定义") pipeline: Optional[List[Dict[str, Any]]] = Field(None, description="算子编排定义")
@@ -128,11 +134,11 @@ class AutoAnnotationTaskResponse(BaseModel):
alias="sourceDatasets", alias="sourceDatasets",
description="本任务实际处理涉及到的所有数据集名称列表", description="本任务实际处理涉及到的所有数据集名称列表",
) )
config: Dict[str, Any] = Field(..., description="任务配置") config: Dict[str, Any] = Field(..., description="任务配置")
status: str = Field(..., description="任务状态") status: str = Field(..., description="任务状态")
progress: int = Field(..., description="任务进度 0-100") progress: int = Field(..., description="任务进度 0-100")
total_images: int = Field(..., alias="totalImages", description="总图片数") total_images: int = Field(..., alias="totalImages", description="总图片数")
processed_images: int = Field(..., alias="processedImages", description="已处理图片数") processed_images: int = Field(..., alias="processedImages", description="已处理图片数")
detected_objects: int = Field(..., alias="detectedObjects", description="检测到的对象总数") detected_objects: int = Field(..., alias="detectedObjects", description="检测到的对象总数")
output_path: Optional[str] = Field(None, alias="outputPath", description="输出路径") output_path: Optional[str] = Field(None, alias="outputPath", description="输出路径")
output_dataset_id: Optional[str] = Field( output_dataset_id: Optional[str] = Field(
@@ -152,14 +158,14 @@ class AutoAnnotationTaskResponse(BaseModel):
created_at: datetime = Field(..., alias="createdAt", description="创建时间") created_at: datetime = Field(..., alias="createdAt", description="创建时间")
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间") updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
completed_at: Optional[datetime] = Field(None, alias="completedAt", description="完成时间") completed_at: Optional[datetime] = Field(None, alias="completedAt", description="完成时间")
model_config = ConfigDict(populate_by_name=True, from_attributes=True) model_config = ConfigDict(populate_by_name=True, from_attributes=True)
class AutoAnnotationTaskListResponse(BaseModel): class AutoAnnotationTaskListResponse(BaseModel):
"""自动标注任务列表响应,目前前端直接使用数组,这里预留分页结构""" """自动标注任务列表响应,目前前端直接使用数组,这里预留分页结构"""
content: List[AutoAnnotationTaskResponse] = Field(..., description="任务列表") content: List[AutoAnnotationTaskResponse] = Field(..., description="任务列表")
total: int = Field(..., description="总数") total: int = Field(..., description="总数")
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)

View File

@@ -15,13 +15,13 @@ from app.db.models.annotation_management import (
) )
from app.db.models.dataset_management import Dataset, DatasetFiles from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.security import RequestUserContext from app.module.annotation.security import RequestUserContext
from ..schema.auto import ( from ..schema.auto import (
CreateAutoAnnotationTaskRequest, CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse, AutoAnnotationTaskResponse,
) )
class AutoAnnotationTaskService: class AutoAnnotationTaskService:
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)""" """自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
@@ -141,11 +141,12 @@ class AutoAnnotationTaskService:
user_context: RequestUserContext, user_context: RequestUserContext,
dataset_name: Optional[str] = None, dataset_name: Optional[str] = None,
total_images: int = 0, total_images: int = 0,
dataset_type: str = "IMAGE",
) -> AutoAnnotationTaskResponse: ) -> AutoAnnotationTaskResponse:
"""创建自动标注任务,初始状态为 pending。 """创建自动标注任务,初始状态为 pending。
这里仅插入任务记录,不负责真正执行 YOLO 推理, 这里仅插入任务记录,不负责真正执行 YOLO 推理,
后续可以由调度器/worker 读取该表并更新进度。 后续可以由调度器/worker 读取该表并更新进度。
""" """
now = datetime.now() now = datetime.now()
@@ -170,6 +171,7 @@ class AutoAnnotationTaskService:
name=request.name, name=request.name,
dataset_id=request.dataset_id, dataset_id=request.dataset_id,
dataset_name=dataset_name, dataset_name=dataset_name,
dataset_type=dataset_type,
created_by=user_context.user_id, created_by=user_context.user_id,
config=normalized_config, config=normalized_config,
task_mode=request.task_mode, task_mode=request.task_mode,
@@ -192,15 +194,15 @@ class AutoAnnotationTaskService:
db.add_all(operator_instances) db.add_all(operator_instances)
await db.commit() await db.commit()
await db.refresh(task) await db.refresh(task)
# 创建后附带 sourceDatasets 信息(通常只有一个原始数据集) # 创建后附带 sourceDatasets 信息(通常只有一个原始数据集)
resp = AutoAnnotationTaskResponse.model_validate(task) resp = AutoAnnotationTaskResponse.model_validate(task)
try: try:
resp.source_datasets = await self._compute_source_datasets(db, task) resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception: except Exception:
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id] resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
return resp return resp
def _apply_dataset_scope(self, query, user_context: RequestUserContext): def _apply_dataset_scope(self, query, user_context: RequestUserContext):
if user_context.is_admin: if user_context.is_admin:
return query return query
@@ -222,21 +224,21 @@ class AutoAnnotationTaskService:
query.order_by(AutoAnnotationTask.created_at.desc()) query.order_by(AutoAnnotationTask.created_at.desc())
) )
tasks: List[AutoAnnotationTask] = list(result.scalars().all()) tasks: List[AutoAnnotationTask] = list(result.scalars().all())
responses: List[AutoAnnotationTaskResponse] = [] responses: List[AutoAnnotationTaskResponse] = []
for task in tasks: for task in tasks:
resp = AutoAnnotationTaskResponse.model_validate(task) resp = AutoAnnotationTaskResponse.model_validate(task)
try: try:
resp.source_datasets = await self._compute_source_datasets(db, task) resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception: except Exception:
# 出错时降级为单个 datasetName/datasetId # 出错时降级为单个 datasetName/datasetId
fallback_name = getattr(task, "dataset_name", None) fallback_name = getattr(task, "dataset_name", None)
fallback_id = getattr(task, "dataset_id", "") fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id] resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
responses.append(resp) responses.append(resp)
return responses return responses
async def get_task( async def get_task(
self, self,
db: AsyncSession, db: AsyncSession,
@@ -252,43 +254,43 @@ class AutoAnnotationTaskService:
task = result.scalar_one_or_none() task = result.scalar_one_or_none()
if not task: if not task:
return None return None
resp = AutoAnnotationTaskResponse.model_validate(task) resp = AutoAnnotationTaskResponse.model_validate(task)
try: try:
resp.source_datasets = await self._compute_source_datasets(db, task) resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception: except Exception:
fallback_name = getattr(task, "dataset_name", None) fallback_name = getattr(task, "dataset_name", None)
fallback_id = getattr(task, "dataset_id", "") fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id] resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
return resp return resp
async def _compute_source_datasets( async def _compute_source_datasets(
self, self,
db: AsyncSession, db: AsyncSession,
task: AutoAnnotationTask, task: AutoAnnotationTask,
) -> List[str]: ) -> List[str]:
"""根据任务的 file_ids 推断实际涉及到的所有数据集名称。 """根据任务的 file_ids 推断实际涉及到的所有数据集名称。
- 如果存在 file_ids,则通过 t_dm_dataset_files 反查 dataset_id,再关联 t_dm_datasets 获取名称; - 如果存在 file_ids,则通过 t_dm_dataset_files 反查 dataset_id,再关联 t_dm_datasets 获取名称;
- 如果没有 file_ids,则退回到任务上冗余的 dataset_name/dataset_id。 - 如果没有 file_ids,则退回到任务上冗余的 dataset_name/dataset_id。
""" """
file_ids = task.file_ids or [] file_ids = task.file_ids or []
if file_ids: if file_ids:
stmt = ( stmt = (
select(Dataset.name) select(Dataset.name)
.join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id) .join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
.where(DatasetFiles.id.in_(file_ids)) .where(DatasetFiles.id.in_(file_ids))
.distinct() .distinct()
) )
result = await db.execute(stmt) result = await db.execute(stmt)
names = [row[0] for row in result.fetchall() if row[0]] names = [row[0] for row in result.fetchall() if row[0]]
if names: if names:
return names return names
# 回退:只显示一个数据集 # 回退:只显示一个数据集
if task.dataset_name: if task.dataset_name:
return [task.dataset_name] return [task.dataset_name]
if task.dataset_id: if task.dataset_id:
return [task.dataset_id] return [task.dataset_id]
return [] return []
@@ -331,7 +333,7 @@ class AutoAnnotationTaskService:
fallback_id = getattr(task, "dataset_id", "") fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id] resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
return resp return resp
async def soft_delete_task( async def soft_delete_task(
self, self,
db: AsyncSession, db: AsyncSession,
@@ -347,7 +349,7 @@ class AutoAnnotationTaskService:
task = result.scalar_one_or_none() task = result.scalar_one_or_none()
if not task: if not task:
return False return False
task.deleted_at = datetime.now() task.deleted_at = datetime.now()
await db.commit() await db.commit()
return True return True

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,25 @@
-- =============================================
-- Migration: let auto-annotation tasks record the dataset type.
-- Adds the dataset_type column to t_dm_auto_annotation_tasks.
-- The column is added only when it is missing, so the script can be re-run.
-- =============================================
USE datamate;
-- Capture the current schema name for the information_schema lookup below.
SET @db_name = DATABASE();
-- Add the dataset_type column (IMAGE/TEXT/AUDIO/VIDEO); existing rows default to IMAGE.
-- The statement text is chosen at runtime: a harmless SELECT when the column
-- already exists, otherwise the ALTER TABLE that creates it after dataset_name.
SET @ddl = (
SELECT IF(
EXISTS(
SELECT 1
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = @db_name
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
AND COLUMN_NAME = 'dataset_type'
),
'SELECT ''skip: column dataset_type already exists''',
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN dataset_type VARCHAR(50) NOT NULL DEFAULT ''IMAGE'' COMMENT ''数据集类型: IMAGE/TEXT/AUDIO/VIDEO'' AFTER dataset_name'
)
);
-- Run whichever statement was selected, then release the prepared statement.
PREPARE stmt FROM @ddl;
EXECUTE stmt;
DEALLOCATE PREPARE stmt;