Files
DataMate/runtime/datamate-python/app/module/annotation/service/editor.py
Jerry Yan 5a5279869e feat(annotation): 添加分段总数提示功能优化性能
- 在编辑器服务中添加 segment_total_hint 变量用于缓存分段总数计算结果
- 使用 with_for_update() 锁定查询以避免并发问题
- 将重复的分段总数计算逻辑替换为使用缓存的提示值
- 减少数据库查询次数提升标注任务处理效率
- 优化了分段索引存在时的总数获取流程
2026-01-31 16:28:39 +08:00

1160 lines
44 KiB
Python

"""
标注编辑器(Label Studio Editor)服务
职责:
- 解析 DataMate 标注项目(t_dm_labeling_projects)
- 以“文件下载/预览接口”读取文本内容,构造 Label Studio task
- 以原始 annotation JSON 形式 upsert 最终标注结果(单人单份)
"""
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import hashlib
import json
import xml.etree.ElementTree as ET
from fastapi import HTTPException
from sqlalchemy import case, func, select, or_
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
from app.db.models.annotation_management import (
ANNOTATION_STATUS_ANNOTATED,
ANNOTATION_STATUS_IN_PROGRESS,
ANNOTATION_STATUS_CLIENT_VALUES,
ANNOTATION_STATUS_NO_ANNOTATION,
ANNOTATION_STATUS_NOT_APPLICABLE,
)
from app.module.annotation.config import LabelStudioTagConfig
from app.module.annotation.schema.editor import (
EditorProjectInfo,
EditorTaskListItem,
EditorTaskListResponse,
EditorTaskResponse,
SegmentInfo,
UpsertAnnotationRequest,
UpsertAnnotationResponse,
)
from app.module.annotation.service.template import AnnotationTemplateService
from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
from app.module.annotation.service.annotation_text_splitter import AnnotationTextSplitter
from app.module.annotation.service.text_fetcher import fetch_text_content_via_download_api
logger = get_logger(__name__)

# --- Keys inside the Label Studio task ``data`` payload -----------------------
TEXT_DATA_KEY = "text"
IMAGE_DATA_KEY = "image"
AUDIO_DATA_KEY = "audio"
VIDEO_DATA_KEY = "video"
# File identity keys are published in both snake_case and camelCase spellings.
DATASET_ID_KEY, DATASET_ID_CAMEL_KEY = "dataset_id", "datasetId"
FILE_ID_KEY, FILE_ID_CAMEL_KEY = "file_id", "fileId"
FILE_NAME_KEY, FILE_NAME_CAMEL_KEY = "file_name", "fileName"
SEGMENT_INDEX_KEY, SEGMENT_INDEX_CAMEL_KEY = "segment_index", "segmentIndex"

# --- Keys of the stored (possibly segmented) annotation JSON ------------------
SEGMENTED_KEY = "segmented"
SEGMENTS_KEY = "segments"
SEGMENT_TOTAL_KEY = "total_segments"
SEGMENT_RESULT_KEY = "result"
SEGMENT_CREATED_AT_KEY = "created_at"
SEGMENT_UPDATED_AT_KEY = "updated_at"

JSONL_EXTENSION = ".jsonl"

# --- Label Studio object-tag categories (as reported by LabelStudioTagConfig) -
TEXTUAL_OBJECT_CATEGORIES = {"text", "document"}
IMAGE_OBJECT_CATEGORIES = {"image"}
MEDIA_OBJECT_CATEGORIES = {"media"}
# Name prefix for the <Header> elements injected above textual object tags.
OBJECT_NAME_HEADER_PREFIX = "dm_object_header_"

# --- Dataset types the embedded editor understands ----------------------------
DATASET_TYPE_TEXT = "TEXT"
DATASET_TYPE_IMAGE = "IMAGE"
DATASET_TYPE_AUDIO = "AUDIO"
DATASET_TYPE_VIDEO = "VIDEO"
SUPPORTED_EDITOR_DATASET_TYPES = (DATASET_TYPE_TEXT, DATASET_TYPE_IMAGE, DATASET_TYPE_AUDIO, DATASET_TYPE_VIDEO)

# Project configuration flag that toggles automatic text segmentation.
SEGMENTATION_ENABLED_KEY = "segmentation_enabled"

# Source documents, matched either by file-name extension or by file type
# (the type list is the extension list without the leading dot).
SOURCE_DOCUMENT_EXTENSIONS = (".pdf", ".doc", ".docx")
SOURCE_DOCUMENT_TYPES = tuple(ext.lstrip(".") for ext in SOURCE_DOCUMENT_EXTENSIONS)
class AnnotationEditorService:
    """Label Studio editor integration service (TEXT proof-of-concept version).

    Builds Label Studio tasks for DataMate labeling projects and persists the
    single final annotation per file.
    """

    # Primary texts longer than this many characters are split into segments.
    SEGMENT_THRESHOLD = 200

    def __init__(self, db: AsyncSession):
        # Async DB session shared by every query this service issues.
        self.db = db
        self.template_service = AnnotationTemplateService()
