feat(annotation): 添加文本分段标注功能

- 引入文本分割器实现长文本按200字符自动分段
- 增加分段状态管理和段落导航界面
- 支持按段落保存和加载标注数据
- 实现分段模式下的标注状态跟踪
- 扩展API接口支持段落索引参数
- 添加分段相关的数据模型定义
This commit is contained in:
2026-01-19 18:18:19 +08:00
parent 3af0f0b3a1
commit 71c4a8d8a6
6 changed files with 395 additions and 41 deletions

View File

@@ -0,0 +1,113 @@
"""
标注文本分割器
职责:将长文本按指定规则分割为适合标注的段落
- 最大200字符(CJK按1字符计)
- 分隔符:。;以及正则 \\?|\\!|(?<!\\d)\\.(?!\\d)
- 超长句子保持完整
"""
import re
from typing import List, TypedDict
class SegmentInfo(TypedDict):
"""段落信息"""
idx: int # 段落索引
text: str # 段落文本
start: int # 在原文中的起始位置
end: int # 在原文中的结束位置
class AnnotationTextSplitter:
"""标注文本分割器"""
# 分隔符正则:全角句号、全角分号、以及非数字间的英文句号/问号/感叹号
# 使用捕获组保留分隔符
SEPARATOR_PATTERN = r'(。|;|\?|\!|(?<!\d)\.(?!\d))'
def __init__(self, max_chars: int = 200):
"""
初始化分割器
Args:
max_chars: 每个段落的最大字符数(默认200)
"""
self.max_chars = max_chars
def split(self, text: str) -> List[SegmentInfo]:
"""
将文本分割为段落列表
规则:
1. 按分隔符切分为句子
2. 贪心合并句子,直到超过 max_chars
3. 单句超过 max_chars 则独立成段(保持句子完整)
Args:
text: 待分割的文本
Returns:
段落列表,每个元素包含 idx, text, start, end
"""
if not text:
return [{"idx": 0, "text": "", "start": 0, "end": 0}]
# 短文本不需要分割
if len(text) <= self.max_chars:
return [{"idx": 0, "text": text, "start": 0, "end": len(text)}]
# 按分隔符切分,保留分隔符
parts = re.split(self.SEPARATOR_PATTERN, text)
# 合并句子和分隔符
sentences: List[str] = []
i = 0
while i < len(parts):
part = parts[i]
# 检查下一个是否是分隔符(匹配捕获组)
if i + 1 < len(parts) and re.fullmatch(self.SEPARATOR_PATTERN, parts[i + 1]):
# 将分隔符附加到当前部分
part += parts[i + 1]
i += 2
else:
i += 1
# 跳过空字符串
if part:
sentences.append(part)
# 贪心合并
segments: List[SegmentInfo] = []
current_text = ""
current_start = 0
idx = 0
for sentence in sentences:
if not current_text:
# 开始新段落
current_text = sentence
elif len(current_text) + len(sentence) <= self.max_chars:
# 可以合并到当前段落
current_text += sentence
else:
# 当前段落已满,保存
segments.append({
"idx": idx,
"text": current_text,
"start": current_start,
"end": current_start + len(current_text)
})
idx += 1
current_start += len(current_text)
current_text = sentence
# 处理最后一个段落
if current_text:
segments.append({
"idx": idx,
"text": current_text,
"start": current_start,
"end": current_start + len(current_text)
})
return segments

View File

