You've already forked DataMate
feat(annotation): 添加文本数据集段落切片功能
- 在前端组件中新增 segmentationEnabled 字段控制切片开关
- 为文本数据集添加段落切片配置选项，默认启用切片功能
- 在后端接口中新增 segmentation_enabled 参数传递给标注项目
- 实现切片逻辑控制，支持文本数据的自动段落分割
- 添加数据集类型判断，仅文本数据集支持切片配置
- 更新标注任务创建和编辑表单中的切片相关字段处理
This commit is contained in:
@@ -25,6 +25,7 @@ router = APIRouter(
|
||||
tags=["annotation/project"]
|
||||
)
|
||||
logger = get_logger(__name__)
|
||||
TEXT_DATASET_TYPE = "TEXT"
|
||||
|
||||
@router.get("/{mapping_id}/login")
|
||||
async def login_label_studio(
|
||||
@@ -62,6 +63,12 @@ async def create_mapping(
|
||||
detail=f"Dataset not found in DM service: {request.dataset_id}"
|
||||
)
|
||||
|
||||
dataset_type = (
|
||||
getattr(dataset_info, "datasetType", None)
|
||||
or getattr(dataset_info, "dataset_type", None)
|
||||
or ""
|
||||
).upper()
|
||||
|
||||
project_name = request.name or \
|
||||
dataset_info.name or \
|
||||
"A new project from DataMate"
|
||||
@@ -97,6 +104,8 @@ async def create_mapping(
|
||||
project_configuration["label_config"] = label_config
|
||||
if project_description:
|
||||
project_configuration["description"] = project_description
|
||||
if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled is not None:
|
||||
project_configuration["segmentation_enabled"] = bool(request.segmentation_enabled)
|
||||
|
||||
labeling_project = LabelingProject(
|
||||
id=str(uuid.uuid4()), # Generate UUID here
|
||||
|
||||
@@ -23,6 +23,11 @@ class DatasetMappingCreateRequest(BaseModel):
|
||||
description: Optional[str] = Field(None, alias="description", description="标注项目描述")
|
||||
template_id: Optional[str] = Field(None, alias="templateId", description="标注模板ID")
|
||||
label_config: Optional[str] = Field(None, alias="labelConfig", description="Label Studio XML配置")
|
||||
segmentation_enabled: Optional[bool] = Field(
|
||||
None,
|
||||
alias="segmentationEnabled",
|
||||
description="是否启用文本分段",
|
||||
)
|
||||
|
||||
class Config:
|
||||
# allow population by field name when constructing model programmatically
|
||||
@@ -49,6 +54,11 @@ class DatasetMappingResponse(BaseModel):
|
||||
template_id: Optional[str] = Field(None, alias="templateId", description="关联的模板ID")
|
||||
template: Optional['AnnotationTemplateResponse'] = Field(None, description="关联的标注模板详情")
|
||||
label_config: Optional[str] = Field(None, alias="labelConfig", description="实际使用的 Label Studio XML 配置")
|
||||
segmentation_enabled: Optional[bool] = Field(
|
||||
None,
|
||||
alias="segmentationEnabled",
|
||||
description="是否启用文本分段",
|
||||
)
|
||||
total_count: int = Field(0, alias="totalCount", description="数据集总数据量")
|
||||
annotated_count: int = Field(0, alias="annotatedCount", description="已标注数据量")
|
||||
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
|
||||
@@ -62,4 +72,4 @@ class DatasetMappingResponse(BaseModel):
|
||||
class DeleteDatasetResponse(BaseResponseModel):
|
||||
"""删除数据集响应模型"""
|
||||
id: str = Field(..., description="映射UUID")
|
||||
status: str = Field(..., description="删除状态")
|
||||
status: str = Field(..., description="删除状态")
|
||||
|
||||
@@ -56,6 +56,7 @@ TEXTUAL_OBJECT_CATEGORIES = {"text", "document"}
|
||||
MEDIA_OBJECT_CATEGORIES = {"image"}
|
||||
OBJECT_NAME_HEADER_PREFIX = "dm_object_header_"
|
||||
SUPPORTED_EDITOR_DATASET_TYPES = ("TEXT", "IMAGE")
|
||||
SEGMENTATION_ENABLED_KEY = "segmentation_enabled"
|
||||
|
||||
|
||||
class AnnotationEditorService:
|
||||
@@ -149,6 +150,18 @@ class AnnotationEditorService:
|
||||
label_config = self._decorate_label_config_for_editor(label_config)
|
||||
return label_config
|
||||
|
||||
@staticmethod
def _resolve_segmentation_enabled(project: LabelingProject) -> bool:
    """Return whether text segmentation is enabled for ``project``.

    The flag is read from ``project.configuration`` under
    ``SEGMENTATION_ENABLED_KEY``. Segmentation defaults to enabled when the
    configuration is not a dict or the key is absent / ``None``.

    String values are parsed explicitly because ``bool("false")`` is
    ``True`` in Python; without this, a flag persisted as text would
    silently enable segmentation.
    """
    config = project.configuration
    if not isinstance(config, dict):
        # No usable configuration: fall back to the default (enabled).
        return True

    value = config.get(SEGMENTATION_ENABLED_KEY)
    if value is None:
        # Key missing or explicitly null: keep the default (enabled).
        return True

    if isinstance(value, str):
        # Treat common textual "off" spellings (and blank text) as disabled.
        return value.strip().lower() not in {"", "false", "0", "no", "off"}

    # Real booleans pass through unchanged; other types use truthiness.
    return bool(value)
|
||||
|
||||
@classmethod
|
||||
def _resolve_primary_text_key(cls, label_config: Optional[str]) -> Optional[str]:
|
||||
if not label_config:
|
||||
@@ -513,13 +526,19 @@ class AnnotationEditorService:
|
||||
ls_task_id = self._make_ls_task_id(project.id, file_id)
|
||||
|
||||
# 判断是否需要分段(JSONL 多行或主文本超过阈值)
|
||||
needs_segmentation = len(records) > 1 or any(
|
||||
len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts
|
||||
segmentation_enabled = self._resolve_segmentation_enabled(project)
|
||||
if not segmentation_enabled:
|
||||
segment_index = None
|
||||
needs_segmentation = segmentation_enabled and (
|
||||
len(records) > 1 or any(len(text or "") > self.SEGMENT_THRESHOLD for text in record_texts)
|
||||
)
|
||||
segments: Optional[List[SegmentInfo]] = None
|
||||
current_segment_index = 0
|
||||
display_text = record_texts[0] if record_texts else text_content
|
||||
selected_payload = records[0][0] if records else None
|
||||
if not segmentation_enabled and len(records) > 1:
|
||||
selected_payload = None
|
||||
display_text = "\n".join(record_texts) if record_texts else text_content
|
||||
|
||||
segment_annotations: Dict[str, Any] = {}
|
||||
if ann and ann.annotation and ann.annotation.get("segmented"):
|
||||
|
||||
@@ -90,9 +90,11 @@ class DatasetMappingService:
|
||||
configuration = getattr(mapping, 'configuration', None) or {}
|
||||
label_config = None
|
||||
description = None
|
||||
segmentation_enabled = None
|
||||
if isinstance(configuration, dict):
|
||||
label_config = configuration.get('label_config')
|
||||
description = configuration.get('description')
|
||||
segmentation_enabled = configuration.get('segmentation_enabled')
|
||||
|
||||
# Optionally fetch full template details
|
||||
template_response = None
|
||||
@@ -117,6 +119,7 @@ class DatasetMappingService:
|
||||
"template_id": template_id,
|
||||
"template": template_response,
|
||||
"label_config": label_config,
|
||||
"segmentation_enabled": segmentation_enabled,
|
||||
"total_count": total_count,
|
||||
"annotated_count": annotated_count,
|
||||
"created_at": mapping.created_at,
|
||||
@@ -154,9 +157,11 @@ class DatasetMappingService:
|
||||
configuration = getattr(mapping, 'configuration', None) or {}
|
||||
label_config = None
|
||||
description = None
|
||||
segmentation_enabled = None
|
||||
if isinstance(configuration, dict):
|
||||
label_config = configuration.get('label_config')
|
||||
description = configuration.get('description')
|
||||
segmentation_enabled = configuration.get('segmentation_enabled')
|
||||
|
||||
# Optionally fetch full template details
|
||||
template_response = None
|
||||
@@ -184,6 +189,7 @@ class DatasetMappingService:
|
||||
"template_id": template_id,
|
||||
"template": template_response,
|
||||
"label_config": label_config,
|
||||
"segmentation_enabled": segmentation_enabled,
|
||||
"total_count": total_count,
|
||||
"annotated_count": annotated_count,
|
||||
"created_at": mapping.created_at,
|
||||
@@ -526,4 +532,4 @@ class DatasetMappingService:
|
||||
for row in rows:
|
||||
response = await self._to_response_from_row(row, include_template=include_template)
|
||||
responses.append(response)
|
||||
return responses, total
|
||||
return responses, total
|
||||
|
||||
Reference in New Issue
Block a user