@staticmethod
def _stable_ls_id(seed: str) -> int:
"""
生成稳定的 Label Studio 风格整数 ID(JS 安全整数范围内)。
说明:
- Label Studio Frontend 的 mobx-state-tree 模型对 task/annotation 的 id 有类型约束(通常为 number)。
- DataMate 使用 UUID 作为 file_id/project_id,因此需映射为整数供编辑器使用。
- 取 sha1 的前 13 个 hex(52bit),落在 JS Number 的安全整数范围。
"""
digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
value = int(digest[:13], 16)
return value if value > 0 else 1
def _make_ls_task_id(self, project_id: str, file_id: str) -> int:
return self._stable_ls_id(f"task:{project_id}:{file_id}")
def _make_ls_annotation_id(self, project_id: str, file_id: str) -> int:
# 单人单份最终标签:每个 task 只保留一个 annotation,id 直接与 task 绑定即可
return self._stable_ls_id(f"annotation:{project_id}:{file_id}")
@staticmethod
def _normalize_dataset_type(dataset_type: Optional[str]) -> str:
return (dataset_type or "").upper()
@staticmethod
def _resolve_public_api_prefix() -> str:
    """
    Return the public API path prefix (for example ``/api``) used to build browser URLs.

    ``settings.datamate_backend_base_url`` may be a full URL (only its path is
    used) or a bare path. Trailing slashes are removed and a leading slash is
    ensured. ``/api`` is the fallback when nothing usable is configured.
    """
    configured = (settings.datamate_backend_base_url or "").strip()
    if not configured:
        return "/api"
    url_parts = urlparse(configured)
    is_absolute_url = bool(url_parts.scheme and url_parts.netloc)
    path = (url_parts.path if is_absolute_url else configured).rstrip("/")
    if not path:
        return "/api"
    return path if path.startswith("/") else "/" + path
