feat(annotation): 支持图像数据集的内嵌标注编辑器

- 添加文件预览接口,支持以 inline 方式预览数据集中的指定文件
- 实现图像任务构建功能,支持图像标注任务的数据结构
- 扩展标注编辑器服务以支持 TEXT 和 IMAGE 类型数据集
- 添加媒体对象分类支持,解析图像标注配置
- 实现图像文件预览 URL 构建逻辑
- 优化项目信息获取和任务响应构建流程
- 修复数据库查询中的项目 ID 引用错误
This commit is contained in:
2026-01-25 17:25:44 +08:00
parent e6d1e4763f
commit c5ace0c4cc
3 changed files with 240 additions and 49 deletions

View File

@@ -336,6 +336,34 @@ paths:
type: string type: string
format: binary format: binary
/data-management/datasets/{datasetId}/files/{fileId}/preview:
get:
tags: [DatasetFile]
operationId: previewDatasetFile
summary: 预览文件
description: 以 inline 方式预览数据集中的指定文件
parameters:
- name: datasetId
in: path
required: true
schema:
type: string
description: 数据集ID
- name: fileId
in: path
required: true
schema:
type: string
description: 文件ID
responses:
'200':
description: 文件内容
content:
application/octet-stream:
schema:
type: string
format: binary
/data-management/datasets/{datasetId}/files/download: /data-management/datasets/{datasetId}/files/download:
get: get:
tags: [ DatasetFile ] tags: [ DatasetFile ]

View File

@@ -22,6 +22,7 @@ import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.http.MediaTypeFactory;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
@@ -103,6 +104,28 @@ public class DatasetFileController {
} }
} }
/**
 * Serves a single dataset file for in-browser preview.
 *
 * <p>The response body is the file resource, the content type is derived from the
 * resource (falling back to {@code application/octet-stream}), and the
 * {@code Content-Disposition} header is {@code inline} so the browser renders the
 * file instead of downloading it.
 *
 * @param datasetId ID of the dataset that owns the file
 * @param fileId    ID of the file to preview
 * @return 200 with the file content; 404 when an {@link IllegalArgumentException} is
 *         raised while looking the file up; 500 for any other failure
 */
