feat(annotation): 自动标注任务支持非图像类型数据集(TEXT/AUDIO/VIDEO)

移除自动标注任务创建流程中的 IMAGE-only 限制,使 TEXT、AUDIO、VIDEO
类型数据集均可用于自动标注任务。

- 新增数据库迁移:t_dm_auto_annotation_tasks 表添加 dataset_type 列
- 后端 schema/API/service 全链路传递 dataset_type
- Worker 动态构建 sample key(image/text/audio/video)和输出目录
- 前端移除数据集类型校验,下拉框显示数据集类型标识
- 输出数据集继承源数据集类型,不再硬编码为 IMAGE
- 保持向后兼容:默认值为 IMAGE,worker 有元数据回退和目录 fallback

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 23:23:05 +08:00
parent 807c2289e2
commit 8ffa131fad
7 changed files with 1161 additions and 1082 deletions

View File

@@ -15,9 +15,9 @@ from typing import List, Literal
from fastapi import APIRouter, Depends, HTTPException, Path
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.module.shared.schema import StandardResponse
from app.db.session import get_db
from app.module.shared.schema import StandardResponse
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
@@ -29,15 +29,15 @@ from ..security import (
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
)
)
from ..service.auto import AutoAnnotationTaskService
router = APIRouter(
tags=["annotation/auto"],
)
logger = get_logger(__name__)
logger = get_logger(__name__)
service = AutoAnnotationTaskService()
@@ -85,23 +85,28 @@ async def _create_task_internal(
await assert_dataset_access(db, normalized_request.dataset_id, user_context)
# 尝试获取数据集名称和总量用于冗余字段
dataset_name = None
dataset_type = "IMAGE"
total_images = len(normalized_request.file_ids) if normalized_request.file_ids else 0
try:
dm_client = DatasetManagementService(db)
dataset = await dm_client.get_dataset(normalized_request.dataset_id)
if dataset is not None:
dataset_name = dataset.name
dataset_type = getattr(dataset, "datasetType", None) or "IMAGE"
if not normalized_request.file_ids:
total_images = getattr(dataset, "fileCount", 0) or 0
except Exception as e: # pragma: no cover - 容错
logger.warning("Failed to fetch dataset summary for annotation task: %s", e)
resolved_dataset_type = normalized_request.dataset_type or dataset_type
return await service.create_task(
db,
normalized_request,
user_context=user_context,
dataset_name=dataset_name,
total_images=total_images,
dataset_type=resolved_dataset_type,
)
@@ -177,10 +182,10 @@ async def get_auto_annotation_task_status(
task = await service.get_task(db, task_id, user_context)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return StandardResponse(
code=200,
message="success",
return StandardResponse(
code=200,
message="success",
data=task,
)
@@ -192,12 +197,12 @@ async def delete_auto_annotation_task(
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
ok = await service.soft_delete_task(db, task_id, user_context)
if not ok:
raise HTTPException(status_code=404, detail="Task not found")
if not ok:
raise HTTPException(status_code=404, detail="Task not found")
return StandardResponse(
code=200,
message="success",
@@ -232,50 +237,50 @@ async def download_auto_annotation_result(
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""下载指定自动标注任务的结果 ZIP。"""
"""下载指定自动标注任务的结果 ZIP。"""
import os
import zipfile
import tempfile
# 复用服务层获取任务信息
# 复用服务层获取任务信息
task = await service.get_task(db, task_id, user_context)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
if not task.output_path:
raise HTTPException(status_code=400, detail="Task has no output path")
output_dir = task.output_path
if not os.path.isdir(output_dir):
raise HTTPException(status_code=404, detail="Output directory not found")
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
os.close(tmp_fd)
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
for root, _, files in os.walk(output_dir):
for filename in files:
file_path = os.path.join(root, filename)
arcname = os.path.relpath(file_path, output_dir)
zf.write(file_path, arcname)
file_size = os.path.getsize(tmp_path)
if file_size == 0:
raise HTTPException(status_code=500, detail="Generated ZIP is empty")
def iterfile():
with open(tmp_path, "rb") as f:
while True:
chunk = f.read(8192)
if not chunk:
break
yield chunk
filename = f"{task.name}_annotations.zip"
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(file_size),
}
return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
if not task.output_path:
raise HTTPException(status_code=400, detail="Task has no output path")
output_dir = task.output_path
if not os.path.isdir(output_dir):
raise HTTPException(status_code=404, detail="Output directory not found")
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
os.close(tmp_fd)
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
for root, _, files in os.walk(output_dir):
for filename in files:
file_path = os.path.join(root, filename)
arcname = os.path.relpath(file_path, output_dir)
zf.write(file_path, arcname)
file_size = os.path.getsize(tmp_path)
if file_size == 0:
raise HTTPException(status_code=500, detail="Generated ZIP is empty")
def iterfile():
with open(tmp_path, "rb") as f:
while True:
chunk = f.read(8192)
if not chunk:
break
yield chunk
filename = f"{task.name}_annotations.zip"
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(file_size),
}
return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)

View File

@@ -1,4 +1,4 @@
"""Schemas for Auto Annotation tasks"""
"""Schemas for Auto Annotation tasks"""
from __future__ import annotations
import json
@@ -7,24 +7,24 @@ from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict, model_validator
class AutoAnnotationConfig(BaseModel):
"""自动标注任务配置(与前端 payload 对齐)"""
model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x")
conf_threshold: float = Field(alias="confThreshold", description="置信度阈值 0-1")
target_classes: List[int] = Field(
default_factory=list,
alias="targetClasses",
description="目标类别ID列表,空表示全部类别",
)
output_dataset_name: Optional[str] = Field(
default=None,
alias="outputDatasetName",
description="自动标注结果要写入的新数据集名称(可选)",
)
"""自动标注任务配置(与前端 payload 对齐)"""
model_size: str = Field(alias="modelSize", description="模型规模: n/s/m/l/x")
conf_threshold: float = Field(alias="confThreshold", description="置信度阈值 0-1")
target_classes: List[int] = Field(
default_factory=list,
alias="targetClasses",
description="目标类别ID列表,空表示全部类别",
)
output_dataset_name: Optional[str] = Field(
default=None,
alias="outputDatasetName",
description="自动标注结果要写入的新数据集名称(可选)",
)
model_config = ConfigDict(populate_by_name=True)
@@ -68,13 +68,18 @@ class OperatorPipelineStep(BaseModel):
return normalized
model_config = ConfigDict(populate_by_name=True)
class CreateAutoAnnotationTaskRequest(BaseModel):
"""创建自动标注任务的请求体,对齐前端 CreateAutoAnnotationDialog 发送的结构"""
name: str = Field(..., min_length=1, max_length=255, description="任务名称")
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
dataset_type: Optional[str] = Field(
default=None,
alias="datasetType",
description="数据集类型: IMAGE/TEXT/AUDIO/VIDEO(不传时由后端自动获取)",
)
config: Optional[AutoAnnotationConfig] = Field(
default=None,
description="兼容旧版 YOLO 任务配置",
@@ -111,15 +116,16 @@ class CreateAutoAnnotationTaskRequest(BaseModel):
return self
model_config = ConfigDict(populate_by_name=True)
class AutoAnnotationTaskResponse(BaseModel):
"""自动标注任务响应模型(列表/详情均可复用)"""
id: str = Field(..., description="任务ID")
name: str = Field(..., description="任务名称")
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
"""自动标注任务响应模型(列表/详情均可复用)"""
id: str = Field(..., description="任务ID")
name: str = Field(..., description="任务名称")
dataset_id: str = Field(..., alias="datasetId", description="数据集ID")
dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
dataset_type: Optional[str] = Field(None, alias="datasetType", description="数据集类型: IMAGE/TEXT/AUDIO/VIDEO")
task_mode: Optional[str] = Field(None, alias="taskMode", description="任务模式")
executor_type: Optional[str] = Field(None, alias="executorType", description="执行器类型")
pipeline: Optional[List[Dict[str, Any]]] = Field(None, description="算子编排定义")
@@ -128,11 +134,11 @@ class AutoAnnotationTaskResponse(BaseModel):
alias="sourceDatasets",
description="本任务实际处理涉及到的所有数据集名称列表",
)
config: Dict[str, Any] = Field(..., description="任务配置")
status: str = Field(..., description="任务状态")
progress: int = Field(..., description="任务进度 0-100")
total_images: int = Field(..., alias="totalImages", description="总图片数")
processed_images: int = Field(..., alias="processedImages", description="已处理图片数")
config: Dict[str, Any] = Field(..., description="任务配置")
status: str = Field(..., description="任务状态")
progress: int = Field(..., description="任务进度 0-100")
total_images: int = Field(..., alias="totalImages", description="总图片数")
processed_images: int = Field(..., alias="processedImages", description="已处理图片数")
detected_objects: int = Field(..., alias="detectedObjects", description="检测到的对象总数")
output_path: Optional[str] = Field(None, alias="outputPath", description="输出路径")
output_dataset_id: Optional[str] = Field(
@@ -152,14 +158,14 @@ class AutoAnnotationTaskResponse(BaseModel):
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
completed_at: Optional[datetime] = Field(None, alias="completedAt", description="完成时间")
model_config = ConfigDict(populate_by_name=True, from_attributes=True)
class AutoAnnotationTaskListResponse(BaseModel):
"""自动标注任务列表响应,目前前端直接使用数组,这里预留分页结构"""
content: List[AutoAnnotationTaskResponse] = Field(..., description="任务列表")
total: int = Field(..., description="总数")
model_config = ConfigDict(populate_by_name=True)
model_config = ConfigDict(populate_by_name=True, from_attributes=True)
class AutoAnnotationTaskListResponse(BaseModel):
"""自动标注任务列表响应,目前前端直接使用数组,这里预留分页结构"""
content: List[AutoAnnotationTaskResponse] = Field(..., description="任务列表")
total: int = Field(..., description="总数")
model_config = ConfigDict(populate_by_name=True)

View File

@@ -15,13 +15,13 @@ from app.db.models.annotation_management import (
)
from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.security import RequestUserContext
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
)
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
)
class AutoAnnotationTaskService:
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
@@ -141,11 +141,12 @@ class AutoAnnotationTaskService:
user_context: RequestUserContext,
dataset_name: Optional[str] = None,
total_images: int = 0,
dataset_type: str = "IMAGE",
) -> AutoAnnotationTaskResponse:
"""创建自动标注任务,初始状态为 pending。
这里仅插入任务记录,不负责真正执行 YOLO 推理,
后续可以由调度器/worker 读取该表并更新进度。
"""创建自动标注任务,初始状态为 pending。
这里仅插入任务记录,不负责真正执行 YOLO 推理,
后续可以由调度器/worker 读取该表并更新进度。
"""
now = datetime.now()
@@ -170,6 +171,7 @@ class AutoAnnotationTaskService:
name=request.name,
dataset_id=request.dataset_id,
dataset_name=dataset_name,
dataset_type=dataset_type,
created_by=user_context.user_id,
config=normalized_config,
task_mode=request.task_mode,
@@ -192,15 +194,15 @@ class AutoAnnotationTaskService:
db.add_all(operator_instances)
await db.commit()
await db.refresh(task)
# 创建后附带 sourceDatasets 信息(通常只有一个原始数据集)
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
return resp
# 创建后附带 sourceDatasets 信息(通常只有一个原始数据集)
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
return resp
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
if user_context.is_admin:
return query
@@ -222,21 +224,21 @@ class AutoAnnotationTaskService:
query.order_by(AutoAnnotationTask.created_at.desc())
)
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
responses: List[AutoAnnotationTaskResponse] = []
for task in tasks:
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
# 出错时降级为单个 datasetName/datasetId
fallback_name = getattr(task, "dataset_name", None)
fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
responses.append(resp)
return responses
responses: List[AutoAnnotationTaskResponse] = []
for task in tasks:
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
# 出错时降级为单个 datasetName/datasetId
fallback_name = getattr(task, "dataset_name", None)
fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
responses.append(resp)
return responses
async def get_task(
self,
db: AsyncSession,
@@ -252,43 +254,43 @@ class AutoAnnotationTaskService:
task = result.scalar_one_or_none()
if not task:
return None
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
fallback_name = getattr(task, "dataset_name", None)
fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
return resp
async def _compute_source_datasets(
self,
db: AsyncSession,
task: AutoAnnotationTask,
) -> List[str]:
"""根据任务的 file_ids 推断实际涉及到的所有数据集名称。
- 如果存在 file_ids,则通过 t_dm_dataset_files 反查 dataset_id,再关联 t_dm_datasets 获取名称;
- 如果没有 file_ids,则退回到任务上冗余的 dataset_name/dataset_id。
"""
file_ids = task.file_ids or []
if file_ids:
stmt = (
select(Dataset.name)
.join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
.where(DatasetFiles.id.in_(file_ids))
.distinct()
)
result = await db.execute(stmt)
names = [row[0] for row in result.fetchall() if row[0]]
if names:
return names
# 回退:只显示一个数据集
if task.dataset_name:
return [task.dataset_name]
resp = AutoAnnotationTaskResponse.model_validate(task)
try:
resp.source_datasets = await self._compute_source_datasets(db, task)
except Exception:
fallback_name = getattr(task, "dataset_name", None)
fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
return resp
async def _compute_source_datasets(
self,
db: AsyncSession,
task: AutoAnnotationTask,
) -> List[str]:
"""根据任务的 file_ids 推断实际涉及到的所有数据集名称。
- 如果存在 file_ids,则通过 t_dm_dataset_files 反查 dataset_id,再关联 t_dm_datasets 获取名称;
- 如果没有 file_ids,则退回到任务上冗余的 dataset_name/dataset_id。
"""
file_ids = task.file_ids or []
if file_ids:
stmt = (
select(Dataset.name)
.join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
.where(DatasetFiles.id.in_(file_ids))
.distinct()
)
result = await db.execute(stmt)
names = [row[0] for row in result.fetchall() if row[0]]
if names:
return names
# 回退:只显示一个数据集
if task.dataset_name:
return [task.dataset_name]
if task.dataset_id:
return [task.dataset_id]
return []
@@ -331,7 +333,7 @@ class AutoAnnotationTaskService:
fallback_id = getattr(task, "dataset_id", "")
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
return resp
async def soft_delete_task(
self,
db: AsyncSession,
@@ -347,7 +349,7 @@ class AutoAnnotationTaskService:
task = result.scalar_one_or_none()
if not task:
return False
task.deleted_at = datetime.now()
await db.commit()
return True
task.deleted_at = datetime.now()
await db.commit()
return True