@classmethod
def _build_file_preview_url(cls, dataset_id: str, file_id: str) -> str:
    """Build the browser-facing preview URL of a dataset file."""
    api_prefix = cls._resolve_public_api_prefix()
    return f"{api_prefix}/data-management/datasets/{dataset_id}/files/{file_id}/preview"
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
    """
    Load a labeling project that has not been soft-deleted.

    Raises:
        HTTPException: 404 when no live project with ``project_id`` exists.
    """
    statement = select(LabelingProject).where(
        LabelingProject.id == project_id,
        LabelingProject.deleted_at.is_(None),
    )
    found = (await self.db.execute(statement)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
    """Return the raw ``dataset_type`` of a dataset, or ``None`` if it does not exist."""
    statement = select(Dataset.dataset_type).where(Dataset.id == dataset_id)
    rows = await self.db.execute(statement)
    return rows.scalar_one_or_none()
async def _get_label_config(self, template_id: Optional[str]) -> Optional[str]:
if not template_id:
return None
template = await self.template_service.get_template(self.db, template_id)
return getattr(template, "label_config", None) if template else None
async def _resolve_project_label_config(self, project: LabelingProject) -> Optional[str]:
label_config = None
if project.configuration and isinstance(project.configuration, dict):
label_config = project.configuration.get("label_config")
if not label_config:
label_config = await self._get_label_config(project.template_id)
return label_config
@staticmethod
def _resolve_segmentation_enabled(project: LabelingProject) -> bool:
config = project.configuration
if not isinstance(config, dict):
return True
value = config.get(SEGMENTATION_ENABLED_KEY)
if isinstance(value, bool):
return value
if value is None:
return True
return bool(value)
@classmethod
def _resolve_primary_text_key(cls, label_config: Optional[str]) -> Optional[str]:
if not label_config:
return None
keys = cls._extract_textual_value_keys(label_config)
if not keys:
return None
if TEXT_DATA_KEY in keys:
return TEXT_DATA_KEY
return keys[0]
@classmethod
def _resolve_media_value_keys(
cls,
label_config: Optional[str],
default_key: str,
categories: Optional[set[str]] = None,
) -> List[str]:
if not label_config:
return [default_key]
target_categories = categories or set()
keys = cls._extract_object_value_keys_by_category(label_config, target_categories)
if not keys:
return [default_key]
return keys
@staticmethod
def _try_parse_json_payload(text_content: str) -> Optional[Dict[str, Any]]:
if not text_content:
return None
stripped = text_content.strip()
if not stripped:
return None
if stripped[0] not in ("{", "["):
return None
try:
parsed = json.loads(stripped)
except Exception:
return None
return parsed if isinstance(parsed, dict) else None
@classmethod
def _parse_jsonl_records(cls, text_content: str) -> List[Tuple[Optional[Dict[str, Any]], str]]:
lines = [line for line in text_content.splitlines() if line.strip()]
records: List[Tuple[Optional[Dict[str, Any]], str]] = []
for line in lines:
payload = cls._try_parse_json_payload(line)
records.append((payload, line))
return records
@staticmethod
def _is_textual_object_tag(object_tag: str) -> bool:
    """True when the Label Studio object tag belongs to a textual category (text/document)."""
    tag_config = LabelStudioTagConfig.get_object_config(object_tag) or {}
    return tag_config.get("category") in TEXTUAL_OBJECT_CATEGORIES
@classmethod
def _extract_object_value_keys_by_category(
    cls,
    label_config: str,
    categories: set[str],
) -> List[str]:
    """
    Collect the ``$key`` references of object tags in a label config.

    Only object tags known to ``LabelStudioTagConfig`` are considered. When
    ``categories`` is non-empty, only tags of those categories count. Keys are
    de-duplicated in document order. An unparsable config yields ``[]`` after a
    warning.
    """
    try:
        root = ET.fromstring(label_config)
    except Exception as exc:
        logger.warning("解析 label_config 失败,已跳过占位填充:%s", exc)
        return []
    object_types = LabelStudioTagConfig.get_object_types()
    collected: List[str] = []
    for node in root.iter():
        if node.tag not in object_types:
            continue
        tag_category = (LabelStudioTagConfig.get_object_config(node.tag) or {}).get("category")
        if categories and tag_category not in categories:
            continue
        reference = node.attrib.get("value", "")
        if not reference.startswith("$"):
            continue
        key = reference[1:].strip()
        if key:
            collected.append(key)
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    return list(dict.fromkeys(collected))
@classmethod
def _extract_textual_value_keys(cls, label_config: str) -> List[str]:
    """``$key`` references of the textual (text/document) object tags in ``label_config``."""
    textual_keys = cls._extract_object_value_keys_by_category(label_config, TEXTUAL_OBJECT_CATEGORIES)
    return textual_keys
@staticmethod
def _needs_placeholder(value: Any) -> bool:
if value is None:
return True
if isinstance(value, str) and not value.strip():
return True
return False
def _apply_text_placeholders(self, data: Dict[str, Any], label_config: Optional[str]) -> None:
if not label_config:
return
for key in self._extract_textual_value_keys(label_config):
if self._needs_placeholder(data.get(key)):
data[key] = key
@staticmethod
def _header_already_present(header: ET.Element, name: str) -> bool:
value = header.attrib.get("value", "")
if value == name:
return True
header_text = (header.text or "").strip()
return header_text == name
def _decorate_label_config_for_editor(self, label_config: str) -> str:
    """
    Insert a ``<Header>`` showing each textual object's ``name`` right before that object.

    Objects that already have such a header as their previous sibling are left
    alone. Generated header names get the ``OBJECT_NAME_HEADER_PREFIX`` prefix and
    a numeric suffix if needed, so they never collide with existing ``name``
    attributes. An unparsable config is returned unchanged after a warning.
    """
    try:
        root = ET.fromstring(label_config)
    except Exception as exc:
        logger.warning("解析 label_config 失败,已跳过 name 展示增强:%s", exc)
        return label_config
    object_types = LabelStudioTagConfig.get_object_types()
    used_names = set()
    for element in root.iter():
        name = element.attrib.get("name")
        if name:
            used_names.add(name)

    def allocate_header_name(base: str) -> str:
        # Reserve a unique header name derived from the object name.
        candidate = f"{OBJECT_NAME_HEADER_PREFIX}{base}"
        if candidate not in used_names:
            used_names.add(candidate)
            return candidate
        idx = 1
        while f"{candidate}_{idx}" in used_names:
            idx += 1
        resolved = f"{candidate}_{idx}"
        used_names.add(resolved)
        return resolved

    # Snapshot every element before editing: ElementTree leaves the result of
    # iter() undefined when the tree is modified during iteration, and Header
    # siblings are inserted below.
    for parent in list(root.iter()):
        # `children` mirrors parent's child list and is kept in sync with inserts.
        children = list(parent)
        i = 0
        while i < len(children):
            child = children[i]
            if child.tag not in object_types or not self._is_textual_object_tag(child.tag):
                i += 1
                continue
            obj_name = child.attrib.get("name")
            if not obj_name:
                i += 1
                continue
            if i > 0:
                prev = children[i - 1]
                if prev.tag == "Header" and self._header_already_present(prev, obj_name):
                    i += 1
                    continue
            header = ET.Element("Header")
            header.set("name", allocate_header_name(obj_name))
            header.set("value", obj_name)
            parent.insert(i, header)
            children.insert(i, header)
            # Skip both the new header and the object it labels.
            i += 2
    return ET.tostring(root, encoding="unicode")
@staticmethod
def _extract_segment_annotations(payload: Optional[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
if not payload or not isinstance(payload, dict):
return {}
segments = payload.get(SEGMENTS_KEY)
if isinstance(segments, dict):
normalized: Dict[str, Dict[str, Any]] = {}
for key, value in segments.items():
if isinstance(value, dict):
normalized[str(key)] = value
return normalized
if isinstance(segments, list):
normalized: Dict[str, Dict[str, Any]] = {}
for idx, value in enumerate(segments):
if not isinstance(value, dict):
continue
key = (
value.get(SEGMENT_INDEX_CAMEL_KEY)
or value.get(SEGMENT_INDEX_KEY)
or value.get("segment")
or value.get("idx")
)
if key is None:
key = idx
normalized[str(key)] = value
return normalized
return {}
@staticmethod
def _is_segmented_annotation(payload: Optional[Dict[str, Any]]) -> bool:
if not payload or not isinstance(payload, dict):
return False
if payload.get(SEGMENTED_KEY):
return True
segments = payload.get(SEGMENTS_KEY)
if isinstance(segments, dict):
return len(segments) > 0
if isinstance(segments, list):
return len(segments) > 0
return False
@staticmethod
def _has_annotation_result(payload: Optional[Dict[str, Any]]) -> bool:
if not payload or not isinstance(payload, dict):
return False
if AnnotationEditorService._is_segmented_annotation(payload):
segments = AnnotationEditorService._extract_segment_annotations(payload)
if not segments:
return False
for segment in segments.values():
if not isinstance(segment, dict):
continue
result = segment.get(SEGMENT_RESULT_KEY)
if isinstance(result, list) and len(result) > 0:
return True
return False
result = payload.get(SEGMENT_RESULT_KEY)
return isinstance(result, list) and len(result) > 0
@staticmethod
def _resolve_segment_total(payload: Optional[Dict[str, Any]]) -> Optional[int]:
if not payload or not isinstance(payload, dict):
return None
value = payload.get(SEGMENT_TOTAL_KEY)
if isinstance(value, int):
return value if value > 0 else None
if isinstance(value, float) and value.is_integer():
return int(value) if value > 0 else None
if isinstance(value, str) and value.isdigit():
parsed = int(value)
return parsed if parsed > 0 else None
return None
async def _compute_segment_total(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
) -> Optional[int]:
    """
    Recompute how many segments the editor would show for a TEXT file.

    Uses the same record/threshold rules as the text task builder. Returns
    ``None`` when the dataset is not TEXT, segmentation is disabled, the
    downloaded content is not a string, or the file does not need segmenting.
    """
    raw_type = await self._get_dataset_type(project.dataset_id)
    if self._normalize_dataset_type(raw_type) != DATASET_TYPE_TEXT:
        return None
    if not self._resolve_segmentation_enabled(project):
        return None
    text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
    if not isinstance(text_content, str):
        return None
    label_config = await self._resolve_project_label_config(project)
    primary_text_key = self._resolve_primary_text_key(label_config)

    lowered_name = str(getattr(file_record, "file_name", "")).lower()
    if lowered_name.endswith(JSONL_EXTENSION):
        records = self._parse_jsonl_records(text_content)
    else:
        whole_payload = self._try_parse_json_payload(text_content)
        records = [(whole_payload, text_content)] if whole_payload else []
    records = records or [(None, text_content)]

    texts = [
        self._resolve_primary_text_value(record_payload, raw_line, primary_text_key)
        for record_payload, raw_line in records
    ] or [text_content]

    threshold = self.SEGMENT_THRESHOLD
    # A single short record is shown unsegmented.
    if len(records) <= 1 and all(len(text or "") <= threshold for text in texts):
        return None

    splitter = AnnotationTextSplitter(max_chars=threshold)
    total = 0
    for text in texts:
        body = text or ""
        if len(body) <= threshold:
            total += 1
            continue
        chunks = splitter.split(body)
        total += len(chunks) if chunks else 1
    return total if total > 0 else 1
@classmethod
def _build_source_document_filter(cls):
    """
    SQL condition matching source documents (pdf/doc/docx).

    A file matches by its lower-cased file type or by its lower-cased file-name
    extension.
    """
    lowered_type = func.lower(DatasetFiles.file_type)
    lowered_name = func.lower(DatasetFiles.file_name)
    conditions = [lowered_type.in_(SOURCE_DOCUMENT_TYPES)]
    conditions.extend(lowered_name.like(f"%{suffix}") for suffix in SOURCE_DOCUMENT_EXTENSIONS)
    return or_(*conditions)
def _build_task_data(
    self,
    display_text: str,
    parsed_payload: Optional[Dict[str, Any]],
    label_config: Optional[str],
    file_record: DatasetFiles,
    dataset_id: str,
    file_id: str,
    primary_text_key: Optional[str],
) -> Dict[str, Any]:
    """
    Assemble the Label Studio ``task.data`` dict for a text task.

    Starts from a copy of ``parsed_payload``. Stores ``display_text`` under the
    primary text key (default ``"text"``), adds file/dataset identity under
    snake_case and camelCase keys, then fills empty textual fields with
    placeholders.
    """
    task_data: Dict[str, Any] = dict(parsed_payload or {})
    task_data[primary_text_key or TEXT_DATA_KEY] = display_text
    name_of_file = str(getattr(file_record, "file_name", ""))
    task_data.update({
        FILE_ID_KEY: file_id,
        FILE_ID_CAMEL_KEY: file_id,
        DATASET_ID_KEY: dataset_id,
        DATASET_ID_CAMEL_KEY: dataset_id,
        FILE_NAME_KEY: name_of_file,
        FILE_NAME_CAMEL_KEY: name_of_file,
    })
    self._apply_text_placeholders(task_data, label_config)
    return task_data
@classmethod
def _resolve_primary_text_value(
cls,
parsed_payload: Optional[Dict[str, Any]],
raw_text: str,
primary_text_key: Optional[str],
) -> str:
if parsed_payload and primary_text_key:
value = parsed_payload.get(primary_text_key)
if isinstance(value, str) and value.strip():
return value
if parsed_payload and not primary_text_key:
value = parsed_payload.get(TEXT_DATA_KEY)
if isinstance(value, str) and value.strip():
return value
return raw_text
async def get_project_info(self, project_id: str) -> EditorProjectInfo:
    """
    Describe a labeling project for the embedded editor.

    Reports the dataset type, whether the editor supports it (with a reason if
    not), and the effective label config: the project's edited config first,
    otherwise the template default.
    """
    project = await self._get_project_or_404(project_id)
    dataset_type = self._normalize_dataset_type(await self._get_dataset_type(project.dataset_id))
    supported = dataset_type in SUPPORTED_EDITOR_DATASET_TYPES
    unsupported_reason = None
    if not supported:
        type_list = "/".join(SUPPORTED_EDITOR_DATASET_TYPES)
        unsupported_reason = f"当前仅支持 {type_list},项目数据类型为: {dataset_type or 'UNKNOWN'}"
    effective_label_config = await self._resolve_project_label_config(project)
    return EditorProjectInfo(
        projectId=project.id,
        datasetId=project.dataset_id,
        datasetType=dataset_type or None,
        templateId=project.template_id,
        labelConfig=effective_label_config,
        supported=supported,
        unsupportedReason=unsupported_reason,
    )
async def list_tasks(
    self,
    project_id: str,
    page: int = 0,
    size: int = 50,
    exclude_source_documents: Optional[bool] = None,
) -> EditorTaskListResponse:
    """
    List a labeling project's files as editor tasks, one page at a time.

    Args:
        project_id: labeling project id.
        page: zero-based page number.
        size: page size.
        exclude_source_documents: when truthy, hide source documents (pdf/doc/docx,
            matched by file type or file-name extension) from both the page and the
            total count. ``None``/``False`` list every file, as before.

    Returns:
        An ``EditorTaskListResponse``. Files without an annotation come first, each
        group newest first.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = await self._get_project_or_404(project_id)
    base_conditions = [
        LabelingProjectFile.project_id == project_id,
        DatasetFiles.dataset_id == project.dataset_id,
    ]
    if exclude_source_documents:
        # The source-document filter is an OR over lower(file_type)/lower(file_name).
        # A NULL file_type can make it NULL, and NOT NULL would silently drop the row,
        # so coalesce to FALSE before negating.
        base_conditions.append(~func.coalesce(self._build_source_document_filter(), False))

    count_result = await self.db.execute(
        select(func.count())
        .select_from(LabelingProjectFile)
        .join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
        .where(*base_conditions)
    )
    total = int(count_result.scalar() or 0)

    # 0 = not yet annotated, 1 = annotated; ascending order lists pending work first.
    annotated_sort_key = case(
        (AnnotationResult.id.isnot(None), 1),
        else_=0,
    )
    files_result = await self.db.execute(
        select(
            DatasetFiles,
            AnnotationResult.id,
            AnnotationResult.updated_at,
            AnnotationResult.annotation_status,
        )
        .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
        .outerjoin(
            AnnotationResult,
            (AnnotationResult.file_id == DatasetFiles.id)
            & (AnnotationResult.project_id == project_id),
        )
        .where(*base_conditions)
        # DatasetFiles.id is a deterministic tiebreaker so OFFSET pagination neither
        # repeats nor skips files that share the same created_at.
        .order_by(annotated_sort_key.asc(), DatasetFiles.created_at.desc(), DatasetFiles.id.asc())
        .offset(page * size)
        .limit(size)
    )
    rows = files_result.all()

    items: List[EditorTaskListItem] = []
    for file_record, annotation_id, annotation_updated_at, annotation_status in rows:
        fid = str(file_record.id)  # type: ignore[arg-type]
        items.append(
            EditorTaskListItem(
                fileId=fid,
                fileName=str(getattr(file_record, "file_name", "")),
                fileType=getattr(file_record, "file_type", None),
                hasAnnotation=annotation_id is not None,
                annotationUpdatedAt=annotation_updated_at,
                annotationStatus=annotation_status,
            )
        )
    total_pages = (total + size - 1) // size if size > 0 else 0
    return EditorTaskListResponse(
        content=items,
        totalElements=total,
        totalPages=total_pages,
        page=page,
        size=size,
    )
async def _fetch_text_content_via_download_api(self, dataset_id: str, file_id: str) -> str:
    """Download a dataset file's text through the shared download-API helper."""
    content = await fetch_text_content_via_download_api(dataset_id, file_id)
    return content
async def get_task(
    self,
    project_id: str,
    file_id: str,
    segment_index: Optional[int] = None,
) -> EditorTaskResponse:
    """
    Build the editor task for one file of a project.

    IMAGE/AUDIO/VIDEO datasets get a media task. TEXT gets a (possibly
    segmented) text task; ``segment_index`` is only used for text.

    Raises:
        HTTPException: 400 for unsupported dataset types; 404 when the file is
            missing or not part of the project's dataset.
    """
    project = await self._get_project_or_404(project_id)
    dataset_type = self._normalize_dataset_type(await self._get_dataset_type(project.dataset_id))
    if dataset_type not in SUPPORTED_EDITOR_DATASET_TYPES:
        raise HTTPException(
            status_code=400,
            detail="当前仅支持 TEXT/IMAGE/AUDIO/VIDEO 项目的内嵌编辑器",
        )

    lookup = await self.db.execute(
        select(DatasetFiles).where(
            DatasetFiles.id == file_id,
            DatasetFiles.dataset_id == project.dataset_id,
        )
    )
    file_record = lookup.scalar_one_or_none()
    if not file_record:
        raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")

    media_builders = {
        DATASET_TYPE_IMAGE: self._build_image_task,
        DATASET_TYPE_AUDIO: self._build_audio_task,
        DATASET_TYPE_VIDEO: self._build_video_task,
    }
    builder = media_builders.get(dataset_type)
    if builder is not None:
        return await builder(project, file_record, file_id)
    # Only TEXT remains among the supported types.
    return await self._build_text_task(project, file_record, file_id, segment_index)
async def _build_text_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
    segment_index: Optional[int],
) -> EditorTaskResponse:
    """
    Build the Label Studio task for a TEXT dataset file.

    The file is split into records: one per non-blank line for ``.jsonl``
    files, otherwise a single record (with its JSON object payload when the
    whole file is a JSON object). When segmentation is enabled and there are
    several records or any primary text exceeds ``SEGMENT_THRESHOLD``
    characters, long texts are chunked by ``AnnotationTextSplitter``. The
    segment at ``segment_index`` (falling back to 0 when out of range) becomes
    the task text.

    The stored annotation is attached as follows:
    - segmented view with a segmented stored annotation: that segment's result;
    - unsegmented view with an unsegmented stored annotation: the stored object
      as-is;
    - any other combination, or no stored annotation: an empty annotation, so
      the frontend always has one to write results into.
    """
    text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
    # NOTE(review): `assert` is stripped under `python -O`; if the download helper can
    # return a non-string, this should be an explicit check that raises.
    assert isinstance(text_content, str)
    label_config = await self._resolve_project_label_config(project)
    primary_text_key = self._resolve_primary_text_key(label_config)
    file_name = str(getattr(file_record, "file_name", "")).lower()
    records: List[Tuple[Optional[Dict[str, Any]], str]] = []
    if file_name.endswith(JSONL_EXTENSION):
        records = self._parse_jsonl_records(text_content)
    else:
        parsed_payload = self._try_parse_json_payload(text_content)
        if parsed_payload:
            records = [(parsed_payload, text_content)]
    if not records:
        records = [(None, text_content)]
    record_texts = [
        self._resolve_primary_text_value(payload, raw_text, primary_text_key)
        for payload, raw_text in records
    ]
    if not record_texts:
        record_texts = [text_content]
    # Load the existing annotation, if any
    ann_result = await self.db.execute(
        select(AnnotationResult).where(
            AnnotationResult.project_id == project.id,
            AnnotationResult.file_id == file_id,
        )
    )
    ann = ann_result.scalar_one_or_none()
    ls_task_id = self._make_ls_task_id(project.id, file_id)
    segment_annotations: Dict[str, Dict[str, Any]] = {}
    has_segmented_annotation = False
    if ann and isinstance(ann.annotation, dict):
        segment_annotations = self._extract_segment_annotations(ann.annotation)
        has_segmented_annotation = self._is_segmented_annotation(ann.annotation)
    # Segment indices (as strings) that already have a stored entry; drives SegmentInfo.hasAnnotation.
    segment_annotation_keys = set(segment_annotations.keys())
    # Decide whether to segment: several JSONL records, or a primary text over the threshold
    segmentation_enabled = self._resolve_segmentation_enabled(project)
    if not segmentation_enabled:
        segment_index = None
    needs_segmentation = segmentation_enabled and (
        len(records) > 1 or any(len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts)
    )
    segments: Optional[List[SegmentInfo]] = None
    current_segment_index = 0
    display_text = record_texts[0] if record_texts else text_content
    selected_payload = records[0][0] if records else None
    # Segmentation disabled with several records: show all record texts joined by newlines
    # and drop the per-record payload.
    if not segmentation_enabled and len(records) > 1:
        selected_payload = None
        display_text = "\n".join(record_texts) if record_texts else text_content
    if needs_segmentation:
        splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
        # Per segment: (record payload, raw record text, segment text, record index, chunk index)
        segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
        segments = []
        # Running global segment index across all records.
        segment_cursor = 0
        for record_index, ((payload, raw_text), record_text) in enumerate(zip(records, record_texts)):
            normalized_text = record_text or ""
            if len(normalized_text) > self.SEGMENT_THRESHOLD:
                raw_segments = splitter.split(normalized_text)
                for chunk_index, seg in enumerate(raw_segments):
                    segments.append(SegmentInfo(
                        idx=segment_cursor,
                        text=seg["text"],
                        start=seg["start"],
                        end=seg["end"],
                        hasAnnotation=str(segment_cursor) in segment_annotation_keys,
                        lineIndex=record_index,
                        chunkIndex=chunk_index,
                    ))
                    segment_contexts.append((payload, raw_text, seg["text"], record_index, chunk_index))
                    segment_cursor += 1
            else:
                # Short record: the whole record text is a single segment.
                segments.append(SegmentInfo(
                    idx=segment_cursor,
                    text=normalized_text,
                    start=0,
                    end=len(normalized_text),
                    hasAnnotation=str(segment_cursor) in segment_annotation_keys,
                    lineIndex=record_index,
                    chunkIndex=0,
                ))
                segment_contexts.append((payload, raw_text, normalized_text, record_index, 0))
                segment_cursor += 1
        if not segments:
            # The splitter produced nothing: expose one empty placeholder segment.
            segments = [SegmentInfo(idx=0, text="", start=0, end=0, hasAnnotation=False, lineIndex=0, chunkIndex=0)]
            segment_contexts = [(None, "", "", 0, 0)]
        current_segment_index = segment_index if segment_index is not None else 0
        if current_segment_index < 0 or current_segment_index >= len(segments):
            current_segment_index = 0
        selected_payload, _, display_text, _, _ = segment_contexts[current_segment_index]
    # Build the task object
    task_data = self._build_task_data(
        display_text=display_text,
        parsed_payload=selected_payload,
        label_config=label_config,
        file_record=file_record,
        dataset_id=project.dataset_id,
        file_id=file_id,
        primary_text_key=primary_text_key,
    )
    if needs_segmentation:
        task_data[SEGMENT_INDEX_KEY] = current_segment_index
        task_data[SEGMENT_INDEX_CAMEL_KEY] = current_segment_index
    task: Dict[str, Any] = {
        "id": ls_task_id,
        "data": task_data,
        "annotations": [],
    }
    annotation_updated_at = None
    if ann:
        annotation_updated_at = ann.updated_at
        if needs_segmentation and has_segmented_annotation:
            # Segmented mode: return the stored annotation of the current segment
            # (annotation id is offset by the segment index).
            seg_ann = segment_annotations.get(str(current_segment_index), {})
            stored = {
                "id": self._make_ls_annotation_id(project.id, file_id) + current_segment_index,
                "task": ls_task_id,
                "result": seg_ann.get(SEGMENT_RESULT_KEY, []),
                "created_at": seg_ann.get(SEGMENT_CREATED_AT_KEY, datetime.utcnow().isoformat() + "Z"),
                "updated_at": seg_ann.get(SEGMENT_UPDATED_AT_KEY, datetime.utcnow().isoformat() + "Z"),
            }
            task["annotations"] = [stored]
        elif not needs_segmentation and not has_segmented_annotation:
            # Unsegmented mode: return the stored annotation object as-is
            stored = dict(ann.annotation or {})
            stored["task"] = ls_task_id
            if not isinstance(stored.get("id"), int):
                stored["id"] = self._make_ls_annotation_id(project.id, file_id)
            task["annotations"] = [stored]
        else:
            # View mode and stored shape disagree (e.g. a previously unsegmented
            # annotation now shown segmented, or the reverse): provide an empty annotation.
            empty_ann_id = self._make_ls_annotation_id(project.id, file_id) + current_segment_index
            task["annotations"] = [
                {
                    "id": empty_ann_id,
                    "task": ls_task_id,
                    "result": [],
                    "created_at": datetime.utcnow().isoformat() + "Z",
                    "updated_at": datetime.utcnow().isoformat() + "Z",
                }
            ]
    else:
        # No stored annotation: provide an empty one so the frontend has an
        # annotation selected to produce results into
        empty_ann_id = self._make_ls_annotation_id(project.id, file_id)
        if needs_segmentation:
            empty_ann_id += current_segment_index
        task["annotations"] = [
            {
                "id": empty_ann_id,
                "task": ls_task_id,
                "result": [],
                "created_at": datetime.utcnow().isoformat() + "Z",
                "updated_at": datetime.utcnow().isoformat() + "Z",
            }
        ]
    return EditorTaskResponse(
        task=task,
        annotationUpdatedAt=annotation_updated_at,
        segmented=needs_segmentation,
        segments=segments,
        totalSegments=len(segments) if segments else 1,
        currentSegmentIndex=current_segment_index,
    )
async def _build_media_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
    default_key: str,
    categories: set[str],
) -> EditorTaskResponse:
    """
    Build an (unsegmented) Label Studio task for an image/audio/video file.

    Every media object key referenced by the label config (or ``default_key``)
    points at the file's preview URL. A stored, non-segmented annotation is
    returned as-is; otherwise an empty annotation is supplied.

    ``annotationUpdatedAt`` is reported whenever a stored record exists, as in the
    text task, so clients can send it back for the optimistic-concurrency check.
    """
    label_config = await self._resolve_project_label_config(project)
    media_keys = self._resolve_media_value_keys(label_config, default_key, categories)
    preview_url = self._build_file_preview_url(project.dataset_id, file_id)
    file_name = str(getattr(file_record, "file_name", ""))
    task_data: Dict[str, Any] = {
        FILE_ID_KEY: file_id,
        FILE_ID_CAMEL_KEY: file_id,
        DATASET_ID_KEY: project.dataset_id,
        DATASET_ID_CAMEL_KEY: project.dataset_id,
        FILE_NAME_KEY: file_name,
        FILE_NAME_CAMEL_KEY: file_name,
    }
    for key in media_keys:
        task_data[key] = preview_url
    self._apply_text_placeholders(task_data, label_config)

    # Load the existing annotation, if any
    ann_result = await self.db.execute(
        select(AnnotationResult).where(
            AnnotationResult.project_id == project.id,
            AnnotationResult.file_id == file_id,
        )
    )
    ann = ann_result.scalar_one_or_none()
    ls_task_id = self._make_ls_task_id(project.id, file_id)
    task: Dict[str, Any] = {
        "id": ls_task_id,
        "data": task_data,
        "annotations": [],
    }

    annotation_updated_at = None
    reusable_annotation: Optional[Dict[str, Any]] = None
    if ann:
        annotation_updated_at = ann.updated_at
        stored_json = ann.annotation
        # Only a non-segmented JSON object can be shown directly. Guard the type
        # because the JSON column may hold any JSON value.
        if isinstance(stored_json, dict) and not stored_json.get(SEGMENTED_KEY):
            reusable_annotation = stored_json

    if reusable_annotation is not None:
        stored = dict(reusable_annotation)
        stored["task"] = ls_task_id
        if not isinstance(stored.get("id"), int):
            stored["id"] = self._make_ls_annotation_id(project.id, file_id)
        task["annotations"] = [stored]
    else:
        now_iso = datetime.utcnow().isoformat() + "Z"
        task["annotations"] = [
            {
                "id": self._make_ls_annotation_id(project.id, file_id),
                "task": ls_task_id,
                "result": [],
                "created_at": now_iso,
                "updated_at": now_iso,
            }
        ]
    return EditorTaskResponse(
        task=task,
        annotationUpdatedAt=annotation_updated_at,
        segmented=False,
        segments=None,
        totalSegments=1,
        currentSegmentIndex=0,
    )
async def _build_image_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
) -> EditorTaskResponse:
    """Media task for IMAGE datasets (image object tags, default key ``"image"``)."""
    return await self._build_media_task(
        project, file_record, file_id, IMAGE_DATA_KEY, IMAGE_OBJECT_CATEGORIES
    )

