You've already forked DataMate
feat(annotation): 自动标注任务支持非图像类型数据集(TEXT/AUDIO/VIDEO)
移除自动标注任务创建流程中的 IMAGE-only 限制,使 TEXT、AUDIO、VIDEO 类型数据集均可用于自动标注任务。 - 新增数据库迁移:t_dm_auto_annotation_tasks 表添加 dataset_type 列 - 后端 schema/API/service 全链路传递 dataset_type - Worker 动态构建 sample key(image/text/audio/video)和输出目录 - 前端移除数据集类型校验,下拉框显示数据集类型标识 - 输出数据集继承源数据集类型,不再硬编码为 IMAGE - 保持向后兼容:默认值为 IMAGE,worker 有元数据回退和目录 fallback Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -15,9 +15,9 @@ from typing import List, Literal
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.module.shared.schema import StandardResponse
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.module.shared.schema import StandardResponse
|
||||
from app.module.dataset import DatasetManagementService
|
||||
from app.core.logging import get_logger
|
||||
|
||||
@@ -29,15 +29,15 @@ from ..security import (
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
)
|
||||
from ..service.auto import AutoAnnotationTaskService
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["annotation/auto"],
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
service = AutoAnnotationTaskService()
|
||||
|
||||
|
||||
@@ -85,23 +85,28 @@ async def _create_task_internal(
|
||||
await assert_dataset_access(db, normalized_request.dataset_id, user_context)
|
||||
# 尝试获取数据集名称和总量用于冗余字段
|
||||
dataset_name = None
|
||||
dataset_type = "IMAGE"
|
||||
total_images = len(normalized_request.file_ids) if normalized_request.file_ids else 0
|
||||
try:
|
||||
dm_client = DatasetManagementService(db)
|
||||
dataset = await dm_client.get_dataset(normalized_request.dataset_id)
|
||||
if dataset is not None:
|
||||
dataset_name = dataset.name
|
||||
dataset_type = getattr(dataset, "datasetType", None) or "IMAGE"
|
||||
if not normalized_request.file_ids:
|
||||
total_images = getattr(dataset, "fileCount", 0) or 0
|
||||
except Exception as e: # pragma: no cover - 容错
|
||||
logger.warning("Failed to fetch dataset summary for annotation task: %s", e)
|
||||
|
||||
resolved_dataset_type = normalized_request.dataset_type or dataset_type
|
||||
|
||||
return await service.create_task(
|
||||
db,
|
||||
normalized_request,
|
||||
user_context=user_context,
|
||||
dataset_name=dataset_name,
|
||||
total_images=total_images,
|
||||
dataset_type=resolved_dataset_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -177,10 +182,10 @@ async def get_auto_annotation_task_status(
|
||||
task = await service.get_task(db, task_id, user_context)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=task,
|
||||
)
|
||||
|
||||
@@ -192,12 +197,12 @@ async def delete_auto_annotation_task(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
|
||||
|
||||
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
|
||||
|
||||
ok = await service.soft_delete_task(db, task_id, user_context)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
@@ -232,50 +237,50 @@ async def download_auto_annotation_result(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||
):
|
||||
"""下载指定自动标注任务的结果 ZIP。"""
|
||||
|
||||
"""下载指定自动标注任务的结果 ZIP。"""
|
||||
|
||||
import os
|
||||
import zipfile
|
||||
import tempfile
|
||||
|
||||
# 复用服务层获取任务信息
|
||||
|
||||
# 复用服务层获取任务信息
|
||||
task = await service.get_task(db, task_id, user_context)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
if not task.output_path:
|
||||
raise HTTPException(status_code=400, detail="Task has no output path")
|
||||
|
||||
output_dir = task.output_path
|
||||
if not os.path.isdir(output_dir):
|
||||
raise HTTPException(status_code=404, detail="Output directory not found")
|
||||
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
|
||||
os.close(tmp_fd)
|
||||
|
||||
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for root, _, files in os.walk(output_dir):
|
||||
for filename in files:
|
||||
file_path = os.path.join(root, filename)
|
||||
arcname = os.path.relpath(file_path, output_dir)
|
||||
zf.write(file_path, arcname)
|
||||
|
||||
file_size = os.path.getsize(tmp_path)
|
||||
if file_size == 0:
|
||||
raise HTTPException(status_code=500, detail="Generated ZIP is empty")
|
||||
|
||||
def iterfile():
|
||||
with open(tmp_path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
filename = f"{task.name}_annotations.zip"
|
||||
headers = {
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"Content-Length": str(file_size),
|
||||
}
|
||||
|
||||
return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
if not task.output_path:
|
||||
raise HTTPException(status_code=400, detail="Task has no output path")
|
||||
|
||||
output_dir = task.output_path
|
||||
if not os.path.isdir(output_dir):
|
||||
raise HTTPException(status_code=404, detail="Output directory not found")
|
||||
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
|
||||
os.close(tmp_fd)
|
||||
|
||||
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for root, _, files in os.walk(output_dir):
|
||||
for filename in files:
|
||||
file_path = os.path.join(root, filename)
|
||||
arcname = os.path.relpath(file_path, output_dir)
|
||||
zf.write(file_path, arcname)
|
||||
|
||||
file_size = os.path.getsize(tmp_path)
|
||||
if file_size == 0:
|
||||
raise HTTPException(status_code=500, detail="Generated ZIP is empty")
|
||||
|
||||
def iterfile():
|
||||
with open(tmp_path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
filename = f"{task.name}_annotations.zip"
|
||||
headers = {
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"Content-Length": str(file_size),
|
||||
}
|
||||
|
||||
return StreamingResponse(iterfile(), media_type="application/zip", headers=headers)
|
||||
|
||||
Reference in New Issue
Block a user