You've already forked DataMate
feat(annotation): 优化文本标注分段功能实现
- 新增 getEditorTaskSegmentsUsingGet 接口用于获取任务分段信息
- 移除 SegmentInfo 中的 text、start、end 字段,精简数据结构
- 添加 EditorTaskSegmentsResponse 类型定义用于分段摘要响应
- 实现服务端 get_task_segments 方法,支持分段信息查询
- 重构前端组件缓存机制,使用 segmentSummaryFileRef 管理分段状态
- 优化分段构建逻辑,提取 _build_segment_contexts 公共方法
- 调整后端 _build_text_task 方法中的分段处理流程
- 更新 API 类型定义,统一 RequestParams 和 RequestPayload 类型
This commit is contained in:
@@ -6,6 +6,7 @@ import { useNavigate, useParams } from "react-router";
|
||||
import {
|
||||
getEditorProjectInfoUsingGet,
|
||||
getEditorTaskUsingGet,
|
||||
getEditorTaskSegmentsUsingGet,
|
||||
listEditorTasksUsingGet,
|
||||
upsertEditorAnnotationUsingPut,
|
||||
} from "../annotation.api";
|
||||
@@ -38,9 +39,6 @@ type LsfMessage = {
|
||||
|
||||
type SegmentInfo = {
|
||||
idx: number;
|
||||
text: string;
|
||||
start: number;
|
||||
end: number;
|
||||
hasAnnotation: boolean;
|
||||
lineIndex: number;
|
||||
chunkIndex: number;
|
||||
@@ -66,10 +64,16 @@ type EditorTaskPayload = {
|
||||
type EditorTaskResponse = {
|
||||
task?: EditorTaskPayload;
|
||||
segmented?: boolean;
|
||||
segments?: SegmentInfo[];
|
||||
totalSegments?: number;
|
||||
currentSegmentIndex?: number;
|
||||
};
|
||||
|
||||
/**
 * Shape the component expects back from `getEditorTaskSegmentsUsingGet`:
 * the segmentation flag plus per-segment summaries, without any task data.
 * Every field is optional, so callers must tolerate a partial payload.
 */
type EditorTaskSegmentsResponse = {
  /** Whether the file is handled in segmented mode. */
  segmented?: boolean;
  /** Summary entries, one per segment. */
  segments?: Array<SegmentInfo>;
  /** Total number of segments reported by the server. */
  totalSegments?: number;
};
|
||||
|
||||
type EditorTaskListResponse = {
|
||||
content?: EditorTaskListItem[];
|
||||
totalElements?: number;
|
||||
@@ -288,6 +292,7 @@ export default function LabelStudioTextEditor() {
|
||||
const segmentStatsCacheRef = useRef<Record<string, SegmentStats>>({});
|
||||
const segmentStatsSeqRef = useRef(0);
|
||||
const segmentStatsLoadingRef = useRef<Set<string>>(new Set());
|
||||
const segmentSummaryFileRef = useRef<string>("");
|
||||
|
||||
const [loadingProject, setLoadingProject] = useState(true);
|
||||
const [loadingTasks, setLoadingTasks] = useState(false);
|
||||
@@ -358,9 +363,7 @@ export default function LabelStudioTextEditor() {
|
||||
if (segmentStatsCacheRef.current[fileId] || segmentStatsLoadingRef.current.has(fileId)) return;
|
||||
segmentStatsLoadingRef.current.add(fileId);
|
||||
try {
|
||||
const resp = (await getEditorTaskUsingGet(projectId, fileId, {
|
||||
segmentIndex: 0,
|
||||
})) as ApiResponse<EditorTaskResponse>;
|
||||
const resp = (await getEditorTaskSegmentsUsingGet(projectId, fileId)) as ApiResponse<EditorTaskSegmentsResponse>;
|
||||
if (segmentStatsSeqRef.current !== seq) return;
|
||||
const data = resp?.data;
|
||||
if (!data?.segmented) return;
|
||||
@@ -591,20 +594,38 @@ export default function LabelStudioTextEditor() {
|
||||
if (seq !== initSeqRef.current) return;
|
||||
|
||||
// 更新分段状态
|
||||
const segmentIndex = data?.segmented
|
||||
const isSegmented = !!data?.segmented;
|
||||
const segmentIndex = isSegmented
|
||||
? resolveSegmentIndex(data.currentSegmentIndex) ?? 0
|
||||
: undefined;
|
||||
if (data?.segmented) {
|
||||
const stats = buildSegmentStats(data.segments);
|
||||
if (isSegmented) {
|
||||
let nextSegments: SegmentInfo[] = [];
|
||||
if (segmentSummaryFileRef.current === fileId && segments.length > 0) {
|
||||
nextSegments = segments;
|
||||
} else {
|
||||
try {
|
||||
const segmentResp = (await getEditorTaskSegmentsUsingGet(projectId, fileId)) as ApiResponse<EditorTaskSegmentsResponse>;
|
||||
if (seq !== initSeqRef.current) return;
|
||||
const segmentData = segmentResp?.data;
|
||||
if (segmentData?.segmented) {
|
||||
nextSegments = Array.isArray(segmentData.segments) ? segmentData.segments : [];
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
const stats = buildSegmentStats(nextSegments);
|
||||
setSegmented(true);
|
||||
setSegments(data.segments || []);
|
||||
setSegments(nextSegments);
|
||||
setCurrentSegmentIndex(segmentIndex ?? 0);
|
||||
updateSegmentStatsCache(fileId, stats);
|
||||
segmentSummaryFileRef.current = fileId;
|
||||
} else {
|
||||
setSegmented(false);
|
||||
setSegments([]);
|
||||
setCurrentSegmentIndex(0);
|
||||
updateSegmentStatsCache(fileId, null);
|
||||
segmentSummaryFileRef.current = fileId;
|
||||
}
|
||||
|
||||
const taskData = {
|
||||
@@ -664,7 +685,7 @@ export default function LabelStudioTextEditor() {
|
||||
} finally {
|
||||
if (seq === initSeqRef.current) setLoadingTaskDetail(false);
|
||||
}
|
||||
}, [iframeReady, message, postToIframe, project, projectId, updateSegmentStatsCache]);
|
||||
}, [iframeReady, message, postToIframe, project, projectId, segments, updateSegmentStatsCache]);
|
||||
|
||||
const advanceAfterSave = useCallback(async (fileId: string, segmentIndex?: number) => {
|
||||
if (!fileId) return;
|
||||
@@ -979,6 +1000,7 @@ export default function LabelStudioTextEditor() {
|
||||
setSegmented(false);
|
||||
setSegments([]);
|
||||
setCurrentSegmentIndex(0);
|
||||
segmentSummaryFileRef.current = "";
|
||||
savedSnapshotsRef.current = {};
|
||||
segmentStatsSeqRef.current += 1;
|
||||
segmentStatsCacheRef.current = {};
|
||||
|
||||
@@ -3,16 +3,19 @@ import { get, post, put, del, download } from "@/utils/request";
|
||||
// 导出格式类型
|
||||
export type ExportFormat = "json" | "jsonl" | "csv" | "coco" | "yolo";
|
||||
|
||||
type RequestParams = Record<string, unknown>;
|
||||
type RequestPayload = Record<string, unknown>;
|
||||
|
||||
// 标注任务管理相关接口
|
||||
export function queryAnnotationTasksUsingGet(params?: any) {
|
||||
export function queryAnnotationTasksUsingGet(params?: RequestParams) {
|
||||
return get("/api/annotation/project", params);
|
||||
}
|
||||
|
||||
export function createAnnotationTaskUsingPost(data: any) {
|
||||
export function createAnnotationTaskUsingPost(data: RequestPayload) {
|
||||
return post("/api/annotation/project", data);
|
||||
}
|
||||
|
||||
export function syncAnnotationTaskUsingPost(data: any) {
|
||||
export function syncAnnotationTaskUsingPost(data: RequestPayload) {
|
||||
return post(`/api/annotation/task/sync`, data);
|
||||
}
|
||||
|
||||
@@ -25,7 +28,7 @@ export function getAnnotationTaskByIdUsingGet(taskId: string) {
|
||||
return get(`/api/annotation/project/${taskId}`);
|
||||
}
|
||||
|
||||
export function updateAnnotationTaskByIdUsingPut(taskId: string, data: any) {
|
||||
export function updateAnnotationTaskByIdUsingPut(taskId: string, data: RequestPayload) {
|
||||
return put(`/api/annotation/project/${taskId}`, data);
|
||||
}
|
||||
|
||||
@@ -35,17 +38,17 @@ export function getTagConfigUsingGet() {
|
||||
}
|
||||
|
||||
// 标注模板管理
|
||||
export function queryAnnotationTemplatesUsingGet(params?: any) {
|
||||
export function queryAnnotationTemplatesUsingGet(params?: RequestParams) {
|
||||
return get("/api/annotation/template", params);
|
||||
}
|
||||
|
||||
export function createAnnotationTemplateUsingPost(data: any) {
|
||||
export function createAnnotationTemplateUsingPost(data: RequestPayload) {
|
||||
return post("/api/annotation/template", data);
|
||||
}
|
||||
|
||||
export function updateAnnotationTemplateByIdUsingPut(
|
||||
templateId: string | number,
|
||||
data: any
|
||||
data: RequestPayload
|
||||
) {
|
||||
return put(`/api/annotation/template/${templateId}`, data);
|
||||
}
|
||||
@@ -65,7 +68,7 @@ export function getEditorProjectInfoUsingGet(projectId: string) {
|
||||
return get(`/api/annotation/editor/projects/${projectId}`);
|
||||
}
|
||||
|
||||
export function listEditorTasksUsingGet(projectId: string, params?: any) {
|
||||
export function listEditorTasksUsingGet(projectId: string, params?: RequestParams) {
|
||||
return get(`/api/annotation/editor/projects/${projectId}/tasks`, params);
|
||||
}
|
||||
|
||||
@@ -77,11 +80,15 @@ export function getEditorTaskUsingGet(
|
||||
return get(`/api/annotation/editor/projects/${projectId}/tasks/${fileId}`, params);
|
||||
}
|
||||
|
||||
/**
 * Requests the segment summary of a single editor task.
 *
 * Sends a GET to the `/segments` sub-resource of the task addressed by
 * `projectId` and `fileId`, and hands back whatever the shared `get`
 * helper returns.
 *
 * @param projectId - Identifier interpolated into the `projects/` path segment.
 * @param fileId - Identifier interpolated into the `tasks/` path segment.
 * @returns The result of the underlying `get` call.
 */
export function getEditorTaskSegmentsUsingGet(projectId: string, fileId: string) {
  const segmentsUrl = `/api/annotation/editor/projects/${projectId}/tasks/${fileId}/segments`;
  return get(segmentsUrl);
}
|
||||
|
||||
export function upsertEditorAnnotationUsingPut(
|
||||
projectId: string,
|
||||
fileId: string,
|
||||
data: {
|
||||
annotation: any;
|
||||
annotation: Record<string, unknown>;
|
||||
expectedUpdatedAt?: string;
|
||||
segmentIndex?: number;
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.module.annotation.schema.editor import (
|
||||
EditorProjectInfo,
|
||||
EditorTaskListResponse,
|
||||
EditorTaskResponse,
|
||||
EditorTaskSegmentsResponse,
|
||||
UpsertAnnotationRequest,
|
||||
UpsertAnnotationResponse,
|
||||
)
|
||||
@@ -87,6 +88,20 @@ async def get_editor_task(
|
||||
return StandardResponse(code=200, message="success", data=task)
|
||||
|
||||
|
||||
@router.get(
    "/projects/{project_id}/tasks/{file_id}/segments",
    response_model=StandardResponse[EditorTaskSegmentsResponse],
)
async def list_editor_task_segments(
    project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
    file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
    db: AsyncSession = Depends(get_db),
):
    # Thin HTTP adapter: all lookup and validation happens in
    # AnnotationEditorService.get_task_segments; this handler only wraps the
    # result in the standard success envelope.
    # (Documented with comments rather than a docstring so the generated
    # API description of this route stays unchanged.)
    segment_summary = await AnnotationEditorService(db).get_task_segments(project_id, file_id)
    return StandardResponse(code=200, message="success", data=segment_summary)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/projects/{project_id}/tasks/{file_id}/annotation",
|
||||
response_model=StandardResponse[UpsertAnnotationResponse],
|
||||
|
||||
@@ -79,12 +79,9 @@ class EditorTaskListResponse(BaseModel):
|
||||
|
||||
|
||||
class SegmentInfo(BaseModel):
|
||||
"""段落信息(用于文本分段标注)"""
|
||||
"""段落摘要(用于文本分段标注)"""
|
||||
|
||||
idx: int = Field(..., description="段落索引")
|
||||
text: str = Field(..., description="段落文本")
|
||||
start: int = Field(..., description="在原文中的起始位置")
|
||||
end: int = Field(..., description="在原文中的结束位置")
|
||||
has_annotation: bool = Field(False, alias="hasAnnotation", description="该段落是否已有标注")
|
||||
line_index: int = Field(0, alias="lineIndex", description="JSONL 行索引(从0开始)")
|
||||
chunk_index: int = Field(0, alias="chunkIndex", description="行内分片索引(从0开始)")
|
||||
@@ -100,13 +97,22 @@ class EditorTaskResponse(BaseModel):
|
||||
|
||||
# 分段相关字段
|
||||
segmented: bool = Field(False, description="是否启用分段模式")
|
||||
segments: Optional[List[SegmentInfo]] = Field(None, description="段落列表")
|
||||
total_segments: int = Field(0, alias="totalSegments", description="总段落数")
|
||||
current_segment_index: int = Field(0, alias="currentSegmentIndex", description="当前段落索引")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class EditorTaskSegmentsResponse(BaseModel):
    """编辑器段落摘要响应"""

    # Response body of the editor "segments" endpoint: a segmentation flag,
    # the list of segment summaries, and their count.
    # NOTE(review): the class docstring is left verbatim because pydantic
    # publishes it as the schema description — confirm before translating.

    # Allow construction by field name as well as by the camelCase alias.
    model_config = ConfigDict(populate_by_name=True)

    # True when the file is served in segmented mode.
    segmented: bool = Field(default=False, description="是否启用分段模式")
    # One summary entry per segment; empty by default.
    segments: List[SegmentInfo] = Field(default_factory=list, description="段落摘要列表")
    # Number of segments; exposed as "totalSegments".
    total_segments: int = Field(default=0, alias="totalSegments", description="总段落数")
|
||||
|
||||
|
||||
class UpsertAnnotationRequest(BaseModel):
|
||||
"""保存/覆盖最终标注(Label Studio annotation 原始对象)"""
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ from app.module.annotation.schema.editor import (
|
||||
EditorTaskListItem,
|
||||
EditorTaskListResponse,
|
||||
EditorTaskResponse,
|
||||
EditorTaskSegmentsResponse,
|
||||
SegmentInfo,
|
||||
UpsertAnnotationRequest,
|
||||
UpsertAnnotationResponse,
|
||||
@@ -538,6 +539,49 @@ class AnnotationEditorService:
|
||||
return value
|
||||
return raw_text
|
||||
|
||||
def _build_segment_contexts(
    self,
    records: List[Tuple[Optional[Dict[str, Any]], str]],
    record_texts: List[str],
    segment_annotation_keys: set[str],
) -> Tuple[List[SegmentInfo], List[Tuple[Optional[Dict[str, Any]], str, str, int, int]]]:
    """Build segment summaries and their per-segment contexts in one pass.

    Each record contributes one segment, unless its text is longer than
    ``self.SEGMENT_THRESHOLD``; in that case the text is split with
    ``AnnotationTextSplitter`` and every chunk becomes its own segment.
    Segment indices are assigned sequentially across all records.

    Args:
        records: ``(payload, raw_text)`` pairs, one per record.
        record_texts: Text for each record, aligned with ``records`` by
            position. ``None`` or empty entries are treated as ``""``.
        segment_annotation_keys: Segment indices, as strings, that already
            have an annotation.

    Returns:
        ``(segments, segment_contexts)`` — two lists of equal length, where
        ``segment_contexts[i]`` is
        ``(payload, raw_text, segment_text, line_index, chunk_index)`` for
        ``segments[i]``. If no segment was produced, both lists hold a single
        empty placeholder so callers can always index position 0.
    """
    splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
    # Fix: this list was previously appended to without ever being created,
    # which raised UnboundLocalError on the first segment.
    segments: List[SegmentInfo] = []
    segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
    segment_cursor = 0

    for record_index, ((payload, raw_text), record_text) in enumerate(zip(records, record_texts)):
        normalized_text = record_text or ""

        # Oversized records are split into chunks; everything else is one chunk.
        if len(normalized_text) > self.SEGMENT_THRESHOLD:
            chunk_texts = [seg["text"] for seg in splitter.split(normalized_text)]
        else:
            chunk_texts = [normalized_text]

        for chunk_index, chunk_text in enumerate(chunk_texts):
            segments.append(
                SegmentInfo(
                    idx=segment_cursor,
                    hasAnnotation=str(segment_cursor) in segment_annotation_keys,
                    lineIndex=record_index,
                    chunkIndex=chunk_index,
                )
            )
            segment_contexts.append((payload, raw_text, chunk_text, record_index, chunk_index))
            segment_cursor += 1

    if not segments:
        # Nothing produced (e.g. no records): fall back to one empty segment.
        segments = [SegmentInfo(idx=0, hasAnnotation=False, lineIndex=0, chunkIndex=0)]
        segment_contexts = [(None, "", "", 0, 0)]

    return segments, segment_contexts
|
||||
|
||||
async def get_project_info(self, project_id: str) -> EditorProjectInfo:
|
||||
project = await self._get_project_or_404(project_id)
|
||||
|
||||
@@ -668,6 +712,87 @@ class AnnotationEditorService:
|
||||
|
||||
return await self._build_text_task(project, file_record, file_id, segment_index)
|
||||
|
||||
async def get_task_segments(
    self,
    project_id: str,
    file_id: str,
) -> EditorTaskSegmentsResponse:
    """Return the segment summary of one file in a TEXT labeling project.

    The file content is fetched and parsed into records (JSONL lines, a
    single JSON payload, or the raw text). The response lists one summary
    per segment, flagged with whether an annotation already exists for it.

    Args:
        project_id: Labeling project to look up.
        file_id: File that must belong to the project's dataset.

    Returns:
        A response with ``segmented=True`` and the segment list, or an empty
        ``segmented=False`` response when segmentation is disabled for the
        project or the content does not need segmenting (a single record no
        longer than ``SEGMENT_THRESHOLD``).

    Raises:
        HTTPException: 400 when the project's dataset is not TEXT; 404 when
            the file is not found in the project's dataset. The project
            lookup helper may raise as well.
    """
    project = await self._get_project_or_404(project_id)

    raw_dataset_type = await self._get_dataset_type(project.dataset_id)
    if self._normalize_dataset_type(raw_dataset_type) != DATASET_TYPE_TEXT:
        raise HTTPException(
            status_code=400,
            detail="当前仅支持 TEXT 项目的段落摘要",
        )

    # The file must exist and belong to the project's dataset.
    file_query = select(DatasetFiles).where(
        DatasetFiles.id == file_id,
        DatasetFiles.dataset_id == project.dataset_id,
    )
    file_record = (await self.db.execute(file_query)).scalar_one_or_none()
    if not file_record:
        raise HTTPException(status_code=404, detail=f"文件不存在或不属于该项目: {file_id}")

    if not self._resolve_segmentation_enabled(project):
        return EditorTaskSegmentsResponse(segmented=False, segments=[], totalSegments=0)

    text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
    assert isinstance(text_content, str)
    label_config = await self._resolve_project_label_config(project)
    primary_text_key = self._resolve_primary_text_key(label_config)
    file_name = str(getattr(file_record, "file_name", "")).lower()

    # Parse into (payload, raw_text) records: JSONL yields one per line,
    # other files yield at most one JSON payload.
    records: List[Tuple[Optional[Dict[str, Any]], str]]
    if file_name.endswith(JSONL_EXTENSION):
        records = self._parse_jsonl_records(text_content)
    else:
        parsed_payload = self._try_parse_json_payload(text_content)
        records = [(parsed_payload, text_content)] if parsed_payload else []

    # Unparseable or empty content is treated as one plain-text record.
    if not records:
        records = [(None, text_content)]

    record_texts = [
        self._resolve_primary_text_value(payload, raw_text, primary_text_key)
        for payload, raw_text in records
    ]
    if not record_texts:
        record_texts = [text_content]

    # No segmentation needed for a single record within the threshold.
    if len(records) <= 1 and not any(
        len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts
    ):
        return EditorTaskSegmentsResponse(segmented=False, segments=[], totalSegments=0)

    # Collect the segment keys that already carry an annotation.
    annotation_query = select(AnnotationResult).where(
        AnnotationResult.project_id == project.id,
        AnnotationResult.file_id == file_id,
    )
    ann = (await self.db.execute(annotation_query)).scalar_one_or_none()
    segment_annotations: Dict[str, Dict[str, Any]] = {}
    if ann and isinstance(ann.annotation, dict):
        segment_annotations = self._extract_segment_annotations(ann.annotation)
    segment_annotation_keys = set(segment_annotations.keys())

    # Only the summaries are needed here; the per-segment contexts are dropped.
    segments, _ = self._build_segment_contexts(
        records,
        record_texts,
        segment_annotation_keys,
    )

    return EditorTaskSegmentsResponse(
        segmented=True,
        segments=segments,
        totalSegments=len(segments),
    )
|
||||
|
||||
async def _build_text_task(
|
||||
self,
|
||||
project: LabelingProject,
|
||||
@@ -723,7 +848,8 @@ class AnnotationEditorService:
|
||||
needs_segmentation = segmentation_enabled and (
|
||||
len(records) > 1 or any(len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts)
|
||||
)
|
||||
segments: Optional[List[SegmentInfo]] = None
|
||||
segments: List[SegmentInfo] = []
|
||||
segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
|
||||
current_segment_index = 0
|
||||
display_text = record_texts[0] if record_texts else text_content
|
||||
selected_payload = records[0][0] if records else None
|
||||
@@ -732,46 +858,13 @@ class AnnotationEditorService:
|
||||
display_text = "\n".join(record_texts) if record_texts else text_content
|
||||
|
||||
if needs_segmentation:
|
||||
splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
|
||||
segment_contexts: List[Tuple[Optional[Dict[str, Any]], str, str, int, int]] = []
|
||||
segments = []
|
||||
segment_cursor = 0
|
||||
|
||||
for record_index, ((payload, raw_text), record_text) in enumerate(zip(records, record_texts)):
|
||||
normalized_text = record_text or ""
|
||||
if len(normalized_text) > self.SEGMENT_THRESHOLD:
|
||||
raw_segments = splitter.split(normalized_text)
|
||||
for chunk_index, seg in enumerate(raw_segments):
|
||||
segments.append(SegmentInfo(
|
||||
idx=segment_cursor,
|
||||
text=seg["text"],
|
||||
start=seg["start"],
|
||||
end=seg["end"],
|
||||
hasAnnotation=str(segment_cursor) in segment_annotation_keys,
|
||||
lineIndex=record_index,
|
||||
chunkIndex=chunk_index,
|
||||
))
|
||||
segment_contexts.append((payload, raw_text, seg["text"], record_index, chunk_index))
|
||||
segment_cursor += 1
|
||||
else:
|
||||
segments.append(SegmentInfo(
|
||||
idx=segment_cursor,
|
||||
text=normalized_text,
|
||||
start=0,
|
||||
end=len(normalized_text),
|
||||
hasAnnotation=str(segment_cursor) in segment_annotation_keys,
|
||||
lineIndex=record_index,
|
||||
chunkIndex=0,
|
||||
))
|
||||
segment_contexts.append((payload, raw_text, normalized_text, record_index, 0))
|
||||
segment_cursor += 1
|
||||
|
||||
if not segments:
|
||||
segments = [SegmentInfo(idx=0, text="", start=0, end=0, hasAnnotation=False, lineIndex=0, chunkIndex=0)]
|
||||
segment_contexts = [(None, "", "", 0, 0)]
|
||||
|
||||
_, segment_contexts = self._build_segment_contexts(
|
||||
records,
|
||||
record_texts,
|
||||
segment_annotation_keys,
|
||||
)
|
||||
current_segment_index = segment_index if segment_index is not None else 0
|
||||
if current_segment_index < 0 or current_segment_index >= len(segments):
|
||||
if current_segment_index < 0 or current_segment_index >= len(segment_contexts):
|
||||
current_segment_index = 0
|
||||
|
||||
selected_payload, _, display_text, _, _ = segment_contexts[current_segment_index]
|
||||
@@ -849,8 +942,7 @@ class AnnotationEditorService:
|
||||
task=task,
|
||||
annotationUpdatedAt=annotation_updated_at,
|
||||
segmented=needs_segmentation,
|
||||
segments=segments,
|
||||
totalSegments=len(segments) if segments else 1,
|
||||
totalSegments=len(segment_contexts) if needs_segmentation else 1,
|
||||
currentSegmentIndex=current_segment_index,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user