async def _build_audio_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
) -> EditorTaskResponse:
    """Media task for AUDIO datasets (media object tags, default key ``"audio"``)."""
    return await self._build_media_task(
        project, file_record, file_id, AUDIO_DATA_KEY, MEDIA_OBJECT_CATEGORIES
    )

async def _build_video_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
) -> EditorTaskResponse:
    """Media task for VIDEO datasets (media object tags, default key ``"video"``)."""
    return await self._build_media_task(
        project, file_record, file_id, VIDEO_DATA_KEY, MEDIA_OBJECT_CATEGORIES
    )
async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse:
    """
    Create or update the single final annotation of a file.

    When ``request.segment_index`` is set, the payload is merged into the
    per-segment structure of the stored annotation. Otherwise it replaces the
    stored annotation. The status is IN_PROGRESS while not all segments have been
    saved, ANNOTATED when any result exists, or the client-confirmed
    NO_ANNOTATION / NOT_APPLICABLE status otherwise. The result is synced to
    knowledge management after commit; sync failures are only logged.

    Raises:
        HTTPException: 404 if the project/file is missing or the file is not part of
            the project; 400 for a non-list ``result``, an invalid
            ``annotationStatus``, or an empty annotation without a confirming status;
            409 when ``expected_updated_at`` does not match the stored timestamp.
    """
    project = await self._get_project_or_404(project_id)
    # The file must be in the project's dataset and in the project's file list.
    file_result = await self.db.execute(
        select(DatasetFiles)
        .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
        .where(
            LabelingProjectFile.project_id == project.id,
            DatasetFiles.id == file_id,
            DatasetFiles.dataset_id == project.dataset_id,
        )
    )
    file_record = file_result.scalar_one_or_none()
    if not file_record:
        raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")

    annotation_payload = dict(request.annotation or {})
    result = annotation_payload.get("result")
    if result is None:
        annotation_payload["result"] = []
    elif not isinstance(result, list):
        raise HTTPException(status_code=400, detail="annotation.result 必须为数组")

    # Reject an invalid status before any file download or row lock below.
    requested_status = request.annotation_status
    if requested_status is not None and requested_status not in ANNOTATION_STATUS_CLIENT_VALUES:
        raise HTTPException(status_code=400, detail="annotationStatus 不合法")

    ls_task_id = self._make_ls_task_id(project_id, file_id)
    # Resolve the segment total before locking: computing it may download the file.
    segment_total_hint = None
    if request.segment_index is not None:
        segment_total_hint = self._resolve_segment_total(annotation_payload)
        if segment_total_hint is None:
            segment_total_hint = await self._compute_segment_total(project, file_record, file_id)

    # Lock the existing row (if any) so concurrent saves of the same file serialize.
    existing_result = await self.db.execute(
        select(AnnotationResult)
        .where(
            AnnotationResult.project_id == project_id,
            AnnotationResult.file_id == file_id,
        )
        .with_for_update()
    )
    existing = existing_result.scalar_one_or_none()
    now = datetime.utcnow()

    if request.segment_index is not None:
        # Segmented save: merge this segment into the whole-file structure.
        final_payload = self._merge_segment_annotation(
            existing.annotation if existing else None,
            request.segment_index,
            annotation_payload,
        )
    else:
        # Unsegmented save: the submitted annotation becomes the stored one.
        annotation_payload["task"] = ls_task_id
        if not isinstance(annotation_payload.get("id"), int):
            annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)
        final_payload = annotation_payload

    segment_total = None
    segment_done = None
    if request.segment_index is not None:
        segment_total = self._resolve_segment_total(final_payload)
        if segment_total is None:
            segment_total = segment_total_hint
        if segment_total and segment_total > 0:
            final_payload[SEGMENT_TOTAL_KEY] = segment_total
        segment_done = len(self._extract_segment_annotations(final_payload))

    if (
        segment_total is not None
        and segment_done is not None
        and segment_done < segment_total
    ):
        final_status = ANNOTATION_STATUS_IN_PROGRESS
    else:
        has_result = self._has_annotation_result(final_payload)
        if has_result:
            final_status = ANNOTATION_STATUS_ANNOTATED
        elif requested_status == ANNOTATION_STATUS_NO_ANNOTATION:
            final_status = ANNOTATION_STATUS_NO_ANNOTATION
        elif requested_status == ANNOTATION_STATUS_NOT_APPLICABLE:
            final_status = ANNOTATION_STATUS_NOT_APPLICABLE
        else:
            raise HTTPException(status_code=400, detail="未发现标注内容,请确认无标注/不适用后再保存")

    if existing:
        # Optimistic concurrency: stored timestamps are naive UTC (datetime.utcnow()),
        # so an offset-aware client value must be converted, not just stripped of tzinfo.
        expected_updated_at = self._to_naive_utc(request.expected_updated_at)
        if expected_updated_at is not None and existing.updated_at:
            if existing.updated_at != expected_updated_at:
                raise HTTPException(status_code=409, detail="标注已被更新,请刷新后重试")
        existing.annotation = final_payload  # type: ignore[assignment]
        existing.annotation_status = final_status  # type: ignore[assignment]
        existing.updated_at = now  # type: ignore[assignment]
        await self.db.commit()
        await self.db.refresh(existing)
        response = UpsertAnnotationResponse(
            annotationId=existing.id,
            updatedAt=existing.updated_at or now,
        )
        await self._sync_annotation_to_knowledge(project, file_record, final_payload, existing.updated_at)
        return response

    new_id = str(uuid.uuid4())
    record = AnnotationResult(
        id=new_id,
        project_id=project_id,
        file_id=file_id,
        annotation=final_payload,
        annotation_status=final_status,
        created_at=now,
        updated_at=now,
    )
    self.db.add(record)
    await self.db.commit()
    await self.db.refresh(record)
    response = UpsertAnnotationResponse(
        annotationId=record.id,
        updatedAt=record.updated_at or now,
    )
    await self._sync_annotation_to_knowledge(project, file_record, final_payload, record.updated_at)
    return response

