You've already forked DataMate
- 修改数据源文件扫描方法,直接在主流程中获取任务详情和路径 - 移除独立的getFilePaths方法,将路径扫描逻辑整合到scanFilePaths方法中 - 新增copyFilesToDatasetDirWithSourceRoot方法支持保留相对路径的文件复制 - 更新数据集文件应用服务中的文件复制逻辑,支持相对路径处理 - 修改Python后端项目接口中的文件查询逻辑,移除注释掉的编辑器服务引用 - 调整文件过滤逻辑,基于元数据中的派生源ID进行文件筛选 - 移除编辑器服务中已废弃的源文档过滤条件
963 lines · 36 KiB · Python
"""
|
|
标注编辑器(Label Studio Editor)服务
|
|
|
|
职责:
|
|
- 解析 DataMate 标注项目(t_dm_labeling_projects)
|
|
- 以“文件下载/预览接口”读取文本内容,构造 Label Studio task
|
|
- 以原始 annotation JSON 形式 upsert 最终标注结果(单人单份)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
from urllib.parse import urlparse
|
|
|
|
import hashlib
|
|
import json
|
|
import xml.etree.ElementTree as ET
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import case, func, select, or_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import get_logger
|
|
from app.db.models import AnnotationResult, Dataset, DatasetFiles, LabelingProject, LabelingProjectFile
|
|
from app.module.annotation.config import LabelStudioTagConfig
|
|
from app.module.annotation.schema.editor import (
|
|
EditorProjectInfo,
|
|
EditorTaskListItem,
|
|
EditorTaskListResponse,
|
|
EditorTaskResponse,
|
|
SegmentInfo,
|
|
UpsertAnnotationRequest,
|
|
UpsertAnnotationResponse,
|
|
)
|
|
from app.module.annotation.service.template import AnnotationTemplateService
|
|
from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
|
|
from app.module.annotation.service.annotation_text_splitter import AnnotationTextSplitter
|
|
from app.module.annotation.service.text_fetcher import fetch_text_content_via_download_api
|
|
|
|
logger = get_logger(__name__)

# --- Keys used in the Label Studio task ``data`` payload ---
# Both snake_case and camelCase variants are emitted for frontend compatibility.
TEXT_DATA_KEY = "text"
IMAGE_DATA_KEY = "image"
AUDIO_DATA_KEY = "audio"
VIDEO_DATA_KEY = "video"
DATASET_ID_KEY = "dataset_id"
FILE_ID_KEY = "file_id"
FILE_NAME_KEY = "file_name"
DATASET_ID_CAMEL_KEY = "datasetId"
FILE_ID_CAMEL_KEY = "fileId"
FILE_NAME_CAMEL_KEY = "fileName"
SEGMENT_INDEX_KEY = "segment_index"
SEGMENT_INDEX_CAMEL_KEY = "segmentIndex"
# Files with this extension are parsed line-by-line as JSON records.
JSONL_EXTENSION = ".jsonl"
# Object-tag categories (as defined by LabelStudioTagConfig) grouped by media kind.
TEXTUAL_OBJECT_CATEGORIES = {"text", "document"}
IMAGE_OBJECT_CATEGORIES = {"image"}
MEDIA_OBJECT_CATEGORIES = {"media"}
# Name prefix for auto-generated <Header> tags injected before textual objects.
OBJECT_NAME_HEADER_PREFIX = "dm_object_header_"
# Dataset types supported by the embedded editor.
DATASET_TYPE_TEXT = "TEXT"
DATASET_TYPE_IMAGE = "IMAGE"
DATASET_TYPE_AUDIO = "AUDIO"
DATASET_TYPE_VIDEO = "VIDEO"
SUPPORTED_EDITOR_DATASET_TYPES = (
    DATASET_TYPE_TEXT,
    DATASET_TYPE_IMAGE,
    DATASET_TYPE_AUDIO,
    DATASET_TYPE_VIDEO,
)
# Project-configuration flag controlling automatic text segmentation.
SEGMENTATION_ENABLED_KEY = "segmentation_enabled"
# Extensions / file_type values identifying source documents
# (used by _build_source_document_filter).
SOURCE_DOCUMENT_EXTENSIONS = (".pdf", ".doc", ".docx")
SOURCE_DOCUMENT_TYPES = ("pdf", "doc", "docx")
|
|
|
|
|
|
class AnnotationEditorService:
    """Label Studio editor integration service (TEXT POC version)."""

    # Segmentation threshold: record texts longer than this many characters
    # are automatically split into segments for per-segment annotation.
    SEGMENT_THRESHOLD = 200

    def __init__(self, db: AsyncSession):
        # Async SQLAlchemy session used for every query in this service.
        self.db = db
        # Resolves label configs from annotation templates.
        self.template_service = AnnotationTemplateService()
|
|
|
|
@staticmethod
|
|
def _stable_ls_id(seed: str) -> int:
|
|
"""
|
|
生成稳定的 Label Studio 风格整数 ID(JS 安全整数范围内)。
|
|
|
|
说明:
|
|
- Label Studio Frontend 的 mobx-state-tree 模型对 task/annotation 的 id 有类型约束(通常为 number)。
|
|
- DataMate 使用 UUID 作为 file_id/project_id,因此需映射为整数供编辑器使用。
|
|
- 取 sha1 的前 13 个 hex(52bit),落在 JS Number 的安全整数范围。
|
|
"""
|
|
digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
|
|
value = int(digest[:13], 16)
|
|
return value if value > 0 else 1
|
|
|
|
def _make_ls_task_id(self, project_id: str, file_id: str) -> int:
|
|
return self._stable_ls_id(f"task:{project_id}:{file_id}")
|
|
|
|
def _make_ls_annotation_id(self, project_id: str, file_id: str) -> int:
|
|
# 单人单份最终标签:每个 task 只保留一个 annotation,id 直接与 task 绑定即可
|
|
return self._stable_ls_id(f"annotation:{project_id}:{file_id}")
|
|
|
|
@staticmethod
|
|
def _normalize_dataset_type(dataset_type: Optional[str]) -> str:
|
|
return (dataset_type or "").upper()
|
|
|
|
@staticmethod
|
|
def _resolve_public_api_prefix() -> str:
|
|
base = (settings.datamate_backend_base_url or "").strip()
|
|
if not base:
|
|
return "/api"
|
|
parsed = urlparse(base)
|
|
if parsed.scheme and parsed.netloc:
|
|
prefix = parsed.path
|
|
else:
|
|
prefix = base
|
|
prefix = prefix.rstrip("/")
|
|
if not prefix:
|
|
return "/api"
|
|
if not prefix.startswith("/"):
|
|
prefix = "/" + prefix
|
|
return prefix
|
|
|
|
@classmethod
|
|
def _build_file_preview_url(cls, dataset_id: str, file_id: str) -> str:
|
|
prefix = cls._resolve_public_api_prefix()
|
|
return f"{prefix}/data-management/datasets/{dataset_id}/files/{file_id}/preview"
|
|
|
|
    async def _get_project_or_404(self, project_id: str) -> LabelingProject:
        """Load a non-deleted labeling project or raise HTTP 404."""
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.id == project_id,
                # Soft-deleted projects are treated as missing.
                LabelingProject.deleted_at.is_(None),
            )
        )
        project = result.scalar_one_or_none()
        if not project:
            raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
        return project
|
|
|
|
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
|
|
result = await self.db.execute(
|
|
select(Dataset.dataset_type).where(Dataset.id == dataset_id)
|
|
)
|
|
return result.scalar_one_or_none()
|
|
|
|
async def _get_label_config(self, template_id: Optional[str]) -> Optional[str]:
|
|
if not template_id:
|
|
return None
|
|
template = await self.template_service.get_template(self.db, template_id)
|
|
return getattr(template, "label_config", None) if template else None
|
|
|
|
async def _resolve_project_label_config(self, project: LabelingProject) -> Optional[str]:
|
|
label_config = None
|
|
if project.configuration and isinstance(project.configuration, dict):
|
|
label_config = project.configuration.get("label_config")
|
|
if not label_config:
|
|
label_config = await self._get_label_config(project.template_id)
|
|
if label_config:
|
|
label_config = self._decorate_label_config_for_editor(label_config)
|
|
return label_config
|
|
|
|
@staticmethod
|
|
def _resolve_segmentation_enabled(project: LabelingProject) -> bool:
|
|
config = project.configuration
|
|
if not isinstance(config, dict):
|
|
return True
|
|
value = config.get(SEGMENTATION_ENABLED_KEY)
|
|
if isinstance(value, bool):
|
|
return value
|
|
if value is None:
|
|
return True
|
|
return bool(value)
|
|
|
|
@classmethod
|
|
def _resolve_primary_text_key(cls, label_config: Optional[str]) -> Optional[str]:
|
|
if not label_config:
|
|
return None
|
|
keys = cls._extract_textual_value_keys(label_config)
|
|
if not keys:
|
|
return None
|
|
if TEXT_DATA_KEY in keys:
|
|
return TEXT_DATA_KEY
|
|
return keys[0]
|
|
|
|
@classmethod
|
|
def _resolve_media_value_keys(
|
|
cls,
|
|
label_config: Optional[str],
|
|
default_key: str,
|
|
categories: Optional[set[str]] = None,
|
|
) -> List[str]:
|
|
if not label_config:
|
|
return [default_key]
|
|
target_categories = categories or set()
|
|
keys = cls._extract_object_value_keys_by_category(label_config, target_categories)
|
|
if not keys:
|
|
return [default_key]
|
|
return keys
|
|
|
|
@staticmethod
|
|
def _try_parse_json_payload(text_content: str) -> Optional[Dict[str, Any]]:
|
|
if not text_content:
|
|
return None
|
|
stripped = text_content.strip()
|
|
if not stripped:
|
|
return None
|
|
if stripped[0] not in ("{", "["):
|
|
return None
|
|
try:
|
|
parsed = json.loads(stripped)
|
|
except Exception:
|
|
return None
|
|
return parsed if isinstance(parsed, dict) else None
|
|
|
|
@classmethod
|
|
def _parse_jsonl_records(cls, text_content: str) -> List[Tuple[Optional[Dict[str, Any]], str]]:
|
|
lines = [line for line in text_content.splitlines() if line.strip()]
|
|
records: List[Tuple[Optional[Dict[str, Any]], str]] = []
|
|
for line in lines:
|
|
payload = cls._try_parse_json_payload(line)
|
|
records.append((payload, line))
|
|
return records
|
|
|
|
@staticmethod
|
|
def _is_textual_object_tag(object_tag: str) -> bool:
|
|
config = LabelStudioTagConfig.get_object_config(object_tag) or {}
|
|
category = config.get("category")
|
|
return category in TEXTUAL_OBJECT_CATEGORIES
|
|
|
|
@classmethod
|
|
def _extract_object_value_keys_by_category(
|
|
cls,
|
|
label_config: str,
|
|
categories: set[str],
|
|
) -> List[str]:
|
|
try:
|
|
root = ET.fromstring(label_config)
|
|
except Exception as exc:
|
|
logger.warning("解析 label_config 失败,已跳过占位填充:%s", exc)
|
|
return []
|
|
|
|
object_types = LabelStudioTagConfig.get_object_types()
|
|
seen: Dict[str, None] = {}
|
|
for element in root.iter():
|
|
if element.tag not in object_types:
|
|
continue
|
|
config = LabelStudioTagConfig.get_object_config(element.tag) or {}
|
|
category = config.get("category")
|
|
if categories and category not in categories:
|
|
continue
|
|
value = element.attrib.get("value", "")
|
|
if not value.startswith("$"):
|
|
continue
|
|
key = value[1:].strip()
|
|
if not key:
|
|
continue
|
|
seen[key] = None
|
|
return list(seen.keys())
|
|
|
|
@classmethod
|
|
def _extract_textual_value_keys(cls, label_config: str) -> List[str]:
|
|
return cls._extract_object_value_keys_by_category(label_config, TEXTUAL_OBJECT_CATEGORIES)
|
|
|
|
@staticmethod
|
|
def _needs_placeholder(value: Any) -> bool:
|
|
if value is None:
|
|
return True
|
|
if isinstance(value, str) and not value.strip():
|
|
return True
|
|
return False
|
|
|
|
def _apply_text_placeholders(self, data: Dict[str, Any], label_config: Optional[str]) -> None:
|
|
if not label_config:
|
|
return
|
|
for key in self._extract_textual_value_keys(label_config):
|
|
if self._needs_placeholder(data.get(key)):
|
|
data[key] = key
|
|
|
|
@staticmethod
|
|
def _header_already_present(header: ET.Element, name: str) -> bool:
|
|
value = header.attrib.get("value", "")
|
|
if value == name:
|
|
return True
|
|
header_text = (header.text or "").strip()
|
|
return header_text == name
|
|
|
|
    def _decorate_label_config_for_editor(self, label_config: str) -> str:
        """Inject a <Header> showing each textual object's name before it.

        Returns the label config unchanged when its XML cannot be parsed
        (logged as a warning). Header tag names are allocated uniquely so
        they never collide with names already used in the config.
        """
        try:
            root = ET.fromstring(label_config)
        except Exception as exc:
            logger.warning("解析 label_config 失败,已跳过 name 展示增强:%s", exc)
            return label_config

        object_types = LabelStudioTagConfig.get_object_types()
        # Collect every tag name already in use so generated names stay unique.
        used_names = set()
        for element in root.iter():
            name = element.attrib.get("name")
            if name:
                used_names.add(name)

        def allocate_header_name(base: str) -> str:
            # Prefer "<prefix><base>"; on collision append "_1", "_2", ...
            candidate = f"{OBJECT_NAME_HEADER_PREFIX}{base}"
            if candidate not in used_names:
                used_names.add(candidate)
                return candidate
            idx = 1
            while f"{candidate}_{idx}" in used_names:
                idx += 1
            resolved = f"{candidate}_{idx}"
            used_names.add(resolved)
            return resolved

        for parent in root.iter():
            # Snapshot the children; the list (and the tree) is mutated below,
            # with `i` tracking the insertion position manually.
            children = list(parent)
            i = 0
            while i < len(children):
                child = children[i]
                if child.tag not in object_types:
                    i += 1
                    continue
                if not self._is_textual_object_tag(child.tag):
                    i += 1
                    continue
                obj_name = child.attrib.get("name")
                if not obj_name:
                    i += 1
                    continue

                # Skip when the previous sibling is already a Header showing
                # this object's name (idempotent on repeated decoration).
                if i > 0:
                    prev = children[i - 1]
                    if prev.tag == "Header" and self._header_already_present(prev, obj_name):
                        i += 1
                        continue

                header = ET.Element("Header")
                header.set("name", allocate_header_name(obj_name))
                header.set("value", obj_name)

                # Insert into both the live tree and the local snapshot so
                # the index arithmetic stays consistent.
                parent.insert(i, header)
                children.insert(i, header)
                i += 2
                # continue outer loop

        return ET.tostring(root, encoding="unicode")
|
|
|
|
    @classmethod
    def _build_source_document_filter(cls):
        """SQLAlchemy filter matching source documents (pdf/doc/docx).

        Matches case-insensitively on either the file_type column or the
        file_name extension.

        NOTE(review): no caller remains in this module after source-document
        filtering was removed from task listing — candidate for deletion once
        external callers are ruled out.
        """
        file_type_lower = func.lower(DatasetFiles.file_type)
        file_name_lower = func.lower(DatasetFiles.file_name)
        type_condition = file_type_lower.in_(SOURCE_DOCUMENT_TYPES)
        name_conditions = [file_name_lower.like(f"%{ext}") for ext in SOURCE_DOCUMENT_EXTENSIONS]
        return or_(type_condition, *name_conditions)
|
|
|
|
def _build_task_data(
|
|
self,
|
|
display_text: str,
|
|
parsed_payload: Optional[Dict[str, Any]],
|
|
label_config: Optional[str],
|
|
file_record: DatasetFiles,
|
|
dataset_id: str,
|
|
file_id: str,
|
|
primary_text_key: Optional[str],
|
|
) -> Dict[str, Any]:
|
|
data: Dict[str, Any] = dict(parsed_payload or {})
|
|
text_key = primary_text_key or TEXT_DATA_KEY
|
|
data[text_key] = display_text
|
|
|
|
file_name = str(getattr(file_record, "file_name", ""))
|
|
data[FILE_ID_KEY] = file_id
|
|
data[FILE_ID_CAMEL_KEY] = file_id
|
|
data[DATASET_ID_KEY] = dataset_id
|
|
data[DATASET_ID_CAMEL_KEY] = dataset_id
|
|
data[FILE_NAME_KEY] = file_name
|
|
data[FILE_NAME_CAMEL_KEY] = file_name
|
|
|
|
self._apply_text_placeholders(data, label_config)
|
|
return data
|
|
|
|
@classmethod
|
|
def _resolve_primary_text_value(
|
|
cls,
|
|
parsed_payload: Optional[Dict[str, Any]],
|
|
raw_text: str,
|
|
primary_text_key: Optional[str],
|
|
) -> str:
|
|
if parsed_payload and primary_text_key:
|
|
value = parsed_payload.get(primary_text_key)
|
|
if isinstance(value, str) and value.strip():
|
|
return value
|
|
if parsed_payload and not primary_text_key:
|
|
value = parsed_payload.get(TEXT_DATA_KEY)
|
|
if isinstance(value, str) and value.strip():
|
|
return value
|
|
return raw_text
|
|
|
|
    async def get_project_info(self, project_id: str) -> EditorProjectInfo:
        """Return editor bootstrap info: dataset type, label config, support flag.

        Raises HTTP 404 when the project does not exist or is soft-deleted.
        """
        project = await self._get_project_or_404(project_id)

        dataset_type = self._normalize_dataset_type(await self._get_dataset_type(project.dataset_id))
        supported = dataset_type in SUPPORTED_EDITOR_DATASET_TYPES
        unsupported_reason = None
        if not supported:
            supported_hint = "/".join(SUPPORTED_EDITOR_DATASET_TYPES)
            unsupported_reason = f"当前仅支持 {supported_hint},项目数据类型为: {dataset_type or 'UNKNOWN'}"

        # Prefer the user-edited label_config from the project configuration,
        # falling back to the template's default configuration.
        label_config = await self._resolve_project_label_config(project)

        return EditorProjectInfo(
            projectId=project.id,
            datasetId=project.dataset_id,
            datasetType=dataset_type or None,
            templateId=project.template_id,
            labelConfig=label_config,
            supported=supported,
            unsupportedReason=unsupported_reason,
        )
|
|
|
|
    async def list_tasks(
        self,
        project_id: str,
        page: int = 0,
        size: int = 50,
        exclude_source_documents: Optional[bool] = None,
    ) -> EditorTaskListResponse:
        """Page through a project's files together with their annotation status.

        Un-annotated files sort first, then by file creation time (newest
        first). ``page`` is zero-based.

        NOTE(review): ``exclude_source_documents`` is accepted for API
        compatibility but is no longer applied in this query (source-document
        filtering was removed); confirm callers no longer rely on it.
        """
        project = await self._get_project_or_404(project_id)
        # Files must both be attached to the project and live in its dataset.
        base_conditions = [
            LabelingProjectFile.project_id == project_id,
            DatasetFiles.dataset_id == project.dataset_id,
        ]

        count_result = await self.db.execute(
            select(func.count())
            .select_from(LabelingProjectFile)
            .join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
            .where(*base_conditions)
        )
        total = int(count_result.scalar() or 0)

        # Sort key: files without an annotation (NULL AnnotationResult.id) first.
        annotated_sort_key = case(
            (AnnotationResult.id.isnot(None), 1),
            else_=0,
        )
        files_result = await self.db.execute(
            select(DatasetFiles, AnnotationResult.id, AnnotationResult.updated_at)
            .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
            .outerjoin(
                AnnotationResult,
                (AnnotationResult.file_id == DatasetFiles.id)
                & (AnnotationResult.project_id == project_id),
            )
            .where(*base_conditions)
            .order_by(annotated_sort_key.asc(), DatasetFiles.created_at.desc())
            .offset(page * size)
            .limit(size)
        )
        rows = files_result.all()

        items: List[EditorTaskListItem] = []
        for file_record, annotation_id, annotation_updated_at in rows:
            fid = str(file_record.id)  # type: ignore[arg-type]
            items.append(
                EditorTaskListItem(
                    fileId=fid,
                    fileName=str(getattr(file_record, "file_name", "")),
                    fileType=getattr(file_record, "file_type", None),
                    hasAnnotation=annotation_id is not None,
                    annotationUpdatedAt=annotation_updated_at,
                )
            )

        # Ceiling division; zero pages when size is non-positive.
        total_pages = (total + size - 1) // size if size > 0 else 0
        return EditorTaskListResponse(
            content=items,
            totalElements=total,
            totalPages=total_pages,
            page=page,
            size=size,
        )
|
|
|
|
    async def _fetch_text_content_via_download_api(self, dataset_id: str, file_id: str) -> str:
        """Fetch a file's text content via the DataMate file download/preview API."""
        return await fetch_text_content_via_download_api(dataset_id, file_id)
