You've already forked DataMate
feat(annotation): 添加文本分段标注功能
- 引入文本分割器实现长文本按200字符自动分段
- 增加分段状态管理和段落导航界面
- 支持按段落保存和加载标注数据
- 实现分段模式下的标注状态跟踪
- 扩展API接口支持段落索引参数
- 添加分段相关的数据模型定义
This commit is contained in:
@@ -27,10 +27,12 @@ from app.module.annotation.schema.editor import (
|
||||
EditorTaskListItem,
|
||||
EditorTaskListResponse,
|
||||
EditorTaskResponse,
|
||||
SegmentInfo,
|
||||
UpsertAnnotationRequest,
|
||||
UpsertAnnotationResponse,
|
||||
)
|
||||
from app.module.annotation.service.template import AnnotationTemplateService
|
||||
from app.module.annotation.service.annotation_text_splitter import AnnotationTextSplitter
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -38,6 +40,9 @@ logger = get_logger(__name__)
|
||||
class AnnotationEditorService:
|
||||
"""Label Studio Editor 集成服务(TEXT POC 版)"""
|
||||
|
||||
# 分段阈值:超过此字符数自动分段
|
||||
SEGMENT_THRESHOLD = 200
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.template_service = AnnotationTemplateService()
|
||||
@@ -206,7 +211,12 @@ class AnnotationEditorService:
|
||||
logger.error(f"读取文本失败: dataset={dataset_id}, file={file_id}, err={e}")
|
||||
raise HTTPException(status_code=502, detail="读取文本失败(下载接口调用异常)")
|
||||
|
||||
async def get_task(self, project_id: str, file_id: str) -> EditorTaskResponse:
|
||||
async def get_task(
|
||||
self,
|
||||
project_id: str,
|
||||
file_id: str,
|
||||
segment_index: Optional[int] = None,
|
||||
) -> EditorTaskResponse:
|
||||
project = await self._get_project_or_404(project_id)
|
||||
|
||||
# TEXT 支持校验
|
||||
@@ -226,6 +236,7 @@ class AnnotationEditorService:
|
||||
|
||||
text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
|
||||
|
||||
# 获取现有标注
|
||||
ann_result = await self.db.execute(
|
||||
select(AnnotationResult).where(
|
||||
AnnotationResult.project_id == project_id,
|
||||
@@ -236,10 +247,44 @@ class AnnotationEditorService:
|
||||
|
||||
ls_task_id = self._make_ls_task_id(project_id, file_id)
|
||||
|
||||
# 判断是否需要分段
|
||||
needs_segmentation = len(text_content) > self.SEGMENT_THRESHOLD
|
||||
segments: Optional[List[SegmentInfo]] = None
|
||||
current_segment_index = 0
|
||||
display_text = text_content
|
||||
|
||||
if needs_segmentation:
|
||||
splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
|
||||
raw_segments = splitter.split(text_content)
|
||||
current_segment_index = segment_index if segment_index is not None else 0
|
||||
|
||||
# 校验段落索引
|
||||
if current_segment_index < 0 or current_segment_index >= len(raw_segments):
|
||||
current_segment_index = 0
|
||||
|
||||
# 标记每个段落是否已有标注
|
||||
segment_annotations: Dict[str, Any] = {}
|
||||
if ann and ann.annotation and ann.annotation.get("segmented"):
|
||||
segment_annotations = ann.annotation.get("segments", {})
|
||||
|
||||
segments = []
|
||||
for seg in raw_segments:
|
||||
segments.append(SegmentInfo(
|
||||
idx=seg["idx"],
|
||||
text=seg["text"],
|
||||
start=seg["start"],
|
||||
end=seg["end"],
|
||||
hasAnnotation=str(seg["idx"]) in segment_annotations,
|
||||
))
|
||||
|
||||
# 当前段落文本用于 task.data.text
|
||||
display_text = raw_segments[current_segment_index]["text"]
|
||||
|
||||
# 构造 task 对象
|
||||
task: Dict[str, Any] = {
|
||||
"id": ls_task_id,
|
||||
"data": {
|
||||
"text": text_content,
|
||||
"text": display_text,
|
||||
"file_id": file_id,
|
||||
"dataset_id": project.dataset_id,
|
||||
"file_name": getattr(file_record, "file_name", ""),
|
||||
@@ -250,15 +295,43 @@ class AnnotationEditorService:
|
||||
annotation_updated_at = None
|
||||
if ann:
|
||||
annotation_updated_at = ann.updated_at
|
||||
# 直接返回存储的 annotation 原始对象(Label Studio 兼容)
|
||||
stored = dict(ann.annotation or {})
|
||||
stored["task"] = ls_task_id
|
||||
if not isinstance(stored.get("id"), int):
|
||||
stored["id"] = self._make_ls_annotation_id(project_id, file_id)
|
||||
task["annotations"] = [stored]
|
||||
|
||||
if needs_segmentation and ann.annotation and ann.annotation.get("segmented"):
|
||||
# 分段模式:获取当前段落的标注
|
||||
segment_annotations = ann.annotation.get("segments", {})
|
||||
seg_ann = segment_annotations.get(str(current_segment_index), {})
|
||||
stored = {
|
||||
"id": self._make_ls_annotation_id(project_id, file_id) + current_segment_index,
|
||||
"task": ls_task_id,
|
||||
"result": seg_ann.get("result", []),
|
||||
"created_at": seg_ann.get("created_at", datetime.utcnow().isoformat() + "Z"),
|
||||
"updated_at": seg_ann.get("updated_at", datetime.utcnow().isoformat() + "Z"),
|
||||
}
|
||||
task["annotations"] = [stored]
|
||||
elif not needs_segmentation:
|
||||
# 非分段模式:直接返回存储的 annotation 原始对象
|
||||
stored = dict(ann.annotation or {})
|
||||
stored["task"] = ls_task_id
|
||||
if not isinstance(stored.get("id"), int):
|
||||
stored["id"] = self._make_ls_annotation_id(project_id, file_id)
|
||||
task["annotations"] = [stored]
|
||||
else:
|
||||
# 首次从非分段切换到分段:提供空标注
|
||||
empty_ann_id = self._make_ls_annotation_id(project_id, file_id) + current_segment_index
|
||||
task["annotations"] = [
|
||||
{
|
||||
"id": empty_ann_id,
|
||||
"task": ls_task_id,
|
||||
"result": [],
|
||||
"created_at": datetime.utcnow().isoformat() + "Z",
|
||||
"updated_at": datetime.utcnow().isoformat() + "Z",
|
||||
}
|
||||
]
|
||||
else:
|
||||
# 提供一个空 annotation,避免前端在没有选中 annotation 时无法产生 result
|
||||
empty_ann_id = self._make_ls_annotation_id(project_id, file_id)
|
||||
if needs_segmentation:
|
||||
empty_ann_id += current_segment_index
|
||||
task["annotations"] = [
|
||||
{
|
||||
"id": empty_ann_id,
|
||||
@@ -272,6 +345,10 @@ class AnnotationEditorService:
|
||||
return EditorTaskResponse(
|
||||
task=task,
|
||||
annotationUpdatedAt=annotation_updated_at,
|
||||
segmented=needs_segmentation,
|
||||
segments=segments,
|
||||
totalSegments=len(segments) if segments else 1,
|
||||
currentSegmentIndex=current_segment_index,
|
||||
)
|
||||
|
||||
async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse:
|
||||
@@ -293,9 +370,6 @@ class AnnotationEditorService:
|
||||
raise HTTPException(status_code=400, detail="annotation.result 必须为数组")
|
||||
|
||||
ls_task_id = self._make_ls_task_id(project_id, file_id)
|
||||
annotation_payload["task"] = ls_task_id
|
||||
if not isinstance(annotation_payload.get("id"), int):
|
||||
annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)
|
||||
|
||||
existing_result = await self.db.execute(
|
||||
select(AnnotationResult).where(
|
||||
@@ -307,12 +381,27 @@ class AnnotationEditorService:
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
# 判断是否为分段保存模式
|
||||
if request.segment_index is not None:
|
||||
# 分段模式:合并段落标注到整体结构
|
||||
final_payload = self._merge_segment_annotation(
|
||||
existing.annotation if existing else None,
|
||||
request.segment_index,
|
||||
annotation_payload,
|
||||
)
|
||||
else:
|
||||
# 非分段模式:直接使用传入的 annotation
|
||||
annotation_payload["task"] = ls_task_id
|
||||
if not isinstance(annotation_payload.get("id"), int):
|
||||
annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)
|
||||
final_payload = annotation_payload
|
||||
|
||||
if existing:
|
||||
if request.expected_updated_at and existing.updated_at:
|
||||
if existing.updated_at != request.expected_updated_at.replace(tzinfo=None):
|
||||
raise HTTPException(status_code=409, detail="标注已被更新,请刷新后重试")
|
||||
|
||||
existing.annotation = annotation_payload # type: ignore[assignment]
|
||||
existing.annotation = final_payload # type: ignore[assignment]
|
||||
existing.updated_at = now # type: ignore[assignment]
|
||||
await self.db.commit()
|
||||
await self.db.refresh(existing)
|
||||
@@ -327,7 +416,7 @@ class AnnotationEditorService:
|
||||
id=new_id,
|
||||
project_id=project_id,
|
||||
file_id=file_id,
|
||||
annotation=annotation_payload,
|
||||
annotation=final_payload,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
@@ -340,3 +429,39 @@ class AnnotationEditorService:
|
||||
updatedAt=record.updated_at or now,
|
||||
)
|
||||
|
||||
def _merge_segment_annotation(
|
||||
self,
|
||||
existing: Optional[Dict[str, Any]],
|
||||
segment_index: int,
|
||||
new_annotation: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并段落标注到整体结构
|
||||
|
||||
Args:
|
||||
existing: 现有的 annotation 数据
|
||||
segment_index: 段落索引
|
||||
new_annotation: 新的段落标注数据
|
||||
|
||||
Returns:
|
||||
合并后的 annotation 结构
|
||||
"""
|
||||
if not existing or not existing.get("segmented"):
|
||||
# 初始化分段结构
|
||||
base: Dict[str, Any] = {
|
||||
"segmented": True,
|
||||
"version": 1,
|
||||
"segments": {},
|
||||
}
|
||||
else:
|
||||
base = dict(existing)
|
||||
|
||||
# 更新指定段落的标注
|
||||
base["segments"][str(segment_index)] = {
|
||||
"result": new_annotation.get("result", []),
|
||||
"created_at": new_annotation.get("created_at", datetime.utcnow().isoformat() + "Z"),
|
||||
"updated_at": datetime.utcnow().isoformat() + "Z",
|
||||
}
|
||||
|
||||
return base
|
||||
|
||||
|
||||
Reference in New Issue
Block a user