@staticmethod
def _to_naive_utc(value: Optional[datetime]) -> Optional[datetime]:
    """Convert ``value`` to a naive UTC datetime. Naive inputs and ``None`` pass through unchanged."""
    if value is None:
        return None
    offset = value.utcoffset()
    if offset is None:
        return value
    return (value - offset).replace(tzinfo=None)
def _merge_segment_annotation(
    self,
    existing: Optional[Dict[str, Any]],
    segment_index: int,
    new_annotation: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Merge one segment's annotation into the whole-file segmented structure.

    Args:
        existing: the currently stored annotation JSON (may be ``None``,
            unsegmented, or a non-dict value).
        segment_index: index of the segment being saved.
        new_annotation: the submitted Label Studio annotation for that segment.

    Returns:
        A NEW dict; ``existing`` is never mutated. The stored value is usually the
        ORM-loaded committed state of a JSON column. Mutating it in place would make
        old and new values compare equal, so SQLAlchemy could skip the column on
        flush and silently drop the save.
    """
    now_iso = datetime.utcnow().isoformat() + "Z"
    if isinstance(existing, dict) and existing.get(SEGMENTED_KEY):
        base: Dict[str, Any] = dict(existing)
        # Fresh dict of previously saved segments. The list storage format is
        # normalized too instead of being discarded.
        segments = self._extract_segment_annotations(existing)
    else:
        # Start a fresh segmented structure (an unsegmented annotation is replaced).
        base = {
            SEGMENTED_KEY: True,
            "version": 1,
        }
        segments = {}
    base[SEGMENTS_KEY] = segments

    key = str(segment_index)
    previous = segments.get(key)
    previous_created_at = previous.get(SEGMENT_CREATED_AT_KEY) if isinstance(previous, dict) else None
    segments[key] = {
        SEGMENT_RESULT_KEY: new_annotation.get(SEGMENT_RESULT_KEY, []),
        # Keep the original creation time when re-saving a segment the client did not stamp.
        SEGMENT_CREATED_AT_KEY: new_annotation.get(SEGMENT_CREATED_AT_KEY, previous_created_at or now_iso),
        SEGMENT_UPDATED_AT_KEY: now_iso,
    }
    return base
async def _sync_annotation_to_knowledge(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    annotation: Dict[str, Any],
    annotation_updated_at: Optional[datetime],
) -> None:
    """
    Push the saved annotation to knowledge management on a best-effort basis.

    Any exception is logged as a warning and swallowed, because the annotation is
    already committed and must not be reported as failed.
    """
    try:
        syncer = KnowledgeSyncService(self.db)
        await syncer.sync_annotation_to_knowledge(
            project=project,
            file_record=file_record,
            annotation=annotation,
            annotation_updated_at=annotation_updated_at,
        )
    except Exception as exc:
        logger.warning("标注同步知识管理失败:%s", exc)