You've already forked DataMate
Change has_new_version logic to compare current file version with latest version, regardless of whether annotation exists. Before: Only show warning if annotation exists and version is outdated After: Show warning if current file is not the latest version This ensures users are informed when viewing an old file version, even if they haven't started annotating yet.
1823 lines
68 KiB
Python
1823 lines
68 KiB
Python
"""
|
|
标注编辑器(Label Studio Editor)服务
|
|
|
|
职责:
|
|
- 解析 DataMate 标注项目(t_dm_labeling_projects)
|
|
- 以“文件下载/预览接口”读取文本内容,构造 Label Studio task
|
|
- 以原始 annotation JSON 形式 upsert 最终标注结果(单人单份)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
from urllib.parse import urlparse
|
|
|
|
import hashlib
|
|
import json
|
|
import xml.etree.ElementTree as ET
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import case, func, select, or_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import get_logger
|
|
from app.db.models import (
|
|
AnnotationResult,
|
|
Dataset,
|
|
DatasetFiles,
|
|
LabelingProject,
|
|
LabelingProjectFile,
|
|
)
|
|
from app.db.models.annotation_management import (
|
|
ANNOTATION_STATUS_ANNOTATED,
|
|
ANNOTATION_STATUS_IN_PROGRESS,
|
|
ANNOTATION_STATUS_CLIENT_VALUES,
|
|
ANNOTATION_STATUS_NO_ANNOTATION,
|
|
ANNOTATION_STATUS_NOT_APPLICABLE,
|
|
)
|
|
from app.module.annotation.config import LabelStudioTagConfig
|
|
from app.module.annotation.schema.editor import (
|
|
EditorProjectInfo,
|
|
EditorTaskListItem,
|
|
EditorTaskListResponse,
|
|
EditorTaskSegmentResponse,
|
|
EditorTaskResponse,
|
|
SegmentDetail,
|
|
SegmentInfo,
|
|
UpsertAnnotationRequest,
|
|
UpsertAnnotationResponse,
|
|
)
|
|
from app.module.annotation.service.template import AnnotationTemplateService
|
|
from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
|
|
from app.module.annotation.service.annotation_text_splitter import (
|
|
AnnotationTextSplitter,
|
|
)
|
|
from app.module.annotation.service.text_fetcher import (
|
|
fetch_text_content_via_download_api,
|
|
)
|
|
|
|
logger = get_logger(__name__)

# ---- Task-data keys (snake_case + camelCase variants kept in sync) ----
TEXT_DATA_KEY = "text"
IMAGE_DATA_KEY = "image"
AUDIO_DATA_KEY = "audio"
VIDEO_DATA_KEY = "video"
DATASET_ID_KEY = "dataset_id"
FILE_ID_KEY = "file_id"
FILE_NAME_KEY = "file_name"
DATASET_ID_CAMEL_KEY = "datasetId"
FILE_ID_CAMEL_KEY = "fileId"
FILE_NAME_CAMEL_KEY = "fileName"

# ---- Keys used inside stored (possibly segmented) annotation payloads ----
SEGMENT_INDEX_KEY = "segment_index"
SEGMENT_INDEX_CAMEL_KEY = "segmentIndex"
SEGMENTED_KEY = "segmented"
SEGMENTS_KEY = "segments"
SEGMENT_TOTAL_KEY = "total_segments"
SEGMENT_RESULT_KEY = "result"
SEGMENT_CREATED_AT_KEY = "created_at"
SEGMENT_UPDATED_AT_KEY = "updated_at"

# ---- File-format and Label Studio object-tag classification ----
JSONL_EXTENSION = ".jsonl"
TEXTUAL_OBJECT_CATEGORIES = {"text", "document"}
IMAGE_OBJECT_CATEGORIES = {"image"}
MEDIA_OBJECT_CATEGORIES = {"media"}
# Prefix for auto-generated <Header> element names injected before object tags.
OBJECT_NAME_HEADER_PREFIX = "dm_object_header_"

# ---- Dataset types the embedded editor supports ----
DATASET_TYPE_TEXT = "TEXT"
DATASET_TYPE_IMAGE = "IMAGE"
DATASET_TYPE_AUDIO = "AUDIO"
DATASET_TYPE_VIDEO = "VIDEO"
SUPPORTED_EDITOR_DATASET_TYPES = (
    DATASET_TYPE_TEXT,
    DATASET_TYPE_IMAGE,
    DATASET_TYPE_AUDIO,
    DATASET_TYPE_VIDEO,
)

# Project-configuration flag toggling text segmentation (defaults to enabled).
SEGMENTATION_ENABLED_KEY = "segmentation_enabled"
# Source documents (pre-conversion originals) matched by file_type or suffix.
SOURCE_DOCUMENT_EXTENSIONS = (".pdf", ".doc", ".docx")
SOURCE_DOCUMENT_TYPES = ("pdf", "doc", "docx")
|
|
|
|
|
|
class AnnotationEditorService:
|
|
"""Label Studio Editor 集成服务(TEXT POC 版)"""
|
|
|
|
# 分段阈值:超过此字符数自动分段
|
|
SEGMENT_THRESHOLD = 200
|
|
|
|
def __init__(self, db: AsyncSession):
    """Bind the service to an async DB session and a template service."""
    self.template_service = AnnotationTemplateService()
    self.db = db
|
|
|
|
@staticmethod
|
|
def _stable_ls_id(seed: str) -> int:
|
|
"""
|
|
生成稳定的 Label Studio 风格整数 ID(JS 安全整数范围内)。
|
|
|
|
说明:
|
|
- Label Studio Frontend 的 mobx-state-tree 模型对 task/annotation 的 id 有类型约束(通常为 number)。
|
|
- DataMate 使用 UUID 作为 file_id/project_id,因此需映射为整数供编辑器使用。
|
|
- 取 sha1 的前 13 个 hex(52bit),落在 JS Number 的安全整数范围。
|
|
"""
|
|
digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
|
|
value = int(digest[:13], 16)
|
|
return value if value > 0 else 1
|
|
|
|
def _make_ls_task_id(self, project_id: str, file_id: str) -> int:
|
|
return self._stable_ls_id(f"task:{project_id}:{file_id}")
|
|
|
|
def _make_ls_annotation_id(self, project_id: str, file_id: str) -> int:
|
|
# 单人单份最终标签:每个 task 只保留一个 annotation,id 直接与 task 绑定即可
|
|
return self._stable_ls_id(f"annotation:{project_id}:{file_id}")
|
|
|
|
@staticmethod
|
|
def _normalize_dataset_type(dataset_type: Optional[str]) -> str:
|
|
return (dataset_type or "").upper()
|
|
|
|
@staticmethod
def _resolve_public_api_prefix() -> str:
    """Resolve the public API path prefix from the configured backend base URL.

    A full URL contributes only its path component; a bare path is used
    as-is. Falls back to "/api" when the setting is empty or the derived
    path is blank, and normalizes to a single leading slash, no trailing
    slash.
    """
    configured = (settings.datamate_backend_base_url or "").strip()
    if not configured:
        return "/api"
    parsed = urlparse(configured)
    # Full URL -> keep only its path; bare path string -> use directly.
    candidate = parsed.path if (parsed.scheme and parsed.netloc) else configured
    candidate = candidate.rstrip("/")
    if not candidate:
        return "/api"
    return candidate if candidate.startswith("/") else "/" + candidate
|
|
|
|
@classmethod
def _build_file_preview_url(cls, dataset_id: str, file_id: str) -> str:
    """Build the dataset-file preview URL served by the DataMate backend."""
    return (
        f"{cls._resolve_public_api_prefix()}"
        f"/data-management/datasets/{dataset_id}/files/{file_id}/preview"
    )
|
|
|
|
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
    """Load a non-soft-deleted labeling project or raise HTTP 404."""
    stmt = select(LabelingProject).where(
        LabelingProject.id == project_id,
        LabelingProject.deleted_at.is_(None),
    )
    project = (await self.db.execute(stmt)).scalar_one_or_none()
    if project is None:
        raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
    return project
|
|
|
|
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
    """Fetch the raw dataset_type column for a dataset (None when missing)."""
    stmt = select(Dataset.dataset_type).where(Dataset.id == dataset_id)
    return (await self.db.execute(stmt)).scalar_one_or_none()
|
|
|
|
async def _get_label_config(self, template_id: Optional[str]) -> Optional[str]:
|
|
if not template_id:
|
|
return None
|
|
template = await self.template_service.get_template(self.db, template_id)
|
|
return getattr(template, "label_config", None) if template else None
|
|
|
|
async def _resolve_project_label_config(
|
|
self, project: LabelingProject
|
|
) -> Optional[str]:
|
|
label_config = None
|
|
if project.configuration and isinstance(project.configuration, dict):
|
|
label_config = project.configuration.get("label_config")
|
|
if not label_config:
|
|
label_config = await self._get_label_config(project.template_id)
|
|
return label_config
|
|
|
|
@staticmethod
|
|
def _resolve_segmentation_enabled(project: LabelingProject) -> bool:
|
|
config = project.configuration
|
|
if not isinstance(config, dict):
|
|
return True
|
|
value = config.get(SEGMENTATION_ENABLED_KEY)
|
|
if isinstance(value, bool):
|
|
return value
|
|
if value is None:
|
|
return True
|
|
return bool(value)
|
|
|
|
@classmethod
|
|
def _resolve_primary_text_key(cls, label_config: Optional[str]) -> Optional[str]:
|
|
if not label_config:
|
|
return None
|
|
keys = cls._extract_textual_value_keys(label_config)
|
|
if not keys:
|
|
return None
|
|
if TEXT_DATA_KEY in keys:
|
|
return TEXT_DATA_KEY
|
|
return keys[0]
|
|
|
|
@classmethod
|
|
def _resolve_media_value_keys(
|
|
cls,
|
|
label_config: Optional[str],
|
|
default_key: str,
|
|
categories: Optional[set[str]] = None,
|
|
) -> List[str]:
|
|
if not label_config:
|
|
return [default_key]
|
|
target_categories = categories or set()
|
|
keys = cls._extract_object_value_keys_by_category(
|
|
label_config, target_categories
|
|
)
|
|
if not keys:
|
|
return [default_key]
|
|
return keys
|
|
|
|
@staticmethod
|
|
def _try_parse_json_payload(text_content: str) -> Optional[Dict[str, Any]]:
|
|
if not text_content:
|
|
return None
|
|
stripped = text_content.strip()
|
|
if not stripped:
|
|
return None
|
|
if stripped[0] not in ("{", "["):
|
|
return None
|
|
try:
|
|
parsed = json.loads(stripped)
|
|
except Exception:
|
|
return None
|
|
return parsed if isinstance(parsed, dict) else None
|
|
|
|
@classmethod
|
|
def _parse_jsonl_records(
|
|
cls, text_content: str
|
|
) -> List[Tuple[Optional[Dict[str, Any]], str]]:
|
|
lines = [line for line in text_content.splitlines() if line.strip()]
|
|
records: List[Tuple[Optional[Dict[str, Any]], str]] = []
|
|
for line in lines:
|
|
payload = cls._try_parse_json_payload(line)
|
|
records.append((payload, line))
|
|
return records
|
|
|
|
@staticmethod
def _is_textual_object_tag(object_tag: str) -> bool:
    """True when the object tag's configured category marks text-like content."""
    tag_config = LabelStudioTagConfig.get_object_config(object_tag) or {}
    return tag_config.get("category") in TEXTUAL_OBJECT_CATEGORIES
