You've already forked DataMate
feat(annotation): 自动标注任务支持非图像类型数据集(TEXT/AUDIO/VIDEO)
移除自动标注任务创建流程中的 IMAGE-only 限制,使 TEXT、AUDIO、VIDEO 类型数据集均可用于自动标注任务。 - 新增数据库迁移:t_dm_auto_annotation_tasks 表添加 dataset_type 列 - 后端 schema/API/service 全链路传递 dataset_type - Worker 动态构建 sample key(image/text/audio/video)和输出目录 - 前端移除数据集类型校验,下拉框显示数据集类型标识 - 输出数据集继承源数据集类型,不再硬编码为 IMAGE - 保持向后兼容:默认值为 IMAGE,worker 有元数据回退和目录 fallback Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import { ArrowLeft } from "lucide-react";
|
|||||||
import { Link, useNavigate } from "react-router";
|
import { Link, useNavigate } from "react-router";
|
||||||
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
|
import { queryDatasetsUsingGet } from "@/pages/DataManagement/dataset.api";
|
||||||
import { mapDataset } from "@/pages/DataManagement/dataset.const";
|
import { mapDataset } from "@/pages/DataManagement/dataset.const";
|
||||||
import { Dataset, DatasetType } from "@/pages/DataManagement/dataset.model";
|
import { Dataset } from "@/pages/DataManagement/dataset.model";
|
||||||
import { createAnnotationOperatorTaskUsingPost } from "../annotation.api";
|
import { createAnnotationOperatorTaskUsingPost } from "../annotation.api";
|
||||||
import { useCreateStepTwo } from "./hooks/useCreateStepTwo";
|
import { useCreateStepTwo } from "./hooks/useCreateStepTwo";
|
||||||
import PipelinePreview from "./components/PipelinePreview";
|
import PipelinePreview from "./components/PipelinePreview";
|
||||||
@@ -85,11 +85,6 @@ export default function AnnotationOperatorTaskCreate() {
|
|||||||
try {
|
try {
|
||||||
if (currentStep === 1) {
|
if (currentStep === 1) {
|
||||||
await form.validateFields();
|
await form.validateFields();
|
||||||
|
|
||||||
if (selectedDataset?.datasetType !== DatasetType.IMAGE) {
|
|
||||||
message.error("自动标注算子编排当前仅支持图片数据集");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
setCurrentStep((prev) => Math.min(prev + 1, 2));
|
setCurrentStep((prev) => Math.min(prev + 1, 2));
|
||||||
} catch {
|
} catch {
|
||||||
@@ -109,11 +104,6 @@ export default function AnnotationOperatorTaskCreate() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (selectedDataset?.datasetType !== DatasetType.IMAGE) {
|
|
||||||
message.error("自动标注算子编排当前仅支持图片数据集");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const outputDatasetName = values.outputDatasetName?.trim();
|
const outputDatasetName = values.outputDatasetName?.trim();
|
||||||
const pipeline = selectedOperators.map((operator, index) => {
|
const pipeline = selectedOperators.map((operator, index) => {
|
||||||
const overrides = {
|
const overrides = {
|
||||||
@@ -200,10 +190,10 @@ export default function AnnotationOperatorTaskCreate() {
|
|||||||
label="选择数据集"
|
label="选择数据集"
|
||||||
name="datasetId"
|
name="datasetId"
|
||||||
rules={[{ required: true, message: "请选择数据集" }]}
|
rules={[{ required: true, message: "请选择数据集" }]}
|
||||||
extra="自动标注算子编排当前仅支持图片数据集"
|
extra="请选择用于自动标注的数据集"
|
||||||
>
|
>
|
||||||
<Select
|
<Select
|
||||||
placeholder="请选择图片数据集"
|
placeholder="请选择数据集"
|
||||||
optionFilterProp="label"
|
optionFilterProp="label"
|
||||||
options={datasets.map((dataset) => ({
|
options={datasets.map((dataset) => ({
|
||||||
label: (
|
label: (
|
||||||
@@ -215,12 +205,11 @@ export default function AnnotationOperatorTaskCreate() {
|
|||||||
{dataset.name}
|
{dataset.name}
|
||||||
</div>
|
</div>
|
||||||
<div className="text-xs text-gray-500">
|
<div className="text-xs text-gray-500">
|
||||||
{dataset?.fileCount} 文件 • {dataset.size}
|
{dataset.datasetType} • {dataset?.fileCount} 文件 • {dataset.size}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
),
|
),
|
||||||
value: dataset.id,
|
value: dataset.id,
|
||||||
disabled: dataset.datasetType !== DatasetType.IMAGE,
|
|
||||||
}))}
|
}))}
|
||||||
/>
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ class AnnotationResult(Base):
|
|||||||
return f"<AnnotationResult(id={self.id}, project_id={self.project_id}, file_id={self.file_id})>"
|
return f"<AnnotationResult(id={self.id}, project_id={self.project_id}, file_id={self.file_id})>"
|
||||||
|
|
||||||
|
|
||||||
class AutoAnnotationTask(Base):
|
class AutoAnnotationTask(Base):
|
||||||
"""自动标注任务模型,对应表 t_dm_auto_annotation_tasks"""
|
"""自动标注任务模型,对应表 t_dm_auto_annotation_tasks"""
|
||||||
|
|
||||||
__tablename__ = "t_dm_auto_annotation_tasks"
|
__tablename__ = "t_dm_auto_annotation_tasks"
|
||||||
@@ -206,94 +206,98 @@ class AutoAnnotationTask(Base):
|
|||||||
String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID"
|
String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID"
|
||||||
)
|
)
|
||||||
name = Column(String(255), nullable=False, comment="任务名称")
|
name = Column(String(255), nullable=False, comment="任务名称")
|
||||||
dataset_id = Column(String(36), nullable=False, comment="数据集ID")
|
dataset_id = Column(String(36), nullable=False, comment="数据集ID")
|
||||||
dataset_name = Column(
|
dataset_name = Column(
|
||||||
String(255), nullable=True, comment="数据集名称(冗余字段,方便查询)"
|
String(255), nullable=True, comment="数据集名称(冗余字段,方便查询)"
|
||||||
)
|
)
|
||||||
created_by = Column(String(255), nullable=True, comment="任务创建人")
|
dataset_type = Column(
|
||||||
config = Column(JSON, nullable=False, comment="任务配置(模型规模、置信度等)")
|
String(50), nullable=False, default="IMAGE",
|
||||||
file_ids = Column(
|
comment="数据集类型: IMAGE/TEXT/AUDIO/VIDEO",
|
||||||
JSON, nullable=True, comment="要处理的文件ID列表,为空则处理数据集所有图像"
|
)
|
||||||
)
|
created_by = Column(String(255), nullable=True, comment="任务创建人")
|
||||||
status = Column(
|
config = Column(JSON, nullable=False, comment="任务配置(模型规模、置信度等)")
|
||||||
String(50),
|
file_ids = Column(
|
||||||
nullable=False,
|
JSON, nullable=True, comment="要处理的文件ID列表,为空则处理数据集所有图像"
|
||||||
default="pending",
|
)
|
||||||
comment="任务状态: pending/running/completed/failed/stopped",
|
status = Column(
|
||||||
)
|
String(50),
|
||||||
task_mode = Column(
|
nullable=False,
|
||||||
String(32),
|
default="pending",
|
||||||
nullable=False,
|
comment="任务状态: pending/running/completed/failed/stopped",
|
||||||
default="legacy_yolo",
|
)
|
||||||
comment="任务模式: legacy_yolo/pipeline",
|
task_mode = Column(
|
||||||
)
|
String(32),
|
||||||
executor_type = Column(
|
nullable=False,
|
||||||
String(32),
|
default="legacy_yolo",
|
||||||
nullable=False,
|
comment="任务模式: legacy_yolo/pipeline",
|
||||||
default="annotation_local",
|
)
|
||||||
comment="执行器类型",
|
executor_type = Column(
|
||||||
)
|
String(32),
|
||||||
pipeline = Column(JSON, nullable=True, comment="算子编排定义")
|
nullable=False,
|
||||||
progress = Column(Integer, default=0, comment="任务进度 0-100")
|
default="annotation_local",
|
||||||
stop_requested = Column(Boolean, default=False, comment="是否请求停止")
|
comment="执行器类型",
|
||||||
total_images = Column(Integer, default=0, comment="总图片数")
|
)
|
||||||
processed_images = Column(Integer, default=0, comment="已处理图片数")
|
pipeline = Column(JSON, nullable=True, comment="算子编排定义")
|
||||||
detected_objects = Column(Integer, default=0, comment="检测到的对象总数")
|
progress = Column(Integer, default=0, comment="任务进度 0-100")
|
||||||
output_path = Column(String(500), nullable=True, comment="输出路径")
|
stop_requested = Column(Boolean, default=False, comment="是否请求停止")
|
||||||
output_dataset_id = Column(String(36), nullable=True, comment="输出数据集ID")
|
total_images = Column(Integer, default=0, comment="总图片数")
|
||||||
error_message = Column(Text, nullable=True, comment="错误信息")
|
processed_images = Column(Integer, default=0, comment="已处理图片数")
|
||||||
created_at = Column(
|
detected_objects = Column(Integer, default=0, comment="检测到的对象总数")
|
||||||
TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间"
|
output_path = Column(String(500), nullable=True, comment="输出路径")
|
||||||
)
|
output_dataset_id = Column(String(36), nullable=True, comment="输出数据集ID")
|
||||||
|
error_message = Column(Text, nullable=True, comment="错误信息")
|
||||||
|
created_at = Column(
|
||||||
|
TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间"
|
||||||
|
)
|
||||||
updated_at = Column(
|
updated_at = Column(
|
||||||
TIMESTAMP,
|
TIMESTAMP,
|
||||||
server_default=func.current_timestamp(),
|
server_default=func.current_timestamp(),
|
||||||
onupdate=func.current_timestamp(),
|
onupdate=func.current_timestamp(),
|
||||||
comment="更新时间",
|
comment="更新时间",
|
||||||
)
|
)
|
||||||
started_at = Column(TIMESTAMP, nullable=True, comment="任务启动时间")
|
started_at = Column(TIMESTAMP, nullable=True, comment="任务启动时间")
|
||||||
heartbeat_at = Column(TIMESTAMP, nullable=True, comment="worker心跳时间")
|
heartbeat_at = Column(TIMESTAMP, nullable=True, comment="worker心跳时间")
|
||||||
run_token = Column(String(64), nullable=True, comment="运行令牌")
|
run_token = Column(String(64), nullable=True, comment="运行令牌")
|
||||||
completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
|
completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
|
||||||
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
|
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
|
||||||
|
|
||||||
def __repr__(self) -> str: # pragma: no cover - repr 简单返回
|
def __repr__(self) -> str: # pragma: no cover - repr 简单返回
|
||||||
return f"<AutoAnnotationTask(id={self.id}, name={self.name}, status={self.status})>"
|
return f"<AutoAnnotationTask(id={self.id}, name={self.name}, status={self.status})>"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_deleted(self) -> bool:
|
def is_deleted(self) -> bool:
|
||||||
"""检查是否已被软删除"""
|
"""检查是否已被软删除"""
|
||||||
return self.deleted_at is not None
|
return self.deleted_at is not None
|
||||||
|
|
||||||
|
|
||||||
class AnnotationTaskOperatorInstance(Base):
|
class AnnotationTaskOperatorInstance(Base):
|
||||||
"""自动标注任务内算子实例模型,对应表 t_dm_annotation_task_operator_instance"""
|
"""自动标注任务内算子实例模型,对应表 t_dm_annotation_task_operator_instance"""
|
||||||
|
|
||||||
__tablename__ = "t_dm_annotation_task_operator_instance"
|
__tablename__ = "t_dm_annotation_task_operator_instance"
|
||||||
|
|
||||||
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="自增主键")
|
id = Column(BigInteger, primary_key=True, autoincrement=True, comment="自增主键")
|
||||||
task_id = Column(String(36), nullable=False, comment="自动标注任务ID")
|
task_id = Column(String(36), nullable=False, comment="自动标注任务ID")
|
||||||
op_index = Column(Integer, nullable=False, comment="算子顺序(从1开始)")
|
op_index = Column(Integer, nullable=False, comment="算子顺序(从1开始)")
|
||||||
operator_id = Column(String(64), nullable=False, comment="算子ID(raw_id)")
|
operator_id = Column(String(64), nullable=False, comment="算子ID(raw_id)")
|
||||||
settings_override = Column(JSON, nullable=True, comment="任务级算子参数覆盖")
|
settings_override = Column(JSON, nullable=True, comment="任务级算子参数覆盖")
|
||||||
inputs = Column(String(64), nullable=True, comment="输入模态")
|
inputs = Column(String(64), nullable=True, comment="输入模态")
|
||||||
outputs = Column(String(64), nullable=True, comment="输出模态")
|
outputs = Column(String(64), nullable=True, comment="输出模态")
|
||||||
created_at = Column(
|
created_at = Column(
|
||||||
TIMESTAMP,
|
TIMESTAMP,
|
||||||
server_default=func.current_timestamp(),
|
server_default=func.current_timestamp(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
comment="创建时间",
|
comment="创建时间",
|
||||||
)
|
)
|
||||||
updated_at = Column(
|
updated_at = Column(
|
||||||
TIMESTAMP,
|
TIMESTAMP,
|
||||||
server_default=func.current_timestamp(),
|
server_default=func.current_timestamp(),
|
||||||
onupdate=func.current_timestamp(),
|
onupdate=func.current_timestamp(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
comment="更新时间",
|
comment="更新时间",
|
||||||
)
|
)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint("task_id", "op_index", name="uk_task_op_index"),
|
UniqueConstraint("task_id", "op_index", name="uk_task_op_index"),
|
||||||
Index("idx_task_id", "task_id"),
|
Index("idx_task_id", "task_id"),
|
||||||
Index("idx_operator_id", "operator_id"),
|
Index("idx_operator_id", "operator_id"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ from typing import List, Literal
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db.session import get_db
|
from app.db.session import get_db
|
||||||
from app.module.shared.schema import StandardResponse
|
from app.module.shared.schema import StandardResponse
|
||||||
from app.module.dataset import DatasetManagementService
|
from app.module.dataset import DatasetManagementService
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
@@ -29,15 +29,15 @@ from ..security import (
|
|||||||
from ..schema.auto import (
|
from ..schema.auto import (
|
||||||
CreateAutoAnnotationTaskRequest,
|
CreateAutoAnnotationTaskRequest,
|
||||||
AutoAnnotationTaskResponse,
|
AutoAnnotationTaskResponse,
|
||||||
)
|
)
|
||||||
from ..service.auto import AutoAnnotationTaskService
|
from ..service.auto import AutoAnnotationTaskService
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
tags=["annotation/auto"],
|
tags=["annotation/auto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
service = AutoAnnotationTaskService()
|
service = AutoAnnotationTaskService()
|
||||||
|
|
||||||
|
|
||||||
@@ -85,23 +85,28 @@ async def _create_task_internal(
|
|||||||
await assert_dataset_access(db, normalized_request.dataset_id, user_context)
|
await assert_dataset_access(db, normalized_request.dataset_id, user_context)
|
||||||
# 尝试获取数据集名称和总量用于冗余字段
|
# 尝试获取数据集名称和总量用于冗余字段
|
||||||
dataset_name = None
|
dataset_name = None
|
||||||
|
dataset_type = "IMAGE"
|
||||||
total_images = len(normalized_request.file_ids) if normalized_request.file_ids else 0
|
total_images = len(normalized_request.file_ids) if normalized_request.file_ids else 0
|
||||||
try:
|
try:
|
||||||
dm_client = DatasetManagementService(db)
|
dm_client = DatasetManagementService(db)
|
||||||
dataset = await dm_client.get_dataset(normalized_request.dataset_id)
|
dataset = await dm_client.get_dataset(normalized_request.dataset_id)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
dataset_name = dataset.name
|
dataset_name = dataset.name
|
||||||
|
dataset_type = getattr(dataset, "datasetType", None) or "IMAGE"
|
||||||
if not normalized_request.file_ids:
|
if not normalized_request.file_ids:
|
||||||
total_images = getattr(dataset, "fileCount", 0) or 0
|
total_images = getattr(dataset, "fileCount", 0) or 0
|
||||||
except Exception as e: # pragma: no cover - 容错
|
except Exception as e: # pragma: no cover - 容错
|
||||||
logger.warning("Failed to fetch dataset summary for annotation task: %s", e)
|
logger.warning("Failed to fetch dataset summary for annotation task: %s", e)
|
||||||
|
|
||||||
|
resolved_dataset_type = normalized_request.dataset_type or dataset_type
|
||||||
|
|
||||||
return await service.create_task(
|
return await service.create_task(
|
||||||
db,
|
db,
|
||||||
normalized_request,
|
normalized_request,
|
||||||
user_context=user_context,
|
user_context=user_context,
|
||||||
dataset_name=dataset_name,
|
dataset_name=dataset_name,
|
||||||
total_images=total_images,
|
total_images=total_images,
|
||||||
|
dataset_type=resolved_dataset_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -177,10 +182,10 @@ async def get_auto_annotation_task_status(
|
|||||||
task = await service.get_task(db, task_id, user_context)
|
task = await service.get_task(db, task_id, user_context)
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
return StandardResponse(
|
return StandardResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="success",
|
message="success",
|
||||||
data=task,
|
data=task,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -192,12 +197,12 @@ async def delete_auto_annotation_task(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
|
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
|
||||||
|
|
||||||
ok = await service.soft_delete_task(db, task_id, user_context)
|
ok = await service.soft_delete_task(db, task_id, user_context)
|
||||||
if not ok:
|
if not ok:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
return StandardResponse(
|
return StandardResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="success",
|
message="success",
|
||||||
@@ -232,50 +237,50 @@ async def download_auto_annotation_result(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""下载指定自动标注任务的结果 ZIP。"""
|
"""下载指定自动标注任务的结果 ZIP。"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import zipfile
|
import zipfile
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
# 复用服务层获取任务信息
|
# 复用服务层获取任务信息
|
||||||
task = await service.get_task(db, task_id, user_context)
|
task = await service.get_task(db, task_id, user_context)
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
if not task.output_path:
|
if not task.output_path:
|
||||||
raise HTTPException(status_code=400, detail="Task has no output path")
|
raise HTTPException(status_code=400, detail="Task has no output path")
|
||||||
|
|
||||||
output_dir = task.output_path
|
output_dir = task.output_path
|
||||||
if not os.path.isdir(output_dir):
|
if not os.path.isdir(output_dir):
|
||||||
raise HTTPException(status_code=404, detail="Output directory not found")
|
raise HTTPException(status_code=404, detail="Output directory not found")
|
||||||
|
|
||||||
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
|
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
|
||||||
os.close(tmp_fd)
|
os.close(tmp_fd)
|
||||||
|
|
||||||
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||||
for root, _, files in os.walk(output_dir):
|
for root, _, files in os.walk(output_dir):
|
||||||
for filename in files:
|
for filename in files:
|
||||||
file_path = os.path.join(root, filename)
|
file_path = os.path.join(root, filename)
|
||||||
arcname = os.path.relpath(file_path, output_dir)
|
arcname = os.path.relpath(file_path, output_dir)
|
||||||
zf.write(file_path, arcname)
|
zf.write(file_path, arcname)
|
||||||
|
|
||||||
file_size = os.path.getsize(tmp_path)
|
file_size = os.path.getsize(tmp_path)
|
||||||
if file_size == 0:
|
if file_size == 0:
|
||||||
raise HTTPException(status_code=500, detail="Generated ZIP is empty")
|
raise HTTPException(status_code=500, detail="Generated ZIP is empty")
|
||||||
|
|
||||||
def iterfile():
|
def iterfile():
|
||||||
with open(tmp_path, "rb") as f:
|
with open(tmp_path, "rb") as f:
|
||||||
while True:
|
while True:
|
||||||
chunk = f.read(8192)
|
chunk = f.read(8192)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
filename = f"{task.name}_annotations.zip"
|
filename = f"{task.name}_annotations.zip"
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||||
"Content-Length": str(file_size),
|
"Content-Length": str(file_size),
|
||||||
}
|
}
|
||||||
|
|
||||||
return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)
|
return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Schemas for Auto Annotation tasks"""
|
"""Schemas for Auto Annotation tasks"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@@ -7,24 +7,24 @@ from typing import List, Optional, Dict, Any
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||||||
|
|
||||||
|
|
||||||
class AutoAnnotationConfig(BaseModel):
|
class AutoAnnotationConfig(BaseModel):
|
||||||
"""自动标注任务配置(与前端 payload 对齐)"""
|
"""自动标注任务配置(与前端 payload 对齐)"""
|
||||||
|
|
||||||
model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x")
|
model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x")
|
||||||
conf_threshold: float = Field(alias="confThreshold", description="置信度阈值 0-1")
|
conf_threshold: float = Field(alias="confThreshold", description="置信度阈值 0-1")
|
||||||
target_classes: List[int] = Field(
|
target_classes: List[int] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
alias="targetClasses",
|
alias="targetClasses",
|
||||||
description="目标类别ID列表,空表示全部类别",
|
description="目标类别ID列表,空表示全部类别",
|
||||||
)
|
)
|
||||||
output_dataset_name: Optional[str] = Field(
|
output_dataset_name: Optional[str] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
alias="outputDatasetName",
|
alias="outputDatasetName",
|
||||||
description="自动标注结果要写入的新数据集名称(可选)",
|
description="自动标注结果要写入的新数据集名称(可选)",
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(populate_by_name=True)
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -68,13 +68,18 @@ class OperatorPipelineStep(BaseModel):
|
|||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
model_config = ConfigDict(populate_by_name=True)
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
|
||||||
class CreateAutoAnnotationTaskRequest(BaseModel):
|
class CreateAutoAnnotationTaskRequest(BaseModel):
|
||||||
"""创建自动标注任务的请求体,对齐前端 CreateAutoAnnotationDialog 发送的结构"""
|
"""创建自动标注任务的请求体,对齐前端 CreateAutoAnnotationDialog 发送的结构"""
|
||||||
|
|
||||||
name: str = Field(..., min_length=1, max_length=255, description="任务名称")
|
name: str = Field(..., min_length=1, max_length=255, description="任务名称")
|
||||||
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
|
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
|
||||||
|
dataset_type: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
alias="datasetType",
|
||||||
|
description="数据集类型: IMAGE/TEXT/AUDIO/VIDEO(不传时由后端自动获取)",
|
||||||
|
)
|
||||||
config: Optional[AutoAnnotationConfig] = Field(
|
config: Optional[AutoAnnotationConfig] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="兼容旧版 YOLO 任务配置",
|
description="兼容旧版 YOLO 任务配置",
|
||||||
@@ -111,15 +116,16 @@ class CreateAutoAnnotationTaskRequest(BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
model_config = ConfigDict(populate_by_name=True)
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
|
||||||
class AutoAnnotationTaskResponse(BaseModel):
|
class AutoAnnotationTaskResponse(BaseModel):
|
||||||
"""自动标注任务响应模型(列表/详情均可复用)"""
|
"""自动标注任务响应模型(列表/详情均可复用)"""
|
||||||
|
|
||||||
id: str = Field(..., description="任务ID")
|
id: str = Field(..., description="任务ID")
|
||||||
name: str = Field(..., description="任务名称")
|
name: str = Field(..., description="任务名称")
|
||||||
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
|
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
|
||||||
dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
|
dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
|
||||||
|
dataset_type: Optional[str] = Field(None, alias="datasetType", description="数据集类型: IMAGE/TEXT/AUDIO/VIDEO")
|
||||||
task_mode: Optional[str] = Field(None, alias="taskMode", description="任务模式")
|
task_mode: Optional[str] = Field(None, alias="taskMode", description="任务模式")
|
||||||
executor_type: Optional[str] = Field(None, alias="executorType", description="执行器类型")
|
executor_type: Optional[str] = Field(None, alias="executorType", description="执行器类型")
|
||||||
pipeline: Optional[List[Dict[str, Any]]] = Field(None, description="算子编排定义")
|
pipeline: Optional[List[Dict[str, Any]]] = Field(None, description="算子编排定义")
|
||||||
@@ -128,11 +134,11 @@ class AutoAnnotationTaskResponse(BaseModel):
|
|||||||
alias="sourceDatasets",
|
alias="sourceDatasets",
|
||||||
description="本任务实际处理涉及到的所有数据集名称列表",
|
description="本任务实际处理涉及到的所有数据集名称列表",
|
||||||
)
|
)
|
||||||
config: Dict[str, Any] = Field(..., description="任务配置")
|
config: Dict[str, Any] = Field(..., description="任务配置")
|
||||||
status: str = Field(..., description="任务状态")
|
status: str = Field(..., description="任务状态")
|
||||||
progress: int = Field(..., description="任务进度 0-100")
|
progress: int = Field(..., description="任务进度 0-100")
|
||||||
total_images: int = Field(..., alias="totalImages", description="总图片数")
|
total_images: int = Field(..., alias="totalImages", description="总图片数")
|
||||||
processed_images: int = Field(..., alias="processedImages", description="已处理图片数")
|
processed_images: int = Field(..., alias="processedImages", description="已处理图片数")
|
||||||
detected_objects: int = Field(..., alias="detectedObjects", description="检测到的对象总数")
|
detected_objects: int = Field(..., alias="detectedObjects", description="检测到的对象总数")
|
||||||
output_path: Optional[str] = Field(None, alias="outputPath", description="输出路径")
|
output_path: Optional[str] = Field(None, alias="outputPath", description="输出路径")
|
||||||
output_dataset_id: Optional[str] = Field(
|
output_dataset_id: Optional[str] = Field(
|
||||||
@@ -152,14 +158,14 @@ class AutoAnnotationTaskResponse(BaseModel):
|
|||||||
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
|
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
|
||||||
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
|
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
|
||||||
completed_at: Optional[datetime] = Field(None, alias="completedAt", description="完成时间")
|
completed_at: Optional[datetime] = Field(None, alias="completedAt", description="完成时间")
|
||||||
|
|
||||||
model_config = ConfigDict(populate_by_name=True, from_attributes=True)
|
model_config = ConfigDict(populate_by_name=True, from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
class AutoAnnotationTaskListResponse(BaseModel):
|
class AutoAnnotationTaskListResponse(BaseModel):
|
||||||
"""自动标注任务列表响应,目前前端直接使用数组,这里预留分页结构"""
|
"""自动标注任务列表响应,目前前端直接使用数组,这里预留分页结构"""
|
||||||
|
|
||||||
content: List[AutoAnnotationTaskResponse] = Field(..., description="任务列表")
|
content: List[AutoAnnotationTaskResponse] = Field(..., description="任务列表")
|
||||||
total: int = Field(..., description="总数")
|
total: int = Field(..., description="总数")
|
||||||
|
|
||||||
model_config = ConfigDict(populate_by_name=True)
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|||||||
@@ -15,13 +15,13 @@ from app.db.models.annotation_management import (
|
|||||||
)
|
)
|
||||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||||
from app.module.annotation.security import RequestUserContext
|
from app.module.annotation.security import RequestUserContext
|
||||||
|
|
||||||
from ..schema.auto import (
|
from ..schema.auto import (
|
||||||
CreateAutoAnnotationTaskRequest,
|
CreateAutoAnnotationTaskRequest,
|
||||||
AutoAnnotationTaskResponse,
|
AutoAnnotationTaskResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoAnnotationTaskService:
|
class AutoAnnotationTaskService:
|
||||||
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
|
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
|
||||||
|
|
||||||
@@ -141,11 +141,12 @@ class AutoAnnotationTaskService:
|
|||||||
user_context: RequestUserContext,
|
user_context: RequestUserContext,
|
||||||
dataset_name: Optional[str] = None,
|
dataset_name: Optional[str] = None,
|
||||||
total_images: int = 0,
|
total_images: int = 0,
|
||||||
|
dataset_type: str = "IMAGE",
|
||||||
) -> AutoAnnotationTaskResponse:
|
) -> AutoAnnotationTaskResponse:
|
||||||
"""创建自动标注任务,初始状态为 pending。
|
"""创建自动标注任务,初始状态为 pending。
|
||||||
|
|
||||||
这里仅插入任务记录,不负责真正执行 YOLO 推理,
|
这里仅插入任务记录,不负责真正执行 YOLO 推理,
|
||||||
后续可以由调度器/worker 读取该表并更新进度。
|
后续可以由调度器/worker 读取该表并更新进度。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
@@ -170,6 +171,7 @@ class AutoAnnotationTaskService:
|
|||||||
name=request.name,
|
name=request.name,
|
||||||
dataset_id=request.dataset_id,
|
dataset_id=request.dataset_id,
|
||||||
dataset_name=dataset_name,
|
dataset_name=dataset_name,
|
||||||
|
dataset_type=dataset_type,
|
||||||
created_by=user_context.user_id,
|
created_by=user_context.user_id,
|
||||||
config=normalized_config,
|
config=normalized_config,
|
||||||
task_mode=request.task_mode,
|
task_mode=request.task_mode,
|
||||||
@@ -192,15 +194,15 @@ class AutoAnnotationTaskService:
|
|||||||
db.add_all(operator_instances)
|
db.add_all(operator_instances)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(task)
|
await db.refresh(task)
|
||||||
|
|
||||||
# 创建后附带 sourceDatasets 信息(通常只有一个原始数据集)
|
# 创建后附带 sourceDatasets 信息(通常只有一个原始数据集)
|
||||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||||
try:
|
try:
|
||||||
resp.source_datasets = await self._compute_source_datasets(db, task)
|
resp.source_datasets = await self._compute_source_datasets(db, task)
|
||||||
except Exception:
|
except Exception:
|
||||||
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
|
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
|
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
|
||||||
if user_context.is_admin:
|
if user_context.is_admin:
|
||||||
return query
|
return query
|
||||||
@@ -222,21 +224,21 @@ class AutoAnnotationTaskService:
|
|||||||
query.order_by(AutoAnnotationTask.created_at.desc())
|
query.order_by(AutoAnnotationTask.created_at.desc())
|
||||||
)
|
)
|
||||||
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
|
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
|
||||||
|
|
||||||
responses: List[AutoAnnotationTaskResponse] = []
|
responses: List[AutoAnnotationTaskResponse] = []
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||||
try:
|
try:
|
||||||
resp.source_datasets = await self._compute_source_datasets(db, task)
|
resp.source_datasets = await self._compute_source_datasets(db, task)
|
||||||
except Exception:
|
except Exception:
|
||||||
# 出错时降级为单个 datasetName/datasetId
|
# 出错时降级为单个 datasetName/datasetId
|
||||||
fallback_name = getattr(task, "dataset_name", None)
|
fallback_name = getattr(task, "dataset_name", None)
|
||||||
fallback_id = getattr(task, "dataset_id", "")
|
fallback_id = getattr(task, "dataset_id", "")
|
||||||
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
||||||
responses.append(resp)
|
responses.append(resp)
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
async def get_task(
|
async def get_task(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
@@ -252,43 +254,43 @@ class AutoAnnotationTaskService:
|
|||||||
task = result.scalar_one_or_none()
|
task = result.scalar_one_or_none()
|
||||||
if not task:
|
if not task:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||||
try:
|
try:
|
||||||
resp.source_datasets = await self._compute_source_datasets(db, task)
|
resp.source_datasets = await self._compute_source_datasets(db, task)
|
||||||
except Exception:
|
except Exception:
|
||||||
fallback_name = getattr(task, "dataset_name", None)
|
fallback_name = getattr(task, "dataset_name", None)
|
||||||
fallback_id = getattr(task, "dataset_id", "")
|
fallback_id = getattr(task, "dataset_id", "")
|
||||||
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
async def _compute_source_datasets(
|
async def _compute_source_datasets(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
task: AutoAnnotationTask,
|
task: AutoAnnotationTask,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""根据任务的 file_ids 推断实际涉及到的所有数据集名称。
|
"""根据任务的 file_ids 推断实际涉及到的所有数据集名称。
|
||||||
|
|
||||||
- 如果存在 file_ids,则通过 t_dm_dataset_files 反查 dataset_id,再关联 t_dm_datasets 获取名称;
|
- 如果存在 file_ids,则通过 t_dm_dataset_files 反查 dataset_id,再关联 t_dm_datasets 获取名称;
|
||||||
- 如果没有 file_ids,则退回到任务上冗余的 dataset_name/dataset_id。
|
- 如果没有 file_ids,则退回到任务上冗余的 dataset_name/dataset_id。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
file_ids = task.file_ids or []
|
file_ids = task.file_ids or []
|
||||||
if file_ids:
|
if file_ids:
|
||||||
stmt = (
|
stmt = (
|
||||||
select(Dataset.name)
|
select(Dataset.name)
|
||||||
.join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
|
.join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
|
||||||
.where(DatasetFiles.id.in_(file_ids))
|
.where(DatasetFiles.id.in_(file_ids))
|
||||||
.distinct()
|
.distinct()
|
||||||
)
|
)
|
||||||
result = await db.execute(stmt)
|
result = await db.execute(stmt)
|
||||||
names = [row[0] for row in result.fetchall() if row[0]]
|
names = [row[0] for row in result.fetchall() if row[0]]
|
||||||
if names:
|
if names:
|
||||||
return names
|
return names
|
||||||
|
|
||||||
# 回退:只显示一个数据集
|
# 回退:只显示一个数据集
|
||||||
if task.dataset_name:
|
if task.dataset_name:
|
||||||
return [task.dataset_name]
|
return [task.dataset_name]
|
||||||
if task.dataset_id:
|
if task.dataset_id:
|
||||||
return [task.dataset_id]
|
return [task.dataset_id]
|
||||||
return []
|
return []
|
||||||
@@ -331,7 +333,7 @@ class AutoAnnotationTaskService:
|
|||||||
fallback_id = getattr(task, "dataset_id", "")
|
fallback_id = getattr(task, "dataset_id", "")
|
||||||
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
async def soft_delete_task(
|
async def soft_delete_task(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
@@ -347,7 +349,7 @@ class AutoAnnotationTaskService:
|
|||||||
task = result.scalar_one_or_none()
|
task = result.scalar_one_or_none()
|
||||||
if not task:
|
if not task:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
task.deleted_at = datetime.now()
|
task.deleted_at = datetime.now()
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return True
|
return True
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
25
scripts/db/data-annotation-multitype-migration.sql
Normal file
25
scripts/db/data-annotation-multitype-migration.sql
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
-- =============================================
|
||||||
|
-- 自动标注任务支持多数据集类型迁移
|
||||||
|
-- 为 t_dm_auto_annotation_tasks 表添加 dataset_type 列
|
||||||
|
-- =============================================
|
||||||
|
|
||||||
|
USE datamate;
|
||||||
|
SET @db_name = DATABASE();
|
||||||
|
|
||||||
|
-- 添加 dataset_type 列(IMAGE/TEXT/AUDIO/VIDEO),已有记录默认为 IMAGE
|
||||||
|
SET @ddl = (
|
||||||
|
SELECT IF(
|
||||||
|
EXISTS(
|
||||||
|
SELECT 1
|
||||||
|
FROM information_schema.COLUMNS
|
||||||
|
WHERE TABLE_SCHEMA = @db_name
|
||||||
|
AND TABLE_NAME = 't_dm_auto_annotation_tasks'
|
||||||
|
AND COLUMN_NAME = 'dataset_type'
|
||||||
|
),
|
||||||
|
'SELECT ''skip: column dataset_type already exists''',
|
||||||
|
'ALTER TABLE t_dm_auto_annotation_tasks ADD COLUMN dataset_type VARCHAR(50) NOT NULL DEFAULT ''IMAGE'' COMMENT ''数据集类型: IMAGE/TEXT/AUDIO/VIDEO'' AFTER dataset_name'
|
||||||
|
)
|
||||||
|
);
|
||||||
|
PREPARE stmt FROM @ddl;
|
||||||
|
EXECUTE stmt;
|
||||||
|
DEALLOCATE PREPARE stmt;
|
||||||
Reference in New Issue
Block a user