"""
Annotation editor (Label Studio Editor) service.

Responsibilities:
- Resolve DataMate labeling projects (``t_dm_labeling_projects``).
- Read text content through the file download/preview API and build a
  Label Studio task from it.
- Upsert the final annotation result as a raw annotation JSON object
  (single annotator, single copy per file).
"""

from __future__ import annotations

import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple
import hashlib
import json
import xml.etree.ElementTree as ET

from fastapi import HTTPException
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject
from app.module.annotation.config import LabelStudioTagConfig
from app.module.annotation.schema.editor import (
    EditorProjectInfo,
    EditorTaskListItem,
    EditorTaskListResponse,
    EditorTaskResponse,
    SegmentInfo,
    UpsertAnnotationRequest,
    UpsertAnnotationResponse,
)
from app.module.annotation.service.template import AnnotationTemplateService
from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
from app.module.annotation.service.annotation_text_splitter import AnnotationTextSplitter
from app.module.annotation.service.text_fetcher import fetch_text_content_via_download_api

logger = get_logger(__name__)

# Keys written into the Label Studio ``task.data`` dict.
TEXT_DATA_KEY = "text"
DATASET_ID_KEY = "dataset_id"
FILE_ID_KEY = "file_id"
FILE_NAME_KEY = "file_name"

# Object-tag categories (as reported by LabelStudioTagConfig) treated as textual.
TEXTUAL_OBJECT_CATEGORIES = {"text", "document"}
# Name prefix of the <Header> elements injected above textual object tags.
OBJECT_NAME_HEADER_PREFIX = "dm_object_header_"