|
|
|
|
@classmethod
def _extract_object_value_keys_by_category(
    cls,
    label_config: str,
    categories: set[str],
) -> List[str]:
    """Collect `$key` value names of object tags matching the categories.

    Preserves first-seen order and de-duplicates. An empty `categories`
    set matches every object tag. An unparseable config yields an empty
    list (placeholder filling is then skipped).
    """
    try:
        root = ET.fromstring(label_config)
    except Exception as exc:
        logger.warning("解析 label_config 失败,已跳过占位填充:%s", exc)
        return []

    object_types = LabelStudioTagConfig.get_object_types()
    ordered: Dict[str, None] = {}  # dict keys keep insertion order
    for node in root.iter():
        if node.tag not in object_types:
            continue
        node_config = LabelStudioTagConfig.get_object_config(node.tag) or {}
        if categories and node_config.get("category") not in categories:
            continue
        value_attr = node.attrib.get("value", "")
        # Only "$key" references name a task-data field.
        if not value_attr.startswith("$"):
            continue
        key = value_attr[1:].strip()
        if key:
            ordered[key] = None
    return list(ordered)
|
|
|
|
@classmethod
def _extract_textual_value_keys(cls, label_config: str) -> List[str]:
    """Shorthand: value keys for text/document-category object tags."""
    return cls._extract_object_value_keys_by_category(
        label_config, TEXTUAL_OBJECT_CATEGORIES
    )
|
|
|
|
@staticmethod
|
|
def _needs_placeholder(value: Any) -> bool:
|
|
if value is None:
|
|
return True
|
|
if isinstance(value, str) and not value.strip():
|
|
return True
|
|
return False
|
|
|
|
def _apply_text_placeholders(
    self, data: Dict[str, Any], label_config: Optional[str]
) -> None:
    """Fill blank textual task-data keys in-place with the key name itself,
    so the editor never renders an empty text region. No-op without a
    label config."""
    if not label_config:
        return
    for key in self._extract_textual_value_keys(label_config):
        current = data.get(key)
        if self._needs_placeholder(current):
            data[key] = key
|
|
|
|
@staticmethod
|
|
def _header_already_present(header: ET.Element, name: str) -> bool:
|
|
value = header.attrib.get("value", "")
|
|
if value == name:
|
|
return True
|
|
header_text = (header.text or "").strip()
|
|
return header_text == name
|
|
|
|
def _decorate_label_config_for_editor(self, label_config: str) -> str:
    """Insert a <Header> showing each textual object tag's name.

    For every textual object tag (e.g. <Text name="...">) that is not
    already preceded by a matching Header, a Header element displaying the
    tag's name is inserted immediately before it, so annotators can tell
    multiple text regions apart. Returns the config unchanged when it
    cannot be parsed as XML.
    """
    try:
        root = ET.fromstring(label_config)
    except Exception as exc:
        # Malformed XML: log and return the input untouched.
        logger.warning("解析 label_config 失败,已跳过 name 展示增强:%s", exc)
        return label_config

    object_types = LabelStudioTagConfig.get_object_types()
    # Collect every `name` already used so generated Header names stay unique.
    used_names = set()
    for element in root.iter():
        name = element.attrib.get("name")
        if name:
            used_names.add(name)

    def allocate_header_name(base: str) -> str:
        # Prefer "<prefix><base>"; on collision append "_1", "_2", ...
        candidate = f"{OBJECT_NAME_HEADER_PREFIX}{base}"
        if candidate not in used_names:
            used_names.add(candidate)
            return candidate
        idx = 1
        while f"{candidate}_{idx}" in used_names:
            idx += 1
        resolved = f"{candidate}_{idx}"
        used_names.add(resolved)
        return resolved

    for parent in root.iter():
        # Work on a snapshot of children; insertions update both the tree
        # and this local list so the cursor stays in sync.
        children = list(parent)
        i = 0
        while i < len(children):
            child = children[i]
            if child.tag not in object_types:
                i += 1
                continue
            if not self._is_textual_object_tag(child.tag):
                i += 1
                continue
            obj_name = child.attrib.get("name")
            if not obj_name:
                i += 1
                continue

            # Skip when the previous sibling is already a matching Header.
            if i > 0:
                prev = children[i - 1]
                if prev.tag == "Header" and self._header_already_present(
                    prev, obj_name
                ):
                    i += 1
                    continue

            header = ET.Element("Header")
            header.set("name", allocate_header_name(obj_name))
            header.set("value", obj_name)

            # Insert before the object tag; advance past Header + tag.
            parent.insert(i, header)
            children.insert(i, header)
            i += 2
            # continue outer loop

    return ET.tostring(root, encoding="unicode")
|
|
|
|
@staticmethod
|
|
def _extract_segment_annotations(
|
|
payload: Optional[Dict[str, Any]],
|
|
) -> Dict[str, Dict[str, Any]]:
|
|
if not payload or not isinstance(payload, dict):
|
|
return {}
|
|
segments = payload.get(SEGMENTS_KEY)
|
|
if isinstance(segments, dict):
|
|
normalized: Dict[str, Dict[str, Any]] = {}
|
|
for key, value in segments.items():
|
|
if isinstance(value, dict):
|
|
normalized[str(key)] = value
|
|
return normalized
|
|
if isinstance(segments, list):
|
|
normalized: Dict[str, Dict[str, Any]] = {}
|
|
for idx, value in enumerate(segments):
|
|
if not isinstance(value, dict):
|
|
continue
|
|
key = (
|
|
value.get(SEGMENT_INDEX_CAMEL_KEY)
|
|
or value.get(SEGMENT_INDEX_KEY)
|
|
or value.get("segment")
|
|
or value.get("idx")
|
|
)
|
|
if key is None:
|
|
key = idx
|
|
normalized[str(key)] = value
|
|
return normalized
|
|
return {}
|
|
|
|
@staticmethod
|
|
def _is_segmented_annotation(payload: Optional[Dict[str, Any]]) -> bool:
|
|
if not payload or not isinstance(payload, dict):
|
|
return False
|
|
if payload.get(SEGMENTED_KEY):
|
|
return True
|
|
segments = payload.get(SEGMENTS_KEY)
|
|
if isinstance(segments, dict):
|
|
return len(segments) > 0
|
|
if isinstance(segments, list):
|
|
return len(segments) > 0
|
|
return False
|
|
|
|
@staticmethod
def _has_annotation_result(payload: Optional[Dict[str, Any]]) -> bool:
    """Whether the stored payload carries at least one non-empty result list.

    Segmented payloads count as annotated when any segment holds a
    non-empty "result" list; flat payloads when the top-level "result" is
    a non-empty list.
    """
    if not isinstance(payload, dict) or not payload:
        return False
    if AnnotationEditorService._is_segmented_annotation(payload):
        per_segment = AnnotationEditorService._extract_segment_annotations(payload)
        for segment in per_segment.values():
            result = (
                segment.get(SEGMENT_RESULT_KEY)
                if isinstance(segment, dict)
                else None
            )
            if isinstance(result, list) and result:
                return True
        return False
    top_result = payload.get(SEGMENT_RESULT_KEY)
    return isinstance(top_result, list) and len(top_result) > 0
