You've already forked DataMate
feat(annotation): 添加导出格式与数据集类型的兼容性检查
- 实现 COCO 格式导出前的数据集类型验证
- COCO 格式仅适用于图像类和目标检测类数据集
- 文本类数据集尝试导出 COCO 格式时返回 HTTP 400 错误
- 添加清晰的错误提示信息,建议使用其他格式

新增功能:
- 数据集类型常量定义(TEXT、IMAGE、OBJECT_DETECTION)
- COCO 兼容类型集合
- 类型值标准化方法
- 数据集类型查询方法
- 模板标注类型解析方法
- 导出格式兼容性验证方法

相关文件:
- runtime/datamate-python/app/module/annotation/service/export.py (+94, -7)

Reviewed-by: Codex AI
This commit is contained in:
@@ -28,6 +28,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import (
|
||||
AnnotationResult,
|
||||
AnnotationTemplate,
|
||||
Dataset,
|
||||
DatasetFiles,
|
||||
LabelingProject,
|
||||
@@ -78,6 +79,15 @@ SEGMENTS_KEY = "segments"
|
||||
SEGMENT_RESULT_KEY = "result"
SEGMENT_INDEX_KEY = "segmentIndex"
SEGMENT_INDEX_FALLBACK_KEY = "segment_index"

# Dataset type identifiers; the export compatibility check compares these
# against the normalized (upper-case) dataset type.
DATASET_TYPE_TEXT = "TEXT"
DATASET_TYPE_IMAGE = "IMAGE"
DATASET_TYPE_OBJECT_DETECTION = "OBJECT_DETECTION"

# Key looked up in a project's ``configuration`` dict to find its labeling type.
LABELING_TYPE_CONFIG_KEY = "labeling_type"
LABELING_TYPE_OBJECT_DETECTION = "OBJECT_DETECTION"

# Dataset types that may be exported as COCO. Frozen so that this shared
# module-level constant cannot be mutated by accident at runtime.
COCO_COMPATIBLE_DATASET_TYPES = frozenset(
    {
        DATASET_TYPE_IMAGE,
        DATASET_TYPE_OBJECT_DETECTION,
    }
)
|
||||
|
||||
|
||||
class AnnotationExportService:
|
||||
@@ -137,6 +147,19 @@ class AnnotationExportService:
|
||||
"""
|
||||
project = await self._get_project_or_404(project_id)
|
||||
|
||||
# 根据格式导出
|
||||
format_type = (
|
||||
ExportFormat(request.format)
|
||||
if isinstance(request.format, str)
|
||||
else request.format
|
||||
)
|
||||
|
||||
# 兼容性检查
|
||||
await self._validate_export_format_compatibility(
|
||||
project=project,
|
||||
format_type=format_type,
|
||||
)
|
||||
|
||||
# 获取标注数据
|
||||
items = await self._fetch_annotation_data(
|
||||
project_id=project_id,
|
||||
@@ -145,13 +168,6 @@ class AnnotationExportService:
|
||||
include_data=request.include_data,
|
||||
)
|
||||
|
||||
# 根据格式导出
|
||||
format_type = (
|
||||
ExportFormat(request.format)
|
||||
if isinstance(request.format, str)
|
||||
else request.format
|
||||
)
|
||||
|
||||
if format_type == ExportFormat.JSON:
|
||||
return self._export_json(items, project.name)
|
||||
elif format_type == ExportFormat.JSONL:
|
||||
@@ -180,6 +196,77 @@ class AnnotationExportService:
|
||||
raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
|
||||
return project
|
||||
|
||||
@staticmethod
|
||||
def _normalize_type_value(value: Optional[str]) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
return str(value).strip().upper().replace("-", "_")
|
||||
|
||||
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
|
||||
result = await self.db.execute(
|
||||
select(Dataset.dataset_type).where(Dataset.id == dataset_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _get_template_labeling_type(
|
||||
self, template_id: Optional[str]
|
||||
) -> Optional[str]:
|
||||
if not template_id:
|
||||
return None
|
||||
result = await self.db.execute(
|
||||
select(AnnotationTemplate.labeling_type).where(
|
||||
AnnotationTemplate.id == template_id,
|
||||
AnnotationTemplate.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _resolve_project_labeling_type(
    self, project: LabelingProject
) -> Optional[str]:
    """Determine the labeling type that applies to ``project``.

    A non-blank string stored under ``LABELING_TYPE_CONFIG_KEY`` in the
    project's ``configuration`` dict takes precedence and is returned as-is
    (not stripped). Otherwise the labeling type of the project's template is
    looked up.

    Args:
        project: The labeling project being exported.

    Returns:
        The labeling type, or ``None`` when neither source provides one.
    """
    project_config = getattr(project, "configuration", None)
    configured_type = (
        project_config.get(LABELING_TYPE_CONFIG_KEY)
        if isinstance(project_config, dict)
        else None
    )
    if isinstance(configured_type, str) and configured_type.strip():
        return configured_type
    # No usable project-level setting: fall back to the template.
    return await self._get_template_labeling_type(project.template_id)
|
||||
|
||||
async def _validate_export_format_compatibility(
    self,
    project: LabelingProject,
    format_type: ExportFormat,
) -> None:
    """Reject a COCO export when the project's data is not COCO-compatible.

    Formats other than COCO are always accepted. For COCO, the export is
    allowed when the normalized dataset type is in
    ``COCO_COMPATIBLE_DATASET_TYPES`` or the project's normalized labeling
    type is ``LABELING_TYPE_OBJECT_DETECTION``.

    Args:
        project: The labeling project being exported.
        format_type: The requested export format.

    Raises:
        HTTPException: With status 400 when the dataset type is TEXT, or when
            neither the dataset type nor the labeling type qualifies for COCO.
    """
    if format_type != ExportFormat.COCO:
        return

    dataset_type = self._normalize_type_value(
        await self._get_dataset_type(project.dataset_id)
    )

    # Text datasets are rejected regardless of the labeling type.
    if dataset_type == DATASET_TYPE_TEXT:
        raise HTTPException(
            status_code=400,
            detail="导出格式 COCO 不支持文本类数据集(TEXT),请改用 JSON/JSONL/CSV 格式",
        )

    # The dataset type alone settles the question here, so the labeling-type
    # lookup (which may hit the database) is skipped.
    if dataset_type in COCO_COMPATIBLE_DATASET_TYPES:
        return

    # Resolved only now: it is needed solely for the remaining check and for
    # the error message below.
    labeling_type = self._normalize_type_value(
        await self._resolve_project_labeling_type(project)
    )
    if labeling_type == LABELING_TYPE_OBJECT_DETECTION:
        return

    raise HTTPException(
        status_code=400,
        detail=(
            "导出格式 COCO 仅适用于图像类或目标检测类数据集,"
            f"当前数据集类型: {dataset_type or 'UNKNOWN'},"
            f"标注类型: {labeling_type or 'UNKNOWN'}"
        ),
    )
|
||||
|
||||
async def _fetch_annotation_data(
|
||||
self,
|
||||
project_id: str,
|
||||
|
||||
Reference in New Issue
Block a user