You've already forked DataMate
feat(annotation): 实现文本切片预生成功能
在创建标注任务时自动预生成文本切片结构,避免每次进入标注页面时的实时计算。 修改内容: 1. 在 AnnotationEditorService 中新增 precompute_segmentation_for_project 方法 - 为项目的所有文本文件预计算切片结构 - 使用 AnnotationTextSplitter 执行切片 - 将切片结构持久化到 AnnotationResult 表(状态为 IN_PROGRESS) - 支持失败重试机制 - 返回统计信息 2. 修改 create_mapping 接口 - 在创建标注任务后,如果启用分段且为文本数据集,自动触发切片预生成 - 使用 try-except 捕获异常,确保切片失败不影响项目创建 特点: - 使用现有的 AnnotationTextSplitter 类 - 切片数据结构与现有分段标注格式一致 - 向后兼容(未切片的任务仍然使用实时计算) - 性能优化:避免进入标注页面时的重复计算 相关文件: - runtime/datamate-python/app/module/annotation/service/editor.py - runtime/datamate-python/app/module/annotation/interface/project.py
This commit is contained in:
@@ -150,6 +150,18 @@ async def create_mapping(
|
||||
labeling_project, snapshot_file_ids
|
||||
)
|
||||
|
||||
# 如果启用了分段且为文本数据集,预生成切片结构
|
||||
if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled:
|
||||
try:
|
||||
from ..service.editor import AnnotationEditorService
|
||||
editor_service = AnnotationEditorService(db)
|
||||
# 异步预计算切片(不阻塞创建响应)
|
||||
segmentation_result = await editor_service.precompute_segmentation_for_project(labeling_project.id)
|
||||
logger.info(f"Precomputed segmentation for project {labeling_project.id}: {segmentation_result}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to precompute segmentation for project {labeling_project.id}: {e}")
|
||||
# 不影响项目创建,只记录警告
|
||||
|
||||
response_data = DatasetMappingCreateResponse(
|
||||
id=mapping.id,
|
||||
labeling_project_id=str(mapping.labeling_project_id),
|
||||
|
||||
@@ -1185,3 +1185,195 @@ class AnnotationEditorService:
|
||||
except Exception as exc:
|
||||
logger.warning("标注同步知识管理失败:%s", exc)
|
||||
|
||||
async def precompute_segmentation_for_project(
|
||||
self,
|
||||
project_id: str,
|
||||
max_retries: int = 3
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
为指定项目的所有文本文件预计算切片结构并持久化到数据库
|
||||
|
||||
Args:
|
||||
project_id: 标注项目ID
|
||||
max_retries: 失败重试次数
|
||||
|
||||
Returns:
|
||||
统计信息:{total_files, succeeded, failed}
|
||||
"""
|
||||
project = await self._get_project_or_404(project_id)
|
||||
dataset_type = self._normalize_dataset_type(await self._get_dataset_type(project.dataset_id))
|
||||
|
||||
# 只处理文本数据集
|
||||
if dataset_type != DATASET_TYPE_TEXT:
|
||||
logger.info(f"项目 {project_id} 不是文本数据集,跳过切片预生成")
|
||||
return {"total_files": 0, "succeeded": 0, "failed": 0}
|
||||
|
||||
# 检查是否启用分段
|
||||
if not self._resolve_segmentation_enabled(project):
|
||||
logger.info(f"项目 {project_id} 未启用分段,跳过切片预生成")
|
||||
return {"total_files": 0, "succeeded": 0, "failed": 0}
|
||||
|
||||
# 获取项目的所有文本文件(排除源文档)
|
||||
files_result = await self.db.execute(
|
||||
select(DatasetFiles)
|
||||
.join(LabelingProjectFile, LabelingProjectFile.file_id == DatasetFiles.id)
|
||||
.where(
|
||||
LabelingProjectFile.project_id == project_id,
|
||||
DatasetFiles.dataset_id == project.dataset_id,
|
||||
)
|
||||
)
|
||||
file_records = files_result.scalars().all()
|
||||
|
||||
if not file_records:
|
||||
logger.info(f"项目 {project_id} 没有文件,跳过切片预生成")
|
||||
return {"total_files": 0, "succeeded": 0, "failed": 0}
|
||||
|
||||
# 过滤源文档文件
|
||||
valid_files = []
|
||||
for file_record in file_records:
|
||||
file_type = str(getattr(file_record, "file_type", "") or "").lower()
|
||||
file_name = str(getattr(file_record, "file_name", "")).lower()
|
||||
is_source_document = (
|
||||
file_type in SOURCE_DOCUMENT_TYPES or
|
||||
any(file_name.endswith(ext) for ext in SOURCE_DOCUMENT_EXTENSIONS)
|
||||
)
|
||||
if not is_source_document:
|
||||
valid_files.append(file_record)
|
||||
|
||||
total_files = len(valid_files)
|
||||
succeeded = 0
|
||||
failed = 0
|
||||
|
||||
label_config = await self._resolve_project_label_config(project)
|
||||
primary_text_key = self._resolve_primary_text_key(label_config)
|
||||
|
||||
for file_record in valid_files:
|
||||
file_id = str(file_record.id) # type: ignore
|
||||
file_name = str(getattr(file_record, "file_name", ""))
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
# 读取文本内容
|
||||
text_content = await self._fetch_text_content_via_download_api(project.dataset_id, file_id)
|
||||
if not isinstance(text_content, str):
|
||||
logger.warning(f"文件 {file_id} 内容不是字符串,跳过切片")
|
||||
failed += 1
|
||||
break
|
||||
|
||||
# 解析文本记录
|
||||
records: List[Tuple[Optional[Dict[str, Any]], str]] = []
|
||||
if file_name.lower().endswith(JSONL_EXTENSION):
|
||||
records = self._parse_jsonl_records(text_content)
|
||||
else:
|
||||
parsed_payload = self._try_parse_json_payload(text_content)
|
||||
if parsed_payload:
|
||||
records = [(parsed_payload, text_content)]
|
||||
|
||||
if not records:
|
||||
records = [(None, text_content)]
|
||||
|
||||
record_texts = [
|
||||
self._resolve_primary_text_value(payload, raw_text, primary_text_key)
|
||||
for payload, raw_text in records
|
||||
]
|
||||
if not record_texts:
|
||||
record_texts = [text_content]
|
||||
|
||||
# 判断是否需要分段
|
||||
needs_segmentation = len(records) > 1 or any(
|
||||
len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts
|
||||
)
|
||||
|
||||
if not needs_segmentation:
|
||||
# 不需要分段的文件,跳过
|
||||
succeeded += 1
|
||||
break
|
||||
|
||||
# 执行切片
|
||||
splitter = AnnotationTextSplitter(max_chars=self.SEGMENT_THRESHOLD)
|
||||
segment_cursor = 0
|
||||
segments = {}
|
||||
|
||||
for record_index, ((payload, raw_text), record_text) in enumerate(zip(records, record_texts)):
|
||||
normalized_text = record_text or ""
|
||||
|
||||
if len(normalized_text) > self.SEGMENT_THRESHOLD:
|
||||
raw_segments = splitter.split(normalized_text)
|
||||
for chunk_index, seg in enumerate(raw_segments):
|
||||
segments[str(segment_cursor)] = {
|
||||
SEGMENT_RESULT_KEY: [],
|
||||
SEGMENT_CREATED_AT_KEY: datetime.utcnow().isoformat() + "Z",
|
||||
SEGMENT_UPDATED_AT_KEY: datetime.utcnow().isoformat() + "Z",
|
||||
}
|
||||
segment_cursor += 1
|
||||
else:
|
||||
segments[str(segment_cursor)] = {
|
||||
SEGMENT_RESULT_KEY: [],
|
||||
SEGMENT_CREATED_AT_KEY: datetime.utcnow().isoformat() + "Z",
|
||||
SEGMENT_UPDATED_AT_KEY: datetime.utcnow().isoformat() + "Z",
|
||||
}
|
||||
segment_cursor += 1
|
||||
|
||||
if not segments:
|
||||
succeeded += 1
|
||||
break
|
||||
|
||||
# 构造分段标注结构
|
||||
final_payload = {
|
||||
SEGMENTED_KEY: True,
|
||||
"version": 1,
|
||||
SEGMENTS_KEY: segments,
|
||||
SEGMENT_TOTAL_KEY: segment_cursor,
|
||||
}
|
||||
|
||||
# 检查是否已存在标注
|
||||
existing_result = await self.db.execute(
|
||||
select(AnnotationResult).where(
|
||||
AnnotationResult.project_id == project_id,
|
||||
AnnotationResult.file_id == file_id,
|
||||
)
|
||||
)
|
||||
existing = existing_result.scalar_one_or_none()
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
if existing:
|
||||
# 更新现有标注
|
||||
existing.annotation = final_payload # type: ignore[assignment]
|
||||
existing.annotation_status = ANNOTATION_STATUS_IN_PROGRESS # type: ignore[assignment]
|
||||
existing.updated_at = now # type: ignore[assignment]
|
||||
else:
|
||||
# 创建新标注记录
|
||||
record = AnnotationResult(
|
||||
id=str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
file_id=file_id,
|
||||
annotation=final_payload,
|
||||
annotation_status=ANNOTATION_STATUS_IN_PROGRESS,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self.db.add(record)
|
||||
|
||||
await self.db.commit()
|
||||
succeeded += 1
|
||||
logger.info(f"成功为文件 {file_id} 预生成 {segment_cursor} 个切片")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"为文件 {file_id} 预生成切片失败 (重试 {retry + 1}/{max_retries}): {e}"
|
||||
)
|
||||
if retry == max_retries - 1:
|
||||
failed += 1
|
||||
await self.db.rollback()
|
||||
|
||||
logger.info(
|
||||
f"项目 {project_id} 切片预生成完成: 总计 {total_files}, 成功 {succeeded}, 失败 {failed}"
|
||||
)
|
||||
return {
|
||||
"total_files": total_files,
|
||||
"succeeded": succeeded,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user