|
|
|
|
@staticmethod
|
|
def _resolve_segment_total(payload: Optional[Dict[str, Any]]) -> Optional[int]:
|
|
if not payload or not isinstance(payload, dict):
|
|
return None
|
|
value = payload.get(SEGMENT_TOTAL_KEY)
|
|
if isinstance(value, int):
|
|
return value if value > 0 else None
|
|
if isinstance(value, float) and value.is_integer():
|
|
return int(value) if value > 0 else None
|
|
if isinstance(value, str) and value.isdigit():
|
|
parsed = int(value)
|
|
return parsed if parsed > 0 else None
|
|
return None
|
|
|
|
async def _compute_segment_total(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
) -> Optional[int]:
    """Estimate how many segments a TEXT file would be split into.

    Returns None when segmentation does not apply (non-TEXT dataset,
    segmentation disabled on the project, or nothing exceeds the
    threshold); otherwise a total of at least 1.
    """
    dataset_type = self._normalize_dataset_type(
        await self._get_dataset_type(project.dataset_id)
    )
    # Segmentation is a TEXT-only feature.
    if dataset_type != DATASET_TYPE_TEXT:
        return None
    if not self._resolve_segmentation_enabled(project):
        return None

    text_content = await self._fetch_text_content_via_download_api(
        project.dataset_id, file_id
    )
    if not isinstance(text_content, str):
        return None

    label_config = await self._resolve_project_label_config(project)
    primary_text_key = self._resolve_primary_text_key(label_config)
    file_name = str(getattr(file_record, "file_name", "")).lower()

    # JSONL files contribute one record per line; other files form a
    # single record (parsed as JSON when possible, raw text otherwise).
    records: List[Tuple[Optional[Dict[str, Any]], str]] = []
    if file_name.endswith(JSONL_EXTENSION):
        records = self._parse_jsonl_records(text_content)
    else:
        parsed_payload = self._try_parse_json_payload(text_content)
        if parsed_payload:
            records = [(parsed_payload, text_content)]

    if not records:
        records = [(None, text_content)]

    record_texts = [
        self._resolve_primary_text_value(payload, raw_text, primary_text_key)
        for payload, raw_text in records
    ]
    if not record_texts:
        record_texts = [text_content]

    # Segment only when there are multiple records or one is over the limit.
    needs_segmentation = len(records) > 1 or any(
        len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts
    )
    if not needs_segmentation:
        return None

    # Long records are chunked by the splitter; short ones count as one.
    splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
    total_segments = 0
    for record_text in record_texts:
        normalized_text = record_text or ""
        if len(normalized_text) > self.SEGMENT_THRESHOLD:
            raw_segments = splitter.split(normalized_text)
            total_segments += len(raw_segments) if raw_segments else 1
        else:
            total_segments += 1

    return total_segments if total_segments > 0 else 1
|
|
|
|
@classmethod
def _build_source_document_filter(cls):
    """SQLAlchemy predicate matching source documents (pdf/doc/docx).

    A row matches when either its lowercased file_type is a known source
    type or its lowercased file_name ends with a known suffix.
    """
    lowered_type = func.lower(DatasetFiles.file_type)
    lowered_name = func.lower(DatasetFiles.file_name)
    by_suffix = [
        lowered_name.like(f"%{ext}") for ext in SOURCE_DOCUMENT_EXTENSIONS
    ]
    return or_(lowered_type.in_(SOURCE_DOCUMENT_TYPES), *by_suffix)
|
|
|
|
def _build_task_data(
|
|
self,
|
|
display_text: str,
|
|
parsed_payload: Optional[Dict[str, Any]],
|
|
label_config: Optional[str],
|
|
file_record: DatasetFiles,
|
|
dataset_id: str,
|
|
file_id: str,
|
|
primary_text_key: Optional[str],
|
|
) -> Dict[str, Any]:
|
|
data: Dict[str, Any] = dict(parsed_payload or {})
|
|
text_key = primary_text_key or TEXT_DATA_KEY
|
|
data[text_key] = display_text
|
|
|
|
file_name = str(getattr(file_record, "file_name", ""))
|
|
data[FILE_ID_KEY] = file_id
|
|
data[FILE_ID_CAMEL_KEY] = file_id
|
|
data[DATASET_ID_KEY] = dataset_id
|
|
data[DATASET_ID_CAMEL_KEY] = dataset_id
|
|
data[FILE_NAME_KEY] = file_name
|
|
data[FILE_NAME_CAMEL_KEY] = file_name
|
|
|
|
self._apply_text_placeholders(data, label_config)
|
|
return data
|
|
|
|
@classmethod
|
|
def _resolve_primary_text_value(
|
|
cls,
|
|
parsed_payload: Optional[Dict[str, Any]],
|
|
raw_text: str,
|
|
primary_text_key: Optional[str],
|
|
) -> str:
|
|
if parsed_payload and primary_text_key:
|
|
value = parsed_payload.get(primary_text_key)
|
|
if isinstance(value, str) and value.strip():
|
|
return value
|
|
if parsed_payload and not primary_text_key:
|
|
value = parsed_payload.get(TEXT_DATA_KEY)
|
|
if isinstance(value, str) and value.strip():
|
|
return value
|
|
return raw_text
|
|
|
|
def _build_segment_contexts(
    self,
    records: List[Tuple[Optional[Dict[str, Any]], str]],
    record_texts: List[str],
    segment_annotation_keys: set[str],
) -> Tuple[
    List[SegmentInfo], List[Tuple[Optional[Dict[str, Any]], str, str, int, int]]
]:
    """Flatten records into a global segment list plus per-segment context.

    Each context tuple is (payload, raw_text, segment_text, line_index,
    chunk_index). `segment_cursor` numbers segments globally across all
    records; `segment_annotation_keys` marks which global indices already
    have a stored annotation. A single empty placeholder segment is
    returned when nothing was produced.
    """
    splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
    segments: List[SegmentInfo] = []
    segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
    segment_cursor = 0  # global index across every record and chunk

    for record_index, ((payload, raw_text), record_text) in enumerate(
        zip(records, record_texts)
    ):
        normalized_text = record_text or ""
        if len(normalized_text) > self.SEGMENT_THRESHOLD:
            # Over-threshold record: one segment per splitter chunk.
            raw_segments = splitter.split(normalized_text)
            for chunk_index, seg in enumerate(raw_segments):
                segments.append(
                    SegmentInfo(
                        idx=segment_cursor,
                        hasAnnotation=str(segment_cursor)
                        in segment_annotation_keys,
                        lineIndex=record_index,
                        chunkIndex=chunk_index,
                    )
                )
                segment_contexts.append(
                    (payload, raw_text, seg["text"], record_index, chunk_index)
                )
                segment_cursor += 1
        else:
            # Short record: a single chunk covering the whole record.
            segments.append(
                SegmentInfo(
                    idx=segment_cursor,
                    hasAnnotation=str(segment_cursor) in segment_annotation_keys,
                    lineIndex=record_index,
                    chunkIndex=0,
                )
            )
            segment_contexts.append(
                (payload, raw_text, normalized_text, record_index, 0)
            )
            segment_cursor += 1

    if not segments:
        # Degenerate input: expose one empty placeholder segment.
        segments = [
            SegmentInfo(idx=0, hasAnnotation=False, lineIndex=0, chunkIndex=0)
        ]
        segment_contexts = [(None, "", "", 0, 0)]

    return segments, segment_contexts
|
|
|
|
async def get_project_info(self, project_id: str) -> EditorProjectInfo:
    """Return editor bootstrap info: dataset-type support and label config.

    Raises HTTP 404 when the project does not exist or is soft-deleted.
    """
    project = await self._get_project_or_404(project_id)

    dataset_type = self._normalize_dataset_type(
        await self._get_dataset_type(project.dataset_id)
    )
    supported = dataset_type in SUPPORTED_EDITOR_DATASET_TYPES
    unsupported_reason = None
    if not supported:
        supported_hint = "/".join(SUPPORTED_EDITOR_DATASET_TYPES)
        unsupported_reason = f"当前仅支持 {supported_hint},项目数据类型为: {dataset_type or 'UNKNOWN'}"

    # Prefer the project-level label_config (user-edited copy); fall back
    # to the template's default configuration.
    label_config = await self._resolve_project_label_config(project)

    return EditorProjectInfo(
        projectId=project.id,
        datasetId=project.dataset_id,
        datasetType=dataset_type or None,
        templateId=project.template_id,
        labelConfig=label_config,
        supported=supported,
        unsupportedReason=unsupported_reason,
    )