@@ -27,10 +27,12 @@ from app.module.annotation.schema.editor import (
EditorTaskListItem,
EditorTaskListResponse,
EditorTaskResponse,
SegmentInfo,
UpsertAnnotationRequest,
UpsertAnnotationResponse,
)
from app.module.annotation.service.template import AnnotationTemplateService
from app.module.annotation.service.annotation_text_splitter import AnnotationTextSplitter
logger = get_logger(__name__)
@@ -38,6 +40,9 @@ logger = get_logger(__name__)
class AnnotationEditorService:
"""Label Studio Editor 集成服务(TEXT POC 版)"""
# 分段阈值:超过此字符数自动分段
SEGMENT_THRESHOLD = 200
def __init__(self, db: AsyncSession):
self.db = db
self.template_service = AnnotationTemplateService()
@@ -206,7 +211,12 @@ class AnnotationEditorService:
logger.error(f"读取文本失败: dataset={dataset_id}, file={file_id}, err={e}")
raise HTTPException(status_code=502, detail="读取文本失败(下载接口调用异常)")
async def get_task(self, project_id: str, file_id: str) -> EditorTaskResponse:
async def get_task(
self,
project_id: str,
file_id: str,
segment_index: Optional[int] = None,
) -> EditorTaskResponse:
project = await self._get_project_or_404(project_id)
# TEXT 支持校验
@@ -226,6 +236,7 @@ class AnnotationEditorService:
text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
# 获取现有标注
ann_result = await self.db.execute(
select(AnnotationResult).where(
AnnotationResult.project_id == project_id,
@@ -236,10 +247,44 @@ class AnnotationEditorService:
ls_task_id = self._make_ls_task_id(project_id, file_id)
# 判断是否需要分段
needs_segmentation = len(text_content) > self.SEGMENT_THRESHOLD
segments: Optional[List[SegmentInfo]] = None
current_segment_index = 0
display_text = text_content
if needs_segmentation:
splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
raw_segments = splitter.split(text_content)
current_segment_index = segment_index if segment_index is not None else 0
# 校验段落索引
if current_segment_index < 0 or current_segment_index >= len(raw_segments):
current_segment_index = 0
# 标记每个段落是否已有标注
segment_annotations: Dict[str, Any] = {}
if ann and ann.annotation and ann.annotation.get("segmented"):
segment_annotations = ann.annotation.get("segments", {})
segments = []
for seg in raw_segments:
segments.append(SegmentInfo(
idx=seg["idx"],
text=seg["text"],
start=seg["start"],
end=seg["end"],
hasAnnotation=str(seg["idx"]) in segment_annotations,
))
# 当前段落文本用于 task.data.text
display_text = raw_segments[current_segment_index]["text"]
# 构造 task 对象
task: Dict[str, Any] = {
"id": ls_task_id,
"data": {
"text": text_content,
"text": display_text,
"file_id": file_id,
"dataset_id": project.dataset_id,
"file_name": getattr(file_record, "file_name", ""),
@@ -250,15 +295,43 @@ class AnnotationEditorService:
annotation_updated_at = None
if ann:
annotation_updated_at = ann.updated_at
# 直接返回存储的 annotation 原始对象(Label Studio 兼容)
stored = dict(ann.annotation or {})
stored["task"] = ls_task_id
if not isinstance(stored.get("id"), int):
stored["id"] = self._make_ls_annotation_id(project_id, file_id)
task["annotations"] = [stored]
if needs_segmentation and ann.annotation and ann.annotation.get("segmented"):
# 分段模式:获取当前段落的标注
segment_annotations = ann.annotation.get("segments", {})
seg_ann = segment_annotations.get(str(current_segment_index), {})
stored = {
"id": self._make_ls_annotation_id(project_id, file_id) + current_segment_index,
"task": ls_task_id,
"result": seg_ann.get("result", []),
"created_at": seg_ann.get("created_at", datetime.utcnow().isoformat() + "Z"),
"updated_at": seg_ann.get("updated_at", datetime.utcnow().isoformat() + "Z"),
}
task["annotations"] = [stored]
elif not needs_segmentation:
# 非分段模式:直接返回存储的 annotation 原始对象
stored = dict(ann.annotation or {})
stored["task"] = ls_task_id
if not isinstance(stored.get("id"), int):
stored["id"] = self._make_ls_annotation_id(project_id, file_id)
task["annotations"] = [stored]
else:
# 首次从非分段切换到分段:提供空标注
empty_ann_id = self._make_ls_annotation_id(project_id, file_id) + current_segment_index
task["annotations"] = [
{
"id": empty_ann_id,
"task": ls_task_id,
"result": [],
"created_at": datetime.utcnow().isoformat() + "Z",
"updated_at": datetime.utcnow().isoformat() + "Z",
}
]
else:
# 提供一个空 annotation,避免前端在没有选中 annotation 时无法产生 result
empty_ann_id = self._make_ls_annotation_id(project_id, file_id)
if needs_segmentation:
empty_ann_id += current_segment_index
task["annotations"] = [
{
"id": empty_ann_id,
@@ -272,6 +345,10 @@ class AnnotationEditorService:
return EditorTaskResponse(
task=task,
annotationUpdatedAt=annotation_updated_at,
segmented=needs_segmentation,
segments=segments,
totalSegments=len(segments) if segments else 1,
currentSegmentIndex=current_segment_index,
)
async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse:
@@ -293,9 +370,6 @@ class AnnotationEditorService:
raise HTTPException(status_code=400, detail="annotation.result 必须为数组")
ls_task_id = self._make_ls_task_id(project_id, file_id)
annotation_payload["task"] = ls_task_id
if not isinstance(annotation_payload.get("id"), int):
annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)
existing_result = await self.db.execute(
select(AnnotationResult).where(
@@ -307,12 +381,27 @@ class AnnotationEditorService:
now = datetime.utcnow()
# 判断是否为分段保存模式
if request.segment_index is not None:
# 分段模式:合并段落标注到整体结构
final_payload = self._merge_segment_annotation(
existing.annotation if existing else None,
request.segment_index,
annotation_payload,
)
else:
# 非分段模式:直接使用传入的 annotation
annotation_payload["task"] = ls_task_id
if not isinstance(annotation_payload.get("id"), int):
annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)
final_payload = annotation_payload
if existing:
if request.expected_updated_at and existing.updated_at:
if existing.updated_at != request.expected_updated_at.replace(tzinfo=None):
raise HTTPException(status_code=409, detail="标注已被更新,请刷新后重试")
existing.annotation = annotation_payload # type: ignore[assignment]
existing.annotation = final_payload # type: ignore[assignment]
existing.updated_at = now # type: ignore[assignment]
await self.db.commit()
await self.db.refresh(existing)
@@ -327,7 +416,7 @@ class AnnotationEditorService:
id=new_id,
project_id=project_id,
file_id=file_id,
annotation=annotation_payload,
annotation=final_payload,
created_at=now,
updated_at=now,
)
@@ -340,3 +429,39 @@ class AnnotationEditorService:
updatedAt=record.updated_at or now,
)
def _merge_segment_annotation(
self,
existing: Optional[Dict[str, Any]],
segment_index: int,
new_annotation: Dict[str, Any],
) -> Dict[str, Any]:
"""
合并段落标注到整体结构
Args:
existing: 现有的 annotation 数据
segment_index: 段落索引
new_annotation: 新的段落标注数据
Returns:
合并后的 annotation 结构
"""
if not existing or not existing.get("segmented"):
# 初始化分段结构
base: Dict[str, Any] = {
"segmented": True,
"version": 1,
"segments": {},
}
else:
base = dict(existing)
# 更新指定段落的标注
base["segments"][str(segment_index)] = {
"result": new_annotation.get("result", []),
"created_at": new_annotation.get("created_at", datetime.utcnow().isoformat() + "Z"),
"updated_at": datetime.utcnow().isoformat() + "Z",
}
return base