|
|
|
|
    async def get_task(
        self,
        project_id: str,
        file_id: str,
        segment_index: Optional[int] = None,
    ) -> EditorTaskResponse:
        """Build the editor task for one file, dispatching on dataset type.

        ``segment_index`` applies to TEXT tasks only. Raises HTTP 400 for
        unsupported dataset types and HTTP 404 when the file does not
        belong to the project's dataset.
        """
        project = await self._get_project_or_404(project_id)

        dataset_type = self._normalize_dataset_type(await self._get_dataset_type(project.dataset_id))
        if dataset_type not in SUPPORTED_EDITOR_DATASET_TYPES:
            raise HTTPException(
                status_code=400,
                detail="当前仅支持 TEXT/IMAGE/AUDIO/VIDEO 项目的内嵌编辑器",
            )

        file_result = await self.db.execute(
            select(DatasetFiles).where(
                DatasetFiles.id == file_id,
                DatasetFiles.dataset_id == project.dataset_id,
            )
        )
        file_record = file_result.scalar_one_or_none()
        if not file_record:
            raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")

        # Dispatch by dataset type; TEXT is the fall-through default.
        if dataset_type == DATASET_TYPE_IMAGE:
            return await self._build_image_task(project, file_record, file_id)

        if dataset_type == DATASET_TYPE_AUDIO:
            return await self._build_audio_task(project, file_record, file_id)

        if dataset_type == DATASET_TYPE_VIDEO:
            return await self._build_video_task(project, file_record, file_id)

        return await self._build_text_task(project, file_record, file_id, segment_index)