|
|
|
|
async def list_tasks(
    self,
    project_id: str,
    page: int = 0,
    size: int = 50,
    exclude_source_documents: Optional[bool] = None,
) -> EditorTaskListResponse:
    """Page through a project's files with their annotation status.

    Files without an annotation sort first, then newest-created first.
    `page` is zero-based. NOTE(review): `exclude_source_documents` is
    accepted but not applied to the query in this version — confirm
    whether `_build_source_document_filter` should be wired in here.
    """
    project = await self._get_project_or_404(project_id)
    base_conditions = [
        LabelingProjectFile.project_id == project_id,
        DatasetFiles.dataset_id == project.dataset_id,
    ]

    # Total count over the same join/filters as the page query below.
    count_result = await self.db.execute(
        select(func.count())
        .select_from(LabelingProjectFile)
        .join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
        .where(*base_conditions)
    )
    total = int(count_result.scalar() or 0)

    # 0 = not annotated, 1 = annotated; ascending puts pending files first.
    annotated_sort_key = case(
        (AnnotationResult.id.isnot(None), 1),
        else_=0,
    )
    files_result = await self.db.execute(
        select(
            DatasetFiles,
            AnnotationResult.id,
            AnnotationResult.updated_at,
            AnnotationResult.annotation_status,
        )
        .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
        # Outer join keeps files that have no annotation row yet.
        .outerjoin(
            AnnotationResult,
            (AnnotationResult.file_id == DatasetFiles.id)
            & (AnnotationResult.project_id == project_id),
        )
        .where(*base_conditions)
        .order_by(annotated_sort_key.asc(), DatasetFiles.created_at.desc())
        .offset(page * size)
        .limit(size)
    )
    rows = files_result.all()

    items: List[EditorTaskListItem] = []
    for (
        file_record,
        annotation_id,
        annotation_updated_at,
        annotation_status,
    ) in rows:
        fid = str(file_record.id)  # type: ignore[arg-type]
        items.append(
            EditorTaskListItem(
                fileId=fid,
                fileName=str(getattr(file_record, "file_name", "")),
                fileType=getattr(file_record, "file_type", None),
                hasAnnotation=annotation_id is not None,
                annotationUpdatedAt=annotation_updated_at,
                annotationStatus=annotation_status,
            )
        )

    # Ceiling division for the page count; guard against size <= 0.
    total_pages = (total + size - 1) // size if size > 0 else 0
    return EditorTaskListResponse(
        content=items,
        totalElements=total,
        totalPages=total_pages,
        page=page,
        size=size,
    )
|
|
|
|
async def _fetch_text_content_via_download_api(
    self, dataset_id: str, file_id: str
) -> str:
    """Fetch file text through the shared download-API helper (kept as an
    instance method so tests can override the seam)."""
    content = await fetch_text_content_via_download_api(dataset_id, file_id)
    return content
|
|
|
|
async def get_task(
    self,
    project_id: str,
    file_id: str,
    segment_index: Optional[int] = None,
) -> EditorTaskResponse:
    """Build the Label Studio task for one file, dispatching on dataset type.

    Raises HTTP 400 for unsupported dataset types and HTTP 404 when the
    file is missing or belongs to a different dataset. `segment_index` is
    only meaningful for TEXT tasks.
    """
    project = await self._get_project_or_404(project_id)

    dataset_type = self._normalize_dataset_type(
        await self._get_dataset_type(project.dataset_id)
    )
    if dataset_type not in SUPPORTED_EDITOR_DATASET_TYPES:
        raise HTTPException(
            status_code=400,
            detail="当前仅支持 TEXT/IMAGE/AUDIO/VIDEO 项目的内嵌编辑器",
        )

    # The file must exist and belong to the project's dataset.
    file_result = await self.db.execute(
        select(DatasetFiles).where(
            DatasetFiles.id == file_id,
            DatasetFiles.dataset_id == project.dataset_id,
        )
    )
    file_record = file_result.scalar_one_or_none()
    if not file_record:
        raise HTTPException(
            status_code=404, detail=f"文件不存在或不属于该项目: {file_id}"
        )

    if dataset_type == DATASET_TYPE_IMAGE:
        return await self._build_image_task(project, file_record, file_id)

    if dataset_type == DATASET_TYPE_AUDIO:
        return await self._build_audio_task(project, file_record, file_id)

    if dataset_type == DATASET_TYPE_VIDEO:
        return await self._build_video_task(project, file_record, file_id)

    # Default: TEXT, which additionally supports segment navigation.
    return await self._build_text_task(project, file_record, file_id, segment_index)
|
|
|
|
async def get_task_segment(
    self,
    project_id: str,
    file_id: str,
    segment_index: int,
) -> EditorTaskSegmentResponse:
    """Return the text and metadata for one segment of a TEXT file.

    Responds with segmented=False when segmentation is disabled on the
    project or the file does not need splitting. Raises HTTP 400 for
    non-TEXT projects or an out-of-range index, HTTP 404 for a missing
    file.
    """
    project = await self._get_project_or_404(project_id)

    dataset_type = self._normalize_dataset_type(
        await self._get_dataset_type(project.dataset_id)
    )
    if dataset_type != DATASET_TYPE_TEXT:
        raise HTTPException(
            status_code=400,
            detail="当前仅支持 TEXT 项目的段落内容",
        )

    file_result = await self.db.execute(
        select(DatasetFiles).where(
            DatasetFiles.id == file_id,
            DatasetFiles.dataset_id == project.dataset_id,
        )
    )
    file_record = file_result.scalar_one_or_none()
    if not file_record:
        raise HTTPException(
            status_code=404, detail=f"文件不存在或不属于该项目: {file_id}"
        )

    # Segmentation disabled: report a flat (non-segmented) task.
    if not self._resolve_segmentation_enabled(project):
        return EditorTaskSegmentResponse(
            segmented=False,
            segment=None,
            totalSegments=0,
            currentSegmentIndex=0,
        )

    text_content = await self._fetch_text_content_via_download_api(
        project.dataset_id, file_id
    )
    assert isinstance(text_content, str)
    label_config = await self._resolve_project_label_config(project)
    primary_text_key = self._resolve_primary_text_key(label_config)
    file_name = str(getattr(file_record, "file_name", "")).lower()

    # Mirrors _build_text_task's record parsing: JSONL -> per-line records,
    # otherwise a single (optionally JSON) record.
    records: List[Tuple[Optional[Dict[str, Any]], str]] = []
    if file_name.endswith(JSONL_EXTENSION):
        records = self._parse_jsonl_records(text_content)
    else:
        parsed_payload = self._try_parse_json_payload(text_content)
        if parsed_payload:
            records = [(parsed_payload, text_content)]

    if not records:
        records = [(None, text_content)]

    record_texts = [
        self._resolve_primary_text_value(payload, raw_text, primary_text_key)
        for payload, raw_text in records
    ]
    if not record_texts:
        record_texts = [text_content]

    needs_segmentation = len(records) > 1 or any(
        len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts
    )
    if not needs_segmentation:
        return EditorTaskSegmentResponse(
            segmented=False,
            segment=None,
            totalSegments=0,
            currentSegmentIndex=0,
        )

    # The stored annotation (if any) marks which segments are labeled.
    ann_result = await self.db.execute(
        select(AnnotationResult).where(
            AnnotationResult.project_id == project.id,
            AnnotationResult.file_id == file_id,
        )
    )
    ann = ann_result.scalar_one_or_none()
    segment_annotations: Dict[str, Dict[str, Any]] = {}
    if ann and isinstance(ann.annotation, dict):
        segment_annotations = self._extract_segment_annotations(ann.annotation)
    segment_annotation_keys = set(segment_annotations.keys())

    segments, segment_contexts = self._build_segment_contexts(
        records,
        record_texts,
        segment_annotation_keys,
    )

    total_segments = len(segment_contexts)
    if total_segments == 0:
        return EditorTaskSegmentResponse(
            segmented=False,
            segment=None,
            totalSegments=0,
            currentSegmentIndex=0,
        )

    if segment_index < 0 or segment_index >= total_segments:
        raise HTTPException(
            status_code=400,
            detail=f"segmentIndex 超出范围: {segment_index}",
        )

    segment_info = segments[segment_index]
    _, _, segment_text, line_index, chunk_index = segment_contexts[segment_index]
    segment_detail = SegmentDetail(
        idx=segment_info.idx,
        text=segment_text,
        hasAnnotation=segment_info.has_annotation,
        lineIndex=line_index,
        chunkIndex=chunk_index,
    )

    return EditorTaskSegmentResponse(
        segmented=True,
        segment=segment_detail,
        totalSegments=total_segments,
        currentSegmentIndex=segment_index,
    )
