diff --git a/runtime/datamate-python/app/module/annotation/service/export.py b/runtime/datamate-python/app/module/annotation/service/export.py
index 0b502b0..c345eec 100644
--- a/runtime/datamate-python/app/module/annotation/service/export.py
+++ b/runtime/datamate-python/app/module/annotation/service/export.py
@@ -28,6 +28,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.core.logging import get_logger
 from app.db.models import (
     AnnotationResult,
+    AnnotationTemplate,
     Dataset,
     DatasetFiles,
     LabelingProject,
@@ -78,6 +79,15 @@ SEGMENTS_KEY = "segments"
 SEGMENT_RESULT_KEY = "result"
 SEGMENT_INDEX_KEY = "segmentIndex"
 SEGMENT_INDEX_FALLBACK_KEY = "segment_index"
+DATASET_TYPE_TEXT = "TEXT"
+DATASET_TYPE_IMAGE = "IMAGE"
+DATASET_TYPE_OBJECT_DETECTION = "OBJECT_DETECTION"
+LABELING_TYPE_CONFIG_KEY = "labeling_type"
+LABELING_TYPE_OBJECT_DETECTION = "OBJECT_DETECTION"
+COCO_COMPATIBLE_DATASET_TYPES = {
+    DATASET_TYPE_IMAGE,
+    DATASET_TYPE_OBJECT_DETECTION,
+}
 
 
 class AnnotationExportService:
@@ -137,6 +147,19 @@ class AnnotationExportService:
         """
         project = await self._get_project_or_404(project_id)
 
+        # Parse the requested export format (accepts a raw string or an ExportFormat).
+        format_type = (
+            ExportFormat(request.format)
+            if isinstance(request.format, str)
+            else request.format
+        )
+
+        # Reject incompatible format/dataset combinations before loading any data.
+        await self._validate_export_format_compatibility(
+            project=project,
+            format_type=format_type,
+        )
+
         # 获取标注数据
         items = await self._fetch_annotation_data(
             project_id=project_id,
@@ -145,13 +168,6 @@ class AnnotationExportService:
             include_data=request.include_data,
         )
 
-        # 根据格式导出
-        format_type = (
-            ExportFormat(request.format)
-            if isinstance(request.format, str)
-            else request.format
-        )
-
         if format_type == ExportFormat.JSON:
             return self._export_json(items, project.name)
         elif format_type == ExportFormat.JSONL:
@@ -180,6 +196,104 @@ class AnnotationExportService:
             raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
         return project
 
+    @staticmethod
+    def _normalize_type_value(value: Optional[str]) -> str:
+        """Return ``value`` stripped, upper-cased and with ``-`` replaced by ``_``.
+
+        Falsy input (``None`` or an empty string) yields ``""``.
+        """
+        if not value:
+            return ""
+        return str(value).strip().upper().replace("-", "_")
+
+    async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
+        """Return the dataset's ``dataset_type``, or ``None`` if no row matches."""
+        result = await self.db.execute(
+            select(Dataset.dataset_type).where(Dataset.id == dataset_id)
+        )
+        return result.scalar_one_or_none()
+
+    async def _get_template_labeling_type(
+        self, template_id: Optional[str]
+    ) -> Optional[str]:
+        """Return the labeling type of a non-deleted template.
+
+        Yields ``None`` when ``template_id`` is falsy or no matching
+        template with ``deleted_at IS NULL`` exists.
+        """
+        if not template_id:
+            return None
+        result = await self.db.execute(
+            select(AnnotationTemplate.labeling_type).where(
+                AnnotationTemplate.id == template_id,
+                AnnotationTemplate.deleted_at.is_(None),
+            )
+        )
+        return result.scalar_one_or_none()
+
+    async def _resolve_project_labeling_type(
+        self, project: LabelingProject
+    ) -> Optional[str]:
+        """Return the project's labeling type.
+
+        A non-blank string under ``labeling_type`` in the project's
+        ``configuration`` dict takes precedence; otherwise the value is
+        looked up from the project's template.
+        """
+        configuration = getattr(project, "configuration", None)
+        if isinstance(configuration, dict):
+            labeling_type = configuration.get(LABELING_TYPE_CONFIG_KEY)
+            if isinstance(labeling_type, str) and labeling_type.strip():
+                return labeling_type
+        return await self._get_template_labeling_type(project.template_id)
+
+    async def _validate_export_format_compatibility(
+        self,
+        project: LabelingProject,
+        format_type: ExportFormat,
+    ) -> None:
+        """Reject COCO export for projects whose data cannot be COCO.
+
+        Formats other than COCO are always accepted. COCO is accepted when
+        the dataset type is IMAGE or OBJECT_DETECTION, or when the project's
+        labeling type is OBJECT_DETECTION.
+
+        Raises:
+            HTTPException: 400 when the dataset type is TEXT, or when neither
+                the dataset type nor the labeling type is COCO-compatible.
+        """
+        if format_type != ExportFormat.COCO:
+            return
+
+        dataset_type = self._normalize_type_value(
+            await self._get_dataset_type(project.dataset_id)
+        )
+        # TEXT is rejected whatever the labeling type is, so check it before
+        # resolving the labeling type (which may cost another query).
+        if dataset_type == DATASET_TYPE_TEXT:
+            raise HTTPException(
+                status_code=400,
+                detail="导出格式 COCO 不支持文本类数据集(TEXT),请改用 JSON/JSONL/CSV 格式",
+            )
+
+        labeling_type = self._normalize_type_value(
+            await self._resolve_project_labeling_type(project)
+        )
+        if (
+            dataset_type in COCO_COMPATIBLE_DATASET_TYPES
+            or labeling_type == LABELING_TYPE_OBJECT_DETECTION
+        ):
+            return
+
+        raise HTTPException(
+            status_code=400,
+            detail=(
+                "导出格式 COCO 仅适用于图像类或目标检测类数据集,"
+                f"当前数据集类型: {dataset_type or 'UNKNOWN'},"
+                f"标注类型: {labeling_type or 'UNKNOWN'}"
+            ),
+        )
+
     async def _fetch_annotation_data(
         self,
         project_id: str,