Files
DataMate/runtime/datamate-python/app/module/annotation/service/editor.py
Jerry Yan 08336e2a13 feat(annotation): 添加标注模板配置功能
- 在schema中新增choice和show_inline字段支持选择模式配置
- 为编辑器服务添加空标注创建逻辑避免前端异常
- 实现标签类型的标准化处理和大小写兼容
- 支持Choices标签的单选/多选和行内显示配置
- 优化前端界面滚动条显示控制样式
2026-01-09 13:05:09 +08:00

338 lines
13 KiB
Python

"""
Label Studio editor service.

Responsibilities:
- Resolve DataMate labeling projects (t_dm_labeling_projects).
- Read text content through the file download/preview endpoint and build a
  Label Studio task from it.
- Upsert the final annotation result as the raw Label Studio annotation JSON
  (single annotator, one final annotation per file).
"""
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import hashlib
import httpx
from fastapi import HTTPException
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
from app.module.annotation.schema.editor import (
EditorProjectInfo,
EditorTaskListItem,
EditorTaskListResponse,
EditorTaskResponse,
UpsertAnnotationRequest,
UpsertAnnotationResponse,
)
from app.module.annotation.service.template import AnnotationTemplateService
# Module-level logger; the service logs download failures through it.
logger = get_logger(__name__)
class AnnotationEditorService:
    """Label Studio editor integration service (TEXT proof-of-concept).

    Bridges DataMate labeling projects and the embedded Label Studio
    frontend: builds Label Studio-shaped tasks from dataset files and stores
    a single final annotation per (project, file) pair.
    """

    def __init__(self, db: AsyncSession):
        # Session used for every query and commit issued by this service.
        self.db = db
        self.template_service = AnnotationTemplateService()

    @staticmethod
    def _stable_ls_id(seed: str) -> int:
        """Map *seed* to a stable Label Studio-style integer ID (JS-safe).

        Label Studio Frontend's mobx-state-tree models constrain task and
        annotation ``id`` values (typically to numbers), while DataMate uses
        UUIDs for file_id/project_id, so those must be mapped to integers for
        the editor. The first 13 hex digits of the SHA-1 digest (52 bits)
        always fall inside JavaScript's safe-integer range.

        Args:
            seed: Arbitrary string; equal seeds always yield equal IDs.

        Returns:
            An integer in ``[1, 2**52)``.
        """
        digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
        value = int(digest[:13], 16)
        # An all-zero prefix would give 0; keep the ID strictly positive.
        return value if value > 0 else 1

    @staticmethod
    def _to_naive_utc(value: Optional[datetime]) -> Optional[datetime]:
        """Normalize *value* to a naive datetime expressed in UTC.

        Timestamps written by this service come from ``datetime.utcnow()``
        (naive UTC). Aware inputs are shifted to UTC *before* the tzinfo is
        dropped, so ``13:05+08:00`` becomes ``05:05`` rather than ``13:05``.
        Naive inputs are assumed to already be UTC and are returned as-is;
        ``None`` passes through.
        """
        if value is None:
            return None
        offset = value.utcoffset()
        if offset is None:
            return value.replace(tzinfo=None)
        return (value - offset).replace(tzinfo=None)

    @staticmethod
    def _text_too_large_error(max_bytes: int) -> HTTPException:
        """Build the 413 error raised when a text file exceeds *max_bytes*."""
        return HTTPException(
            status_code=413,
            detail=f"文本文件过大,限制 {max_bytes} 字节",
        )

    def _make_ls_task_id(self, project_id: str, file_id: str) -> int:
        """Return the stable Label Studio task ID for a (project, file) pair."""
        return self._stable_ls_id(f"task:{project_id}:{file_id}")

    def _make_ls_annotation_id(self, project_id: str, file_id: str) -> int:
        """Return the stable Label Studio annotation ID for a (project, file) pair."""
        # One final annotation per task (single annotator), so the ID can be
        # derived from the same (project, file) key as the task itself.
        return self._stable_ls_id(f"annotation:{project_id}:{file_id}")

    async def _get_project_or_404(self, project_id: str) -> LabelingProject:
        """Load a non-deleted labeling project.

        Raises:
            HTTPException: 404 if the project does not exist or is soft-deleted.
        """
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.id == project_id,
                LabelingProject.deleted_at.is_(None),
            )
        )
        project = result.scalar_one_or_none()
        if not project:
            raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
        return project

    async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
        """Return the dataset's ``dataset_type``, or ``None`` if no row matches."""
        result = await self.db.execute(
            select(Dataset.dataset_type).where(Dataset.id == dataset_id)
        )
        return result.scalar_one_or_none()

    async def _get_label_config(self, template_id: Optional[str]) -> Optional[str]:
        """Return the template's ``label_config``.

        Returns ``None`` when *template_id* is empty, the template is not found,
        or it has no ``label_config`` attribute.
        """
        if not template_id:
            return None
        template = await self.template_service.get_template(self.db, template_id)
        return getattr(template, "label_config", None) if template else None

    async def get_project_info(self, project_id: str) -> EditorProjectInfo:
        """Describe a project for the editor.

        The description includes the label config and whether the embedded
        editor supports it. Only datasets whose type is ``TEXT``
        (case-insensitive) are supported.

        Raises:
            HTTPException: 404 if the project does not exist.
        """
        project = await self._get_project_or_404(project_id)
        dataset_type = await self._get_dataset_type(project.dataset_id)

        supported = (dataset_type or "").upper() == "TEXT"
        unsupported_reason = None
        if not supported:
            unsupported_reason = f"当前仅支持 TEXT,项目数据类型为: {dataset_type or 'UNKNOWN'}"

        label_config = await self._get_label_config(project.template_id)

        return EditorProjectInfo(
            projectId=project.id,
            datasetId=project.dataset_id,
            templateId=project.template_id,
            labelConfig=label_config,
            supported=supported,
            unsupportedReason=unsupported_reason,
        )

    async def list_tasks(self, project_id: str, page: int = 0, size: int = 50) -> EditorTaskListResponse:
        """List one page of the project's dataset files with their annotation status.

        Args:
            project_id: Labeling project ID.
            page: Zero-based page index; negative values are treated as 0.
            size: Page size; negative values are treated as 0.

        Raises:
            HTTPException: 404 if the project does not exist.
        """
        project = await self._get_project_or_404(project_id)

        # Negative values would produce an invalid SQL OFFSET/LIMIT.
        page = max(page, 0)
        size = max(size, 0)

        count_result = await self.db.execute(
            select(func.count()).select_from(DatasetFiles).where(
                DatasetFiles.dataset_id == project.dataset_id
            )
        )
        total = int(count_result.scalar() or 0)

        files_result = await self.db.execute(
            select(DatasetFiles)
            .where(DatasetFiles.dataset_id == project.dataset_id)
            # ``id`` breaks ties between equal ``created_at`` values so that
            # OFFSET/LIMIT pagination is stable (no repeated/skipped rows).
            .order_by(DatasetFiles.created_at.desc(), DatasetFiles.id.desc())
            .offset(page * size)
            .limit(size)
        )
        files = files_result.scalars().all()

        file_ids = [str(f.id) for f in files]  # type: ignore[arg-type]
        updated_map: Dict[str, datetime] = {}
        if file_ids:
            ann_result = await self.db.execute(
                select(AnnotationResult.file_id, AnnotationResult.updated_at).where(
                    AnnotationResult.project_id == project_id,
                    AnnotationResult.file_id.in_(file_ids),
                )
            )
            for ann_file_id, updated_at in ann_result.all():
                if ann_file_id and updated_at:
                    updated_map[str(ann_file_id)] = updated_at

        items: List[EditorTaskListItem] = []
        for f in files:
            fid = str(f.id)  # type: ignore[arg-type]
            items.append(
                EditorTaskListItem(
                    fileId=fid,
                    fileName=str(getattr(f, "file_name", "")),
                    fileType=getattr(f, "file_type", None),
                    hasAnnotation=fid in updated_map,
                    annotationUpdatedAt=updated_map.get(fid),
                )
            )

        total_pages = (total + size - 1) // size if size > 0 else 0

        return EditorTaskListResponse(
            content=items,
            totalElements=total,
            totalPages=total_pages,
            page=page,
            size=size,
        )

    async def _fetch_text_content_via_download_api(self, dataset_id: str, file_id: str) -> str:
        """Download a dataset file via the DataMate backend and decode it as text.

        The response body is streamed, and the download is aborted as soon
        as it exceeds ``settings.editor_max_text_bytes``. Oversized files are
        therefore never fully buffered in memory.

        Raises:
            HTTPException: 413 if the file exceeds the size limit; 502 if the
                download endpoint returns an error status or the call fails.
        """
        base = settings.datamate_backend_base_url.rstrip("/")
        url = f"{base}/data-management/datasets/{dataset_id}/files/{file_id}/download"
        max_bytes = settings.editor_max_text_bytes

        try:
            async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
                async with client.stream("GET", url) as resp:
                    resp.raise_for_status()

                    # Fast path: reject early when the declared size is already too large.
                    content_length = resp.headers.get("content-length")
                    if content_length:
                        try:
                            declared = int(content_length)
                        except ValueError:
                            # Invalid content-length: ignore it and rely on the streamed byte count.
                            declared = None
                        if declared is not None and declared > max_bytes:
                            raise self._text_too_large_error(max_bytes)

                    chunks: List[bytes] = []
                    received = 0
                    async for chunk in resp.aiter_bytes():
                        received += len(chunk)
                        if received > max_bytes:
                            raise self._text_too_large_error(max_bytes)
                        chunks.append(chunk)

            data = b"".join(chunks)
            # TEXT POC: decode as UTF-8; undecodable bytes become U+FFFD.
            return data.decode("utf-8", errors="replace")
        except HTTPException:
            raise
        except httpx.HTTPStatusError as e:
            logger.error(f"读取文本失败: dataset={dataset_id}, file={file_id}, http={e.response.status_code}")
            raise HTTPException(status_code=502, detail="读取文本失败(下载接口返回错误)") from e
        except Exception as e:
            logger.error(f"读取文本失败: dataset={dataset_id}, file={file_id}, err={e}")
            raise HTTPException(status_code=502, detail="读取文本失败(下载接口调用异常)") from e

    async def get_task(self, project_id: str, file_id: str) -> EditorTaskResponse:
        """Build the Label Studio task (text + annotations) for one file.

        Returns the stored annotation if one exists. Otherwise returns an
        empty placeholder annotation.

        Raises:
            HTTPException: 404 if the project or file is missing or the file
                does not belong to the project's dataset; 400 if the dataset
                is not TEXT; 413/502 if downloading the text fails.
        """
        project = await self._get_project_or_404(project_id)

        # Only TEXT projects are supported by the embedded editor.
        dataset_type = await self._get_dataset_type(project.dataset_id)
        if (dataset_type or "").upper() != "TEXT":
            raise HTTPException(status_code=400, detail="当前仅支持 TEXT 项目的内嵌编辑器")

        file_result = await self.db.execute(
            select(DatasetFiles).where(
                DatasetFiles.id == file_id,
                DatasetFiles.dataset_id == project.dataset_id,
            )
        )
        file_record = file_result.scalar_one_or_none()
        if not file_record:
            raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")

        text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)

        ann_result = await self.db.execute(
            select(AnnotationResult).where(
                AnnotationResult.project_id == project_id,
                AnnotationResult.file_id == file_id,
            )
        )
        ann = ann_result.scalar_one_or_none()

        ls_task_id = self._make_ls_task_id(project_id, file_id)
        ls_annotation_id = self._make_ls_annotation_id(project_id, file_id)

        task: Dict[str, Any] = {
            "id": ls_task_id,
            "data": {
                "text": text_content,
                "file_id": file_id,
                "dataset_id": project.dataset_id,
                "file_name": getattr(file_record, "file_name", ""),
            },
            "annotations": [],
        }

        annotation_updated_at = None
        if ann:
            annotation_updated_at = ann.updated_at
            # Return the stored raw annotation object (Label Studio compatible),
            # re-bound to this task and given an integer ID if it lacks one.
            stored = dict(ann.annotation or {})
            stored["task"] = ls_task_id
            if not isinstance(stored.get("id"), int):
                stored["id"] = ls_annotation_id
            task["annotations"] = [stored]
        else:
            # Provide an empty annotation so the frontend always has a selected
            # annotation into which it can produce results.
            now_iso = datetime.utcnow().isoformat() + "Z"
            task["annotations"] = [
                {
                    "id": ls_annotation_id,
                    "task": ls_task_id,
                    "result": [],
                    "created_at": now_iso,
                    "updated_at": now_iso,
                }
            ]

        return EditorTaskResponse(
            task=task,
            annotationUpdatedAt=annotation_updated_at,
        )

    async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse:
        """Create or replace the single final annotation for (project, file).

        When ``request.expected_updated_at`` is given, it is compared against
        the stored ``updated_at`` after both are normalized to naive UTC. This
        is an optimistic-concurrency check.

        Raises:
            HTTPException: 404 if the project/file is missing or the file does
                not belong to the project's dataset; 400 if
                ``annotation.result`` is not a list; 409 if the stored
                annotation was updated since ``expected_updated_at``.
        """
        project = await self._get_project_or_404(project_id)

        # Verify that the file belongs to the project's dataset.
        file_check = await self.db.execute(
            select(DatasetFiles.id).where(
                DatasetFiles.id == file_id,
                DatasetFiles.dataset_id == project.dataset_id,
            )
        )
        if not file_check.scalar_one_or_none():
            raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")

        annotation_payload = dict(request.annotation or {})
        result = annotation_payload.get("result")
        if not isinstance(result, list):
            raise HTTPException(status_code=400, detail="annotation.result 必须为数组")

        ls_task_id = self._make_ls_task_id(project_id, file_id)
        annotation_payload["task"] = ls_task_id
        if not isinstance(annotation_payload.get("id"), int):
            annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)

        existing_result = await self.db.execute(
            select(AnnotationResult).where(
                AnnotationResult.project_id == project_id,
                AnnotationResult.file_id == file_id,
            )
        )
        existing = existing_result.scalar_one_or_none()

        now = datetime.utcnow()

        if existing:
            # Normalize both sides: a client-supplied "+08:00" timestamp must be
            # converted to UTC, not merely stripped of its tzinfo.
            expected = self._to_naive_utc(request.expected_updated_at)
            current = self._to_naive_utc(existing.updated_at)
            if expected is not None and current is not None and current != expected:
                raise HTTPException(status_code=409, detail="标注已被更新,请刷新后重试")

            existing.annotation = annotation_payload  # type: ignore[assignment]
            existing.updated_at = now  # type: ignore[assignment]
            await self.db.commit()
            await self.db.refresh(existing)
            return UpsertAnnotationResponse(
                annotationId=existing.id,
                updatedAt=existing.updated_at or now,
            )

        new_id = str(uuid.uuid4())
        record = AnnotationResult(
            id=new_id,
            project_id=project_id,
            file_id=file_id,
            annotation=annotation_payload,
            created_at=now,
            updated_at=now,
        )
        self.db.add(record)
        await self.db.commit()
        await self.db.refresh(record)

        return UpsertAnnotationResponse(
            annotationId=record.id,
            updatedAt=record.updated_at or now,
        )