|
|
|
|
async def _build_text_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
    segment_index: Optional[int],
) -> EditorTaskResponse:
    """Build the Label Studio task payload for a TEXT file.

    Handles JSONL multi-record files, automatic segmentation of long
    texts, and reconciling the stored annotation with the current viewing
    mode (segmented vs. flat).
    """
    text_content = await self._fetch_text_content_via_download_api(
        project.dataset_id, file_id
    )
    assert isinstance(text_content, str)
    label_config = await self._resolve_project_label_config(project)
    primary_text_key = self._resolve_primary_text_key(label_config)
    file_name = str(getattr(file_record, "file_name", "")).lower()
    # JSONL -> one record per line; otherwise a single record, parsed as
    # JSON when the content looks like a JSON object.
    records: List[Tuple[Optional[Dict[str, Any]], str]] = []
    if file_name.endswith(JSONL_EXTENSION):
        records = self._parse_jsonl_records(text_content)
    else:
        parsed_payload = self._try_parse_json_payload(text_content)
        if parsed_payload:
            records = [(parsed_payload, text_content)]

    if not records:
        records = [(None, text_content)]

    record_texts = [
        self._resolve_primary_text_value(payload, raw_text, primary_text_key)
        for payload, raw_text in records
    ]
    if not record_texts:
        record_texts = [text_content]

    # Load the stored annotation (single final copy per project/file).
    ann_result = await self.db.execute(
        select(AnnotationResult).where(
            AnnotationResult.project_id == project.id,
            AnnotationResult.file_id == file_id,
        )
    )
    ann = ann_result.scalar_one_or_none()

    ls_task_id = self._make_ls_task_id(project.id, file_id)

    segment_annotations: Dict[str, Dict[str, Any]] = {}
    has_segmented_annotation = False
    if ann and isinstance(ann.annotation, dict):
        segment_annotations = self._extract_segment_annotations(ann.annotation)
        has_segmented_annotation = self._is_segmented_annotation(ann.annotation)
    segment_annotation_keys = set(segment_annotations.keys())

    # Segment when enabled AND (multiple JSONL records OR any record text
    # exceeds the threshold).
    segmentation_enabled = self._resolve_segmentation_enabled(project)
    if not segmentation_enabled:
        segment_index = None
    needs_segmentation = segmentation_enabled and (
        len(records) > 1
        or any(len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts)
    )
    segments: List[SegmentInfo] = []
    segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
    current_segment_index = 0
    display_text = record_texts[0] if record_texts else text_content
    selected_payload = records[0][0] if records else None
    if not segmentation_enabled and len(records) > 1:
        # Segmentation off but multiple records: show them joined as one.
        selected_payload = None
        display_text = "\n".join(record_texts) if record_texts else text_content

    if needs_segmentation:
        _, segment_contexts = self._build_segment_contexts(
            records,
            record_texts,
            segment_annotation_keys,
        )
        # Clamp an out-of-range requested index back to the first segment.
        current_segment_index = segment_index if segment_index is not None else 0
        if current_segment_index < 0 or current_segment_index >= len(
            segment_contexts
        ):
            current_segment_index = 0

        selected_payload, _, display_text, _, _ = segment_contexts[
            current_segment_index
        ]

    # Assemble the Label Studio task object.
    task_data = self._build_task_data(
        display_text=display_text,
        parsed_payload=selected_payload,
        label_config=label_config,
        file_record=file_record,
        dataset_id=project.dataset_id,
        file_id=file_id,
        primary_text_key=primary_text_key,
    )
    if needs_segmentation:
        task_data[SEGMENT_INDEX_KEY] = current_segment_index
        task_data[SEGMENT_INDEX_CAMEL_KEY] = current_segment_index

    task: Dict[str, Any] = {
        "id": ls_task_id,
        "data": task_data,
        "annotations": [],
    }

    annotation_updated_at = None
    if ann:
        annotation_updated_at = ann.updated_at

        if needs_segmentation and has_segmented_annotation:
            # Segmented view + segmented storage: surface only the current
            # segment's stored annotation. The annotation id is offset by
            # the segment index to stay unique per segment.
            # NOTE(review): datetime.utcnow() is deprecated since 3.12;
            # consider datetime.now(timezone.utc).
            seg_ann = segment_annotations.get(str(current_segment_index), {})
            stored = {
                "id": self._make_ls_annotation_id(project.id, file_id)
                + current_segment_index,
                "task": ls_task_id,
                "result": seg_ann.get(SEGMENT_RESULT_KEY, []),
                "created_at": seg_ann.get(
                    SEGMENT_CREATED_AT_KEY, datetime.utcnow().isoformat() + "Z"
                ),
                "updated_at": seg_ann.get(
                    SEGMENT_UPDATED_AT_KEY, datetime.utcnow().isoformat() + "Z"
                ),
            }
            task["annotations"] = [stored]
        elif not needs_segmentation and not has_segmented_annotation:
            # Flat view + flat storage: return the raw stored annotation.
            stored = dict(ann.annotation or {})
            stored["task"] = ls_task_id
            if not isinstance(stored.get("id"), int):
                stored["id"] = self._make_ls_annotation_id(project.id, file_id)
            task["annotations"] = [stored]
        else:
            # View/storage mode mismatch (e.g. first switch from flat to
            # segmented): provide an empty annotation for this segment.
            empty_ann_id = (
                self._make_ls_annotation_id(project.id, file_id)
                + current_segment_index
            )
            task["annotations"] = [
                {
                    "id": empty_ann_id,
                    "task": ls_task_id,
                    "result": [],
                    "created_at": datetime.utcnow().isoformat() + "Z",
                    "updated_at": datetime.utcnow().isoformat() + "Z",
                }
            ]
    else:
        # No stored annotation: supply an empty one so the frontend can
        # record results without the user pre-selecting an annotation.
        empty_ann_id = self._make_ls_annotation_id(project.id, file_id)
        if needs_segmentation:
            empty_ann_id += current_segment_index
        task["annotations"] = [
            {
                "id": empty_ann_id,
                "task": ls_task_id,
                "result": [],
                "created_at": datetime.utcnow().isoformat() + "Z",
                "updated_at": datetime.utcnow().isoformat() + "Z",
            }
        ]

    return EditorTaskResponse(
        task=task,
        annotationUpdatedAt=annotation_updated_at,
        segmented=needs_segmentation,
        totalSegments=len(segment_contexts) if needs_segmentation else 1,
        currentSegmentIndex=current_segment_index,
    )
|
|
|
|
async def _build_media_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
    default_key: str,
    categories: set[str],
) -> EditorTaskResponse:
    """Build a Label Studio task for a media (image/audio/video) file.

    The task's data dict carries file/dataset identifiers under both
    snake_case and camelCase keys, and the file preview URL under every
    media value key resolved from the project's label config. Media tasks
    are never segmented, so the response always reports a single segment.

    Args:
        project: Labeling project the file belongs to.
        file_record: Dataset file row for the file being annotated.
        file_id: ID of the file (matches ``file_record.id``).
        default_key: Fallback data key to use when the label config does
            not declare an explicit media value key.
        categories: Label Studio object-tag categories considered "media"
            for this file type (e.g. image vs. audio/video tags).

    Returns:
        EditorTaskResponse with the constructed task and, if a stored
        annotation exists, its last-updated timestamp.
    """
    label_config = await self._resolve_project_label_config(project)
    media_keys = self._resolve_media_value_keys(
        label_config, default_key, categories
    )
    preview_url = self._build_file_preview_url(project.dataset_id, file_id)
    file_name = str(getattr(file_record, "file_name", ""))

    # Both snake_case and camelCase keys are populated so that either
    # naming style used in label-config value bindings resolves.
    task_data: Dict[str, Any] = {
        FILE_ID_KEY: file_id,
        FILE_ID_CAMEL_KEY: file_id,
        DATASET_ID_KEY: project.dataset_id,
        DATASET_ID_CAMEL_KEY: project.dataset_id,
        FILE_NAME_KEY: file_name,
        FILE_NAME_CAMEL_KEY: file_name,
    }
    # Every media value key declared by the config points at the same
    # preview URL for this file.
    for key in media_keys:
        task_data[key] = preview_url
    self._apply_text_placeholders(task_data, label_config)

    # Fetch the existing annotation (single annotation per project+file).
    ann_result = await self.db.execute(
        select(AnnotationResult).where(
            AnnotationResult.project_id == project.id,
            AnnotationResult.file_id == file_id,
        )
    )
    ann = ann_result.scalar_one_or_none()
    ls_task_id = self._make_ls_task_id(project.id, file_id)

    task: Dict[str, Any] = {
        "id": ls_task_id,
        "data": task_data,
        "annotations": [],
    }

    annotation_updated_at = None
    # Only reuse the stored annotation when it is NOT a segmented payload;
    # segmented payloads belong to text tasks and cannot be rendered here.
    if ann and not (ann.annotation or {}).get("segmented"):
        annotation_updated_at = ann.updated_at
        stored = dict(ann.annotation or {})
        stored["task"] = ls_task_id
        # Label Studio requires an integer annotation id; synthesize a
        # deterministic one when the stored payload lacks it.
        if not isinstance(stored.get("id"), int):
            stored["id"] = self._make_ls_annotation_id(project.id, file_id)
        task["annotations"] = [stored]
    else:
        # Provide an empty annotation so the frontend can produce results
        # even when no annotation is selected yet.
        empty_ann_id = self._make_ls_annotation_id(project.id, file_id)
        task["annotations"] = [
            {
                "id": empty_ann_id,
                "task": ls_task_id,
                "result": [],
                "created_at": datetime.utcnow().isoformat() + "Z",
                "updated_at": datetime.utcnow().isoformat() + "Z",
            }
        ]

    return EditorTaskResponse(
        task=task,
        annotationUpdatedAt=annotation_updated_at,
        segmented=False,
        segments=None,
        totalSegments=1,
        currentSegmentIndex=0,
    )
