feat(annotation): 优化文本标注分段功能实现

- 新增 getEditorTaskSegmentsUsingGet 接口用于获取任务分段信息
- 移除 SegmentInfo 中的 text、start、end 字段,精简数据结构
- 添加 EditorTaskSegmentsResponse 类型定义用于分段摘要响应
- 实现服务端 get_task_segments 方法,支持分段信息查询
- 重构前端组件缓存机制,使用 segmentSummaryFileRef 管理分段状态
- 优化分段构建逻辑,提取 _build_segment_contexts 公共方法
- 调整后端 _build_text_task 方法中的分段处理流程
- 更新 API 类型定义,统一 RequestParams 和 RequestPayload 类型
This commit is contained in:
2026-02-04 16:59:04 +08:00
parent 394e2bda18
commit cda22a720c
5 changed files with 250 additions and 108 deletions

View File

@@ -20,6 +20,7 @@ from app.module.annotation.schema.editor import (
EditorProjectInfo,
EditorTaskListResponse,
EditorTaskResponse,
EditorTaskSegmentsResponse,
UpsertAnnotationRequest,
UpsertAnnotationResponse,
)
@@ -87,6 +88,20 @@ async def get_editor_task(
return StandardResponse(code=200, message="success", data=task)
@router.get(
"/projects/{project_id}/tasks/{file_id}/segments",
response_model=StandardResponse[EditorTaskSegmentsResponse],
)
async def list_editor_task_segments(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db),
):
service = AnnotationEditorService(db)
result = await service.get_task_segments(project_id, file_id)
return StandardResponse(code=200, message="success", data=result)
@router.put(
"/projects/{project_id}/tasks/{file_id}/annotation",
response_model=StandardResponse[UpsertAnnotationResponse],

View File

@@ -79,12 +79,9 @@ class EditorTaskListResponse(BaseModel):
class SegmentInfo(BaseModel):
"""段落信息(用于文本分段标注)"""
"""段落摘要(用于文本分段标注)"""
idx: int = Field(..., description="段落索引")
text: str = Field(..., description="段落文本")
start: int = Field(..., description="在原文中的起始位置")
end: int = Field(..., description="在原文中的结束位置")
has_annotation: bool = Field(False, alias="hasAnnotation", description="该段落是否已有标注")
line_index: int = Field(0, alias="lineIndex", description="JSONL 行索引(从0开始)")
chunk_index: int = Field(0, alias="chunkIndex", description="行内分片索引(从0开始)")
@@ -100,13 +97,22 @@ class EditorTaskResponse(BaseModel):
# 分段相关字段
segmented: bool = Field(False, description="是否启用分段模式")
segments: Optional[List[SegmentInfo]] = Field(None, description="段落列表")
total_segments: int = Field(0, alias="totalSegments", description="总段落数")
current_segment_index: int = Field(0, alias="currentSegmentIndex", description="当前段落索引")
model_config = ConfigDict(populate_by_name=True)
class EditorTaskSegmentsResponse(BaseModel):
"""编辑器段落摘要响应"""
segmented: bool = Field(False, description="是否启用分段模式")
segments: List[SegmentInfo] = Field(default_factory=list, description="段落摘要列表")
total_segments: int = Field(0, alias="totalSegments", description="总段落数")
model_config = ConfigDict(populate_by_name=True)
class UpsertAnnotationRequest(BaseModel):
"""保存/覆盖最终标注(Label Studio annotation 原始对象)"""

View File

@@ -37,6 +37,7 @@ from app.module.annotation.schema.editor import (
EditorTaskListItem,
EditorTaskListResponse,
EditorTaskResponse,
EditorTaskSegmentsResponse,
SegmentInfo,
UpsertAnnotationRequest,
UpsertAnnotationResponse,
@@ -538,6 +539,49 @@ class AnnotationEditorService:
return value
return raw_text
def _build_segment_contexts(
self,
records: List[Tuple[Optional[Dict[str, Any]], str]],
record_texts: List[str],
segment_annotation_keys: set[str],
) -> Tuple[List[SegmentInfo], List[Tuple[Optional[Dict[str, Any]], str, str, int, int]]]:
splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
segment_cursor = 0
for record_index, ((payload, raw_text), record_text) in enumerate(zip(records, record_texts)):
normalized_text = record_text or ""
if len(normalized_text) > self.SEGMENT_THRESHOLD:
raw_segments = splitter.split(normalized_text)
for chunk_index, seg in enumerate(raw_segments):
segments.append(
SegmentInfo(
idx=segment_cursor,
hasAnnotation=str(segment_cursor) in segment_annotation_keys,
lineIndex=record_index,
chunkIndex=chunk_index,
)
)
segment_contexts.append((payload, raw_text, seg["text"], record_index, chunk_index))
segment_cursor += 1
else:
segments.append(
SegmentInfo(
idx=segment_cursor,
hasAnnotation=str(segment_cursor) in segment_annotation_keys,
lineIndex=record_index,
chunkIndex=0,
)
)
segment_contexts.append((payload, raw_text, normalized_text, record_index, 0))
segment_cursor += 1
if not segments:
segments = [SegmentInfo(idx=0, hasAnnotation=False, lineIndex=0, chunkIndex=0)]
segment_contexts = [(None, "", "", 0, 0)]
return segments, segment_contexts
async def get_project_info(self, project_id: str) -> EditorProjectInfo:
project = await self._get_project_or_404(project_id)
@@ -668,6 +712,87 @@ class AnnotationEditorService:
return await self._build_text_task(project, file_record, file_id, segment_index)
async def get_task_segments(
self,
project_id: str,
file_id: str,
) -> EditorTaskSegmentsResponse:
project = await self._get_project_or_404(project_id)
dataset_type = self._normalize_dataset_type(await self._get_dataset_type(project.dataset_id))
if dataset_type != DATASET_TYPE_TEXT:
raise HTTPException(
status_code=400,
detail="当前仅支持 TEXT 项目的段落摘要",
)
file_result = await self.db.execute(
select(DatasetFiles).where(
DatasetFiles.id == file_id,
DatasetFiles.dataset_id == project.dataset_id,
)
)
file_record = file_result.scalar_one_or_none()
if not file_record:
raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")
if not self._resolve_segmentation_enabled(project):
return EditorTaskSegmentsResponse(segmented=False, segments=[], totalSegments=0)
text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
assert isinstance(text_content, str)
label_config = await self._resolve_project_label_config(project)
primary_text_key = self._resolve_primary_text_key(label_config)
file_name = str(getattr(file_record, "file_name", "")).lower()
records: List[Tuple[Optional[Dict[str, Any]], str]] = []
if file_name.endswith(JSONL_EXTENSION):
records = self._parse_jsonl_records(text_content)
else:
parsed_payload = self._try_parse_json_payload(text_content)
if parsed_payload:
records = [(parsed_payload, text_content)]
if not records:
records = [(None, text_content)]
record_texts = [
self._resolve_primary_text_value(payload, raw_text, primary_text_key)
for payload, raw_text in records
]
if not record_texts:
record_texts = [text_content]
needs_segmentation = len(records) > 1 or any(
len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts
)
if not needs_segmentation:
return EditorTaskSegmentsResponse(segmented=False, segments=[], totalSegments=0)
ann_result = await self.db.execute(
select(AnnotationResult).where(
AnnotationResult.project_id == project.id,
AnnotationResult.file_id == file_id,
)
)
ann = ann_result.scalar_one_or_none()
segment_annotations: Dict[str, Dict[str, Any]] = {}
if ann and isinstance(ann.annotation, dict):
segment_annotations = self._extract_segment_annotations(ann.annotation)
segment_annotation_keys = set(segment_annotations.keys())
segments, _ = self._build_segment_contexts(
records,
record_texts,
segment_annotation_keys,
)
return EditorTaskSegmentsResponse(
segmented=True,
segments=segments,
totalSegments=len(segments),
)
async def _build_text_task(
self,
project: LabelingProject,
@@ -723,7 +848,8 @@ class AnnotationEditorService:
needs_segmentation = segmentation_enabled and (
len(records) > 1 or any(len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts)
)
segments: Optional[List[SegmentInfo]] = None
segments: List[SegmentInfo] = []
segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
current_segment_index = 0
display_text = record_texts[0] if record_texts else text_content
selected_payload = records[0][0] if records else None
@@ -732,46 +858,13 @@ class AnnotationEditorService:
display_text = "\n".join(record_texts) if record_texts else text_content
if needs_segmentation:
splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
segments = []
segment_cursor = 0
for record_index, ((payload, raw_text), record_text) in enumerate(zip(records, record_texts)):
normalized_text = record_text or ""
if len(normalized_text) > self.SEGMENT_THRESHOLD:
raw_segments = splitter.split(normalized_text)
for chunk_index, seg in enumerate(raw_segments):
segments.append(SegmentInfo(
idx=segment_cursor,
text=seg["text"],
start=seg["start"],
end=seg["end"],
hasAnnotation=str(segment_cursor) in segment_annotation_keys,
lineIndex=record_index,
chunkIndex=chunk_index,
))
segment_contexts.append((payload, raw_text, seg["text"], record_index, chunk_index))
segment_cursor += 1
else:
segments.append(SegmentInfo(
idx=segment_cursor,
text=normalized_text,
start=0,
end=len(normalized_text),
hasAnnotation=str(segment_cursor) in segment_annotation_keys,
lineIndex=record_index,
chunkIndex=0,
))
segment_contexts.append((payload, raw_text, normalized_text, record_index, 0))
segment_cursor += 1
if not segments:
segments = [SegmentInfo(idx=0, text="", start=0, end=0, hasAnnotation=False, lineIndex=0, chunkIndex=0)]
segment_contexts = [(None, "", "", 0, 0)]
_, segment_contexts = self._build_segment_contexts(
records,
record_texts,
segment_annotation_keys,
)
current_segment_index = segment_index if segment_index is not None else 0
if current_segment_index < 0 or current_segment_index >= len(segments):
if current_segment_index < 0 or current_segment_index >= len(segment_contexts):
current_segment_index = 0
selected_payload, _, display_text, _, _ = segment_contexts[current_segment_index]
@@ -849,8 +942,7 @@ class AnnotationEditorService:
task=task,
annotationUpdatedAt=annotation_updated_at,
segmented=needs_segmentation,
segments=segments,
totalSegments=len(segments) if segments else 1,
totalSegments=len(segment_contexts) if needs_segmentation else 1,
currentSegmentIndex=current_segment_index,
)