@IgnoreResponseWrap
@GetMapping(value = "/{fileId}/preview", produces = MediaType.ALL_VALUE)
public ResponseEntity<Resource> previewDatasetFileById(@PathVariable("datasetId") String datasetId,
                                                       @PathVariable("fileId") String fileId) {
    try {
        DatasetFile datasetFile = datasetFileApplicationService.getDatasetFile(datasetId, fileId);
        Resource resource = datasetFileApplicationService.downloadFile(datasetId, fileId);
        MediaType mediaType = MediaTypeFactory.getMediaType(resource)
                .orElse(MediaType.APPLICATION_OCTET_STREAM);
        // Let Spring build the header value: concatenating the raw file name breaks on
        // quotes / control characters and garbles non-ASCII names. The builder escapes
        // the value and emits an RFC 5987 encoded filename for the given charset.
        // Fully qualified names are used so the file's import block stays untouched.
        String contentDisposition = org.springframework.http.ContentDisposition.builder("inline")
                .filename(datasetFile.getFileName(), java.nio.charset.StandardCharsets.UTF_8)
                .build()
                .toString();
        return ResponseEntity.ok()
                .contentType(mediaType)
                .header(HttpHeaders.CONTENT_DISPOSITION, contentDisposition)
                .body(resource);
    } catch (IllegalArgumentException e) {
        return ResponseEntity.status(HttpStatus.NOT_FOUND).build();
    } catch (Exception e) {
        return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).build();
    }
}
@IgnoreResponseWrap @IgnoreResponseWrap
@GetMapping(value = "/download", produces = MediaType.APPLICATION_OCTET_STREAM_VALUE) @GetMapping(value = "/download", produces = MediaType.APPLICATION_OCTET_STREAM_VALUE)
public void downloadDatasetFileAsZip(@PathVariable("datasetId") String datasetId, HttpServletResponse response) { public void downloadDatasetFileAsZip(@PathVariable("datasetId") String datasetId, HttpServletResponse response) {

View File

@@ -12,6 +12,7 @@ from __future__ import annotations
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import hashlib import hashlib
import json import json
@@ -20,6 +21,7 @@ from fastapi import HTTPException
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
from app.module.annotation.config import LabelStudioTagConfig from app.module.annotation.config import LabelStudioTagConfig
@@ -40,6 +42,7 @@ from app.module.annotation.service.text_fetcher import fetch_text_content_via_do
logger = get_logger(__name__) logger = get_logger(__name__)
TEXT_DATA_KEY = "text" TEXT_DATA_KEY = "text"
IMAGE_DATA_KEY = "image"
DATASET_ID_KEY = "dataset_id" DATASET_ID_KEY = "dataset_id"
FILE_ID_KEY = "file_id" FILE_ID_KEY = "file_id"
FILE_NAME_KEY = "file_name" FILE_NAME_KEY = "file_name"
@@ -50,7 +53,9 @@ SEGMENT_INDEX_KEY = "segment_index"
SEGMENT_INDEX_CAMEL_KEY = "segmentIndex" SEGMENT_INDEX_CAMEL_KEY = "segmentIndex"
JSONL_EXTENSION = ".jsonl" JSONL_EXTENSION = ".jsonl"
TEXTUAL_OBJECT_CATEGORIES = {"text", "document"} TEXTUAL_OBJECT_CATEGORIES = {"text", "document"}
MEDIA_OBJECT_CATEGORIES = {"image"}
OBJECT_NAME_HEADER_PREFIX = "dm_object_header_" OBJECT_NAME_HEADER_PREFIX = "dm_object_header_"
SUPPORTED_EDITOR_DATASET_TYPES = ("TEXT", "IMAGE")
class AnnotationEditorService: class AnnotationEditorService:
@@ -84,6 +89,32 @@ class AnnotationEditorService:
# 单人单份最终标签:每个 task 只保留一个 annotation,id 直接与 task 绑定即可 # 单人单份最终标签:每个 task 只保留一个 annotation,id 直接与 task 绑定即可
return self._stable_ls_id(f"annotation:{project_id}:{file_id}") return self._stable_ls_id(f"annotation:{project_id}:{file_id}")
@staticmethod
def _normalize_dataset_type(dataset_type: Optional[str]) -> str:
return (dataset_type or "").upper()
@staticmethod
def _resolve_public_api_prefix() -> str:
    """Derive the URL path prefix under which the backend API is exposed.

    Reads ``settings.datamate_backend_base_url``. When the value is an
    absolute URL (scheme and host both present) only its path component is
    kept; otherwise the whole value is treated as a path. Trailing slashes
    are removed and a leading slash is enforced.

    Returns:
        The normalized prefix, or ``"/api"`` when the setting is empty or
        reduces to an empty path.
    """
    fallback = "/api"
    configured = (settings.datamate_backend_base_url or "").strip()
    if not configured:
        return fallback

    url_parts = urlparse(configured)
    is_absolute_url = bool(url_parts.scheme and url_parts.netloc)
    raw_path = url_parts.path if is_absolute_url else configured

    trimmed = raw_path.rstrip("/")
    if not trimmed:
        return fallback
    return trimmed if trimmed.startswith("/") else "/" + trimmed
@classmethod
def _build_file_preview_url(cls, dataset_id: str, file_id: str) -> str:
    """Return the preview endpoint path for one file of a dataset.

    The path is the public API prefix followed by
    ``/data-management/datasets/<dataset_id>/files/<file_id>/preview``.
    """
    api_prefix = cls._resolve_public_api_prefix()
    file_route = f"data-management/datasets/{dataset_id}/files/{file_id}/preview"
    return f"{api_prefix}/{file_route}"
async def _get_project_or_404(self, project_id: str) -> LabelingProject: async def _get_project_or_404(self, project_id: str) -> LabelingProject:
result = await self.db.execute( result = await self.db.execute(
select(LabelingProject).where( select(LabelingProject).where(
@@ -129,6 +160,23 @@ class AnnotationEditorService:
return TEXT_DATA_KEY return TEXT_DATA_KEY
return keys[0] return keys[0]
@classmethod
def _resolve_primary_media_key(
cls,
label_config: Optional[str],
default_key: str,
categories: Optional[set[str]] = None,
) -> str:
if not label_config:
return default_key
target_categories = categories or set()
keys = cls._extract_object_value_keys_by_category(label_config, target_categories)
if not keys:
return default_key
if default_key in keys:
return default_key
return keys[0]
@staticmethod @staticmethod
def _try_parse_json_payload(text_content: str) -> Optional[Dict[str, Any]]: def _try_parse_json_payload(text_content: str) -> Optional[Dict[str, Any]]:
if not text_content: if not text_content:
@@ -160,7 +208,11 @@ class AnnotationEditorService:
return category in TEXTUAL_OBJECT_CATEGORIES return category in TEXTUAL_OBJECT_CATEGORIES
@classmethod @classmethod
def _extract_textual_value_keys(cls, label_config: str) -> List[str]: def _extract_object_value_keys_by_category(
cls,
label_config: str,
categories: set[str],
) -> List[str]:
try: try:
root = ET.fromstring(label_config) root = ET.fromstring(label_config)
except Exception as exc: except Exception as exc:
@@ -172,7 +224,9 @@ class AnnotationEditorService:
for element in root.iter(): for element in root.iter():
if element.tag not in object_types: if element.tag not in object_types:
continue continue
if not cls._is_textual_object_tag(element.tag): config = LabelStudioTagConfig.get_object_config(element.tag) or {}
category = config.get("category")
if categories and category not in categories:
continue continue
value = element.attrib.get("value", "") value = element.attrib.get("value", "")
if not value.startswith("$"): if not value.startswith("$"):
@@ -183,6 +237,10 @@ class AnnotationEditorService:
seen[key] = None seen[key] = None
return list(seen.keys()) return list(seen.keys())
@classmethod
def _extract_textual_value_keys(cls, label_config: str) -> List[str]:
    """Return the value keys found for the textual object categories.

    Thin wrapper that delegates to ``_extract_object_value_keys_by_category``
    with ``TEXTUAL_OBJECT_CATEGORIES`` as the category filter.
    """
    textual_keys = cls._extract_object_value_keys_by_category(
        label_config,
        TEXTUAL_OBJECT_CATEGORIES,
    )
    return textual_keys
@staticmethod @staticmethod
def _needs_placeholder(value: Any) -> bool: def _needs_placeholder(value: Any) -> bool:
if value is None: if value is None:
@@ -310,11 +368,12 @@ class AnnotationEditorService:
async def get_project_info(self, project_id: str) -> EditorProjectInfo: async def get_project_info(self, project_id: str) -> EditorProjectInfo:
project = await self._get_project_or_404(project_id) project = await self._get_project_or_404(project_id)
dataset_type = await self._get_dataset_type(project.dataset_id) dataset_type = self._normalize_dataset_type(await self._get_dataset_type(project.dataset_id))
supported = (dataset_type or "").upper() == "TEXT" supported = dataset_type in SUPPORTED_EDITOR_DATASET_TYPES
unsupported_reason = None unsupported_reason = None
if not supported: if not supported:
unsupported_reason = f"当前仅支持 TEXT,项目数据类型为: {dataset_type or 'UNKNOWN'}" supported_hint = "/".join(SUPPORTED_EDITOR_DATASET_TYPES)
unsupported_reason = f"当前仅支持 {supported_hint},项目数据类型为: {dataset_type or 'UNKNOWN'}"
# 优先使用项目配置中的label_config(用户编辑版本),其次使用模板默认配置 # 优先使用项目配置中的label_config(用户编辑版本),其次使用模板默认配置
label_config = await self._resolve_project_label_config(project) label_config = await self._resolve_project_label_config(project)
@@ -393,10 +452,9 @@ class AnnotationEditorService:
) -> EditorTaskResponse: ) -> EditorTaskResponse:
project = await self._get_project_or_404(project_id) project = await self._get_project_or_404(project_id)
# TEXT 支持校验 dataset_type = self._normalize_dataset_type(await self._get_dataset_type(project.dataset_id))
dataset_type = await self._get_dataset_type(project.dataset_id) if dataset_type not in SUPPORTED_EDITOR_DATASET_TYPES:
if (dataset_type or "").upper() != "TEXT": raise HTTPException(status_code=400, detail="当前仅支持 TEXT/IMAGE 项目的内嵌编辑器")
raise HTTPException(status_code=400, detail="当前仅支持 TEXT 项目的内嵌编辑器")
file_result = await self.db.execute( file_result = await self.db.execute(
select(DatasetFiles).where( select(DatasetFiles).where(
@@ -408,6 +466,18 @@ class AnnotationEditorService:
if not file_record: if not file_record:
raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}") raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")
if dataset_type == "IMAGE":
return await self._build_image_task(project, file_record, file_id)
return await self._build_text_task(project, file_record, file_id, segment_index)
async def _build_text_task(
self,
project: LabelingProject,
file_record: DatasetFiles,
file_id: str,
segment_index: Optional[int],
) -> EditorTaskResponse:
text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id) text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
assert isinstance(text_content, str) assert isinstance(text_content, str)
label_config = await self._resolve_project_label_config(project) label_config = await self._resolve_project_label_config(project)
@@ -434,13 +504,13 @@ class AnnotationEditorService:
# 获取现有标注 # 获取现有标注
ann_result = await self.db.execute( ann_result = await self.db.execute(
select(AnnotationResult).where( select(AnnotationResult).where(
AnnotationResult.project_id == project_id, AnnotationResult.project_id == project.id,
AnnotationResult.file_id == file_id, AnnotationResult.file_id == file_id,
) )
) )
ann = ann_result.scalar_one_or_none() ann = ann_result.scalar_one_or_none()
ls_task_id = self._make_ls_task_id(project_id, file_id) ls_task_id = self._make_ls_task_id(project.id, file_id)
# 判断是否需要分段(JSONL 多行或主文本超过阈值) # 判断是否需要分段(JSONL 多行或主文本超过阈值)
needs_segmentation = len(records) > 1 or any( needs_segmentation = len(records) > 1 or any(
@@ -529,7 +599,7 @@ class AnnotationEditorService:
segment_annotations = ann.annotation.get("segments", {}) segment_annotations = ann.annotation.get("segments", {})
seg_ann = segment_annotations.get(str(current_segment_index), {}) seg_ann = segment_annotations.get(str(current_segment_index), {})
stored = { stored = {
"id": self._make_ls_annotation_id(project_id, file_id) + current_segment_index, "id": self._make_ls_annotation_id(project.id, file_id) + current_segment_index,
"task": ls_task_id, "task": ls_task_id,
"result": seg_ann.get("result", []), "result": seg_ann.get("result", []),
"created_at": seg_ann.get("created_at", datetime.utcnow().isoformat() + "Z"), "created_at": seg_ann.get("created_at", datetime.utcnow().isoformat() + "Z"),
@@ -541,11 +611,11 @@ class AnnotationEditorService:
stored = dict(ann.annotation or {}) stored = dict(ann.annotation or {})
stored["task"] = ls_task_id stored["task"] = ls_task_id
if not isinstance(stored.get("id"), int): if not isinstance(stored.get("id"), int):
stored["id"] = self._make_ls_annotation_id(project_id, file_id) stored["id"] = self._make_ls_annotation_id(project.id, file_id)
task["annotations"] = [stored] task["annotations"] = [stored]
else: else:
# 首次从非分段切换到分段:提供空标注 # 首次从非分段切换到分段:提供空标注
empty_ann_id = self._make_ls_annotation_id(project_id, file_id) + current_segment_index empty_ann_id = self._make_ls_annotation_id(project.id, file_id) + current_segment_index
task["annotations"] = [ task["annotations"] = [
{ {
"id": empty_ann_id, "id": empty_ann_id,
@@ -557,7 +627,7 @@ class AnnotationEditorService:
] ]
else: else:
# 提供一个空 annotation,避免前端在没有选中 annotation 时无法产生 result # 提供一个空 annotation,避免前端在没有选中 annotation 时无法产生 result
empty_ann_id = self._make_ls_annotation_id(project_id, file_id) empty_ann_id = self._make_ls_annotation_id(project.id, file_id)
if needs_segmentation: if needs_segmentation:
empty_ann_id += current_segment_index empty_ann_id += current_segment_index
task["annotations"] = [ task["annotations"] = [
@@ -579,6 +649,76 @@ class AnnotationEditorService:
currentSegmentIndex=current_segment_index, currentSegmentIndex=current_segment_index,
) )
async def _build_image_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
) -> EditorTaskResponse:
    """Build the editor task payload for a single image file.

    The task data points the label config's image key at the file preview
    URL and carries the file/dataset identifiers and the file name in both
    snake_case and camelCase keys. An existing annotation for the
    (project, file) pair is reused unless it is flagged ``segmented``;
    otherwise one empty annotation is supplied.

    Args:
        project: Labeling project the task belongs to.
        file_record: Dataset file row; only ``file_name`` is read from it.
        file_id: ID of the file to build the task for.

    Returns:
        A non-segmented ``EditorTaskResponse`` (one segment, index 0).
    """
    label_config = await self._resolve_project_label_config(project)
    image_key = self._resolve_primary_media_key(
        label_config,
        IMAGE_DATA_KEY,
        MEDIA_OBJECT_CATEGORIES,
    )
    preview_url = self._build_file_preview_url(project.dataset_id, file_id)
    # getattr's default only covers a *missing* attribute; an attribute that
    # is present but None would otherwise be rendered as the text "None".
    file_name = str(getattr(file_record, "file_name", None) or "")

    task_data: Dict[str, Any] = {
        image_key: preview_url,
        FILE_ID_KEY: file_id,
        FILE_ID_CAMEL_KEY: file_id,
        DATASET_ID_KEY: project.dataset_id,
        DATASET_ID_CAMEL_KEY: project.dataset_id,
        FILE_NAME_KEY: file_name,
        FILE_NAME_CAMEL_KEY: file_name,
    }

    # Look up the existing annotation for this project/file pair.
    ann_result = await self.db.execute(
        select(AnnotationResult).where(
            AnnotationResult.project_id == project.id,
            AnnotationResult.file_id == file_id,
        )
    )
    ann = ann_result.scalar_one_or_none()

    ls_task_id = self._make_ls_task_id(project.id, file_id)
    task: Dict[str, Any] = {
        "id": ls_task_id,
        "data": task_data,
        "annotations": [],
    }

    annotation_updated_at = None
    if ann and not (ann.annotation or {}).get("segmented"):
        annotation_updated_at = ann.updated_at
        stored = dict(ann.annotation or {})
        stored["task"] = ls_task_id
        # Fall back to the deterministic annotation id when the stored one
        # is not an int.
        if not isinstance(stored.get("id"), int):
            stored["id"] = self._make_ls_annotation_id(project.id, file_id)
        task["annotations"] = [stored]
    else:
        # Supply one empty annotation. The timestamp is taken once so that
        # created_at and updated_at are identical for a fresh annotation.
        created_iso = datetime.utcnow().isoformat() + "Z"
        task["annotations"] = [
            {
                "id": self._make_ls_annotation_id(project.id, file_id),
                "task": ls_task_id,
                "result": [],
                "created_at": created_iso,
                "updated_at": created_iso,
            }
        ]

    return EditorTaskResponse(
        task=task,
        annotationUpdatedAt=annotation_updated_at,
        segmented=False,
        segments=None,
        totalSegments=1,
        currentSegmentIndex=0,
    )
async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse: async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse:
project = await self._get_project_or_404(project_id) project = await self._get_project_or_404(project_id)