You've already forked DataMate
feat: Add labeling template (#72)
* feat: Enhance annotation module with template management and validation - Added DatasetMappingCreateRequest and DatasetMappingUpdateRequest schemas to handle dataset mapping requests with camelCase and snake_case support. - Introduced Annotation Template schemas including CreateAnnotationTemplateRequest, UpdateAnnotationTemplateRequest, and AnnotationTemplateResponse for managing annotation templates. - Implemented AnnotationTemplateService for creating, updating, retrieving, and deleting annotation templates, including validation of configurations and XML generation. - Added utility class LabelStudioConfigValidator for validating Label Studio configurations and XML formats. - Updated database schema for annotation templates and labeling projects to include new fields and constraints. - Seeded initial annotation templates for various use cases including image classification, object detection, and text classification. * feat: Enhance TemplateForm with improved validation and dynamic field rendering; update LabelStudio config validation for camelCase support * feat: Update docker-compose.yml to mark datamate dataset volume and network as external
This commit is contained in:
@@ -3,25 +3,32 @@ Tables of Annotation Management Module
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, BigInteger, Boolean, TIMESTAMP, Text, Integer, JSON, Date
|
||||
from sqlalchemy import Column, String, BigInteger, Boolean, TIMESTAMP, Text, Integer, JSON, Date, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
class AnnotationTemplate(Base):
|
||||
"""标注模板模型"""
|
||||
"""标注配置模板模型"""
|
||||
|
||||
__tablename__ = "t_dm_annotation_templates"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID主键ID")
|
||||
name = Column(String(32), nullable=False, comment="模板名称")
|
||||
description = Column(String(255), nullable=True, comment="模板描述")
|
||||
configuration = Column(JSON, nullable=True, comment="配置信息(JSON格式)")
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
|
||||
name = Column(String(100), nullable=False, comment="模板名称")
|
||||
description = Column(String(500), nullable=True, comment="模板描述")
|
||||
data_type = Column(String(50), nullable=False, comment="数据类型: image/text/audio/video/timeseries")
|
||||
labeling_type = Column(String(50), nullable=False, comment="标注类型: classification/detection/segmentation/ner/relation/etc")
|
||||
configuration = Column(JSON, nullable=False, comment="标注配置(包含labels定义等)")
|
||||
style = Column(String(32), nullable=False, comment="样式配置: horizontal/vertical")
|
||||
category = Column(String(50), default='custom', comment="模板分类: medical/general/custom/system")
|
||||
built_in = Column(Boolean, default=False, comment="是否系统内置模板")
|
||||
version = Column(String(20), default='1.0', comment="模板版本")
|
||||
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
|
||||
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
|
||||
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<AnnotationTemplate(id={self.id}, name={self.name})>"
|
||||
return f"<AnnotationTemplate(id={self.id}, name={self.name}, data_type={self.data_type})>"
|
||||
|
||||
@property
|
||||
def is_deleted(self) -> bool:
|
||||
@@ -29,21 +36,23 @@ class AnnotationTemplate(Base):
|
||||
return self.deleted_at is not None
|
||||
|
||||
class LabelingProject(Base):
|
||||
"""标注工程表"""
|
||||
"""标注项目模型"""
|
||||
|
||||
__tablename__ = "t_dm_labeling_projects"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID主键ID")
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
|
||||
dataset_id = Column(String(36), nullable=False, comment="数据集ID")
|
||||
name = Column(String(32), nullable=False, comment="项目名称")
|
||||
name = Column(String(100), nullable=False, comment="项目名称")
|
||||
labeling_project_id = Column(String(8), nullable=False, comment="Label Studio项目ID")
|
||||
configuration = Column(JSON, nullable=True, comment="标签配置")
|
||||
progress = Column(JSON, nullable=True, comment="标注进度统计")
|
||||
template_id = Column(String(36), ForeignKey('t_dm_annotation_templates.id', ondelete='SET NULL'), nullable=True, comment="使用的模板ID")
|
||||
configuration = Column(JSON, nullable=True, comment="项目配置(可能包含对模板的自定义修改)")
|
||||
progress = Column(JSON, nullable=True, comment="项目进度信息")
|
||||
created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
|
||||
updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
|
||||
deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<LabelingProject(id={self.id}, dataset_id={self.dataset_id}, name={self.name})>"
|
||||
return f"<LabelingProject(id={self.id}, name={self.name}, dataset_id={self.dataset_id})>"
|
||||
|
||||
@property
|
||||
def is_deleted(self) -> bool:
|
||||
|
||||
@@ -103,6 +103,7 @@ class Client:
|
||||
"""创建Label Studio项目"""
|
||||
try:
|
||||
logger.debug(f"Creating Label Studio project: {title}")
|
||||
logger.debug(f"Label Studio URL: {self.base_url}/api/projects")
|
||||
|
||||
project_data = {
|
||||
"title": title,
|
||||
@@ -123,10 +124,28 @@ class Client:
|
||||
return project
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Create project failed HTTP {e.response.status_code}: {e.response.text}")
|
||||
logger.error(
|
||||
f"Create project failed - HTTP {e.response.status_code}\n"
|
||||
f"URL: {e.request.url}\n"
|
||||
f"Response Headers: {dict(e.response.headers)}\n"
|
||||
f"Response Body: {e.response.text[:1000]}" # First 1000 chars
|
||||
)
|
||||
return None
|
||||
except httpx.ConnectError as e:
|
||||
logger.error(
|
||||
f"Failed to connect to Label Studio at {self.base_url}\n"
|
||||
f"Error: {str(e)}\n"
|
||||
f"Possible causes:\n"
|
||||
f" - Label Studio service is not running\n"
|
||||
f" - Incorrect URL configuration\n"
|
||||
f" - Network connectivity issue"
|
||||
)
|
||||
return None
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"Request to Label Studio timed out after {self.timeout}s: {str(e)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error while creating Label Studio project: {e}")
|
||||
logger.error(f"Error while creating Label Studio project: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def import_tasks(
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi import APIRouter
|
||||
from .about import router as about_router
|
||||
from .project import router as project_router
|
||||
from .task import router as task_router
|
||||
from .template import router as template_router
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/annotation",
|
||||
@@ -11,4 +12,5 @@ router = APIRouter(
|
||||
|
||||
router.include_router(about_router)
|
||||
router.include_router(project_router)
|
||||
router.include_router(task_router)
|
||||
router.include_router(task_router)
|
||||
router.include_router(template_router)
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Optional
|
||||
import math
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -14,6 +15,7 @@ from app.core.config import settings
|
||||
from ..client import LabelStudioClient
|
||||
from ..service.mapping import DatasetMappingService
|
||||
from ..service.sync import SyncService
|
||||
from ..service.template import AnnotationTemplateService
|
||||
from ..schema import (
|
||||
DatasetMappingCreateRequest,
|
||||
DatasetMappingCreateResponse,
|
||||
@@ -39,6 +41,8 @@ async def create_mapping(
|
||||
在数据库中记录这一关联关系,返回Label Studio数据集的ID
|
||||
|
||||
注意:一个数据集可以创建多个标注项目
|
||||
|
||||
支持通过 template_id 指定标注模板,如果提供了模板ID,则使用模板的配置
|
||||
"""
|
||||
try:
|
||||
dm_client = DatasetManagementService(db)
|
||||
@@ -46,6 +50,7 @@ async def create_mapping(
|
||||
token=settings.label_studio_user_token)
|
||||
mapping_service = DatasetMappingService(db)
|
||||
sync_service = SyncService(dm_client, ls_client, mapping_service)
|
||||
template_service = AnnotationTemplateService()
|
||||
|
||||
logger.info(f"Create dataset mapping request: {request.dataset_id}")
|
||||
|
||||
@@ -65,10 +70,24 @@ async def create_mapping(
|
||||
dataset_info.description or \
|
||||
f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
|
||||
|
||||
# 如果提供了模板ID,获取模板配置
|
||||
label_config = None
|
||||
if request.template_id:
|
||||
logger.info(f"Using template: {request.template_id}")
|
||||
template = await template_service.get_template(db, request.template_id)
|
||||
if not template:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Template not found: {request.template_id}"
|
||||
)
|
||||
label_config = template.label_config
|
||||
logger.debug(f"Template label config loaded for template: {template.name}")
|
||||
|
||||
# 在Label Studio中创建项目
|
||||
project_data = await ls_client.create_project(
|
||||
title=project_name,
|
||||
description=project_description,
|
||||
label_config=label_config # 传递模板配置
|
||||
)
|
||||
|
||||
if not project_data:
|
||||
@@ -96,9 +115,11 @@ async def create_mapping(
|
||||
logger.info(f"Local storage configured for project {project_id}: {local_storage_path}")
|
||||
|
||||
labeling_project = LabelingProject(
|
||||
id=str(uuid.uuid4()), # Generate UUID here
|
||||
dataset_id=request.dataset_id,
|
||||
labeling_project_id=str(project_id),
|
||||
name=project_name,
|
||||
template_id=request.template_id, # Save template_id to database
|
||||
)
|
||||
|
||||
# 创建映射关系,包含项目名称(先持久化映射以获得 mapping.id)
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Annotation Template API Endpoints
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
from app.module.shared.schema import StandardResponse
|
||||
from app.module.annotation.schema.template import (
|
||||
CreateAnnotationTemplateRequest,
|
||||
UpdateAnnotationTemplateRequest,
|
||||
AnnotationTemplateResponse,
|
||||
AnnotationTemplateListResponse
|
||||
)
|
||||
from app.module.annotation.service.template import AnnotationTemplateService
|
||||
|
||||
router = APIRouter(prefix="/templates", tags=["Annotation Template"])
|
||||
|
||||
template_service = AnnotationTemplateService()
|
||||
|
||||
|
||||
@router.post(
    "",
    summary="创建标注模板",
    response_model=StandardResponse[AnnotationTemplateResponse],
)
async def create_template(
    request: CreateAnnotationTemplateRequest,
    db: AsyncSession = Depends(get_db),
):
    """
    Create a new annotation template.

    - **name**: template name (required, at most 100 characters)
    - **description**: template description (optional, at most 500 characters)
    - **dataType**: data type (required)
    - **labelingType**: labeling type (required)
    - **configuration**: labeling configuration (required, holds labels and objects)
    - **style**: layout style (defaults to horizontal)
    - **category**: template category (defaults to custom)
    """
    # All validation and persistence is delegated to the template service.
    created = await template_service.create_template(db, request)
    return StandardResponse(code=200, message="success", data=created)
|
||||
|
||||
|
||||
@router.get(
    "/{template_id}",
    summary="获取模板详情",
    response_model=StandardResponse[AnnotationTemplateResponse],
)
async def get_template(
    template_id: str,
    db: AsyncSession = Depends(get_db),
):
    """
    Return the details of a single template looked up by its ID.

    Raises:
        HTTPException: 404 when the service returns no template for the ID.
    """
    found = await template_service.get_template(db, template_id)
    if found:
        return StandardResponse(code=200, message="success", data=found)
    # A falsy service result is reported to the client as "not found".
    raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
|
||||
@router.get(
    "",
    summary="获取模板列表",
    response_model=StandardResponse[AnnotationTemplateListResponse],
)
async def list_templates(
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(10, ge=1, le=100, description="每页大小"),
    category: Optional[str] = Query(None, description="分类筛选"),
    dataType: Optional[str] = Query(None, alias="dataType", description="数据类型筛选"),
    labelingType: Optional[str] = Query(None, alias="labelingType", description="标注类型筛选"),
    builtIn: Optional[bool] = Query(None, alias="builtIn", description="是否内置模板"),
    db: AsyncSession = Depends(get_db),
):
    """
    List templates with pagination and optional filters.

    - **page**: page number (starting at 1)
    - **size**: page size (1-100)
    - **category**: filter by template category
    - **dataType**: filter by data type
    - **labelingType**: filter by labeling type
    - **builtIn**: filter on the built-in flag
    """
    # The camelCase query parameters are handed to the service under its
    # snake_case keyword names.
    listing = await template_service.list_templates(
        db=db,
        page=page,
        size=size,
        category=category,
        data_type=dataType,
        labeling_type=labelingType,
        built_in=builtIn,
    )
    return StandardResponse(code=200, message="success", data=listing)
|
||||
|
||||
|
||||
@router.put(
    "/{template_id}",
    summary="更新模板",
    response_model=StandardResponse[AnnotationTemplateResponse],
)
async def update_template(
    template_id: str,
    request: UpdateAnnotationTemplateRequest,
    db: AsyncSession = Depends(get_db),
):
    """
    Update a template.

    Every field of the request body is optional.

    Raises:
        HTTPException: 404 when the service returns no template for the ID.
    """
    updated = await template_service.update_template(db, template_id, request)
    if updated:
        return StandardResponse(code=200, message="success", data=updated)
    # A falsy service result is reported to the client as "not found".
    raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
|
||||
@router.delete(
    "/{template_id}",
    summary="删除模板",
    response_model=StandardResponse[bool],
)
async def delete_template(
    template_id: str,
    db: AsyncSession = Depends(get_db),
):
    """
    Delete a template (soft delete).

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    deleted = await template_service.delete_template(db, template_id)
    if deleted:
        return StandardResponse(code=200, message="success", data=True)
    # A falsy service result is reported to the client as "not found".
    raise HTTPException(status_code=404, detail="Template not found")
|
||||
@@ -11,13 +11,14 @@ class DatasetMappingCreateRequest(BaseModel):
|
||||
|
||||
Accept both snake_case and camelCase field names from frontend JSON by
|
||||
declaring explicit aliases. Frontend sends `datasetId`, `name`,
|
||||
`description` (camelCase), so provide aliases so pydantic will map them
|
||||
`description`, `templateId` (camelCase), so provide aliases so pydantic will map them
|
||||
to the internal attributes used in the service code (dataset_id, name,
|
||||
description).
|
||||
description, template_id).
|
||||
"""
|
||||
dataset_id: str = Field(..., alias="datasetId", description="源数据集ID")
|
||||
name: Optional[str] = Field(None, alias="name", description="标注项目名称")
|
||||
description: Optional[str] = Field(None, alias="description", description="标注项目描述")
|
||||
template_id: Optional[str] = Field(None, alias="templateId", description="标注模板ID")
|
||||
|
||||
class Config:
|
||||
# allow population by field name when constructing model programmatically
|
||||
@@ -34,13 +35,16 @@ class DatasetMappingUpdateRequest(BaseResponseModel):
|
||||
dataset_id: Optional[str] = Field(None, description="源数据集ID")
|
||||
|
||||
class DatasetMappingResponse(BaseModel):
|
||||
dataset_id: str = Field(..., description="源数据集ID")
|
||||
"""数据集映射 查询 响应模型"""
|
||||
id: str = Field(..., description="映射UUID")
|
||||
labeling_project_id: str = Field(..., description="标注项目ID")
|
||||
dataset_id: str = Field(..., alias="datasetId", description="源数据集ID")
|
||||
dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
|
||||
labeling_project_id: str = Field(..., alias="labelingProjectId", description="标注项目ID")
|
||||
name: Optional[str] = Field(None, description="标注项目名称")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
deleted_at: Optional[datetime] = Field(None, description="删除时间")
|
||||
description: Optional[str] = Field(None, description="标注项目描述")
|
||||
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
|
||||
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
|
||||
deleted_at: Optional[datetime] = Field(None, alias="deletedAt", description="删除时间")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Annotation Template Schemas
|
||||
"""
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class LabelDefinition(BaseModel):
    """Definition of one labeling control inside a template configuration."""

    # Accept both the camelCase aliases and the snake_case field names.
    model_config = ConfigDict(populate_by_name=True)

    from_name: str = Field(
        alias="fromName",
        description="控件名称",
    )
    to_name: str = Field(
        alias="toName",
        description="目标对象名称",
    )
    type: str = Field(
        description="控件类型: choices/rectanglelabels/polygonlabels/textarea/etc",
    )
    options: Optional[List[str]] = Field(
        None,
        description="选项列表(用于choices类型)",
    )
    labels: Optional[List[str]] = Field(
        None,
        description="标签列表(用于rectanglelabels等类型)",
    )
    required: bool = Field(
        False,
        description="是否必填",
    )
    description: Optional[str] = Field(
        None,
        description="标签描述",
    )
|
||||
|
||||
|
||||
class ObjectDefinition(BaseModel):
    """Definition of one data object (e.g. an image or text) to be labeled."""

    model_config = ConfigDict(populate_by_name=True)

    name: str = Field(
        description="对象标识符",
    )
    type: str = Field(
        description="对象类型: Image/Text/Audio/Video/etc",
    )
    value: str = Field(
        description="变量名,如$image",
    )
|
||||
|
||||
|
||||
class TemplateConfiguration(BaseModel):
    """Structure of a template configuration: controls, objects and metadata."""

    model_config = ConfigDict(populate_by_name=True)

    labels: List[LabelDefinition] = Field(
        description="标签定义列表",
    )
    objects: List[ObjectDefinition] = Field(
        description="对象定义列表",
    )
    metadata: Optional[Dict[str, Any]] = Field(
        None,
        description="额外元数据",
    )
|
||||
|
||||
|
||||
class CreateAnnotationTemplateRequest(BaseModel):
    """Request body for creating an annotation template."""

    # Accept both the camelCase aliases and the snake_case field names.
    model_config = ConfigDict(populate_by_name=True)

    name: str = Field(
        ...,
        min_length=1,
        max_length=100,
        description="模板名称",
    )
    description: Optional[str] = Field(
        None,
        max_length=500,
        description="模板描述",
    )
    data_type: str = Field(
        alias="dataType",
        description="数据类型",
    )
    labeling_type: str = Field(
        alias="labelingType",
        description="标注类型",
    )
    configuration: TemplateConfiguration = Field(
        ...,
        description="标注配置",
    )
    style: str = Field(
        default="horizontal",
        description="样式配置",
    )
    category: str = Field(
        default="custom",
        description="模板分类",
    )
|
||||
|
||||
|
||||
class UpdateAnnotationTemplateRequest(BaseModel):
    """Request body for updating an annotation template; every field is optional."""

    # Accept both the camelCase aliases and the snake_case field names.
    model_config = ConfigDict(populate_by_name=True)

    name: Optional[str] = Field(
        None,
        min_length=1,
        max_length=100,
        description="模板名称",
    )
    description: Optional[str] = Field(
        None,
        max_length=500,
        description="模板描述",
    )
    data_type: Optional[str] = Field(
        None,
        alias="dataType",
        description="数据类型",
    )
    labeling_type: Optional[str] = Field(
        None,
        alias="labelingType",
        description="标注类型",
    )
    configuration: Optional[TemplateConfiguration] = Field(
        None,
        description="标注配置",
    )
    style: Optional[str] = Field(
        None,
        description="样式配置",
    )
    category: Optional[str] = Field(
        None,
        description="模板分类",
    )
|
||||
|
||||
|
||||
class AnnotationTemplateResponse(BaseModel):
    """Response model describing a single annotation template."""

    # from_attributes lets the model be built from attribute-bearing objects;
    # populate_by_name accepts snake_case names as well as the aliases.
    model_config = ConfigDict(populate_by_name=True, from_attributes=True)

    id: str = Field(
        ...,
        description="模板ID",
    )
    name: str = Field(
        ...,
        description="模板名称",
    )
    description: Optional[str] = Field(
        None,
        description="模板描述",
    )
    data_type: str = Field(
        alias="dataType",
        description="数据类型",
    )
    labeling_type: str = Field(
        alias="labelingType",
        description="标注类型",
    )
    configuration: TemplateConfiguration = Field(
        ...,
        description="标注配置",
    )
    label_config: Optional[str] = Field(
        None,
        alias="labelConfig",
        description="生成的Label Studio XML配置",
    )
    style: str = Field(
        ...,
        description="样式配置",
    )
    category: str = Field(
        ...,
        description="模板分类",
    )
    built_in: bool = Field(
        alias="builtIn",
        description="是否内置模板",
    )
    version: str = Field(
        ...,
        description="版本号",
    )
    created_at: datetime = Field(
        alias="createdAt",
        description="创建时间",
    )
    updated_at: Optional[datetime] = Field(
        None,
        alias="updatedAt",
        description="更新时间",
    )
|
||||
|
||||
|
||||
class AnnotationTemplateListResponse(BaseModel):
    """Paginated response model for a list of annotation templates."""

    model_config = ConfigDict(populate_by_name=True)

    content: List[AnnotationTemplateResponse] = Field(
        ...,
        description="模板列表",
    )
    total: int = Field(
        ...,
        description="总数",
    )
    page: int = Field(
        ...,
        description="当前页",
    )
    size: int = Field(
        ...,
        description="每页大小",
    )
    total_pages: int = Field(
        alias="totalPages",
        description="总页数",
    )
|
||||
@@ -1,12 +1,14 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import update, func
|
||||
from sqlalchemy.orm import aliased
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import LabelingProject
|
||||
from app.db.models.dataset_management import Dataset
|
||||
from app.module.annotation.schema import (
|
||||
DatasetMappingCreateRequest,
|
||||
DatasetMappingUpdateRequest,
|
||||
@@ -21,6 +23,61 @@ class DatasetMappingService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
def _build_query_with_dataset_name(self):
    """Return the base SELECT of labeling projects plus their dataset name.

    Each result row carries the ``LabelingProject`` entity followed by the
    dataset's name labelled ``dataset_name``. The outer join keeps projects
    that have no matching dataset row.
    """
    dataset_name_column = Dataset.name.label('dataset_name')
    base_query = select(LabelingProject, dataset_name_column)
    return base_query.outerjoin(Dataset, LabelingProject.dataset_id == Dataset.id)
|
||||
|
||||
def _to_response_from_row(self, row) -> DatasetMappingResponse:
    """Convert a joined query row into a ``DatasetMappingResponse``.

    Args:
        row: Row whose first element is the ``LabelingProject`` object and
            whose second element is the joined dataset name.

    Returns:
        The response model populated from the project and the dataset name.
    """
    project = row[0]
    joined_dataset_name = row[1]

    return DatasetMappingResponse(
        id=project.id,
        dataset_id=project.dataset_id,
        dataset_name=joined_dataset_name,
        labeling_project_id=project.labeling_project_id,
        name=project.name,
        # Read defensively: falls back to None if the object has no
        # 'description' attribute.
        description=getattr(project, 'description', None),
        created_at=project.created_at,
        updated_at=project.updated_at,
        deleted_at=project.deleted_at,
    )
|
||||
|
||||
async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse:
    """Convert one ORM object into a response, looking up its dataset name.

    Used for single-entity operations; issues one extra query for the
    dataset name when the mapping has a dataset ID.

    Args:
        mapping: The labeling project ORM object to convert.

    Returns:
        The response model, with ``dataset_name`` left as None when the
        mapping has no dataset ID or no dataset row matches.
    """
    source_dataset_id = getattr(mapping, 'dataset_id', None)

    resolved_name = None
    if source_dataset_id:
        name_lookup = select(Dataset.name).where(Dataset.id == source_dataset_id)
        lookup_result = await self.db.execute(name_lookup)
        resolved_name = lookup_result.scalar_one_or_none()

    return DatasetMappingResponse(
        id=mapping.id,
        dataset_id=source_dataset_id,
        dataset_name=resolved_name,
        labeling_project_id=mapping.labeling_project_id,
        name=mapping.name,
        # Read defensively: falls back to None if the object has no
        # 'description' attribute.
        description=getattr(mapping, 'description', None),
        created_at=mapping.created_at,
        updated_at=mapping.updated_at,
        deleted_at=mapping.deleted_at,
    )
|
||||
|
||||
async def create_mapping(
|
||||
self,
|
||||
labeling_project: LabelingProject
|
||||
@@ -28,19 +85,13 @@ class DatasetMappingService:
|
||||
"""创建数据集映射"""
|
||||
logger.info(f"Create dataset mapping: {labeling_project.dataset_id} -> {labeling_project.labeling_project_id}")
|
||||
|
||||
db_mapping = LabelingProject(
|
||||
id=str(uuid.uuid4()),
|
||||
dataset_id=labeling_project.dataset_id,
|
||||
labeling_project_id=labeling_project.labeling_project_id,
|
||||
name=labeling_project.name
|
||||
)
|
||||
|
||||
self.db.add(db_mapping)
|
||||
# Use the passed object directly
|
||||
self.db.add(labeling_project)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(db_mapping)
|
||||
await self.db.refresh(labeling_project)
|
||||
|
||||
logger.debug(f"Mapping created: {db_mapping.id}")
|
||||
return DatasetMappingResponse.model_validate(db_mapping)
|
||||
logger.debug(f"Mapping created: {labeling_project.id}")
|
||||
return await self._to_response(labeling_project)
|
||||
|
||||
async def get_mapping_by_source_uuid(
|
||||
self,
|
||||
@@ -59,7 +110,7 @@ class DatasetMappingService:
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"Found mapping: {mapping.id}")
|
||||
return DatasetMappingResponse.model_validate(mapping)
|
||||
return await self._to_response(mapping)
|
||||
|
||||
logger.debug(f"No mapping found for source dataset id: {dataset_id}")
|
||||
return None
|
||||
@@ -72,7 +123,7 @@ class DatasetMappingService:
|
||||
"""根据源数据集ID获取所有映射关系"""
|
||||
logger.debug(f"Get all mappings by source dataset id: {dataset_id}")
|
||||
|
||||
query = select(LabelingProject).where(
|
||||
query = self._build_query_with_dataset_name().where(
|
||||
LabelingProject.dataset_id == dataset_id
|
||||
)
|
||||
|
||||
@@ -82,10 +133,10 @@ class DatasetMappingService:
|
||||
result = await self.db.execute(
|
||||
query.order_by(LabelingProject.created_at.desc())
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(mappings)} mappings")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
|
||||
logger.debug(f"Found {len(rows)} mappings")
|
||||
return [self._to_response_from_row(row) for row in rows]
|
||||
|
||||
async def get_mapping_by_labeling_project_id(
|
||||
self,
|
||||
@@ -103,8 +154,8 @@ class DatasetMappingService:
|
||||
mapping = result.scalar_one_or_none()
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"Found mapping: {mapping.mapping_id}")
|
||||
return DatasetMappingResponse.model_validate(mapping)
|
||||
logger.debug(f"Found mapping: {mapping.id}")
|
||||
return await self._to_response(mapping)
|
||||
|
||||
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
|
||||
return None
|
||||
@@ -123,9 +174,9 @@ class DatasetMappingService:
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"Found mapping: {mapping.id}")
|
||||
return DatasetMappingResponse.model_validate(mapping)
|
||||
return await self._to_response(mapping)
|
||||
|
||||
logger.debug(f"Mapping not found: {mapping_id}")
|
||||
logger.debug(f"No mapping found for mapping id: {mapping_id}")
|
||||
return None
|
||||
|
||||
async def update_mapping(
|
||||
@@ -184,17 +235,20 @@ class DatasetMappingService:
|
||||
"""获取所有有效映射"""
|
||||
logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
|
||||
|
||||
query = self._build_query_with_dataset_name().where(
|
||||
LabelingProject.deleted_at.is_(None)
|
||||
)
|
||||
|
||||
result = await self.db.execute(
|
||||
select(LabelingProject)
|
||||
.where(LabelingProject.deleted_at.is_(None))
|
||||
query
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.order_by(LabelingProject.created_at.desc())
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(mappings)} mappings")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
|
||||
logger.debug(f"Found {len(rows)} mappings")
|
||||
return [self._to_response_from_row(row) for row in rows]
|
||||
|
||||
async def count_mappings(self, include_deleted: bool = False) -> int:
|
||||
"""统计映射总数"""
|
||||
@@ -216,7 +270,7 @@ class DatasetMappingService:
|
||||
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
|
||||
|
||||
# 构建查询
|
||||
query = select(LabelingProject)
|
||||
query = self._build_query_with_dataset_name()
|
||||
if not include_deleted:
|
||||
query = query.where(LabelingProject.deleted_at.is_(None))
|
||||
|
||||
@@ -235,10 +289,10 @@ class DatasetMappingService:
|
||||
.limit(limit)
|
||||
.order_by(LabelingProject.created_at.desc())
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
|
||||
logger.debug(f"Found {len(rows)} mappings, total: {total}")
|
||||
return [self._to_response_from_row(row) for row in rows], total
|
||||
|
||||
async def get_mappings_by_source_with_count(
|
||||
self,
|
||||
@@ -251,7 +305,7 @@ class DatasetMappingService:
|
||||
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
|
||||
|
||||
# 构建查询
|
||||
query = select(LabelingProject).where(
|
||||
query = self._build_query_with_dataset_name().where(
|
||||
LabelingProject.dataset_id == dataset_id
|
||||
)
|
||||
|
||||
@@ -275,7 +329,7 @@ class DatasetMappingService:
|
||||
.limit(limit)
|
||||
.order_by(LabelingProject.created_at.desc())
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
|
||||
logger.debug(f"Found {len(rows)} mappings, total: {total}")
|
||||
return [self._to_response_from_row(row) for row in rows], total
|
||||
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
Annotation Template Service
|
||||
"""
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from uuid import uuid4
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.db.models.annotation_management import AnnotationTemplate
|
||||
from app.module.annotation.schema.template import (
|
||||
CreateAnnotationTemplateRequest,
|
||||
UpdateAnnotationTemplateRequest,
|
||||
AnnotationTemplateResponse,
|
||||
AnnotationTemplateListResponse,
|
||||
TemplateConfiguration
|
||||
)
|
||||
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator
|
||||
|
||||
|
||||
class AnnotationTemplateService:
    """Service for managing annotation templates.

    Covers create / get / list / update / soft-delete of ``AnnotationTemplate``
    rows. The Label Studio XML is never stored: it is rendered from the
    persisted configuration JSON whenever a response is built.
    """

    @staticmethod
    def generate_label_studio_config(config: TemplateConfiguration) -> str:
        """Render a template configuration as a Label Studio XML config.

        Args:
            config: Template configuration exposing ``objects`` (each with
                ``name``, ``type``, ``value``) and ``labels`` (each with
                ``from_name``, ``to_name``, ``type``, ``required``,
                ``options``, ``labels``).

        Returns:
            The Label Studio XML string, rooted at ``<View>``.
        """
        # Function-scope import: only this method needs it.
        from xml.sax.saxutils import escape

        def render_attr(key, value) -> str:
            # Escape &, <, > and double quotes so user-supplied text cannot
            # break out of the attribute or produce malformed XML.
            escaped = escape(format(value), {'"': '&quot;'})
            return f'{key}="{escaped}"'

        xml_parts = ['<View>']

        # Data objects (Image, Text, ...)
        for obj in config.objects:
            obj_attrs = [
                render_attr('name', obj.name),
                render_attr('value', obj.value),
            ]
            xml_parts.append(f'  <{obj.type} {" ".join(obj_attrs)}/>')

        # Annotation controls
        for label in config.labels:
            label_attrs = [
                render_attr('name', label.from_name),
                render_attr('toName', label.to_name),
            ]

            # Optional attributes
            if label.required:
                label_attrs.append('required="true"')

            # Upper-case only the first character. str.capitalize() would also
            # lower-case the rest and turn "RectangleLabels" into
            # "Rectanglelabels", which is not a Label Studio tag.
            raw_type = label.type
            tag_type = raw_type[:1].upper() + raw_type[1:] if raw_type else "Choices"

            choices = label.options or label.labels or []
            if choices:
                # Control with child options.
                # NOTE(review): Label Studio documents <Choice> children for
                # <Choices>; <Label> is emitted for every control type here
                # because that is what validate_xml is known to accept.
                # TODO confirm and align generator and validator.
                xml_parts.append(f'  <{tag_type} {" ".join(label_attrs)}>')
                for choice in choices:
                    xml_parts.append(f'    <Label {render_attr("value", choice)}/>')
                xml_parts.append(f'  </{tag_type}>')
            else:
                # Simple self-closing control (e.g. TextArea)
                xml_parts.append(f'  <{tag_type} {" ".join(label_attrs)}/>')

        xml_parts.append('</View>')
        return '\n'.join(xml_parts)

    def _validate_configuration(self, configuration: TemplateConfiguration) -> dict:
        """Validate a configuration model and return its JSON-ready dict.

        Runs the structural JSON check, renders the Label Studio XML and
        validates that XML as well. The XML is used for validation only.

        Raises:
            HTTPException: 400 when the JSON or the generated XML is invalid.
        """
        config_dict = configuration.model_dump(mode='json', by_alias=False)
        valid, error = LabelStudioConfigValidator.validate_configuration_json(config_dict)
        if not valid:
            raise HTTPException(status_code=400, detail=f"Invalid configuration: {error}")

        label_config = self.generate_label_studio_config(configuration)
        valid, error = LabelStudioConfigValidator.validate_xml(label_config)
        if not valid:
            raise HTTPException(status_code=400, detail=f"Generated XML is invalid: {error}")

        return config_dict

    async def _get_active_template(self, db: AsyncSession, template_id: str):
        """Return the non-deleted template row with this id, or None."""
        result = await db.execute(
            select(AnnotationTemplate).where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None),
            )
        )
        return result.scalar_one_or_none()

    async def create_template(
        self,
        db: AsyncSession,
        request: CreateAnnotationTemplateRequest
    ) -> AnnotationTemplateResponse:
        """Create an annotation template.

        Args:
            db: Database session.
            request: Creation request.

        Returns:
            The response for the newly created template.

        Raises:
            HTTPException: 400 when the configuration (or the XML generated
                from it) fails validation.
        """
        config_dict = self._validate_configuration(request.configuration)

        # The row stores only the configuration JSON, not the generated XML.
        template = AnnotationTemplate(
            id=str(uuid4()),
            name=request.name,
            description=request.description,
            data_type=request.data_type,
            labeling_type=request.labeling_type,
            configuration=config_dict,
            style=request.style,
            category=request.category,
            built_in=False,
            version="1.0.0",
            created_at=datetime.now()
        )

        db.add(template)
        await db.commit()
        await db.refresh(template)

        return self._to_response(template)

    async def get_template(
        self,
        db: AsyncSession,
        template_id: str
    ) -> Optional[AnnotationTemplateResponse]:
        """Fetch a single template.

        Args:
            db: Database session.
            template_id: Template id.

        Returns:
            The template response, or None when no non-deleted template has
            this id.
        """
        template = await self._get_active_template(db, template_id)
        if template:
            return self._to_response(template)
        return None

    async def list_templates(
        self,
        db: AsyncSession,
        page: int = 1,
        size: int = 10,
        category: Optional[str] = None,
        data_type: Optional[str] = None,
        labeling_type: Optional[str] = None,
        built_in: Optional[bool] = None
    ) -> AnnotationTemplateListResponse:
        """List templates with optional filters and pagination.

        Args:
            db: Database session.
            page: 1-based page number.
            size: Page size.
            category: Filter by category.
            data_type: Filter by data type.
            labeling_type: Filter by labeling type.
            built_in: Filter by built-in flag (None means no filter).

        Returns:
            The paginated list response, newest templates first.
        """
        # Soft-deleted rows are always excluded.
        conditions: List = [AnnotationTemplate.deleted_at.is_(None)]

        if category:
            conditions.append(AnnotationTemplate.category == category)  # type: ignore
        if data_type:
            conditions.append(AnnotationTemplate.data_type == data_type)  # type: ignore
        if labeling_type:
            conditions.append(AnnotationTemplate.labeling_type == labeling_type)  # type: ignore
        if built_in is not None:
            conditions.append(AnnotationTemplate.built_in == built_in)  # type: ignore

        # Total count for the same filter set
        count_result = await db.execute(
            select(func.count()).select_from(AnnotationTemplate).where(*conditions)
        )
        total = count_result.scalar() or 0

        # Page query
        result = await db.execute(
            select(AnnotationTemplate)
            .where(*conditions)
            .order_by(AnnotationTemplate.created_at.desc())
            .limit(size)
            .offset((page - 1) * size)
        )
        templates = result.scalars().all()

        # Ceiling division; guarded so size == 0 does not divide by zero.
        total_pages = (total + size - 1) // size if size > 0 else 0

        return AnnotationTemplateListResponse(
            content=[self._to_response(t) for t in templates],
            total=total,
            page=page,
            size=size,
            totalPages=total_pages
        )

    async def update_template(
        self,
        db: AsyncSession,
        template_id: str,
        request: UpdateAnnotationTemplateRequest
    ) -> Optional[AnnotationTemplateResponse]:
        """Update a template with the fields present in the request.

        Args:
            db: Database session.
            template_id: Template id.
            request: Update request; only explicitly set fields are applied.

        Returns:
            The updated template response, or None when the template does not
            exist (or is soft-deleted).

        Raises:
            HTTPException: 400 when a new configuration fails validation.
        """
        template = await self._get_active_template(db, template_id)
        if template is None:
            return None

        # exclude_unset: touch only what the client actually sent.
        update_data = request.model_dump(exclude_unset=True, by_alias=False)

        for field, value in update_data.items():
            if field == 'configuration' and value is not None:
                # model_dump() has already converted the nested model into a
                # plain dict, so validation must use the model on the request.
                config_dict = self._validate_configuration(request.configuration)
                # Only the configuration JSON is stored, never the XML.
                setattr(template, field, config_dict)
            else:
                setattr(template, field, value)

        template.updated_at = datetime.now()  # type: ignore

        await db.commit()
        await db.refresh(template)

        return self._to_response(template)

    async def delete_template(
        self,
        db: AsyncSession,
        template_id: str
    ) -> bool:
        """Soft-delete a template by stamping ``deleted_at``.

        Args:
            db: Database session.
            template_id: Template id.

        Returns:
            True when a template was marked deleted, False when none matched.
        """
        template = await self._get_active_template(db, template_id)
        if template is None:
            return False

        template.deleted_at = datetime.now()  # type: ignore
        await db.commit()

        return True

    def _to_response(self, template: AnnotationTemplate) -> AnnotationTemplateResponse:
        """Convert an ORM row into a response object.

        Args:
            template: Database model instance.

        Returns:
            The response object, with ``configuration`` as a
            ``TemplateConfiguration`` and ``label_config`` rendered from it.
        """
        from typing import cast, Dict, Any

        # Stored configuration JSON -> TemplateConfiguration model
        config_dict = cast(Dict[str, Any], template.configuration)
        config = TemplateConfiguration(**config_dict)

        # The XML is generated on the fly; it is not a stored column.
        label_config = self.generate_label_studio_config(config)

        # Build the response from the ORM object, then attach derived fields.
        response = AnnotationTemplateResponse.model_validate(template)
        response.configuration = config
        response.label_config = label_config  # type: ignore

        return response
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Annotation Module Utilities
|
||||
"""
|
||||
from .config_validator import LabelStudioConfigValidator
|
||||
|
||||
__all__ = ['LabelStudioConfigValidator']
|
||||
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Label Studio Configuration Validation Utilities
|
||||
"""
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
|
||||
class LabelStudioConfigValidator:
    """Utilities for validating Label Studio configurations.

    Two representations are checked: the Label Studio XML config string and
    the configuration JSON (``labels`` / ``objects`` lists) it is built from.
    Every validator returns ``(is_valid, error_message)``; the message is
    None when the input is valid.
    """

    # Supported control (annotation) tag types
    CONTROL_TYPES = {
        'Choices', 'RectangleLabels', 'PolygonLabels', 'Labels',
        'TextArea', 'Rating', 'KeyPointLabels', 'BrushLabels',
        'EllipseLabels', 'VideoRectangle', 'AudioPlus'
    }

    # Supported data object tag types
    OBJECT_TYPES = {
        'Image', 'Text', 'Audio', 'Video', 'HyperText',
        'AudioPlus', 'Paragraphs', 'Table'
    }

    # Control types that must carry child option tags
    LABEL_BASED_CONTROLS = {
        'Choices', 'RectangleLabels', 'PolygonLabels', 'Labels',
        'KeyPointLabels', 'BrushLabels', 'EllipseLabels'
    }

    @staticmethod
    def _option_children(control: ET.Element) -> List[ET.Element]:
        """Return the direct child option tags of a control, in document order.

        ``<Label>`` children count for every control. ``<Choices>`` also
        accepts ``<Choice>`` children, the child tag used by Label Studio's
        own Choices configs.
        """
        allowed = {'Label', 'Choice'} if control.tag == 'Choices' else {'Label'}
        return [child for child in control if child.tag in allowed]

    @staticmethod
    def validate_xml(xml_string: str) -> Tuple[bool, Optional[str]]:
        """Validate a Label Studio XML config string.

        Checks that the document parses, the root is ``<View>``, at least one
        data object and one control are direct children of the root, and every
        such control passes ``_validate_control``.

        Args:
            xml_string: Label Studio XML config.

        Returns:
            (is_valid, error_message)
        """
        try:
            root = ET.fromstring(xml_string)

            # Root element
            if root.tag != 'View':
                return False, "Root element must be <View>"

            # At least one data object
            objects = [child for child in root if child.tag in LabelStudioConfigValidator.OBJECT_TYPES]
            if not objects:
                return False, "No data objects (Image, Text, etc.) found"

            # At least one control
            controls = [child for child in root if child.tag in LabelStudioConfigValidator.CONTROL_TYPES]
            if not controls:
                return False, "No annotation controls found"

            # Each control individually
            for control in controls:
                valid, error = LabelStudioConfigValidator._validate_control(control)
                if not valid:
                    return False, f"Control {control.tag}: {error}"

            return True, None

        except ET.ParseError as e:
            return False, f"XML parse error: {str(e)}"
        except Exception as e:
            # Validation boundary: report unexpected failures (e.g. a
            # non-string input) as a validation error instead of raising.
            return False, f"Validation error: {str(e)}"

    @staticmethod
    def _validate_control(control: ET.Element) -> Tuple[bool, Optional[str]]:
        """Validate a single control element.

        Args:
            control: Control XML element.

        Returns:
            (is_valid, error_message)
        """
        # Required attributes
        if 'name' not in control.attrib:
            return False, "Missing 'name' attribute"

        if 'toName' not in control.attrib:
            return False, "Missing 'toName' attribute"

        # Label-based controls need at least one option child with a value
        if control.tag in LabelStudioConfigValidator.LABEL_BASED_CONTROLS:
            options = LabelStudioConfigValidator._option_children(control)
            if not options:
                expected = "<Choice> or <Label>" if control.tag == 'Choices' else "<Label>"
                return False, f"{control.tag} must have at least one {expected} child"

            for option in options:
                if 'value' not in option.attrib:
                    return False, f"{option.tag} missing 'value' attribute"

        return True, None

    @staticmethod
    def extract_label_values(xml_string: str) -> Dict[str, List[str]]:
        """Extract the option values of every label-based control.

        Best-effort: input that cannot be parsed yields an empty dict.

        Args:
            xml_string: Label Studio XML config.

        Returns:
            Mapping of control name to its list of option values. Controls
            without a ``name`` are keyed as ``'unknown'``.
        """
        result: Dict[str, List[str]] = {}

        try:
            root = ET.fromstring(xml_string)
        except (ET.ParseError, TypeError, ValueError):
            # Unparseable or non-string input: nothing to extract.
            return result

        for control in root:
            if control.tag not in LabelStudioConfigValidator.LABEL_BASED_CONTROLS:
                continue
            control_name = control.get('name', 'unknown')
            options = LabelStudioConfigValidator._option_children(control)
            result[control_name] = [option.get('value', '') for option in options]

        return result

    @staticmethod
    def validate_configuration_json(config: Dict) -> Tuple[bool, Optional[str]]:
        """Validate the structure of a configuration JSON dict.

        Requires non-empty ``labels`` and ``objects`` lists, validates every
        entry, and checks that each label's ``toName`` refers to a declared
        object name. Both camelCase and snake_case label keys are accepted.

        Args:
            config: Configuration dict.

        Returns:
            (is_valid, error_message)
        """
        if not isinstance(config, dict):
            return False, "Configuration must be an object"

        sections = ('labels', 'objects')

        # Required fields, then their type, then non-emptiness.
        for key in sections:
            if key not in config:
                return False, f"Missing '{key}' field"
        for key in sections:
            if not isinstance(config[key], list):
                return False, f"'{key}' must be an array"
        for key in sections:
            if not config[key]:
                return False, f"'{key}' array cannot be empty"

        # Each label definition
        for idx, label in enumerate(config['labels']):
            valid, error = LabelStudioConfigValidator._validate_label_definition(label)
            if not valid:
                return False, f"Label {idx}: {error}"

        # Each object definition
        for idx, obj in enumerate(config['objects']):
            valid, error = LabelStudioConfigValidator._validate_object_definition(obj)
            if not valid:
                return False, f"Object {idx}: {error}"

        # toName must reference a declared object
        object_names = {obj['name'] for obj in config['objects']}
        for label in config['labels']:
            to_name = label.get('toName') or label.get('to_name')
            from_name = label.get('fromName') or label.get('from_name')
            if to_name not in object_names:
                return False, f"Label '{from_name}' references unknown object '{to_name}'"

        return True, None

    @staticmethod
    def _validate_label_definition(label: Dict) -> Tuple[bool, Optional[str]]:
        """Validate one label (control) definition."""
        if not isinstance(label, dict):
            return False, "Label definition must be an object"

        # Support both camelCase and snake_case
        from_name = label.get('fromName') or label.get('from_name')
        to_name = label.get('toName') or label.get('to_name')
        label_type = label.get('type')

        if not from_name:
            return False, "Missing required field 'fromName'"
        if not to_name:
            return False, "Missing required field 'toName'"
        if not label_type:
            return False, "Missing required field 'type'"

        # Supported control type (the isinstance check avoids a TypeError on
        # unhashable values)
        if not isinstance(label_type, str) or label_type not in LabelStudioConfigValidator.CONTROL_TYPES:
            return False, f"Unsupported control type '{label_type}'"

        # Label-based controls need actual options. A key that is present but
        # None or empty (as produced by dumping a model with unset optional
        # fields) does not count.
        if label_type in LabelStudioConfigValidator.LABEL_BASED_CONTROLS:
            if not label.get('options') and not label.get('labels'):
                return False, f"{label_type} must have 'options' or 'labels' field"

        return True, None

    @staticmethod
    def _validate_object_definition(obj: Dict) -> Tuple[bool, Optional[str]]:
        """Validate one data object definition."""
        if not isinstance(obj, dict):
            return False, "Object definition must be an object"

        for field in ('name', 'type', 'value'):
            if field not in obj:
                return False, f"Missing required field '{field}'"

        name = obj['name']
        if not isinstance(name, str) or not name:
            return False, "Object 'name' must be a non-empty string"

        # Supported object type
        obj_type = obj['type']
        if not isinstance(obj_type, str) or obj_type not in LabelStudioConfigValidator.OBJECT_TYPES:
            return False, f"Unsupported object type '{obj_type}'"

        # Value must be a task-data variable reference such as "$image"
        value = obj['value']
        if not isinstance(value, str) or not value.startswith('$'):
            return False, "Object value must start with '$' (e.g., '$image')"

        return True, None
|
||||
|
||||
|
||||
# Usage example
if __name__ == "__main__":
    # 1) Validate a Label Studio XML config
    sample_xml = """<View>
<Image name="image" value="$image"/>
<Choices name="choice" toName="image" required="true">
<Label value="Cat"/>
<Label value="Dog"/>
</Choices>
</View>"""

    xml_ok, xml_error = LabelStudioConfigValidator.validate_xml(sample_xml)
    print(f"XML Valid: {xml_ok}, Error: {xml_error}")

    # 2) Validate the equivalent configuration JSON
    sample_label = {
        "fromName": "choice",
        "toName": "image",
        "type": "Choices",
        "options": ["Cat", "Dog"],
        "required": True,
    }
    sample_object = {
        "name": "image",
        "type": "Image",
        "value": "$image",
    }
    sample_config = {"labels": [sample_label], "objects": [sample_object]}

    config_ok, config_error = LabelStudioConfigValidator.validate_configuration_json(sample_config)
    print(f"Config Valid: {config_ok}, Error: {config_error}")
|
||||
Reference in New Issue
Block a user