|
|
|
|
    async def _build_text_task(
        self,
        project: LabelingProject,
        file_record: DatasetFiles,
        file_id: str,
        segment_index: Optional[int],
    ) -> EditorTaskResponse:
        """Build the Label Studio task for a TEXT file, with optional segmentation.

        Flow: fetch the raw text via the download API, parse it as JSON/JSONL
        records when possible, decide whether segmentation applies (multiple
        JSONL records, or any record text longer than SEGMENT_THRESHOLD, and
        segmentation enabled for the project), then assemble the task data
        plus the stored (or empty) annotation for the selected segment.
        """
        text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
        assert isinstance(text_content, str)
        label_config = await self._resolve_project_label_config(project)
        primary_text_key = self._resolve_primary_text_key(label_config)
        file_name = str(getattr(file_record, "file_name", "")).lower()
        # Each record is (parsed JSON payload or None, raw text of the record).
        records: List[Tuple[Optional[Dict[str, Any]], str]] = []
        if file_name.endswith(JSONL_EXTENSION):
            records = self._parse_jsonl_records(text_content)
        else:
            parsed_payload = self._try_parse_json_payload(text_content)
            if parsed_payload:
                records = [(parsed_payload, text_content)]

        # Fall back to treating the whole file as one plain-text record.
        if not records:
            records = [(None, text_content)]

        record_texts = [
            self._resolve_primary_text_value(payload, raw_text, primary_text_key)
            for payload, raw_text in records
        ]
        if not record_texts:
            record_texts = [text_content]

        # Load the existing annotation, if any.
        ann_result = await self.db.execute(
            select(AnnotationResult).where(
                AnnotationResult.project_id == project.id,
                AnnotationResult.file_id == file_id,
            )
        )
        ann = ann_result.scalar_one_or_none()

        ls_task_id = self._make_ls_task_id(project.id, file_id)

        # Decide whether segmentation is needed (multiple JSONL lines, or a
        # primary text exceeding the threshold) — only when enabled.
        segmentation_enabled = self._resolve_segmentation_enabled(project)
        if not segmentation_enabled:
            segment_index = None
        needs_segmentation = segmentation_enabled and (
            len(records) > 1 or any(len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts)
        )
        segments: Optional[List[SegmentInfo]] = None
        current_segment_index = 0
        display_text = record_texts[0] if record_texts else text_content
        selected_payload = records[0][0] if records else None
        if not segmentation_enabled and len(records) > 1:
            # Segmentation disabled but multiple records: show them all joined.
            selected_payload = None
            display_text = "\n".join(record_texts) if record_texts else text_content

        # Previously stored per-segment annotations, keyed by str(segment index).
        segment_annotations: Dict[str, Any] = {}
        if ann and ann.annotation and ann.annotation.get("segmented"):
            segment_annotations = ann.annotation.get("segments", {})

        if needs_segmentation:
            splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
            # Each context is (payload, raw record text, segment text,
            # record index, chunk index within the record).
            segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
            segments = []
            segment_cursor = 0

            for record_index, ((payload, raw_text), record_text) in enumerate(zip(records, record_texts)):
                normalized_text = record_text or ""
                if len(normalized_text) > self.SEGMENT_THRESHOLD:
                    # Long record: split it into multiple chunks.
                    raw_segments = splitter.split(normalized_text)
                    for chunk_index, seg in enumerate(raw_segments):
                        segments.append(SegmentInfo(
                            idx=segment_cursor,
                            text=seg["text"],
                            start=seg["start"],
                            end=seg["end"],
                            hasAnnotation=str(segment_cursor) in segment_annotations,
                            lineIndex=record_index,
                            chunkIndex=chunk_index,
                        ))
                        segment_contexts.append((payload, raw_text, seg["text"], record_index, chunk_index))
                        segment_cursor += 1
                else:
                    # Short record: a single segment covering the whole record.
                    segments.append(SegmentInfo(
                        idx=segment_cursor,
                        text=normalized_text,
                        start=0,
                        end=len(normalized_text),
                        hasAnnotation=str(segment_cursor) in segment_annotations,
                        lineIndex=record_index,
                        chunkIndex=0,
                    ))
                    segment_contexts.append((payload, raw_text, normalized_text, record_index, 0))
                    segment_cursor += 1

            if not segments:
                # Defensive: guarantee at least one (empty) segment.
                segments = [SegmentInfo(idx=0, text="", start=0, end=0, hasAnnotation=False, lineIndex=0, chunkIndex=0)]
                segment_contexts = [(None, "", "", 0, 0)]

            # Clamp the requested segment index into range (fallback to 0).
            current_segment_index = segment_index if segment_index is not None else 0
            if current_segment_index < 0 or current_segment_index >= len(segments):
                current_segment_index = 0

            selected_payload, _, display_text, _, _ = segment_contexts[current_segment_index]

        # Assemble the LS task object.
        task_data = self._build_task_data(
            display_text=display_text,
            parsed_payload=selected_payload,
            label_config=label_config,
            file_record=file_record,
            dataset_id=project.dataset_id,
            file_id=file_id,
            primary_text_key=primary_text_key,
        )
        if needs_segmentation:
            task_data[SEGMENT_INDEX_KEY] = current_segment_index
            task_data[SEGMENT_INDEX_CAMEL_KEY] = current_segment_index

        task: Dict[str, Any] = {
            "id": ls_task_id,
            "data": task_data,
            "annotations": [],
        }

        annotation_updated_at = None
        if ann:
            annotation_updated_at = ann.updated_at

            if needs_segmentation and ann.annotation and ann.annotation.get("segmented"):
                # Segmented mode: surface only the current segment's annotation.
                segment_annotations = ann.annotation.get("segments", {})
                seg_ann = segment_annotations.get(str(current_segment_index), {})
                stored = {
                    # Offset the base annotation id so each segment is distinct.
                    "id": self._make_ls_annotation_id(project.id, file_id) + current_segment_index,
                    "task": ls_task_id,
                    "result": seg_ann.get("result", []),
                    "created_at": seg_ann.get("created_at", datetime.utcnow().isoformat() + "Z"),
                    "updated_at": seg_ann.get("updated_at", datetime.utcnow().isoformat() + "Z"),
                }
                task["annotations"] = [stored]
            elif not needs_segmentation and not (ann.annotation or {}).get("segmented"):
                # Non-segmented mode: return the stored annotation object as-is.
                stored = dict(ann.annotation or {})
                stored["task"] = ls_task_id
                if not isinstance(stored.get("id"), int):
                    stored["id"] = self._make_ls_annotation_id(project.id, file_id)
                task["annotations"] = [stored]
            else:
                # Mode mismatch (e.g. first switch from non-segmented to
                # segmented storage): provide an empty annotation.
                empty_ann_id = self._make_ls_annotation_id(project.id, file_id) + current_segment_index
                task["annotations"] = [
                    {
                        "id": empty_ann_id,
                        "task": ls_task_id,
                        "result": [],
                        "created_at": datetime.utcnow().isoformat() + "Z",
                        "updated_at": datetime.utcnow().isoformat() + "Z",
                    }
                ]
        else:
            # No stored annotation: provide an empty one so the frontend can
            # produce results without an annotation being selected first.
            empty_ann_id = self._make_ls_annotation_id(project.id, file_id)
            if needs_segmentation:
                empty_ann_id += current_segment_index
            task["annotations"] = [
                {
                    "id": empty_ann_id,
                    "task": ls_task_id,
                    "result": [],
                    "created_at": datetime.utcnow().isoformat() + "Z",
                    "updated_at": datetime.utcnow().isoformat() + "Z",
                }
            ]

        return EditorTaskResponse(
            task=task,
            annotationUpdatedAt=annotation_updated_at,
            segmented=needs_segmentation,
            segments=segments,
            totalSegments=len(segments) if segments else 1,
            currentSegmentIndex=current_segment_index,
        )