|
|
|
|
async def _build_image_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
) -> EditorTaskResponse:
    """Build a Label Studio task for an image file.

    Thin wrapper that delegates to the shared media-task builder with the
    image data key and image object-tag categories.
    """
    response = await self._build_media_task(
        project,
        file_record,
        file_id,
        IMAGE_DATA_KEY,
        IMAGE_OBJECT_CATEGORIES,
    )
    return response
|
|
|
|
async def _build_audio_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
) -> EditorTaskResponse:
    """Build a Label Studio task for an audio file.

    Thin wrapper that delegates to the shared media-task builder with the
    audio data key and the generic media object-tag categories.
    """
    response = await self._build_media_task(
        project,
        file_record,
        file_id,
        AUDIO_DATA_KEY,
        MEDIA_OBJECT_CATEGORIES,
    )
    return response
|
|
|
|
async def _build_video_task(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    file_id: str,
) -> EditorTaskResponse:
    """Build a Label Studio task for a video file.

    Thin wrapper that delegates to the shared media-task builder with the
    video data key and the generic media object-tag categories.
    """
    response = await self._build_media_task(
        project,
        file_record,
        file_id,
        VIDEO_DATA_KEY,
        MEDIA_OBJECT_CATEGORIES,
    )
    return response
|
|
|
|
async def upsert_annotation(
    self, project_id: str, file_id: str, request: UpsertAnnotationRequest
) -> UpsertAnnotationResponse:
    """Create or update the single annotation for a project/file pair.

    Handles both whole-file saves and per-segment saves (when
    ``request.segment_index`` is set). Performs file-ownership and
    file-version validation, resolves the final annotation status, takes a
    row lock before writing, and best-effort syncs the result to knowledge
    management after commit.

    Args:
        project_id: Labeling project ID.
        file_id: Dataset file ID the annotation belongs to.
        request: Upsert payload (annotation JSON, optional segment index,
            optional requested status, optional optimistic-lock timestamp).

    Returns:
        UpsertAnnotationResponse with the annotation ID and its new
        ``updated_at`` timestamp.

    Raises:
        HTTPException: 404 if the file is not part of the project; 409 on
            a file-version conflict or a stale ``expectedUpdatedAt``; 400
            on malformed payload / status.
    """
    project = await self._get_project_or_404(project_id)

    # Validate file ownership: file must be linked to the project and
    # belong to the project's dataset.
    file_result = await self.db.execute(
        select(DatasetFiles)
        .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
        .where(
            LabelingProjectFile.project_id == project.id,
            DatasetFiles.id == file_id,
            DatasetFiles.dataset_id == project.dataset_id,
        )
    )
    file_record = file_result.scalar_one_or_none()
    if not file_record:
        raise HTTPException(
            status_code=404, detail=f"文件不存在或不属于该项目: {file_id}"
        )

    # Reject the save if the file has moved to a different version than
    # the one the annotation was recorded against.
    current_file_version = file_record.version
    existing_result = await self.db.execute(
        select(AnnotationResult).where(
            AnnotationResult.project_id == project_id,
            AnnotationResult.file_id == file_id,
        )
    )
    existing_annotation = existing_result.scalar_one_or_none()

    if existing_annotation and existing_annotation.file_version is not None:
        if existing_annotation.file_version != current_file_version:
            raise HTTPException(
                status_code=409,
                detail=f"文件已更新到新版本(当前版本: {current_file_version}, 标注版本: {existing_annotation.file_version}),请使用新版本",
            )

    # Normalize the incoming payload: "result" must be a list (defaults
    # to empty).
    annotation_payload = dict(request.annotation or {})
    result = annotation_payload.get("result")
    if result is None:
        annotation_payload["result"] = []
    elif not isinstance(result, list):
        raise HTTPException(status_code=400, detail="annotation.result 必须为数组")

    ls_task_id = self._make_ls_task_id(project_id, file_id)

    # For segment saves, pre-compute the expected total segment count so
    # progress can be tracked even before all segments are annotated.
    segment_total_hint = None
    if request.segment_index is not None:
        segment_total_hint = self._resolve_segment_total(annotation_payload)
        if segment_total_hint is None:
            segment_total_hint = await self._compute_segment_total(
                project, file_record, file_id
            )

    # Re-select with a row lock (SELECT ... FOR UPDATE) to serialize
    # concurrent saves for the same project/file.
    existing_result = await self.db.execute(
        select(AnnotationResult)
        .where(
            AnnotationResult.project_id == project_id,
            AnnotationResult.file_id == file_id,
        )
        .with_for_update()
    )
    existing = existing_result.scalar_one_or_none()

    now = datetime.utcnow()

    # Determine whether this is a per-segment save.
    if request.segment_index is not None:
        # Segment mode: merge this segment's annotation into the stored
        # whole-file segmented structure.
        final_payload = self._merge_segment_annotation(
            existing.annotation if existing else None,
            request.segment_index,
            annotation_payload,
        )
        segment_entries = self._extract_segment_annotations(final_payload)
        if str(request.segment_index) not in segment_entries:
            # Should not happen; merge is expected to record the segment.
            logger.warning(
                "分段标注合并异常:未找到当前段落 key,project_id=%s file_id=%s segment_index=%s",
                project_id,
                file_id,
                request.segment_index,
            )
    else:
        # Non-segment mode: store the incoming annotation as-is, ensuring
        # the Label Studio task reference and an integer annotation id.
        annotation_payload["task"] = ls_task_id
        if not isinstance(annotation_payload.get("id"), int):
            annotation_payload["id"] = self._make_ls_annotation_id(
                project_id, file_id
            )
        final_payload = annotation_payload

    # Validate any client-requested status against the allowed set.
    requested_status = request.annotation_status
    if (
        requested_status is not None
        and requested_status not in ANNOTATION_STATUS_CLIENT_VALUES
    ):
        raise HTTPException(status_code=400, detail="annotationStatus 不合法")

    # Track segment progress (done vs total) for segmented saves.
    segment_total = None
    segment_done = None
    if request.segment_index is not None:
        segment_total = self._resolve_segment_total(final_payload)
        if segment_total is None:
            segment_total = segment_total_hint
        if segment_total and segment_total > 0:
            final_payload[SEGMENT_TOTAL_KEY] = segment_total
        segment_done = len(self._extract_segment_annotations(final_payload))

    # Resolve the final status: partially-segmented => in progress;
    # otherwise annotated when results exist, else require an explicit
    # no-annotation / not-applicable confirmation from the client.
    if (
        segment_total is not None
        and segment_done is not None
        and segment_done < segment_total
    ):
        final_status = ANNOTATION_STATUS_IN_PROGRESS
    else:
        has_result = self._has_annotation_result(final_payload)
        if has_result:
            final_status = ANNOTATION_STATUS_ANNOTATED
        else:
            if requested_status == ANNOTATION_STATUS_NO_ANNOTATION:
                final_status = ANNOTATION_STATUS_NO_ANNOTATION
            elif requested_status == ANNOTATION_STATUS_NOT_APPLICABLE:
                final_status = ANNOTATION_STATUS_NOT_APPLICABLE
            else:
                raise HTTPException(
                    status_code=400,
                    detail="未发现标注内容,请确认无标注/不适用后再保存",
                )

    if request.segment_index is not None:
        segment_entries = self._extract_segment_annotations(final_payload)
        logger.info(
            "分段标注保存:project_id=%s file_id=%s segment_index=%s segments=%s total=%s status=%s",
            project_id,
            file_id,
            request.segment_index,
            len(segment_entries),
            segment_total,
            final_status,
        )

    if existing:
        # Optimistic-lock check: reject if the row changed since the
        # client last read it.
        if request.expected_updated_at and existing.updated_at:
            if existing.updated_at != request.expected_updated_at.replace(
                tzinfo=None
            ):
                raise HTTPException(
                    status_code=409, detail="标注已被更新,请刷新后重试"
                )

        existing.annotation = final_payload  # type: ignore[assignment]
        existing.annotation_status = final_status  # type: ignore[assignment]
        existing.file_version = current_file_version  # type: ignore[assignment]
        existing.updated_at = now  # type: ignore[assignment]
        await self.db.commit()
        await self.db.refresh(existing)

        response = UpsertAnnotationResponse(
            annotationId=existing.id,
            updatedAt=existing.updated_at or now,
        )
        # Best-effort knowledge sync after commit (never blocks the save).
        await self._sync_annotation_to_knowledge(
            project, file_record, final_payload, existing.updated_at
        )
        return response

    # No existing row: insert a new annotation record.
    new_id = str(uuid.uuid4())
    record = AnnotationResult(
        id=new_id,
        project_id=project_id,
        file_id=file_id,
        annotation=final_payload,
        annotation_status=final_status,
        file_version=current_file_version,
        created_at=now,
        updated_at=now,
    )
    self.db.add(record)
    await self.db.commit()
    await self.db.refresh(record)

    response = UpsertAnnotationResponse(
        annotationId=record.id,
        updatedAt=record.updated_at or now,
    )
    # Best-effort knowledge sync after commit (never blocks the save).
    await self._sync_annotation_to_knowledge(
        project, file_record, final_payload, record.updated_at
    )
    return response
