feat(annotation): 添加分段标注功能支持

- 定义分段标注相关常量(segmented、segments、result等键名)
- 实现分段标注提取方法_extract_segment_annotations处理字典和列表格式
- 添加分段标注判断方法_is_segmented_annotation检测标注状态
- 修改_has_annotation_result方法使用新的分段标注处理逻辑
- 在任务创建过程中集成分段标注数据处理
- 更新导出服务中的分段标注结果扁平化处理
- 实现标注归一化方法支持分段标注格式转换
- 调整JSON和CSV导出格式适配分段标注结构
This commit is contained in:
2026-01-31 14:36:16 +08:00
parent 8fdc7d99b8
commit c5c8e6c69e
2 changed files with 145 additions and 29 deletions

View File

@@ -59,6 +59,11 @@ FILE_ID_CAMEL_KEY = "fileId"
FILE_NAME_CAMEL_KEY = "fileName" FILE_NAME_CAMEL_KEY = "fileName"
SEGMENT_INDEX_KEY = "segment_index" SEGMENT_INDEX_KEY = "segment_index"
SEGMENT_INDEX_CAMEL_KEY = "segmentIndex" SEGMENT_INDEX_CAMEL_KEY = "segmentIndex"
# Keys of the stored segmented-annotation payload.
SEGMENTED_KEY = "segmented"  # truthy flag marking a payload as segmented
SEGMENTS_KEY = "segments"  # container of per-segment entries (read as dict or list)
SEGMENT_RESULT_KEY = "result"  # result list; also read from the top level of unsegmented payloads
SEGMENT_CREATED_AT_KEY = "created_at"  # per-segment creation timestamp
SEGMENT_UPDATED_AT_KEY = "updated_at"  # per-segment last-update timestamp
JSONL_EXTENSION = ".jsonl" JSONL_EXTENSION = ".jsonl"
TEXTUAL_OBJECT_CATEGORIES = {"text", "document"} TEXTUAL_OBJECT_CATEGORIES = {"text", "document"}
IMAGE_OBJECT_CATEGORIES = {"image"} IMAGE_OBJECT_CATEGORIES = {"image"}
@@ -352,22 +357,63 @@ class AnnotationEditorService:
return ET.tostring(root, encoding="unicode") return ET.tostring(root, encoding="unicode")
@staticmethod
def _extract_segment_annotations(payload: Optional[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Return the per-segment annotations of ``payload`` keyed by segment index.

    Both storage shapes of the ``segments`` container are accepted:

    * dict -- keys are stringified; values that are not dicts are dropped.
    * list -- each dict entry is keyed by the first usable index found under
      ``segmentIndex``, ``segment_index``, ``segment`` or ``idx``; when none
      is usable the entry's position in the list is used instead.

    Args:
        payload: Stored annotation object, or ``None``.

    Returns:
        Mapping of ``str(segment index)`` to the segment's annotation dict.
        Empty when ``payload`` is empty, not a dict, or carries no usable
        ``segments`` container.
    """
    if not payload or not isinstance(payload, dict):
        return {}

    segments = payload.get(SEGMENTS_KEY)

    if isinstance(segments, dict):
        return {
            str(key): value
            for key, value in segments.items()
            if isinstance(value, dict)
        }

    if isinstance(segments, list):
        index_fields = (SEGMENT_INDEX_CAMEL_KEY, SEGMENT_INDEX_KEY, "segment", "idx")
        normalized: Dict[str, Dict[str, Any]] = {}
        for position, value in enumerate(segments):
            if not isinstance(value, dict):
                continue
            key: Any = position
            for field in index_fields:
                candidate = value.get(field)
                # Test for presence, not truthiness: segment index 0 is a
                # valid index and must not fall through to the list position
                # (which could collide with another segment's key).
                if candidate is not None and candidate != "":
                    key = candidate
                    break
            normalized[str(key)] = value
        return normalized

    return {}
@staticmethod
def _is_segmented_annotation(payload: Optional[Dict[str, Any]]) -> bool:
    """Tell whether ``payload`` is stored in the segmented annotation format.

    A payload counts as segmented when its ``segmented`` flag is truthy, or
    when its ``segments`` container is a non-empty dict or list.

    Args:
        payload: Stored annotation object, or ``None``.

    Returns:
        ``True`` for a segmented payload, ``False`` otherwise (including an
        empty or non-dict ``payload``).
    """
    if not payload or not isinstance(payload, dict):
        return False

    if payload.get(SEGMENTED_KEY):
        return True

    # Without the explicit flag, a populated segments container is enough.
    container = payload.get(SEGMENTS_KEY)
    return isinstance(container, (dict, list)) and len(container) > 0
@staticmethod
def _has_annotation_result(payload: Optional[Dict[str, Any]]) -> bool:
    """Tell whether ``payload`` holds at least one annotation result.

    For a segmented payload the answer is ``True`` as soon as any segment
    carries a non-empty ``result`` list; otherwise the top-level ``result``
    list is inspected.

    Args:
        payload: Stored annotation object, or ``None``.

    Returns:
        ``True`` when a non-empty result list is found, ``False`` otherwise.
    """
    if not payload or not isinstance(payload, dict):
        return False

    def _non_empty_list(candidate: Any) -> bool:
        return isinstance(candidate, list) and len(candidate) > 0

    service = AnnotationEditorService
    if not service._is_segmented_annotation(payload):
        return _non_empty_list(payload.get(SEGMENT_RESULT_KEY))

    # Segmented payload: look at every segment, whatever container shape
    # the segments were stored in.
    per_segment = service._extract_segment_annotations(payload)
    return any(
        _non_empty_list(segment.get(SEGMENT_RESULT_KEY))
        for segment in per_segment.values()
        if isinstance(segment, dict)
    )
@classmethod @classmethod
@@ -591,6 +637,13 @@ class AnnotationEditorService:
ls_task_id = self._make_ls_task_id(project.id, file_id) ls_task_id = self._make_ls_task_id(project.id, file_id)
segment_annotations: Dict[str, Dict[str, Any]] = {}
has_segmented_annotation = False
if ann and isinstance(ann.annotation, dict):
segment_annotations = self._extract_segment_annotations(ann.annotation)
has_segmented_annotation = self._is_segmented_annotation(ann.annotation)
segment_annotation_keys = set(segment_annotations.keys())
# 判断是否需要分段(JSONL 多行或主文本超过阈值) # 判断是否需要分段(JSONL 多行或主文本超过阈值)
segmentation_enabled = self._resolve_segmentation_enabled(project) segmentation_enabled = self._resolve_segmentation_enabled(project)
if not segmentation_enabled: if not segmentation_enabled:
@@ -606,10 +659,6 @@ class AnnotationEditorService:
selected_payload = None selected_payload = None
display_text = "\n".join(record_texts) if record_texts else text_content display_text = "\n".join(record_texts) if record_texts else text_content
segment_annotations: Dict[str, Any] = {}
if ann and ann.annotation and ann.annotation.get("segmented"):
segment_annotations = ann.annotation.get("segments", {})
if needs_segmentation: if needs_segmentation:
splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD) splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = [] segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
@@ -626,7 +675,7 @@ class AnnotationEditorService:
text=seg["text"], text=seg["text"],
start=seg["start"], start=seg["start"],
end=seg["end"], end=seg["end"],
hasAnnotation=str(segment_cursor) in segment_annotations, hasAnnotation=str(segment_cursor) in segment_annotation_keys,
lineIndex=record_index, lineIndex=record_index,
chunkIndex=chunk_index, chunkIndex=chunk_index,
)) ))
@@ -638,7 +687,7 @@ class AnnotationEditorService:
text=normalized_text, text=normalized_text,
start=0, start=0,
end=len(normalized_text), end=len(normalized_text),
hasAnnotation=str(segment_cursor) in segment_annotations, hasAnnotation=str(segment_cursor) in segment_annotation_keys,
lineIndex=record_index, lineIndex=record_index,
chunkIndex=0, chunkIndex=0,
)) ))
@@ -679,19 +728,18 @@ class AnnotationEditorService:
if ann: if ann:
annotation_updated_at = ann.updated_at annotation_updated_at = ann.updated_at
if needs_segmentation and ann.annotation and ann.annotation.get("segmented"): if needs_segmentation and has_segmented_annotation:
# 分段模式:获取当前段落的标注 # 分段模式:获取当前段落的标注
segment_annotations = ann.annotation.get("segments", {})
seg_ann = segment_annotations.get(str(current_segment_index), {}) seg_ann = segment_annotations.get(str(current_segment_index), {})
stored = { stored = {
"id": self._make_ls_annotation_id(project.id, file_id) + current_segment_index, "id": self._make_ls_annotation_id(project.id, file_id) + current_segment_index,
"task": ls_task_id, "task": ls_task_id,
"result": seg_ann.get("result", []), "result": seg_ann.get(SEGMENT_RESULT_KEY, []),
"created_at": seg_ann.get("created_at", datetime.utcnow().isoformat() + "Z"), "created_at": seg_ann.get(SEGMENT_CREATED_AT_KEY, datetime.utcnow().isoformat() + "Z"),
"updated_at": seg_ann.get("updated_at", datetime.utcnow().isoformat() + "Z"), "updated_at": seg_ann.get(SEGMENT_UPDATED_AT_KEY, datetime.utcnow().isoformat() + "Z"),
} }
task["annotations"] = [stored] task["annotations"] = [stored]
elif not needs_segmentation and not (ann.annotation or {}).get("segmented"): elif not needs_segmentation and not has_segmented_annotation:
# 非分段模式:直接返回存储的 annotation 原始对象 # 非分段模式:直接返回存储的 annotation 原始对象
stored = dict(ann.annotation or {}) stored = dict(ann.annotation or {})
stored["task"] = ls_task_id stored["task"] = ls_task_id
@@ -968,21 +1016,28 @@ class AnnotationEditorService:
Returns: Returns:
合并后的 annotation 结构 合并后的 annotation 结构
""" """
if not existing or not existing.get("segmented"): if not existing or not existing.get(SEGMENTED_KEY):
# 初始化分段结构 # 初始化分段结构
base: Dict[str, Any] = { base: Dict[str, Any] = {
"segmented": True, SEGMENTED_KEY: True,
"version": 1, "version": 1,
"segments": {}, SEGMENTS_KEY: {},
} }
else: else:
base = dict(existing) base = dict(existing)
if not base.get(SEGMENTED_KEY):
base[SEGMENTED_KEY] = True
segments = base.get(SEGMENTS_KEY)
if not isinstance(segments, dict):
segments = {}
base[SEGMENTS_KEY] = segments
# 更新指定段落的标注 # 更新指定段落的标注
base["segments"][str(segment_index)] = { segments[str(segment_index)] = {
"result": new_annotation.get("result", []), SEGMENT_RESULT_KEY: new_annotation.get(SEGMENT_RESULT_KEY, []),
"created_at": new_annotation.get("created_at", datetime.utcnow().isoformat() + "Z"), SEGMENT_CREATED_AT_KEY: new_annotation.get(SEGMENT_CREATED_AT_KEY, datetime.utcnow().isoformat() + "Z"),
"updated_at": datetime.utcnow().isoformat() + "Z", SEGMENT_UPDATED_AT_KEY: datetime.utcnow().isoformat() + "Z",
} }
return base return base

View File

@@ -63,6 +63,12 @@ from ..schema.export import (
logger = get_logger(__name__) logger = get_logger(__name__)
# Keys read when flattening segmented annotations for export.
SEGMENTED_KEY = "segmented"
SEGMENTS_KEY = "segments"
SEGMENT_RESULT_KEY = "result"
# NOTE(review): in this module SEGMENT_INDEX_KEY is the camelCase spelling,
# while the annotation editor service binds the same constant name to
# "segment_index" -- confirm the differing meaning is intended.
SEGMENT_INDEX_KEY = "segmentIndex"
SEGMENT_INDEX_FALLBACK_KEY = "segment_index"
class AnnotationExportService: class AnnotationExportService:
"""标注数据导出服务""" """标注数据导出服务"""
@@ -239,6 +245,61 @@ class AnnotationExportService:
return items return items
@staticmethod
def _flatten_annotation_results(annotation: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten an annotation into a single list of result items.

    For an unsegmented annotation the top-level ``result`` list is returned
    as is (or ``[]`` when it is not a list).  For a segmented annotation the
    ``result`` lists of all segments are concatenated; every item is tagged
    with ``segmentIndex`` unless it already carries ``segmentIndex`` or
    ``segment_index``.  Items that are not dicts are wrapped as
    ``{"value": item, "segmentIndex": index}``.

    The segment index comes from the dict key (converted to ``int`` when it
    is a decimal string) or, for list-shaped segments, from the segment's
    own ``segmentIndex`` / ``segment_index`` field, falling back to the
    position in the list.

    Args:
        annotation: Stored annotation object.

    Returns:
        Flat list of result items; empty when nothing usable is found.
    """
    if not annotation or not isinstance(annotation, dict):
        return []

    segments = annotation.get(SEGMENTS_KEY)
    if not (annotation.get(SEGMENTED_KEY) or isinstance(segments, (dict, list))):
        result = annotation.get(SEGMENT_RESULT_KEY)
        return result if isinstance(result, list) else []

    # Normalise both container shapes into (segment_index, segment) pairs so
    # that a single loop below tags the items consistently.
    indexed_segments: List[Any] = []
    if isinstance(segments, dict):
        for key, segment in segments.items():
            if not isinstance(segment, dict):
                continue
            # isdecimal(), unlike isdigit(), only accepts characters that
            # int() can parse (isdigit() also accepts e.g. superscripts).
            index = int(key) if str(key).isdecimal() else key
            indexed_segments.append((index, segment))
    elif isinstance(segments, list):
        for position, segment in enumerate(segments):
            if not isinstance(segment, dict):
                continue
            index = position
            for field in (SEGMENT_INDEX_KEY, SEGMENT_INDEX_FALLBACK_KEY):
                # A field that is present but None is treated as missing.
                if segment.get(field) is not None:
                    index = segment[field]
                    break
            indexed_segments.append((index, segment))

    results: List[Dict[str, Any]] = []
    for index, segment in indexed_segments:
        segment_results = segment.get(SEGMENT_RESULT_KEY)
        if not isinstance(segment_results, list):
            continue
        for item in segment_results:
            if isinstance(item, dict):
                entry = dict(item)  # copy: never mutate the stored annotation
                if SEGMENT_INDEX_KEY not in entry and SEGMENT_INDEX_FALLBACK_KEY not in entry:
                    entry[SEGMENT_INDEX_KEY] = index
            else:
                entry = {"value": item, SEGMENT_INDEX_KEY: index}
            results.append(entry)
    return results
@classmethod
def _normalize_annotation_for_export(cls, annotation: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``annotation`` in the shape used by the JSON-style exports.

    An unsegmented annotation is returned unchanged.  A segmented one is
    shallow-copied and, when it has no top-level ``result`` list of its own,
    receives the flattened per-segment results under ``result``.

    Args:
        annotation: Stored annotation object.

    Returns:
        The annotation to export; ``{}`` when the input is empty or not a
        dict.
    """
    if not annotation or not isinstance(annotation, dict):
        return {}

    segment_container = annotation.get(SEGMENTS_KEY)
    is_segmented = bool(annotation.get(SEGMENTED_KEY)) or isinstance(
        segment_container, (dict, list)
    )
    if not is_segmented:
        return annotation

    export_copy = dict(annotation)
    flattened = cls._flatten_annotation_results(annotation)
    # An existing top-level result list is kept; otherwise use the flattened one.
    if not isinstance(export_copy.get(SEGMENT_RESULT_KEY), list):
        export_copy[SEGMENT_RESULT_KEY] = flattened
    return export_copy
def _export_json( def _export_json(
self, items: List[AnnotationExportItem], project_name: str self, items: List[AnnotationExportItem], project_name: str
) -> Tuple[bytes, str, str]: ) -> Tuple[bytes, str, str]:
@@ -252,7 +313,7 @@ class AnnotationExportService:
"file_id": item.file_id, "file_id": item.file_id,
"file_name": item.file_name, "file_name": item.file_name,
"data": item.data, "data": item.data,
"annotations": item.annotations, "annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
"created_at": item.created_at.isoformat() if item.created_at else None, "created_at": item.created_at.isoformat() if item.created_at else None,
"updated_at": item.updated_at.isoformat() if item.updated_at else None, "updated_at": item.updated_at.isoformat() if item.updated_at else None,
} }
@@ -274,7 +335,7 @@ class AnnotationExportService:
"file_id": item.file_id, "file_id": item.file_id,
"file_name": item.file_name, "file_name": item.file_name,
"data": item.data, "data": item.data,
"annotations": item.annotations, "annotations": [self._normalize_annotation_for_export(ann) for ann in item.annotations],
"created_at": item.created_at.isoformat() if item.created_at else None, "created_at": item.created_at.isoformat() if item.created_at else None,
"updated_at": item.updated_at.isoformat() if item.updated_at else None, "updated_at": item.updated_at.isoformat() if item.updated_at else None,
} }
@@ -307,7 +368,7 @@ class AnnotationExportService:
# 提取标签信息(支持多种标注类型) # 提取标签信息(支持多种标注类型)
labels = [] labels = []
for ann in item.annotations: for ann in item.annotations:
results = ann.get("result", []) results = self._flatten_annotation_results(ann)
for r in results: for r in results:
value = r.get("value", {}) value = r.get("value", {})
label_type = r.get("type", "") label_type = r.get("type", "")
@@ -382,7 +443,7 @@ class AnnotationExportService:
# 处理标注 # 处理标注
for ann in item.annotations: for ann in item.annotations:
results = ann.get("result", []) results = self._flatten_annotation_results(ann)
for r in results: for r in results:
# 处理矩形框标注 (rectanglelabels) # 处理矩形框标注 (rectanglelabels)
if r.get("type") == "rectanglelabels": if r.get("type") == "rectanglelabels":
@@ -434,7 +495,7 @@ class AnnotationExportService:
lines = [] lines = []
for ann in item.annotations: for ann in item.annotations:
results = ann.get("result", []) results = self._flatten_annotation_results(ann)
for r in results: for r in results:
# 处理矩形框标注 # 处理矩形框标注
if r.get("type") == "rectanglelabels": if r.get("type") == "rectanglelabels":