You've already forked DataMate
Merge branch 'editor_next' into lsf
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import { useEffect, useMemo, useRef, useState } from "react";
|
import { useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { App, Button, Card, List, Spin, Typography } from "antd";
|
import { App, Button, Card, List, Spin, Typography, Tag } from "antd";
|
||||||
import { LeftOutlined, ReloadOutlined, SaveOutlined, MenuFoldOutlined, MenuUnfoldOutlined } from "@ant-design/icons";
|
import { LeftOutlined, ReloadOutlined, SaveOutlined, MenuFoldOutlined, MenuUnfoldOutlined, CheckOutlined } from "@ant-design/icons";
|
||||||
import { useNavigate, useParams } from "react-router";
|
import { useNavigate, useParams } from "react-router";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@@ -32,6 +32,14 @@ type LsfMessage = {
|
|||||||
payload?: any;
|
payload?: any;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type SegmentInfo = {
|
||||||
|
idx: number;
|
||||||
|
text: string;
|
||||||
|
start: number;
|
||||||
|
end: number;
|
||||||
|
hasAnnotation: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
const LSF_IFRAME_SRC = "/lsf/lsf.html";
|
const LSF_IFRAME_SRC = "/lsf/lsf.html";
|
||||||
|
|
||||||
export default function LabelStudioTextEditor() {
|
export default function LabelStudioTextEditor() {
|
||||||
@@ -56,6 +64,11 @@ export default function LabelStudioTextEditor() {
|
|||||||
const [selectedFileId, setSelectedFileId] = useState<string>("");
|
const [selectedFileId, setSelectedFileId] = useState<string>("");
|
||||||
const [sidebarCollapsed, setSidebarCollapsed] = useState(false);
|
const [sidebarCollapsed, setSidebarCollapsed] = useState(false);
|
||||||
|
|
||||||
|
// 分段相关状态
|
||||||
|
const [segmented, setSegmented] = useState(false);
|
||||||
|
const [segments, setSegments] = useState<SegmentInfo[]>([]);
|
||||||
|
const [currentSegmentIndex, setCurrentSegmentIndex] = useState(0);
|
||||||
|
|
||||||
const postToIframe = (type: string, payload?: any) => {
|
const postToIframe = (type: string, payload?: any) => {
|
||||||
const win = iframeRef.current?.contentWindow;
|
const win = iframeRef.current?.contentWindow;
|
||||||
if (!win) return;
|
if (!win) return;
|
||||||
@@ -102,7 +115,7 @@ export default function LabelStudioTextEditor() {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const initEditorForFile = async (fileId: string) => {
|
const initEditorForFile = async (fileId: string, segmentIdx?: number) => {
|
||||||
if (!project?.supported) return;
|
if (!project?.supported) return;
|
||||||
if (!project?.labelConfig) {
|
if (!project?.labelConfig) {
|
||||||
message.error("该项目未绑定标注模板,无法加载编辑器");
|
message.error("该项目未绑定标注模板,无法加载编辑器");
|
||||||
@@ -116,14 +129,28 @@ export default function LabelStudioTextEditor() {
|
|||||||
expectedTaskIdRef.current = null;
|
expectedTaskIdRef.current = null;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const resp = (await getEditorTaskUsingGet(projectId, fileId)) as any;
|
const resp = (await getEditorTaskUsingGet(projectId, fileId, {
|
||||||
const task = resp?.data?.task;
|
segmentIndex: segmentIdx,
|
||||||
|
})) as any;
|
||||||
|
const data = resp?.data;
|
||||||
|
const task = data?.task;
|
||||||
if (!task) {
|
if (!task) {
|
||||||
message.error("获取任务详情失败");
|
message.error("获取任务详情失败");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (seq !== initSeqRef.current) return;
|
if (seq !== initSeqRef.current) return;
|
||||||
|
|
||||||
|
// 更新分段状态
|
||||||
|
if (data?.segmented) {
|
||||||
|
setSegmented(true);
|
||||||
|
setSegments(data.segments || []);
|
||||||
|
setCurrentSegmentIndex(data.currentSegmentIndex || 0);
|
||||||
|
} else {
|
||||||
|
setSegmented(false);
|
||||||
|
setSegments([]);
|
||||||
|
setCurrentSegmentIndex(0);
|
||||||
|
}
|
||||||
|
|
||||||
expectedTaskIdRef.current = Number(task?.id) || null;
|
expectedTaskIdRef.current = Number(task?.id) || null;
|
||||||
postToIframe("LS_INIT", {
|
postToIframe("LS_INIT", {
|
||||||
labelConfig: project.labelConfig,
|
labelConfig: project.labelConfig,
|
||||||
@@ -173,9 +200,23 @@ export default function LabelStudioTextEditor() {
|
|||||||
|
|
||||||
setSaving(true);
|
setSaving(true);
|
||||||
try {
|
try {
|
||||||
await upsertEditorAnnotationUsingPut(projectId, String(fileId), { annotation });
|
await upsertEditorAnnotationUsingPut(projectId, String(fileId), {
|
||||||
|
annotation,
|
||||||
|
segmentIndex: segmented ? currentSegmentIndex : undefined,
|
||||||
|
});
|
||||||
message.success("标注已保存");
|
message.success("标注已保存");
|
||||||
await loadTasks(true);
|
await loadTasks(true);
|
||||||
|
|
||||||
|
// 分段模式下更新当前段落的标注状态
|
||||||
|
if (segmented) {
|
||||||
|
setSegments((prev) =>
|
||||||
|
prev.map((seg) =>
|
||||||
|
seg.idx === currentSegmentIndex
|
||||||
|
? { ...seg, hasAnnotation: true }
|
||||||
|
: seg
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error(e);
|
console.error(e);
|
||||||
message.error("保存失败");
|
message.error("保存失败");
|
||||||
@@ -192,6 +233,13 @@ export default function LabelStudioTextEditor() {
|
|||||||
postToIframe("LS_EXPORT", {});
|
postToIframe("LS_EXPORT", {});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// 段落切换处理
|
||||||
|
const handleSegmentChange = async (newIndex: number) => {
|
||||||
|
if (newIndex === currentSegmentIndex) return;
|
||||||
|
setCurrentSegmentIndex(newIndex);
|
||||||
|
await initEditorForFile(selectedFileId, newIndex);
|
||||||
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setIframeReady(false);
|
setIframeReady(false);
|
||||||
setProject(null);
|
setProject(null);
|
||||||
@@ -200,6 +248,10 @@ export default function LabelStudioTextEditor() {
|
|||||||
initSeqRef.current = 0;
|
initSeqRef.current = 0;
|
||||||
setLsReady(false);
|
setLsReady(false);
|
||||||
expectedTaskIdRef.current = null;
|
expectedTaskIdRef.current = null;
|
||||||
|
// 重置分段状态
|
||||||
|
setSegmented(false);
|
||||||
|
setSegments([]);
|
||||||
|
setCurrentSegmentIndex(0);
|
||||||
|
|
||||||
if (projectId) loadProject();
|
if (projectId) loadProject();
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
@@ -379,26 +431,55 @@ export default function LabelStudioTextEditor() {
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* 右侧编辑器 - Label Studio iframe */}
|
{/* 右侧编辑器 - Label Studio iframe */}
|
||||||
<div className="flex-1 relative">
|
<div className="flex-1 flex flex-col min-h-0">
|
||||||
{(!iframeReady || loadingTaskDetail || (selectedFileId && !lsReady)) && (
|
{/* 段落导航栏 */}
|
||||||
<div className="absolute inset-0 z-10 flex items-center justify-center bg-white/80">
|
{segmented && segments.length > 0 && (
|
||||||
<Spin
|
<div className="flex items-center gap-2 px-3 py-2 bg-gray-50 border-b border-gray-200">
|
||||||
tip={
|
<Typography.Text style={{ fontSize: 12 }}>段落:</Typography.Text>
|
||||||
!iframeReady
|
<div className="flex gap-1 flex-wrap">
|
||||||
? "编辑器资源加载中..."
|
{segments.map((seg) => (
|
||||||
: loadingTaskDetail
|
<Button
|
||||||
? "任务数据加载中..."
|
key={seg.idx}
|
||||||
: "编辑器初始化中..."
|
size="small"
|
||||||
}
|
type={seg.idx === currentSegmentIndex ? "primary" : "default"}
|
||||||
/>
|
onClick={() => handleSegmentChange(seg.idx)}
|
||||||
|
style={{ minWidth: 32, padding: "0 8px" }}
|
||||||
|
>
|
||||||
|
{seg.idx + 1}
|
||||||
|
{seg.hasAnnotation && (
|
||||||
|
<CheckOutlined style={{ marginLeft: 2, fontSize: 10 }} />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
<Tag color="blue" style={{ marginLeft: 8 }}>
|
||||||
|
{currentSegmentIndex + 1} / {segments.length}
|
||||||
|
</Tag>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
<iframe
|
|
||||||
ref={iframeRef}
|
{/* 编辑器区域 */}
|
||||||
title="Label Studio Frontend"
|
<div className="flex-1 relative">
|
||||||
src={LSF_IFRAME_SRC}
|
{(!iframeReady || loadingTaskDetail || (selectedFileId && !lsReady)) && (
|
||||||
className="w-full h-full border-0"
|
<div className="absolute inset-0 z-10 flex items-center justify-center bg-white/80">
|
||||||
/>
|
<Spin
|
||||||
|
tip={
|
||||||
|
!iframeReady
|
||||||
|
? "编辑器资源加载中..."
|
||||||
|
: loadingTaskDetail
|
||||||
|
? "任务数据加载中..."
|
||||||
|
: "编辑器初始化中..."
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<iframe
|
||||||
|
ref={iframeRef}
|
||||||
|
title="Label Studio Frontend"
|
||||||
|
src={LSF_IFRAME_SRC}
|
||||||
|
className="w-full h-full border-0"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -69,14 +69,22 @@ export function listEditorTasksUsingGet(projectId: string, params?: any) {
|
|||||||
return get(`/api/annotation/editor/projects/${projectId}/tasks`, params);
|
return get(`/api/annotation/editor/projects/${projectId}/tasks`, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
 * Fetch the Label Studio editor task for one file of an annotation project.
 *
 * @param projectId - Annotation project ID, interpolated into the request path.
 * @param fileId - Dataset file ID, interpolated into the request path.
 * @param params - Optional request parameters handed to `get`. `segmentIndex`
 *   selects which segment of the text should be loaded; the backend route
 *   reads it as the `segmentIndex` query parameter.
 * @returns Whatever the shared `get` helper returns for this endpoint.
 */
export function getEditorTaskUsingGet(
  projectId: string,
  fileId: string,
  params?: { segmentIndex?: number }
) {
  return get(`/api/annotation/editor/projects/${projectId}/tasks/${fileId}`, params);
}
|
||||||
|
|
||||||
/**
 * Create or update the stored annotation for one file of an annotation project.
 *
 * @param projectId - Annotation project ID, interpolated into the request path.
 * @param fileId - Dataset file ID, interpolated into the request path.
 * @param data - Request body passed to `put` unchanged:
 *   - `annotation`: the Label Studio annotation object to persist.
 *   - `expectedUpdatedAt`: optional optimistic-lock value; the backend schema
 *     describes it as "must match the record's updated_at, otherwise 409".
 *   - `segmentIndex`: index of the segment being saved; the backend schema
 *     describes it as required in segmented mode.
 * @returns Whatever the shared `put` helper returns for this endpoint.
 */
export function upsertEditorAnnotationUsingPut(
  projectId: string,
  fileId: string,
  data: {
    annotation: any;
    expectedUpdatedAt?: string;
    segmentIndex?: number;
  }
) {
  return put(`/api/annotation/editor/projects/${projectId}/tasks/${fileId}/annotation`, data);
}
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ Label Studio Editor(前端嵌入式)接口
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, Path
|
from fastapi import APIRouter, Depends, Query, Path
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -67,10 +69,11 @@ async def list_editor_tasks(
|
|||||||
async def get_editor_task(
    project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
    file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
    segment_index: Optional[int] = Query(None, alias="segmentIndex", description="段落索引(分段模式下使用)"),
    db: AsyncSession = Depends(get_db),
):
    """Return the editor task for one file of an annotation project.

    Args:
        project_id: Annotation project ID taken from the URL path.
        file_id: Dataset file ID taken from the URL path.
        segment_index: Optional ``segmentIndex`` query parameter; forwarded to
            the service to choose which segment is loaded in segmented mode.
        db: Async database session injected by ``get_db``.

    Returns:
        A ``StandardResponse`` with code 200 whose ``data`` is the value
        returned by ``AnnotationEditorService.get_task``.
    """
    service = AnnotationEditorService(db)
    task = await service.get_task(project_id, file_id, segment_index=segment_index)
    return StandardResponse(code=200, message="success", data=task)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -51,12 +51,30 @@ class EditorTaskListResponse(BaseModel):
|
|||||||
model_config = ConfigDict(populate_by_name=True)
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentInfo(BaseModel):
    """Segment descriptor returned to the editor for segmented text annotation."""

    # Segment index within the file's segment list.
    idx: int = Field(..., description="段落索引")
    # Text of this segment.
    text: str = Field(..., description="段落文本")
    # Start offset of the segment in the original text.
    start: int = Field(..., description="在原文中的起始位置")
    # End offset of the segment in the original text.
    end: int = Field(..., description="在原文中的结束位置")
    # Whether an annotation is already stored for this segment; serialized
    # under the camelCase alias "hasAnnotation".
    has_annotation: bool = Field(False, alias="hasAnnotation", description="该段落是否已有标注")

    # Allow population by field name as well as by alias.
    model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
|
||||||
class EditorTaskResponse(BaseModel):
    """Editor task detail (a task object that can be fed to Label Studio Editor)."""

    # Label Studio task object.
    task: Dict[str, Any] = Field(..., description="Label Studio task 对象")
    # Last update time of the stored annotation, if one exists.
    annotation_updated_at: Optional[datetime] = Field(None, alias="annotationUpdatedAt", description="标注更新时间")

    # Segmentation-related fields.
    # Whether segmented mode is active for this task.
    segmented: bool = Field(False, description="是否启用分段模式")
    # Segment list; None when the task is not segmented.
    segments: Optional[List[SegmentInfo]] = Field(None, description="段落列表")
    # Total number of segments.
    total_segments: int = Field(0, alias="totalSegments", description="总段落数")
    # Index of the segment whose text is carried in ``task``.
    current_segment_index: int = Field(0, alias="currentSegmentIndex", description="当前段落索引")

    # Allow population by field name as well as by alias.
    model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -69,6 +87,12 @@ class UpsertAnnotationRequest(BaseModel):
|
|||||||
alias="expectedUpdatedAt",
|
alias="expectedUpdatedAt",
|
||||||
description="乐观锁:若提供则要求与当前记录 updated_at 一致,否则返回 409",
|
description="乐观锁:若提供则要求与当前记录 updated_at 一致,否则返回 409",
|
||||||
)
|
)
|
||||||
|
# 分段保存支持
|
||||||
|
segment_index: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
alias="segmentIndex",
|
||||||
|
description="段落索引(分段模式下必填)",
|
||||||
|
)
|
||||||
|
|
||||||
model_config = ConfigDict(populate_by_name=True)
|
model_config = ConfigDict(populate_by_name=True)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,113 @@
|
|||||||
|
"""
|
||||||
|
标注文本分割器
|
||||||
|
|
||||||
|
职责:将长文本按指定规则分割为适合标注的段落
|
||||||
|
- 最大200字符(CJK按1字符计)
|
||||||
|
- 分隔符:。;以及正则 \\?|\\!|(?<!\\d)\\.(?!\\d)
|
||||||
|
- 超长句子保持完整
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import List, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentInfo(TypedDict):
    """One segment produced by ``AnnotationTextSplitter.split``."""

    idx: int  # zero-based segment index
    text: str  # segment text
    start: int  # start offset in the original text
    end: int  # end offset in the original text (start + len(text))


class AnnotationTextSplitter:
    """Split long text into segments that are small enough to annotate.

    Sentences are delimited by ``SEPARATOR_PATTERN`` and each separator stays
    attached to the sentence it terminates. Consecutive sentences are merged
    greedily while the merged length stays within ``max_chars``; a single
    sentence longer than ``max_chars`` is kept whole as its own segment.
    Lengths are measured with ``len``, so every character counts as 1.
    """

    # Sentence separators: the ideographic full stop, the semicolon written in
    # the pattern, "?", "!", and a "." that is not adjacent to a digit (so
    # decimals such as 3.14 are not split). The capture group is kept so the
    # attribute stays compatible with ``re.split`` callers.
    SEPARATOR_PATTERN = r'(。|;|\?|\!|(?<!\d)\.(?!\d))'

    def __init__(self, max_chars: int = 200):
        """Create a splitter.

        Args:
            max_chars: Maximum number of characters per segment (default 200).
        """
        self.max_chars = max_chars

    def split(self, text: str) -> List[SegmentInfo]:
        """Split ``text`` into a list of segments.

        Args:
            text: Text to split.

        Returns:
            Segments in document order, each with ``idx``, ``text``, ``start``
            and ``end``. The segment texts concatenate back to ``text``. Empty
            input yields one empty segment; input no longer than ``max_chars``
            yields one segment covering the whole text.
        """
        if not text:
            return [{"idx": 0, "text": "", "start": 0, "end": 0}]

        # Short text needs no splitting.
        if len(text) <= self.max_chars:
            return [{"idx": 0, "text": text, "start": 0, "end": len(text)}]

        segments: List[SegmentInfo] = []
        seg_start = 0  # offset where the segment being built begins
        seg_end = 0  # offset just past the last sentence merged into it

        for sentence_end in self._sentence_ends(text):
            segment_has_content = seg_end > seg_start
            if segment_has_content and sentence_end - seg_start > self.max_chars:
                # Adding this sentence would overflow: close the current segment
                # and start a new one at the sentence boundary.
                segments.append(self._make_segment(len(segments), text, seg_start, seg_end))
                seg_start = seg_end
            seg_end = sentence_end

        # Flush the trailing segment.
        if seg_end > seg_start:
            segments.append(self._make_segment(len(segments), text, seg_start, seg_end))

        return segments

    def _sentence_ends(self, text: str) -> List[int]:
        """Return the end offset of every sentence in ``text``, in order."""
        ends = [match.end() for match in re.finditer(self.SEPARATOR_PATTERN, text)]
        # Text after the last separator forms a final, unterminated sentence.
        if not ends or ends[-1] < len(text):
            ends.append(len(text))
        return ends

    @staticmethod
    def _make_segment(idx: int, text: str, start: int, end: int) -> SegmentInfo:
        """Build the segment dict for ``text[start:end]``."""
        return {"idx": idx, "text": text[start:end], "start": start, "end": end}
|
||||||
@@ -27,10 +27,12 @@ from app.module.annotation.schema.editor import (
|
|||||||
EditorTaskListItem,
|
EditorTaskListItem,
|
||||||
EditorTaskListResponse,
|
EditorTaskListResponse,
|
||||||
EditorTaskResponse,
|
EditorTaskResponse,
|
||||||
|
SegmentInfo,
|
||||||
UpsertAnnotationRequest,
|
UpsertAnnotationRequest,
|
||||||
UpsertAnnotationResponse,
|
UpsertAnnotationResponse,
|
||||||
)
|
)
|
||||||
from app.module.annotation.service.template import AnnotationTemplateService
|
from app.module.annotation.service.template import AnnotationTemplateService
|
||||||
|
from app.module.annotation.service.annotation_text_splitter import AnnotationTextSplitter
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -38,6 +40,9 @@ logger = get_logger(__name__)
|
|||||||
class AnnotationEditorService:
|
class AnnotationEditorService:
|
||||||
"""Label Studio Editor 集成服务(TEXT POC 版)"""
|
"""Label Studio Editor 集成服务(TEXT POC 版)"""
|
||||||
|
|
||||||
|
# 分段阈值:超过此字符数自动分段
|
||||||
|
SEGMENT_THRESHOLD = 200
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession):
|
def __init__(self, db: AsyncSession):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.template_service = AnnotationTemplateService()
|
self.template_service = AnnotationTemplateService()
|
||||||
@@ -206,7 +211,12 @@ class AnnotationEditorService:
|
|||||||
logger.error(f"读取文本失败: dataset={dataset_id}, file={file_id}, err={e}")
|
logger.error(f"读取文本失败: dataset={dataset_id}, file={file_id}, err={e}")
|
||||||
raise HTTPException(status_code=502, detail="读取文本失败(下载接口调用异常)")
|
raise HTTPException(status_code=502, detail="读取文本失败(下载接口调用异常)")
|
||||||
|
|
||||||
async def get_task(self, project_id: str, file_id: str) -> EditorTaskResponse:
|
async def get_task(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
file_id: str,
|
||||||
|
segment_index: Optional[int] = None,
|
||||||
|
) -> EditorTaskResponse:
|
||||||
project = await self._get_project_or_404(project_id)
|
project = await self._get_project_or_404(project_id)
|
||||||
|
|
||||||
# TEXT 支持校验
|
# TEXT 支持校验
|
||||||
@@ -226,6 +236,7 @@ class AnnotationEditorService:
|
|||||||
|
|
||||||
text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
|
text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
|
||||||
|
|
||||||
|
# 获取现有标注
|
||||||
ann_result = await self.db.execute(
|
ann_result = await self.db.execute(
|
||||||
select(AnnotationResult).where(
|
select(AnnotationResult).where(
|
||||||
AnnotationResult.project_id == project_id,
|
AnnotationResult.project_id == project_id,
|
||||||
@@ -236,10 +247,44 @@ class AnnotationEditorService:
|
|||||||
|
|
||||||
ls_task_id = self._make_ls_task_id(project_id, file_id)
|
ls_task_id = self._make_ls_task_id(project_id, file_id)
|
||||||
|
|
||||||
|
# 判断是否需要分段
|
||||||
|
needs_segmentation = len(text_content) > self.SEGMENT_THRESHOLD
|
||||||
|
segments: Optional[List[SegmentInfo]] = None
|
||||||
|
current_segment_index = 0
|
||||||
|
display_text = text_content
|
||||||
|
|
||||||
|
if needs_segmentation:
|
||||||
|
splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
|
||||||
|
raw_segments = splitter.split(text_content)
|
||||||
|
current_segment_index = segment_index if segment_index is not None else 0
|
||||||
|
|
||||||
|
# 校验段落索引
|
||||||
|
if current_segment_index < 0 or current_segment_index >= len(raw_segments):
|
||||||
|
current_segment_index = 0
|
||||||
|
|
||||||
|
# 标记每个段落是否已有标注
|
||||||
|
segment_annotations: Dict[str, Any] = {}
|
||||||
|
if ann and ann.annotation and ann.annotation.get("segmented"):
|
||||||
|
segment_annotations = ann.annotation.get("segments", {})
|
||||||
|
|
||||||
|
segments = []
|
||||||
|
for seg in raw_segments:
|
||||||
|
segments.append(SegmentInfo(
|
||||||
|
idx=seg["idx"],
|
||||||
|
text=seg["text"],
|
||||||
|
start=seg["start"],
|
||||||
|
end=seg["end"],
|
||||||
|
hasAnnotation=str(seg["idx"]) in segment_annotations,
|
||||||
|
))
|
||||||
|
|
||||||
|
# 当前段落文本用于 task.data.text
|
||||||
|
display_text = raw_segments[current_segment_index]["text"]
|
||||||
|
|
||||||
|
# 构造 task 对象
|
||||||
task: Dict[str, Any] = {
|
task: Dict[str, Any] = {
|
||||||
"id": ls_task_id,
|
"id": ls_task_id,
|
||||||
"data": {
|
"data": {
|
||||||
"text": text_content,
|
"text": display_text,
|
||||||
"file_id": file_id,
|
"file_id": file_id,
|
||||||
"dataset_id": project.dataset_id,
|
"dataset_id": project.dataset_id,
|
||||||
"file_name": getattr(file_record, "file_name", ""),
|
"file_name": getattr(file_record, "file_name", ""),
|
||||||
@@ -250,15 +295,43 @@ class AnnotationEditorService:
|
|||||||
annotation_updated_at = None
|
annotation_updated_at = None
|
||||||
if ann:
|
if ann:
|
||||||
annotation_updated_at = ann.updated_at
|
annotation_updated_at = ann.updated_at
|
||||||
# 直接返回存储的 annotation 原始对象(Label Studio 兼容)
|
|
||||||
stored = dict(ann.annotation or {})
|
if needs_segmentation and ann.annotation and ann.annotation.get("segmented"):
|
||||||
stored["task"] = ls_task_id
|
# 分段模式:获取当前段落的标注
|
||||||
if not isinstance(stored.get("id"), int):
|
segment_annotations = ann.annotation.get("segments", {})
|
||||||
stored["id"] = self._make_ls_annotation_id(project_id, file_id)
|
seg_ann = segment_annotations.get(str(current_segment_index), {})
|
||||||
task["annotations"] = [stored]
|
stored = {
|
||||||
|
"id": self._make_ls_annotation_id(project_id, file_id) + current_segment_index,
|
||||||
|
"task": ls_task_id,
|
||||||
|
"result": seg_ann.get("result", []),
|
||||||
|
"created_at": seg_ann.get("created_at", datetime.utcnow().isoformat() + "Z"),
|
||||||
|
"updated_at": seg_ann.get("updated_at", datetime.utcnow().isoformat() + "Z"),
|
||||||
|
}
|
||||||
|
task["annotations"] = [stored]
|
||||||
|
elif not needs_segmentation:
|
||||||
|
# 非分段模式:直接返回存储的 annotation 原始对象
|
||||||
|
stored = dict(ann.annotation or {})
|
||||||
|
stored["task"] = ls_task_id
|
||||||
|
if not isinstance(stored.get("id"), int):
|
||||||
|
stored["id"] = self._make_ls_annotation_id(project_id, file_id)
|
||||||
|
task["annotations"] = [stored]
|
||||||
|
else:
|
||||||
|
# 首次从非分段切换到分段:提供空标注
|
||||||
|
empty_ann_id = self._make_ls_annotation_id(project_id, file_id) + current_segment_index
|
||||||
|
task["annotations"] = [
|
||||||
|
{
|
||||||
|
"id": empty_ann_id,
|
||||||
|
"task": ls_task_id,
|
||||||
|
"result": [],
|
||||||
|
"created_at": datetime.utcnow().isoformat() + "Z",
|
||||||
|
"updated_at": datetime.utcnow().isoformat() + "Z",
|
||||||
|
}
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
# 提供一个空 annotation,避免前端在没有选中 annotation 时无法产生 result
|
# 提供一个空 annotation,避免前端在没有选中 annotation 时无法产生 result
|
||||||
empty_ann_id = self._make_ls_annotation_id(project_id, file_id)
|
empty_ann_id = self._make_ls_annotation_id(project_id, file_id)
|
||||||
|
if needs_segmentation:
|
||||||
|
empty_ann_id += current_segment_index
|
||||||
task["annotations"] = [
|
task["annotations"] = [
|
||||||
{
|
{
|
||||||
"id": empty_ann_id,
|
"id": empty_ann_id,
|
||||||
@@ -272,6 +345,10 @@ class AnnotationEditorService:
|
|||||||
return EditorTaskResponse(
|
return EditorTaskResponse(
|
||||||
task=task,
|
task=task,
|
||||||
annotationUpdatedAt=annotation_updated_at,
|
annotationUpdatedAt=annotation_updated_at,
|
||||||
|
segmented=needs_segmentation,
|
||||||
|
segments=segments,
|
||||||
|
totalSegments=len(segments) if segments else 1,
|
||||||
|
currentSegmentIndex=current_segment_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse:
|
async def upsert_annotation(self, project_id: str, file_id: str, request: UpsertAnnotationRequest) -> UpsertAnnotationResponse:
|
||||||
@@ -293,9 +370,6 @@ class AnnotationEditorService:
|
|||||||
raise HTTPException(status_code=400, detail="annotation.result 必须为数组")
|
raise HTTPException(status_code=400, detail="annotation.result 必须为数组")
|
||||||
|
|
||||||
ls_task_id = self._make_ls_task_id(project_id, file_id)
|
ls_task_id = self._make_ls_task_id(project_id, file_id)
|
||||||
annotation_payload["task"] = ls_task_id
|
|
||||||
if not isinstance(annotation_payload.get("id"), int):
|
|
||||||
annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)
|
|
||||||
|
|
||||||
existing_result = await self.db.execute(
|
existing_result = await self.db.execute(
|
||||||
select(AnnotationResult).where(
|
select(AnnotationResult).where(
|
||||||
@@ -307,12 +381,27 @@ class AnnotationEditorService:
|
|||||||
|
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
# 判断是否为分段保存模式
|
||||||
|
if request.segment_index is not None:
|
||||||
|
# 分段模式:合并段落标注到整体结构
|
||||||
|
final_payload = self._merge_segment_annotation(
|
||||||
|
existing.annotation if existing else None,
|
||||||
|
request.segment_index,
|
||||||
|
annotation_payload,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 非分段模式:直接使用传入的 annotation
|
||||||
|
annotation_payload["task"] = ls_task_id
|
||||||
|
if not isinstance(annotation_payload.get("id"), int):
|
||||||
|
annotation_payload["id"] = self._make_ls_annotation_id(project_id, file_id)
|
||||||
|
final_payload = annotation_payload
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
if request.expected_updated_at and existing.updated_at:
|
if request.expected_updated_at and existing.updated_at:
|
||||||
if existing.updated_at != request.expected_updated_at.replace(tzinfo=None):
|
if existing.updated_at != request.expected_updated_at.replace(tzinfo=None):
|
||||||
raise HTTPException(status_code=409, detail="标注已被更新,请刷新后重试")
|
raise HTTPException(status_code=409, detail="标注已被更新,请刷新后重试")
|
||||||
|
|
||||||
existing.annotation = annotation_payload # type: ignore[assignment]
|
existing.annotation = final_payload # type: ignore[assignment]
|
||||||
existing.updated_at = now # type: ignore[assignment]
|
existing.updated_at = now # type: ignore[assignment]
|
||||||
await self.db.commit()
|
await self.db.commit()
|
||||||
await self.db.refresh(existing)
|
await self.db.refresh(existing)
|
||||||
@@ -327,7 +416,7 @@ class AnnotationEditorService:
|
|||||||
id=new_id,
|
id=new_id,
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
annotation=annotation_payload,
|
annotation=final_payload,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
)
|
)
|
||||||
@@ -340,3 +429,39 @@ class AnnotationEditorService:
|
|||||||
updatedAt=record.updated_at or now,
|
updatedAt=record.updated_at or now,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _merge_segment_annotation(
|
||||||
|
self,
|
||||||
|
existing: Optional[Dict[str, Any]],
|
||||||
|
segment_index: int,
|
||||||
|
new_annotation: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
合并段落标注到整体结构
|
||||||
|
|
||||||
|
Args:
|
||||||
|
existing: 现有的 annotation 数据
|
||||||
|
segment_index: 段落索引
|
||||||
|
new_annotation: 新的段落标注数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
合并后的 annotation 结构
|
||||||
|
"""
|
||||||
|
if not existing or not existing.get("segmented"):
|
||||||
|
# 初始化分段结构
|
||||||
|
base: Dict[str, Any] = {
|
||||||
|
"segmented": True,
|
||||||
|
"version": 1,
|
||||||
|
"segments": {},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
base = dict(existing)
|
||||||
|
|
||||||
|
# 更新指定段落的标注
|
||||||
|
base["segments"][str(segment_index)] = {
|
||||||
|
"result": new_annotation.get("result", []),
|
||||||
|
"created_at": new_annotation.get("created_at", datetime.utcnow().isoformat() + "Z"),
|
||||||
|
"updated_at": datetime.utcnow().isoformat() + "Z",
|
||||||
|
}
|
||||||
|
|
||||||
|
return base
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user