|
|
|
|
def _merge_segment_annotation(
    self,
    existing: Optional[Dict[str, Any]],
    segment_index: int,
    new_annotation: Dict[str, Any],
) -> Dict[str, Any]:
    """Merge a single segment's annotation into the whole-file structure.

    Args:
        existing: Currently stored annotation payload, or None.
        segment_index: Index of the segment being saved.
        new_annotation: Annotation payload for that segment.

    Returns:
        The merged segmented annotation structure.
    """
    if existing and existing.get(SEGMENTED_KEY):
        merged = dict(existing)
    else:
        # No usable segmented base: start a fresh segmented structure.
        merged = {
            SEGMENTED_KEY: True,
            "version": 1,
            SEGMENTS_KEY: {},
        }

    if not merged.get(SEGMENTED_KEY):
        merged[SEGMENTED_KEY] = True

    raw_segments = merged.get(SEGMENTS_KEY)
    if isinstance(raw_segments, list):
        # Legacy list layout: normalize into the dict layout.
        segment_map = self._extract_segment_annotations(merged)
    elif isinstance(raw_segments, dict):
        # Shallow-copy so in-place edits don't defeat SQLAlchemy's
        # change detection on the stored JSON column.
        segment_map = dict(raw_segments)
    else:
        segment_map = {}
    merged[SEGMENTS_KEY] = segment_map

    # Overwrite (or insert) the entry for this segment.
    segment_map[str(segment_index)] = {
        SEGMENT_RESULT_KEY: new_annotation.get(SEGMENT_RESULT_KEY, []),
        SEGMENT_CREATED_AT_KEY: new_annotation.get(
            SEGMENT_CREATED_AT_KEY, datetime.utcnow().isoformat() + "Z"
        ),
        SEGMENT_UPDATED_AT_KEY: datetime.utcnow().isoformat() + "Z",
    }

    return merged
|
|
|
|
async def _sync_annotation_to_knowledge(
    self,
    project: LabelingProject,
    file_record: DatasetFiles,
    annotation: Dict[str, Any],
    annotation_updated_at: Optional[datetime],
) -> None:
    """Best-effort push of an annotation to knowledge management.

    Any failure is logged and swallowed so it never blocks saving the
    annotation itself.
    """
    try:
        sync_service = KnowledgeSyncService(self.db)
        await sync_service.sync_annotation_to_knowledge(
            project=project,
            file_record=file_record,
            annotation=annotation,
            annotation_updated_at=annotation_updated_at,
        )
    except Exception as exc:
        logger.warning("标注同步知识管理失败:%s", exc)
|
|
|
|
async def precompute_segmentation_for_project(
    self, project_id: str, max_retries: int = 3
) -> Dict[str, Any]:
    """Precompute and persist segment structures for all text files in a project.

    Only applies to text datasets with segmentation enabled. Source
    documents (e.g. original PDFs/DOCs) are skipped. For each eligible
    file whose content exceeds the segment threshold (or contains multiple
    records), an empty segmented annotation skeleton is written to the
    database so the editor can page through segments immediately.

    Args:
        project_id: Labeling project ID.
        max_retries: Number of attempts per file before counting it failed.

    Returns:
        Stats dict: {total_files, succeeded, failed}.
    """
    project = await self._get_project_or_404(project_id)
    dataset_type = self._normalize_dataset_type(
        await self._get_dataset_type(project.dataset_id)
    )

    # Only text datasets are segmented.
    if dataset_type != DATASET_TYPE_TEXT:
        logger.info(f"项目 {project_id} 不是文本数据集,跳过切片预生成")
        return {"total_files": 0, "succeeded": 0, "failed": 0}

    # Skip when the project has segmentation disabled.
    if not self._resolve_segmentation_enabled(project):
        logger.info(f"项目 {project_id} 未启用分段,跳过切片预生成")
        return {"total_files": 0, "succeeded": 0, "failed": 0}

    # Collect all files linked to the project within its dataset.
    files_result = await self.db.execute(
        select(DatasetFiles)
        .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
        .where(
            LabelingProjectFile.project_id == project_id,
            DatasetFiles.dataset_id == project.dataset_id,
        )
    )
    file_records = files_result.scalars().all()

    if not file_records:
        logger.info(f"项目 {project_id} 没有文件,跳过切片预生成")
        return {"total_files": 0, "succeeded": 0, "failed": 0}

    # Filter out source documents (by file type or extension).
    valid_files = []
    for file_record in file_records:
        file_type = str(getattr(file_record, "file_type", "") or "").lower()
        file_name = str(getattr(file_record, "file_name", "")).lower()
        is_source_document = file_type in SOURCE_DOCUMENT_TYPES or any(
            file_name.endswith(ext) for ext in SOURCE_DOCUMENT_EXTENSIONS
        )
        if not is_source_document:
            valid_files.append(file_record)

    total_files = len(valid_files)
    succeeded = 0
    failed = 0

    label_config = await self._resolve_project_label_config(project)
    primary_text_key = self._resolve_primary_text_key(label_config)

    for file_record in valid_files:
        file_id = str(file_record.id)  # type: ignore
        file_name = str(getattr(file_record, "file_name", ""))
        current_file_version = getattr(file_record, "version", None)

        # Retry loop: a `break` marks the file as finished (success or
        # permanent skip); exceptions fall through to the next attempt.
        for retry in range(max_retries):
            try:
                # Fetch the raw text via the file-download API.
                text_content = await self._fetch_text_content_via_download_api(
                    project.dataset_id, file_id
                )
                if not isinstance(text_content, str):
                    logger.warning(f"文件 {file_id} 内容不是字符串,跳过切片")
                    failed += 1
                    break

                # Parse text records: JSONL → one record per line; single
                # JSON payload → one record; otherwise treat the whole
                # content as one plain-text record.
                records: List[Tuple[Optional[Dict[str, Any]], str]] = []
                if file_name.lower().endswith(JSONL_EXTENSION):
                    records = self._parse_jsonl_records(text_content)
                else:
                    parsed_payload = self._try_parse_json_payload(text_content)
                    if parsed_payload:
                        records = [(parsed_payload, text_content)]

                if not records:
                    records = [(None, text_content)]

                record_texts = [
                    self._resolve_primary_text_value(
                        payload, raw_text, primary_text_key
                    )
                    for payload, raw_text in records
                ]
                if not record_texts:
                    record_texts = [text_content]

                # Segmentation is needed when there are multiple records
                # or any record exceeds the character threshold.
                needs_segmentation = len(records) > 1 or any(
                    len(text or "") > self.SEGMENT_THRESHOLD
                    for text in record_texts
                )

                if not needs_segmentation:
                    # Nothing to precompute for this file.
                    succeeded += 1
                    break

                # Split long records; each chunk (or short record) becomes
                # one empty segment entry keyed by a running cursor.
                splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
                segment_cursor = 0
                segments = {}

                for record_index, ((payload, raw_text), record_text) in enumerate(
                    zip(records, record_texts)
                ):
                    normalized_text = record_text or ""

                    if len(normalized_text) > self.SEGMENT_THRESHOLD:
                        raw_segments = splitter.split(normalized_text)
                        for chunk_index, seg in enumerate(raw_segments):
                            segments[str(segment_cursor)] = {
                                SEGMENT_RESULT_KEY: [],
                                SEGMENT_CREATED_AT_KEY: datetime.utcnow().isoformat()
                                + "Z",
                                SEGMENT_UPDATED_AT_KEY: datetime.utcnow().isoformat()
                                + "Z",
                            }
                            segment_cursor += 1
                    else:
                        segments[str(segment_cursor)] = {
                            SEGMENT_RESULT_KEY: [],
                            SEGMENT_CREATED_AT_KEY: datetime.utcnow().isoformat()
                            + "Z",
                            SEGMENT_UPDATED_AT_KEY: datetime.utcnow().isoformat()
                            + "Z",
                        }
                        segment_cursor += 1

                if not segments:
                    succeeded += 1
                    break

                # Build the segmented annotation skeleton.
                final_payload = {
                    SEGMENTED_KEY: True,
                    "version": 1,
                    SEGMENTS_KEY: segments,
                    SEGMENT_TOTAL_KEY: segment_cursor,
                }

                # Check for an existing annotation row for this file.
                existing_result = await self.db.execute(
                    select(AnnotationResult).where(
                        AnnotationResult.project_id == project_id,
                        AnnotationResult.file_id == file_id,
                    )
                )
                existing = existing_result.scalar_one_or_none()

                now = datetime.utcnow()

                if existing:
                    # Overwrite the existing annotation with the skeleton.
                    existing.annotation = final_payload  # type: ignore[assignment]
                    existing.annotation_status = ANNOTATION_STATUS_IN_PROGRESS  # type: ignore[assignment]
                    existing.file_version = current_file_version  # type: ignore[assignment]
                    existing.updated_at = now  # type: ignore[assignment]
                else:
                    # Create a fresh annotation record.
                    record = AnnotationResult(
                        id=str(uuid.uuid4()),
                        project_id=project_id,
                        file_id=file_id,
                        annotation=final_payload,
                        annotation_status=ANNOTATION_STATUS_IN_PROGRESS,
                        file_version=current_file_version,
                        created_at=now,
                        updated_at=now,
                    )
                    self.db.add(record)

                await self.db.commit()
                succeeded += 1
                logger.info(f"成功为文件 {file_id} 预生成 {segment_cursor} 个切片")
                break

            except Exception as e:
                logger.warning(
                    f"为文件 {file_id} 预生成切片失败 (重试 {retry + 1}/{max_retries}): {e}"
                )
                if retry == max_retries - 1:
                    failed += 1
                # Roll back the failed transaction before retrying.
                await self.db.rollback()

    logger.info(
        f"项目 {project_id} 切片预生成完成: 总计 {total_files}, 成功 {succeeded}, 失败 {failed}"
    )
    return {
        "total_files": total_files,
        "succeeded": succeeded,
        "failed": failed,
    }