|
|
|
|
    async def _build_media_task(
        self,
        project: LabelingProject,
        file_record: DatasetFiles,
        file_id: str,
        default_key: str,
        categories: set[str],
    ) -> EditorTaskResponse:
        """Build the Label Studio task for a media (image/audio/video) file.

        The file is exposed to the editor via its preview URL, stored under
        every media data key the label config references in *categories*
        (or under *default_key* as a fallback). Media tasks are never
        segmented.
        """
        label_config = await self._resolve_project_label_config(project)
        media_keys = self._resolve_media_value_keys(label_config, default_key, categories)
        preview_url = self._build_file_preview_url(project.dataset_id, file_id)
        file_name = str(getattr(file_record, "file_name", ""))

        task_data: Dict[str, Any] = {
            FILE_ID_KEY: file_id,
            FILE_ID_CAMEL_KEY: file_id,
            DATASET_ID_KEY: project.dataset_id,
            DATASET_ID_CAMEL_KEY: project.dataset_id,
            FILE_NAME_KEY: file_name,
            FILE_NAME_CAMEL_KEY: file_name,
        }
        for key in media_keys:
            task_data[key] = preview_url
        self._apply_text_placeholders(task_data, label_config)

        # Load the existing annotation, if any.
        ann_result = await self.db.execute(
            select(AnnotationResult).where(
                AnnotationResult.project_id == project.id,
                AnnotationResult.file_id == file_id,
            )
        )
        ann = ann_result.scalar_one_or_none()
        ls_task_id = self._make_ls_task_id(project.id, file_id)

        task: Dict[str, Any] = {
            "id": ls_task_id,
            "data": task_data,
            "annotations": [],
        }

        annotation_updated_at = None
        # Segmented payloads never apply to media tasks and are ignored here.
        if ann and not (ann.annotation or {}).get("segmented"):
            annotation_updated_at = ann.updated_at
            stored = dict(ann.annotation or {})
            stored["task"] = ls_task_id
            if not isinstance(stored.get("id"), int):
                stored["id"] = self._make_ls_annotation_id(project.id, file_id)
            task["annotations"] = [stored]
        else:
            # Provide an empty annotation so the frontend can record results.
            empty_ann_id = self._make_ls_annotation_id(project.id, file_id)
            task["annotations"] = [
                {
                    "id": empty_ann_id,
                    "task": ls_task_id,
                    "result": [],
                    "created_at": datetime.utcnow().isoformat() + "Z",
                    "updated_at": datetime.utcnow().isoformat() + "Z",
                }
            ]

        return EditorTaskResponse(
            task=task,
            annotationUpdatedAt=annotation_updated_at,
            segmented=False,
            segments=None,
            totalSegments=1,
            currentSegmentIndex=0,
        )
