You've already forked DataMate
feat(annotation): 自动标注任务支持非图像类型数据集(TEXT/AUDIO/VIDEO)
移除自动标注任务创建流程中的 IMAGE-only 限制,使 TEXT、AUDIO、VIDEO 类型数据集均可用于自动标注任务。 - 新增数据库迁移:t_dm_auto_annotation_tasks 表添加 dataset_type 列 - 后端 schema/API/service 全链路传递 dataset_type - Worker 动态构建 sample key(image/text/audio/video)和输出目录 - 前端移除数据集类型校验,下拉框显示数据集类型标识 - 输出数据集继承源数据集类型,不再硬编码为 IMAGE - 保持向后兼容:默认值为 IMAGE,worker 有元数据回退和目录 fallback Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -15,13 +15,13 @@ from app.db.models.annotation_management import (
|
||||
)
|
||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||
from app.module.annotation.security import RequestUserContext
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
|
||||
|
||||
|
||||
from ..schema.auto import (
|
||||
CreateAutoAnnotationTaskRequest,
|
||||
AutoAnnotationTaskResponse,
|
||||
)
|
||||
|
||||
|
||||
class AutoAnnotationTaskService:
|
||||
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
|
||||
|
||||
@@ -141,11 +141,12 @@ class AutoAnnotationTaskService:
|
||||
user_context: RequestUserContext,
|
||||
dataset_name: Optional[str] = None,
|
||||
total_images: int = 0,
|
||||
dataset_type: str = "IMAGE",
|
||||
) -> AutoAnnotationTaskResponse:
|
||||
"""创建自动标注任务,初始状态为 pending。
|
||||
|
||||
这里仅插入任务记录,不负责真正执行 YOLO 推理,
|
||||
后续可以由调度器/worker 读取该表并更新进度。
|
||||
"""创建自动标注任务,初始状态为 pending。
|
||||
|
||||
这里仅插入任务记录,不负责真正执行 YOLO 推理,
|
||||
后续可以由调度器/worker 读取该表并更新进度。
|
||||
"""
|
||||
|
||||
now = datetime.now()
|
||||
@@ -170,6 +171,7 @@ class AutoAnnotationTaskService:
|
||||
name=request.name,
|
||||
dataset_id=request.dataset_id,
|
||||
dataset_name=dataset_name,
|
||||
dataset_type=dataset_type,
|
||||
created_by=user_context.user_id,
|
||||
config=normalized_config,
|
||||
task_mode=request.task_mode,
|
||||
@@ -192,15 +194,15 @@ class AutoAnnotationTaskService:
|
||||
db.add_all(operator_instances)
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
|
||||
# 创建后附带 sourceDatasets 信息(通常只有一个原始数据集)
|
||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||
try:
|
||||
resp.source_datasets = await self._compute_source_datasets(db, task)
|
||||
except Exception:
|
||||
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
|
||||
return resp
|
||||
|
||||
|
||||
# 创建后附带 sourceDatasets 信息(通常只有一个原始数据集)
|
||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||
try:
|
||||
resp.source_datasets = await self._compute_source_datasets(db, task)
|
||||
except Exception:
|
||||
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
|
||||
return resp
|
||||
|
||||
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
|
||||
if user_context.is_admin:
|
||||
return query
|
||||
@@ -222,21 +224,21 @@ class AutoAnnotationTaskService:
|
||||
query.order_by(AutoAnnotationTask.created_at.desc())
|
||||
)
|
||||
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
|
||||
|
||||
responses: List[AutoAnnotationTaskResponse] = []
|
||||
for task in tasks:
|
||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||
try:
|
||||
resp.source_datasets = await self._compute_source_datasets(db, task)
|
||||
except Exception:
|
||||
# 出错时降级为单个 datasetName/datasetId
|
||||
fallback_name = getattr(task, "dataset_name", None)
|
||||
fallback_id = getattr(task, "dataset_id", "")
|
||||
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
||||
responses.append(resp)
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
responses: List[AutoAnnotationTaskResponse] = []
|
||||
for task in tasks:
|
||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||
try:
|
||||
resp.source_datasets = await self._compute_source_datasets(db, task)
|
||||
except Exception:
|
||||
# 出错时降级为单个 datasetName/datasetId
|
||||
fallback_name = getattr(task, "dataset_name", None)
|
||||
fallback_id = getattr(task, "dataset_id", "")
|
||||
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
||||
responses.append(resp)
|
||||
|
||||
return responses
|
||||
|
||||
async def get_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
@@ -252,43 +254,43 @@ class AutoAnnotationTaskService:
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return None
|
||||
|
||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||
try:
|
||||
resp.source_datasets = await self._compute_source_datasets(db, task)
|
||||
except Exception:
|
||||
fallback_name = getattr(task, "dataset_name", None)
|
||||
fallback_id = getattr(task, "dataset_id", "")
|
||||
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
||||
return resp
|
||||
|
||||
async def _compute_source_datasets(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
task: AutoAnnotationTask,
|
||||
) -> List[str]:
|
||||
"""根据任务的 file_ids 推断实际涉及到的所有数据集名称。
|
||||
|
||||
- 如果存在 file_ids,则通过 t_dm_dataset_files 反查 dataset_id,再关联 t_dm_datasets 获取名称;
|
||||
- 如果没有 file_ids,则退回到任务上冗余的 dataset_name/dataset_id。
|
||||
"""
|
||||
|
||||
file_ids = task.file_ids or []
|
||||
if file_ids:
|
||||
stmt = (
|
||||
select(Dataset.name)
|
||||
.join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
|
||||
.where(DatasetFiles.id.in_(file_ids))
|
||||
.distinct()
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
names = [row[0] for row in result.fetchall() if row[0]]
|
||||
if names:
|
||||
return names
|
||||
|
||||
# 回退:只显示一个数据集
|
||||
if task.dataset_name:
|
||||
return [task.dataset_name]
|
||||
|
||||
resp = AutoAnnotationTaskResponse.model_validate(task)
|
||||
try:
|
||||
resp.source_datasets = await self._compute_source_datasets(db, task)
|
||||
except Exception:
|
||||
fallback_name = getattr(task, "dataset_name", None)
|
||||
fallback_id = getattr(task, "dataset_id", "")
|
||||
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
||||
return resp
|
||||
|
||||
async def _compute_source_datasets(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
task: AutoAnnotationTask,
|
||||
) -> List[str]:
|
||||
"""根据任务的 file_ids 推断实际涉及到的所有数据集名称。
|
||||
|
||||
- 如果存在 file_ids,则通过 t_dm_dataset_files 反查 dataset_id,再关联 t_dm_datasets 获取名称;
|
||||
- 如果没有 file_ids,则退回到任务上冗余的 dataset_name/dataset_id。
|
||||
"""
|
||||
|
||||
file_ids = task.file_ids or []
|
||||
if file_ids:
|
||||
stmt = (
|
||||
select(Dataset.name)
|
||||
.join(DatasetFiles, Dataset.id == DatasetFiles.dataset_id)
|
||||
.where(DatasetFiles.id.in_(file_ids))
|
||||
.distinct()
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
names = [row[0] for row in result.fetchall() if row[0]]
|
||||
if names:
|
||||
return names
|
||||
|
||||
# 回退:只显示一个数据集
|
||||
if task.dataset_name:
|
||||
return [task.dataset_name]
|
||||
if task.dataset_id:
|
||||
return [task.dataset_id]
|
||||
return []
|
||||
@@ -331,7 +333,7 @@ class AutoAnnotationTaskService:
|
||||
fallback_id = getattr(task, "dataset_id", "")
|
||||
resp.source_datasets = [fallback_name] if fallback_name else [fallback_id]
|
||||
return resp
|
||||
|
||||
|
||||
async def soft_delete_task(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
@@ -347,7 +349,7 @@ class AutoAnnotationTaskService:
|
||||
task = result.scalar_one_or_none()
|
||||
if not task:
|
||||
return False
|
||||
|
||||
task.deleted_at = datetime.now()
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
task.deleted_at = datetime.now()
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user