|
|
|
|
async def check_file_version(self, project_id: str, file_id: str) -> Dict[str, Any]:
    """Check whether a newer version of the viewed file exists.

    Finds the highest-version ACTIVE file sharing the same logical path
    in the dataset and compares it against the file the user is viewing.
    ``hasNewVersion`` is true whenever the viewed file is not the latest
    version — regardless of whether an annotation exists — so users are
    warned about stale files even before they start annotating.

    Args:
        project_id: Labeling project ID.
        file_id: ID of the file the annotation is associated with (the
            version currently being viewed).

    Returns:
        Dict with: fileId, currentFileVersion (latest known version),
        annotationFileVersion (version recorded at annotation time, or
        None), hasNewVersion, annotationVersionUnknown (annotation exists
        but has no recorded version), latestFileId.

    Raises:
        HTTPException: 404 if the file is not part of the project.
    """
    project = await self._get_project_or_404(project_id)

    # Resolve the file the annotation is currently associated with,
    # validating project/dataset ownership.
    file_result = await self.db.execute(
        select(DatasetFiles)
        .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
        .where(
            LabelingProjectFile.project_id == project.id,
            DatasetFiles.id == file_id,
            DatasetFiles.dataset_id == project.dataset_id,
        )
    )
    file_record = file_result.scalar_one_or_none()
    if not file_record:
        raise HTTPException(
            status_code=404, detail=f"文件不存在或不属于该项目: {file_id}"
        )

    # Look up the newest ACTIVE file under the same logical path; file
    # versions share a logical path and differ by the version column.
    logical_path = file_record.logical_path
    latest_file_result = await self.db.execute(
        select(DatasetFiles)
        .where(
            DatasetFiles.dataset_id == project.dataset_id,
            DatasetFiles.logical_path == logical_path,
            DatasetFiles.status == "ACTIVE",
        )
        .order_by(DatasetFiles.version.desc())
        .limit(1)
    )
    latest_file = latest_file_result.scalar_one_or_none()

    # Fetch the annotation for the viewed file (may not exist yet).
    annotation_result = await self.db.execute(
        select(AnnotationResult).where(
            AnnotationResult.project_id == project_id,
            AnnotationResult.file_id == file_id,
        )
    )
    annotation = annotation_result.scalar_one_or_none()

    # Latest version on the logical path; fall back to the viewed file's
    # own version when no ACTIVE row was found.
    latest_file_version = latest_file.version if latest_file else file_record.version
    # Version recorded when the annotation was last saved (None if no
    # annotation, or if the annotation predates version tracking).
    annotation_file_version = annotation.file_version if annotation else None

    annotation_version_unknown = (
        annotation is not None and annotation_file_version is None
    )

    # New version exists iff the latest version exceeds the viewed file's
    # version — independent of whether an annotation exists.
    has_new_version = latest_file_version > file_record.version

    return {
        "fileId": file_id,
        "currentFileVersion": latest_file_version,
        "annotationFileVersion": annotation_file_version,
        "hasNewVersion": has_new_version,
        "annotationVersionUnknown": annotation_version_unknown,
        "latestFileId": latest_file.id if latest_file else file_id,
    }
|
|
|
|
async def use_new_version(self, project_id: str, file_id: str) -> Dict[str, Any]:
    """Adopt the file's new version and clear the stored annotation.

    The annotation row is locked, its results are emptied (the segmented
    skeleton is preserved with empty results when applicable), its status
    is reset to NO_ANNOTATION, and its file_version is bumped to the
    file's current version. The cleared payload is then best-effort
    synced to knowledge management.

    Args:
        project_id: Labeling project ID.
        file_id: Dataset file ID.

    Returns:
        Dict with fileId, previousFileVersion, currentFileVersion and a
        human-readable message.

    Raises:
        HTTPException: 404 if the file or annotation is missing; 400 if
            the file version has not actually advanced.
    """
    project = await self._get_project_or_404(project_id)

    # Resolve the file and validate project/dataset ownership.
    file_result = await self.db.execute(
        select(DatasetFiles)
        .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
        .where(
            LabelingProjectFile.project_id == project.id,
            DatasetFiles.id == file_id,
            DatasetFiles.dataset_id == project.dataset_id,
        )
    )
    file_record = file_result.scalar_one_or_none()
    if not file_record:
        raise HTTPException(
            status_code=404, detail=f"文件不存在或不属于该项目: {file_id}"
        )

    # Lock the annotation row to serialize concurrent version switches.
    annotation_result = await self.db.execute(
        select(AnnotationResult)
        .where(
            AnnotationResult.project_id == project_id,
            AnnotationResult.file_id == file_id,
        )
        .with_for_update()
    )
    annotation = annotation_result.scalar_one_or_none()

    current_file_version = file_record.version

    if not annotation:
        raise HTTPException(status_code=404, detail=f"标注不存在: {file_id}")

    previous_file_version = annotation.file_version

    # Only allow the switch when the file version actually advanced past
    # the version the annotation was saved against.
    if annotation.file_version is not None:
        if current_file_version <= annotation.file_version:
            raise HTTPException(
                status_code=400,
                detail=f"文件版本({current_file_version})未更新或低于标注版本({annotation.file_version})",
            )

    # Clear the annotation and record the new file version.
    now = datetime.utcnow()
    cleared_payload: Dict[str, Any] = {}
    if isinstance(annotation.annotation, dict) and self._is_segmented_annotation(
        annotation.annotation
    ):
        # Segmented annotation: keep the segment skeleton but empty every
        # segment's result list.
        segments = self._extract_segment_annotations(annotation.annotation)
        cleared_segments: Dict[str, Dict[str, Any]] = {}
        for segment_id, segment_data in segments.items():
            if not isinstance(segment_data, dict):
                continue
            normalized = dict(segment_data)
            normalized[SEGMENT_RESULT_KEY] = []
            cleared_segments[str(segment_id)] = normalized

        total_segments = self._resolve_segment_total(annotation.annotation)
        if total_segments is None:
            total_segments = len(cleared_segments)

        cleared_payload = {
            SEGMENTED_KEY: True,
            "version": annotation.annotation.get("version", 1),
            SEGMENTS_KEY: cleared_segments,
            SEGMENT_TOTAL_KEY: total_segments,
        }

    annotation.annotation = cleared_payload
    annotation.annotation_status = ANNOTATION_STATUS_NO_ANNOTATION
    annotation.file_version = current_file_version
    annotation.updated_at = now

    await self.db.commit()
    await self.db.refresh(annotation)

    # Best-effort knowledge sync of the cleared payload (never raises).
    await self._sync_annotation_to_knowledge(
        project,
        file_record,
        cleared_payload,
        annotation.updated_at or now,
    )

    return {
        "fileId": file_id,
        "previousFileVersion": previous_file_version,
        "currentFileVersion": current_file_version,
        "message": "已使用新版本并清空标注",
    }
|