class AnnotationEditorService:
    """Label Studio Editor integration service (TEXT POC version)."""

    # Segmentation threshold: texts longer than this many characters are
    # split into segments automatically.
    SEGMENT_THRESHOLD = 200

    def __init__(self, db: AsyncSession):
        self.db = db
        self.template_service = AnnotationTemplateService()

    @staticmethod
    def _stable_ls_id(seed: str) -> int:
        """
        Generate a stable Label Studio style integer ID (within the JS safe integer range).

        Notes:
        - Label Studio Frontend's mobx-state-tree models constrain the type of
          task/annotation ``id`` (usually a number).
        - DataMate uses UUIDs for file_id/project_id, so they must be mapped to
          integers for the editor.
        - The first 13 hex digits of the SHA-1 digest (52 bits) fit inside the
          JS Number safe integer range.
        """
        digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
        value = int(digest[:13], 16)
        # Never return 0 (NOTE(review): presumably a falsy id is problematic for the frontend).
        return value if value > 0 else 1

    def _make_ls_task_id(self, project_id: str, file_id: str) -> int:
        """Return the stable integer task id for a (project, file) pair."""
        return self._stable_ls_id(f"task:{project_id}:{file_id}")

    def _make_ls_annotation_id(self, project_id: str, file_id: str) -> int:
        """Return the stable integer annotation id for a (project, file) pair."""
        # Single annotator, single final label: each task keeps exactly one
        # annotation, so the id can be bound directly to the task.
        return self._stable_ls_id(f"annotation:{project_id}:{file_id}")

    async def _get_project_or_404(self, project_id: str) -> LabelingProject:
        """Load a non-deleted labeling project or raise HTTP 404."""
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.id == project_id,
                LabelingProject.deleted_at.is_(None),
            )
        )
        project = result.scalar_one_or_none()
        if not project:
            raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
        return project

    async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
        """Return the dataset's ``dataset_type``, or None if the dataset is not found."""
        result = await self.db.execute(
            select(Dataset.dataset_type).where(Dataset.id == dataset_id)
        )
        return result.scalar_one_or_none()

    async def _get_label_config(self, template_id: Optional[str]) -> Optional[str]:
        """Return the template's ``label_config``, or None when there is no template/config."""
        if not template_id:
            return None
        template = await self.template_service.get_template(self.db, template_id)
        return getattr(template, "label_config", None) if template else None

    async def _resolve_project_label_config(self, project: LabelingProject) -> Optional[str]:
        """
        Resolve the label config to send to the editor.

        The project's own ``configuration["label_config"]`` (user-edited version)
        takes precedence over the template default. A resolved config is then
        decorated with name headers for display.
        """
        label_config = None
        if project.configuration and isinstance(project.configuration, dict):
            label_config = project.configuration.get("label_config")
        if not label_config:
            label_config = await self._get_label_config(project.template_id)
        if label_config:
            label_config = self._decorate_label_config_for_editor(label_config)
        return label_config

    @staticmethod
    def _try_parse_json_payload(text_content: str) -> Optional[Dict[str, Any]]:
        """
        Parse ``text_content`` as a JSON object if it looks like one.

        Returns the parsed dict, or None when the text is empty, does not start
        with ``{``/``[``, is not valid JSON, or parses to a non-object (e.g. a list).
        """
        if not text_content:
            return None
        stripped = text_content.strip()
        if not stripped:
            return None
        if stripped[0] not in ("{", "["):
            return None
        try:
            parsed = json.loads(stripped)
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError; RecursionError guards
            # against pathologically nested input.
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _is_textual_object_tag(object_tag: str) -> bool:
        """Return True if the object tag's configured category is textual."""
        config = LabelStudioTagConfig.get_object_config(object_tag) or {}
        category = config.get("category")
        return category in TEXTUAL_OBJECT_CATEGORIES

    @classmethod
    def _extract_textual_value_keys(cls, label_config: str) -> List[str]:
        """
        Collect the ``$key`` data references of textual object tags in ``label_config``.

        Keys are returned de-duplicated in first-seen order. An unparsable
        config is logged and yields an empty list.
        """
        try:
            root = ET.fromstring(label_config)
        except Exception as exc:
            logger.warning("解析 label_config 失败,已跳过占位填充:%s", exc)
            return []

        object_types = LabelStudioTagConfig.get_object_types()
        # A dict is used as an insertion-ordered set.
        seen: Dict[str, None] = {}
        for element in root.iter():
            if element.tag not in object_types:
                continue
            if not cls._is_textual_object_tag(element.tag):
                continue
            value = element.attrib.get("value", "")
            if not value.startswith("$"):
                continue
            key = value[1:].strip()
            if not key:
                continue
            seen[key] = None
        return list(seen.keys())

    @staticmethod
    def _needs_placeholder(value: Any) -> bool:
        """Return True for None or a blank string."""
        if value is None:
            return True
        if isinstance(value, str) and not value.strip():
            return True
        return False

    def _apply_text_placeholders(self, data: Dict[str, Any], label_config: Optional[str]) -> None:
        """Fill missing/blank textual data keys in ``data`` (in place) with the key name itself."""
        if not label_config:
            return
        for key in self._extract_textual_value_keys(label_config):
            if self._needs_placeholder(data.get(key)):
                data[key] = key

    @staticmethod
    def _header_already_present(header: ET.Element, name: str) -> bool:
        """Return True if ``header`` already displays ``name`` (via ``value`` or its text)."""
        value = header.attrib.get("value", "")
        if value == name:
            return True
        header_text = (header.text or "").strip()
        return header_text == name

    def _decorate_label_config_for_editor(self, label_config: str) -> str:
        """
        Insert a ``<Header>`` showing the object name before each named textual object tag.

        Headers get unique names (``OBJECT_NAME_HEADER_PREFIX`` + object name,
        suffixed with ``_N`` on collision). An existing immediately-preceding
        Header that already shows the name is respected. An unparsable config
        is logged and returned unchanged.
        """
        try:
            root = ET.fromstring(label_config)
        except Exception as exc:
            logger.warning("解析 label_config 失败,已跳过 name 展示增强:%s", exc)
            return label_config

        object_types = LabelStudioTagConfig.get_object_types()
        used_names = set()
        for element in root.iter():
            name = element.attrib.get("name")
            if name:
                used_names.add(name)

        def allocate_header_name(base: str) -> str:
            # Reserve and return a header name not used anywhere in the config.
            candidate = f"{OBJECT_NAME_HEADER_PREFIX}{base}"
            if candidate not in used_names:
                used_names.add(candidate)
                return candidate
            idx = 1
            while f"{candidate}_{idx}" in used_names:
                idx += 1
            resolved = f"{candidate}_{idx}"
            used_names.add(resolved)
            return resolved

        # Iterate over a snapshot: the loop inserts new children into the tree,
        # and ElementTree leaves modification during iter() undefined.
        for parent in list(root.iter()):
            children = list(parent)
            i = 0
            while i < len(children):
                child = children[i]
                if child.tag not in object_types:
                    i += 1
                    continue
                if not self._is_textual_object_tag(child.tag):
                    i += 1
                    continue
                obj_name = child.attrib.get("name")
                if not obj_name:
                    i += 1
                    continue

                if i > 0:
                    prev = children[i - 1]
                    if prev.tag == "Header" and self._header_already_present(prev, obj_name):
                        i += 1
                        continue

                header = ET.Element("Header")
                header.set("name", allocate_header_name(obj_name))
                header.set("value", obj_name)
                parent.insert(i, header)
                # Keep the local list in sync, then skip past the header and the object.
                children.insert(i, header)
                i += 2

        return ET.tostring(root, encoding="unicode")

    def _build_task_data(
        self,
        display_text: str,
        parsed_payload: Optional[Dict[str, Any]],
        label_config: Optional[str],
        file_record: DatasetFiles,
        dataset_id: str,
        file_id: str,
    ) -> Dict[str, Any]:
        """
        Build the Label Studio ``task.data`` dict.

        Starts from the parsed JSON payload (if any). ``text`` is set to
        ``display_text`` when missing/blank. The file id, dataset id and file
        name are added unless already present. Remaining textual keys are
        filled with placeholders.
        """
        data: Dict[str, Any] = dict(parsed_payload or {})
        if self._needs_placeholder(data.get(TEXT_DATA_KEY)):
            data[TEXT_DATA_KEY] = display_text
        data.setdefault(FILE_ID_KEY, file_id)
        data.setdefault(DATASET_ID_KEY, dataset_id)
        data.setdefault(FILE_NAME_KEY, getattr(file_record, "file_name", ""))
        self._apply_text_placeholders(data, label_config)
        return data

    async def get_project_info(self, project_id: str) -> EditorProjectInfo:
        """Return editor-facing project info, including whether the dataset type is supported."""
        project = await self._get_project_or_404(project_id)

        dataset_type = await self._get_dataset_type(project.dataset_id)
        supported = (dataset_type or "").upper() == "TEXT"
        unsupported_reason = None
        if not supported:
            unsupported_reason = f"当前仅支持 TEXT,项目数据类型为: {dataset_type or 'UNKNOWN'}"

        # Prefer the label_config in the project configuration (user-edited
        # version), falling back to the template default.
        label_config = await self._resolve_project_label_config(project)

        return EditorProjectInfo(
            projectId=project.id,
            datasetId=project.dataset_id,
            templateId=project.template_id,
            labelConfig=label_config,
            supported=supported,
            unsupportedReason=unsupported_reason,
        )

    async def list_tasks(self, project_id: str, page: int = 0, size: int = 50) -> EditorTaskListResponse:
        """
        List the project's dataset files (newest first) with their annotation status.

        ``page`` is zero-based. ``hasAnnotation`` is True when an annotation row
        with a non-null ``updated_at`` exists for the file.
        """
        project = await self._get_project_or_404(project_id)

        count_result = await self.db.execute(
            select(func.count()).select_from(DatasetFiles).where(
                DatasetFiles.dataset_id == project.dataset_id
            )
        )
        total = int(count_result.scalar() or 0)

        files_result = await self.db.execute(
            select(DatasetFiles)
            .where(DatasetFiles.dataset_id == project.dataset_id)
            .order_by(DatasetFiles.created_at.desc())
            .offset(page * size)
            .limit(size)
        )
        files = files_result.scalars().all()

        file_ids = [str(f.id) for f in files]  # type: ignore[arg-type]
        updated_map: Dict[str, datetime] = {}
        if file_ids:
            ann_result = await self.db.execute(
                select(AnnotationResult.file_id, AnnotationResult.updated_at).where(
                    AnnotationResult.project_id == project_id,
                    AnnotationResult.file_id.in_(file_ids),
                )
            )
            for file_id, updated_at in ann_result.all():
                if file_id and updated_at:
                    updated_map[str(file_id)] = updated_at

        items: List[EditorTaskListItem] = []
        for f in files:
            fid = str(f.id)  # type: ignore[arg-type]
            items.append(
                EditorTaskListItem(
                    fileId=fid,
                    fileName=str(getattr(f, "file_name", "")),
                    fileType=getattr(f, "file_type", None),
                    hasAnnotation=fid in updated_map,
                    annotationUpdatedAt=updated_map.get(fid),
                )
            )

        total_pages = (total + size - 1) // size if size > 0 else 0
        return EditorTaskListResponse(
            content=items,
            totalElements=total,
            totalPages=total_pages,
            page=page,
            size=size,
        )

    async def _fetch_text_content_via_download_api(self, dataset_id: str, file_id: str) -> str:
        """Fetch a file's text content via the download/preview API helper."""
        return await fetch_text_content_via_download_api(dataset_id, file_id)

    async def get_task(
        self,
        project_id: str,
        file_id: str,
        segment_index: Optional[int] = None,
    ) -> EditorTaskResponse:
        """
        Build the Label Studio task (and its single annotation) for one file.

        Texts longer than ``SEGMENT_THRESHOLD`` are split into segments. Only
        ``segment_index`` (clamped to 0 when out of range) is shown, with that
        segment's stored annotation.

        Raises:
            HTTPException: 400 for non-TEXT projects; 404 if the project or the
                file (within the project's dataset) does not exist.
        """
        project = await self._get_project_or_404(project_id)

        # Only TEXT projects are supported by the embedded editor.
        dataset_type = await self._get_dataset_type(project.dataset_id)
        if (dataset_type or "").upper() != "TEXT":
            raise HTTPException(status_code=400, detail="当前仅支持 TEXT 项目的内嵌编辑器")

        file_result = await self.db.execute(
            select(DatasetFiles).where(
                DatasetFiles.id == file_id,
                DatasetFiles.dataset_id == project.dataset_id,
            )
        )
        file_record = file_result.scalar_one_or_none()
        if not file_record:
            raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")

        text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
        assert isinstance(text_content, str)
        label_config = await self._resolve_project_label_config(project)
        parsed_payload = self._try_parse_json_payload(text_content)

        # Load the existing annotation, if any.
        ann_result = await self.db.execute(
            select(AnnotationResult).where(
                AnnotationResult.project_id == project_id,
                AnnotationResult.file_id == file_id,
            )
        )
        ann = ann_result.scalar_one_or_none()

        ls_task_id = self._make_ls_task_id(project_id, file_id)

        # Decide whether the text must be segmented.
        needs_segmentation = len(text_content) > self.SEGMENT_THRESHOLD
        segments: Optional[List[SegmentInfo]] = None
        current_segment_index = 0
        display_text = text_content

        # Per-segment stored annotations, keyed by str(segment index).
        segment_annotations: Dict[str, Any] = {}
        stored_is_segmented = bool(ann and ann.annotation and ann.annotation.get("segmented"))
        if stored_is_segmented:
            # `or {}` guards against a stored `"segments": null`.
            segment_annotations = ann.annotation.get("segments") or {}

        if needs_segmentation:
            splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
            raw_segments = splitter.split(text_content)
            current_segment_index = segment_index if segment_index is not None else 0

            # Fall back to the first segment on an out-of-range index.
            if current_segment_index < 0 or current_segment_index >= len(raw_segments):
                current_segment_index = 0

            # Mark which segments already have a stored annotation.
            segments = [
                SegmentInfo(
                    idx=seg["idx"],
                    text=seg["text"],
                    start=seg["start"],
                    end=seg["end"],
                    hasAnnotation=str(seg["idx"]) in segment_annotations,
                )
                for seg in raw_segments
            ]

            # The current segment's text becomes task.data.text.
            display_text = raw_segments[current_segment_index]["text"]

        # Build the task object.
        task_data = self._build_task_data(
            display_text=display_text,
            parsed_payload=parsed_payload,
            label_config=label_config,
            file_record=file_record,
            dataset_id=project.dataset_id,
            file_id=file_id,
        )

        task: Dict[str, Any] = {
            "id": ls_task_id,
            "data": task_data,
            "annotations": [],
        }

        annotation_updated_at = None
        if ann:
            annotation_updated_at = ann.updated_at

            if needs_segmentation and stored_is_segmented:
                # Segmented mode: return the current segment's annotation.
                seg_ann = segment_annotations.get(str(current_segment_index)) or {}
                stored = {
                    "id": self._make_ls_annotation_id(project_id, file_id) + current_segment_index,
                    "task": ls_task_id,
                    "result": seg_ann.get("result", []),
                    "created_at": seg_ann.get("created_at", datetime.utcnow().isoformat() + "Z"),
                    "updated_at": seg_ann.get("updated_at", datetime.utcnow().isoformat() + "Z"),
                }
                task["annotations"] = [stored]
            elif not needs_segmentation:
                # Non-segmented mode: return the stored raw annotation object.
                stored = dict(ann.annotation or {})
                stored["task"] = ls_task_id
                if not isinstance(stored.get("id"), int):
                    stored["id"] = self._make_ls_annotation_id(project_id, file_id)
                task["annotations"] = [stored]
            else:
                # First switch from non-segmented to segmented: provide an empty annotation.
                empty_ann_id = self._make_ls_annotation_id(project_id, file_id) + current_segment_index
                task["annotations"] = [
                    {
                        "id": empty_ann_id,
                        "task": ls_task_id,
                        "result": [],
                        "created_at": datetime.utcnow().isoformat() + "Z",
                        "updated_at": datetime.utcnow().isoformat() + "Z",
                    }
                ]
        else:
            # Provide an empty annotation so the frontend can produce a result
            # even when no annotation is selected.
            empty_ann_id = self._make_ls_annotation_id(project_id, file_id)
            if needs_segmentation:
                empty_ann_id += current_segment_index
            task["annotations"] = [
                {
                    "id": empty_ann_id,
                    "task": ls_task_id,
                    "result": [],
                    "created_at": datetime.utcnow().isoformat() + "Z",
                    "updated_at": datetime.utcnow().isoformat() + "Z",
                }
            ]

        return EditorTaskResponse(
            task=task,
            annotationUpdatedAt=annotation_updated_at,
            segmented=needs_segmentation,
            segments=segments,
            totalSegments=len(segments) if segments else 1,
            currentSegmentIndex=current_segment_index,
        )

    async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse:
        """
        Create or update the single annotation of a file, then sync it to knowledge management.

        When ``request.segment_index`` is set, the payload is merged into the
        segmented structure; otherwise it replaces the stored annotation.
        ``request.expected_updated_at`` enables an optimistic-lock check.

        Raises:
            HTTPException: 404 if the project/file is missing; 400 if
                ``annotation.result`` is not a list; 409 on an optimistic-lock conflict.
        """
        project = await self._get_project_or_404(project_id)

        # Verify the file belongs to the project's dataset.
        file_result = await self.db.execute(
            select(DatasetFiles).where(
                DatasetFiles.id == file_id,
                DatasetFiles.dataset_id == project.dataset_id,
            )
        )
        file_record = file_result.scalar_one_or_none()
        if not file_record:
            raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")

        annotation_payload = dict(request.annotation or {})
        result = annotation_payload.get("result")
        if not isinstance(result, list):
            raise HTTPException(status_code=400, detail="annotation.result 必须为数组")

        ls_task_id = self._make_ls_task_id(project_id, file_id)

        existing_result = await self.db.execute(
            select(AnnotationResult).where(
                AnnotationResult.project_id == project_id,
                AnnotationResult.file_id == file_id,
            )
        )
        existing = existing_result.scalar_one_or_none()

        # Naive UTC, consistent with the other utcnow() timestamps in this service.
        now = datetime.utcnow()

        if request.segment_index is not None:
            # Segmented mode: merge this segment's annotation into the overall structure.
            final_payload = self._merge_segment_annotation(
                existing.annotation if existing else None,
                request.segment_index,
                annotation_payload,
            )
        else:
            # Non-segmented mode: use the incoming annotation as-is.
            annotation_payload["task"] = ls_task_id
            if not isinstance(annotation_payload.get("id"), int):
                annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)
            final_payload = annotation_payload

        if existing:
            if request.expected_updated_at and existing.updated_at:
                expected = request.expected_updated_at
                if expected.tzinfo is not None:
                    # Stored timestamps are naive UTC: convert an aware value to
                    # UTC before dropping tzinfo, instead of discarding its offset.
                    expected = expected.astimezone(timezone.utc).replace(tzinfo=None)
                if existing.updated_at != expected:
                    raise HTTPException(status_code=409, detail="标注已被更新,请刷新后重试")

            existing.annotation = final_payload  # type: ignore[assignment]
            existing.updated_at = now  # type: ignore[assignment]
            await self.db.commit()
            await self.db.refresh(existing)

            response = UpsertAnnotationResponse(
                annotationId=existing.id,
                updatedAt=existing.updated_at or now,
            )
            await self._sync_annotation_to_knowledge(project, file_record, final_payload, existing.updated_at)
            return response

        new_id = str(uuid.uuid4())
        record = AnnotationResult(
            id=new_id,
            project_id=project_id,
            file_id=file_id,
            annotation=final_payload,
            created_at=now,
            updated_at=now,
        )
        self.db.add(record)
        await self.db.commit()
        await self.db.refresh(record)

        response = UpsertAnnotationResponse(
            annotationId=record.id,
            updatedAt=record.updated_at or now,
        )
        await self._sync_annotation_to_knowledge(project, file_record, final_payload, record.updated_at)
        return response

    def _merge_segment_annotation(
        self,
        existing: Optional[Dict[str, Any]],
        segment_index: int,
        new_annotation: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Merge one segment's annotation into the overall segmented structure.

        A new top-level dict AND a new ``segments`` dict are built, so
        ``existing`` (typically the ORM-loaded JSON value) is never mutated in
        place. Mutating the nested dict would also alter the loaded value that
        SQLAlchemy compares against on flush. The merged payload could then
        look unchanged and the JSON column update could be skipped.

        Args:
            existing: The currently stored annotation data (may be None or non-segmented).
            segment_index: Index of the segment being saved.
            new_annotation: The segment's new annotation data.

        Returns:
            The merged annotation structure.
        """
        previous_segments: Dict[str, Any] = {}
        if not existing or not existing.get("segmented"):
            # Initialise a fresh segmented structure.
            base: Dict[str, Any] = {
                "segmented": True,
                "version": 1,
            }
        else:
            base = dict(existing)
            raw_segments = existing.get("segments")
            if isinstance(raw_segments, dict):
                previous_segments = raw_segments

        segments = dict(previous_segments)
        key = str(segment_index)
        previous = segments.get(key)
        previous_created_at = previous.get("created_at") if isinstance(previous, dict) else None
        now_iso = datetime.utcnow().isoformat() + "Z"

        # Update the given segment's annotation; keep its original creation
        # time when the client does not send one.
        segments[key] = {
            "result": new_annotation.get("result", []),
            "created_at": new_annotation.get("created_at") or previous_created_at or now_iso,
            "updated_at": now_iso,
        }
        base["segments"] = segments
        return base

    async def _sync_annotation_to_knowledge(
        self,
        project: LabelingProject,
        file_record: DatasetFiles,
        annotation: Dict[str, Any],
        annotation_updated_at: Optional[datetime],
    ) -> None:
        """Sync the annotation result to knowledge management (best effort: failures are logged, not raised)."""
        try:
            await KnowledgeSyncService(self.db).sync_annotation_to_knowledge(
                project=project,
                file_record=file_record,
                annotation=annotation,
                annotation_updated_at=annotation_updated_at,
            )
        except Exception as exc:
            logger.warning("标注同步知识管理失败:%s", exc)