|
|
|
|
async def _build_image_task(
|
|
self,
|
|
project: LabelingProject,
|
|
file_record: DatasetFiles,
|
|
file_id: str,
|
|
) -> EditorTaskResponse:
|
|
return await self._build_media_task(
|
|
project=project,
|
|
file_record=file_record,
|
|
file_id=file_id,
|
|
default_key=IMAGE_DATA_KEY,
|
|
categories=IMAGE_OBJECT_CATEGORIES,
|
|
)
|
|
|
|
async def _build_audio_task(
|
|
self,
|
|
project: LabelingProject,
|
|
file_record: DatasetFiles,
|
|
file_id: str,
|
|
) -> EditorTaskResponse:
|
|
return await self._build_media_task(
|
|
project=project,
|
|
file_record=file_record,
|
|
file_id=file_id,
|
|
default_key=AUDIO_DATA_KEY,
|
|
categories=MEDIA_OBJECT_CATEGORIES,
|
|
)
|
|
|
|
async def _build_video_task(
|
|
self,
|
|
project: LabelingProject,
|
|
file_record: DatasetFiles,
|
|
file_id: str,
|
|
) -> EditorTaskResponse:
|
|
return await self._build_media_task(
|
|
project=project,
|
|
file_record=file_record,
|
|
file_id=file_id,
|
|
default_key=VIDEO_DATA_KEY,
|
|
categories=MEDIA_OBJECT_CATEGORIES,
|
|
)
|
|
|
|
    async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse:
        """Create or update the single final annotation for (project, file).

        Supports whole-file mode and segmented mode (``request.segment_index``
        set), with optimistic concurrency via ``request.expected_updated_at``.
        After committing, the result is best-effort synced to knowledge
        management.

        Raises HTTP 404 when the file is not part of the project, HTTP 400
        on a malformed result, HTTP 409 on a concurrent update.
        """
        project = await self._get_project_or_404(project_id)

        # Verify the file is attached to the project and lives in its dataset.
        file_result = await self.db.execute(
            select(DatasetFiles)
            .join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
            .where(
                LabelingProjectFile.project_id == project.id,
                DatasetFiles.id == file_id,
                DatasetFiles.dataset_id == project.dataset_id,
            )
        )
        file_record = file_result.scalar_one_or_none()
        if not file_record:
            raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")

        annotation_payload = dict(request.annotation or {})
        result = annotation_payload.get("result")
        if result is None:
            annotation_payload["result"] = []
        elif not isinstance(result, list):
            raise HTTPException(status_code=400, detail="annotation.result 必须为数组")

        ls_task_id = self._make_ls_task_id(project_id, file_id)

        existing_result = await self.db.execute(
            select(AnnotationResult).where(
                AnnotationResult.project_id == project_id,
                AnnotationResult.file_id == file_id,
            )
        )
        existing = existing_result.scalar_one_or_none()

        now = datetime.utcnow()

        # Segmented save mode?
        if request.segment_index is not None:
            # Segmented: merge this segment's annotation into the stored structure.
            final_payload = self._merge_segment_annotation(
                existing.annotation if existing else None,
                request.segment_index,
                annotation_payload,
            )
        else:
            # Whole-file: store the incoming annotation directly, pinned to
            # the stable LS task/annotation ids.
            annotation_payload["task"] = ls_task_id
            if not isinstance(annotation_payload.get("id"), int):
                annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)
            final_payload = annotation_payload

        if existing:
            # Optimistic concurrency: reject stale clients (naive-datetime compare).
            if request.expected_updated_at and existing.updated_at:
                if existing.updated_at != request.expected_updated_at.replace(tzinfo=None):
                    raise HTTPException(status_code=409, detail="标注已被更新,请刷新后重试")

            existing.annotation = final_payload  # type: ignore[assignment]
            existing.updated_at = now  # type: ignore[assignment]
            await self.db.commit()
            await self.db.refresh(existing)

            response = UpsertAnnotationResponse(
                annotationId=existing.id,
                updatedAt=existing.updated_at or now,
            )
            await self._sync_annotation_to_knowledge(project, file_record, final_payload, existing.updated_at)
            return response

        new_id = str(uuid.uuid4())
        record = AnnotationResult(
            id=new_id,
            project_id=project_id,
            file_id=file_id,
            annotation=final_payload,
            created_at=now,
            updated_at=now,
        )
        self.db.add(record)
        await self.db.commit()
        await self.db.refresh(record)

        response = UpsertAnnotationResponse(
            annotationId=record.id,
            updatedAt=record.updated_at or now,
        )
        await self._sync_annotation_to_knowledge(project, file_record, final_payload, record.updated_at)
        return response
|
|
|
|
def _merge_segment_annotation(
|
|
self,
|
|
existing: Optional[Dict[str, Any]],
|
|
segment_index: int,
|
|
new_annotation: Dict[str, Any],
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
合并段落标注到整体结构
|
|
|
|
Args:
|
|
existing: 现有的 annotation 数据
|
|
segment_index: 段落索引
|
|
new_annotation: 新的段落标注数据
|
|
|
|
Returns:
|
|
合并后的 annotation 结构
|
|
"""
|
|
if not existing or not existing.get("segmented"):
|
|
# 初始化分段结构
|
|
base: Dict[str, Any] = {
|
|
"segmented": True,
|
|
"version": 1,
|
|
"segments": {},
|
|
}
|
|
else:
|
|
base = dict(existing)
|
|
|
|
# 更新指定段落的标注
|
|
base["segments"][str(segment_index)] = {
|
|
"result": new_annotation.get("result", []),
|
|
"created_at": new_annotation.get("created_at", datetime.utcnow().isoformat() + "Z"),
|
|
"updated_at": datetime.utcnow().isoformat() + "Z",
|
|
}
|
|
|
|
return base
|
|
|
|
    async def _sync_annotation_to_knowledge(
        self,
        project: LabelingProject,
        file_record: DatasetFiles,
        annotation: Dict[str, Any],
        annotation_updated_at: Optional[datetime],
    ) -> None:
        """Sync the annotation result to knowledge management (best effort).

        Any failure is logged as a warning and swallowed so it never blocks
        saving the annotation itself.
        """
        try:
            await KnowledgeSyncService(self.db).sync_annotation_to_knowledge(
                project=project,
                file_record=file_record,
                annotation=annotation,
                annotation_updated_at=annotation_updated_at,
            )
        except Exception as exc:
            logger.warning("标注同步知识管理失败:%s", exc)
|
|
|