feat: Add labeling template (#72)

* feat: Enhance annotation module with template management and validation

- Added DatasetMappingCreateRequest and DatasetMappingUpdateRequest schemas to handle dataset mapping requests with camelCase and snake_case support.
- Introduced Annotation Template schemas including CreateAnnotationTemplateRequest, UpdateAnnotationTemplateRequest, and AnnotationTemplateResponse for managing annotation templates.
- Implemented AnnotationTemplateService for creating, updating, retrieving, and deleting annotation templates, including validation of configurations and XML generation.
- Added utility class LabelStudioConfigValidator for validating Label Studio configurations and XML formats.
- Updated database schema for annotation templates and labeling projects to include new fields and constraints.
- Seeded initial annotation templates for various use cases including image classification, object detection, and text classification.

* feat: Enhance TemplateForm with improved validation and dynamic field rendering; update LabelStudio config validation for camelCase support

* feat: Update docker-compose.yml to mark datamate dataset volume and network as external
This commit is contained in:
Jason Wang
2025-11-11 09:14:14 +08:00
committed by GitHub
parent 451d3c8207
commit c5ccc56cca
24 changed files with 2794 additions and 253 deletions

View File

@@ -103,6 +103,7 @@ class Client:
"""创建Label Studio项目"""
try:
logger.debug(f"Creating Label Studio project: {title}")
logger.debug(f"Label Studio URL: {self.base_url}/api/projects")
project_data = {
"title": title,
@@ -123,10 +124,28 @@ class Client:
return project
except httpx.HTTPStatusError as e:
logger.error(f"Create project failed HTTP {e.response.status_code}: {e.response.text}")
logger.error(
f"Create project failed - HTTP {e.response.status_code}\n"
f"URL: {e.request.url}\n"
f"Response Headers: {dict(e.response.headers)}\n"
f"Response Body: {e.response.text[:1000]}" # First 1000 chars
)
return None
except httpx.ConnectError as e:
logger.error(
f"Failed to connect to Label Studio at {self.base_url}\n"
f"Error: {str(e)}\n"
f"Possible causes:\n"
f" - Label Studio service is not running\n"
f" - Incorrect URL configuration\n"
f" - Network connectivity issue"
)
return None
except httpx.TimeoutException as e:
logger.error(f"Request to Label Studio timed out after {self.timeout}s: {str(e)}")
return None
except Exception as e:
logger.error(f"Error while creating Label Studio project: {e}")
logger.error(f"Error while creating Label Studio project: {str(e)}", exc_info=True)
return None
async def import_tasks(

View File

@@ -3,6 +3,7 @@ from fastapi import APIRouter
from .about import router as about_router
from .project import router as project_router
from .task import router as task_router
from .template import router as template_router
router = APIRouter(
prefix="/annotation",
@@ -11,4 +12,5 @@ router = APIRouter(
router.include_router(about_router)
router.include_router(project_router)
router.include_router(task_router)
router.include_router(task_router)
router.include_router(template_router)

View File

@@ -1,5 +1,6 @@
from typing import Optional
import math
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
@@ -14,6 +15,7 @@ from app.core.config import settings
from ..client import LabelStudioClient
from ..service.mapping import DatasetMappingService
from ..service.sync import SyncService
from ..service.template import AnnotationTemplateService
from ..schema import (
DatasetMappingCreateRequest,
DatasetMappingCreateResponse,
@@ -39,6 +41,8 @@ async def create_mapping(
在数据库中记录这一关联关系,返回Label Studio数据集的ID
注意:一个数据集可以创建多个标注项目
支持通过 template_id 指定标注模板,如果提供了模板ID,则使用模板的配置
"""
try:
dm_client = DatasetManagementService(db)
@@ -46,6 +50,7 @@ async def create_mapping(
token=settings.label_studio_user_token)
mapping_service = DatasetMappingService(db)
sync_service = SyncService(dm_client, ls_client, mapping_service)
template_service = AnnotationTemplateService()
logger.info(f"Create dataset mapping request: {request.dataset_id}")
@@ -65,10 +70,24 @@ async def create_mapping(
dataset_info.description or \
f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
# 如果提供了模板ID,获取模板配置
label_config = None
if request.template_id:
logger.info(f"Using template: {request.template_id}")
template = await template_service.get_template(db, request.template_id)
if not template:
raise HTTPException(
status_code=404,
detail=f"Template not found: {request.template_id}"
)
label_config = template.label_config
logger.debug(f"Template label config loaded for template: {template.name}")
# 在Label Studio中创建项目
project_data = await ls_client.create_project(
title=project_name,
description=project_description,
label_config=label_config # 传递模板配置
)
if not project_data:
@@ -96,9 +115,11 @@ async def create_mapping(
logger.info(f"Local storage configured for project {project_id}: {local_storage_path}")
labeling_project = LabelingProject(
id=str(uuid.uuid4()), # Generate UUID here
dataset_id=request.dataset_id,
labeling_project_id=str(project_id),
name=project_name,
template_id=request.template_id, # Save template_id to database
)
# 创建映射关系,包含项目名称(先持久化映射以获得 mapping.id)

View File

@@ -0,0 +1,137 @@
"""
Annotation Template API Endpoints
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.module.shared.schema import StandardResponse
from app.module.annotation.schema.template import (
CreateAnnotationTemplateRequest,
UpdateAnnotationTemplateRequest,
AnnotationTemplateResponse,
AnnotationTemplateListResponse
)
from app.module.annotation.service.template import AnnotationTemplateService
# Router for the annotation-template endpoints; every route declared in this
# module is registered under the "/templates" prefix.
router = APIRouter(prefix="/templates", tags=["Annotation Template"])
# Single module-level service instance used by all handlers below.
template_service = AnnotationTemplateService()
@router.post(
    "",
    response_model=StandardResponse[AnnotationTemplateResponse],
    summary="创建标注模板",
)
async def create_template(
    request: CreateAnnotationTemplateRequest,
    db: AsyncSession = Depends(get_db),
):
    """
    Create a new annotation template.

    - **name**: template name (required, at most 100 characters)
    - **description**: template description (optional, at most 500 characters)
    - **dataType**: data type (required)
    - **labelingType**: labeling type (required)
    - **configuration**: labeling configuration (required, contains labels and objects)
    - **style**: style setting (defaults to horizontal)
    - **category**: template category (defaults to custom)
    """
    created = await template_service.create_template(db, request)
    return StandardResponse(code=200, message="success", data=created)
@router.get(
    "/{template_id}",
    response_model=StandardResponse[AnnotationTemplateResponse],
    summary="获取模板详情",
)
async def get_template(
    template_id: str,
    db: AsyncSession = Depends(get_db),
):
    """
    Return the details of the template identified by ``template_id``.

    Responds with HTTP 404 when the service returns no template.
    """
    found = await template_service.get_template(db, template_id)
    if found:
        return StandardResponse(code=200, message="success", data=found)
    raise HTTPException(status_code=404, detail="Template not found")
@router.get(
    "",
    response_model=StandardResponse[AnnotationTemplateListResponse],
    summary="获取模板列表",
)
async def list_templates(
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(10, ge=1, le=100, description="每页大小"),
    category: Optional[str] = Query(None, description="分类筛选"),
    dataType: Optional[str] = Query(None, alias="dataType", description="数据类型筛选"),
    labelingType: Optional[str] = Query(None, alias="labelingType", description="标注类型筛选"),
    builtIn: Optional[bool] = Query(None, alias="builtIn", description="是否内置模板"),
    db: AsyncSession = Depends(get_db),
):
    """
    List templates with pagination and optional filters.

    - **page**: page number (starts at 1)
    - **size**: page size (1-100)
    - **category**: filter by template category
    - **dataType**: filter by data type
    - **labelingType**: filter by labeling type
    - **builtIn**: filter on the built-in flag
    """
    # Translate the camelCase query parameters into the service's snake_case keywords.
    filters = {
        "category": category,
        "data_type": dataType,
        "labeling_type": labelingType,
        "built_in": builtIn,
    }
    listing = await template_service.list_templates(db=db, page=page, size=size, **filters)
    return StandardResponse(code=200, message="success", data=listing)
@router.put(
    "/{template_id}",
    response_model=StandardResponse[AnnotationTemplateResponse],
    summary="更新模板",
)
async def update_template(
    template_id: str,
    request: UpdateAnnotationTemplateRequest,
    db: AsyncSession = Depends(get_db),
):
    """
    Update a template.

    Every field of the request is optional; only the supplied fields are updated.
    Responds with HTTP 404 when the service returns no template.
    """
    updated = await template_service.update_template(db, template_id, request)
    if updated:
        return StandardResponse(code=200, message="success", data=updated)
    raise HTTPException(status_code=404, detail="Template not found")
@router.delete(
    "/{template_id}",
    response_model=StandardResponse[bool],
    summary="删除模板",
)
async def delete_template(
    template_id: str,
    db: AsyncSession = Depends(get_db),
):
    """
    Delete a template (soft delete).

    Responds with HTTP 404 when the service reports that nothing was deleted.
    """
    deleted = await template_service.delete_template(db, template_id)
    if deleted:
        return StandardResponse(code=200, message="success", data=True)
    raise HTTPException(status_code=404, detail="Template not found")

View File

@@ -11,13 +11,14 @@ class DatasetMappingCreateRequest(BaseModel):
Accept both snake_case and camelCase field names from frontend JSON by
declaring explicit aliases. Frontend sends `datasetId`, `name`,
`description` (camelCase), so provide aliases so pydantic will map them
`description`, `templateId` (camelCase), so provide aliases so pydantic will map them
to the internal attributes used in the service code (dataset_id, name,
description).
description, template_id).
"""
dataset_id: str = Field(..., alias="datasetId", description="源数据集ID")
name: Optional[str] = Field(None, alias="name", description="标注项目名称")
description: Optional[str] = Field(None, alias="description", description="标注项目描述")
template_id: Optional[str] = Field(None, alias="templateId", description="标注模板ID")
class Config:
# allow population by field name when constructing model programmatically
@@ -34,13 +35,16 @@ class DatasetMappingUpdateRequest(BaseResponseModel):
dataset_id: Optional[str] = Field(None, description="源数据集ID")
class DatasetMappingResponse(BaseModel):
dataset_id: str = Field(..., description="源数据集ID")
"""数据集映射 查询 响应模型"""
id: str = Field(..., description="映射UUID")
labeling_project_id: str = Field(..., description="标注项目ID")
dataset_id: str = Field(..., alias="datasetId", description="源数据集ID")
dataset_name: Optional[str] = Field(None, alias="datasetName", description="数据集名称")
labeling_project_id: str = Field(..., alias="labelingProjectId", description="标注项目ID")
name: Optional[str] = Field(None, description="标注项目名称")
created_at: datetime = Field(..., description="创建时间")
deleted_at: Optional[datetime] = Field(None, description="删除时间")
description: Optional[str] = Field(None, description="标注项目描述")
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
deleted_at: Optional[datetime] = Field(None, alias="deletedAt", description="删除时间")
class Config:
from_attributes = True

View File

@@ -0,0 +1,93 @@
"""
Annotation Template Schemas
"""
from typing import List, Dict, Any, Optional, Literal
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict
class LabelDefinition(BaseModel):
    """Definition of a single labeling control inside a template configuration."""
    # Control name; supplied as "fromName" (alias) or "from_name" (populate_by_name).
    from_name: str = Field(alias="fromName", description="控件名称")
    # Name of the object this control annotates; alias "toName".
    to_name: str = Field(alias="toName", description="目标对象名称")
    # Control type.
    # NOTE(review): LabelStudioConfigValidator._validate_label_definition compares
    # this value against PascalCase names such as 'RectangleLabels'; the lower-case
    # examples in the description may be misleading — confirm.
    type: str = Field(description="控件类型: choices/rectanglelabels/polygonlabels/textarea/etc")
    # Option list (described as being for the choices type).
    options: Optional[List[str]] = Field(None, description="选项列表(用于choices类型)")
    # Label list (described as being for rectanglelabels and similar types).
    labels: Optional[List[str]] = Field(None, description="标签列表(用于rectanglelabels等类型)")
    # Whether the control is required; defaults to False.
    required: bool = Field(False, description="是否必填")
    # Optional free-text description of the label.
    description: Optional[str] = Field(None, description="标签描述")
    # Allow populating fields by their Python names as well as by alias.
    model_config = ConfigDict(populate_by_name=True)
class ObjectDefinition(BaseModel):
    """Definition of a data object that labeling controls can refer to."""
    # Object identifier.
    name: str = Field(description="对象标识符")
    # Object type, e.g. Image/Text/Audio/Video.
    type: str = Field(description="对象类型: Image/Text/Audio/Video/etc")
    # Variable name, e.g. "$image".
    value: str = Field(description="变量名,如$image")
    # Allow populating fields by their Python names as well as by alias.
    model_config = ConfigDict(populate_by_name=True)
class TemplateConfiguration(BaseModel):
    """Structure of a template configuration: controls, data objects and metadata."""
    # List of labeling control definitions.
    labels: List[LabelDefinition] = Field(description="标签定义列表")
    # List of data object definitions.
    objects: List[ObjectDefinition] = Field(description="对象定义列表")
    # Optional extra metadata.
    metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据")
    # Allow populating fields by their Python names as well as by alias.
    model_config = ConfigDict(populate_by_name=True)
class CreateAnnotationTemplateRequest(BaseModel):
    """Request body for creating an annotation template."""
    # Template name: required, 1 to 100 characters.
    name: str = Field(..., min_length=1, max_length=100, description="模板名称")
    # Template description: optional, at most 500 characters.
    description: Optional[str] = Field(None, max_length=500, description="模板描述")
    # Data type; alias "dataType".
    data_type: str = Field(alias="dataType", description="数据类型")
    # Labeling type; alias "labelingType".
    labeling_type: str = Field(alias="labelingType", description="标注类型")
    # Labeling configuration (required).
    configuration: TemplateConfiguration = Field(..., description="标注配置")
    # Style setting; defaults to "horizontal".
    style: str = Field(default="horizontal", description="样式配置")
    # Template category; defaults to "custom".
    category: str = Field(default="custom", description="模板分类")
    # Allow populating fields by their Python names as well as by alias.
    model_config = ConfigDict(populate_by_name=True)
class UpdateAnnotationTemplateRequest(BaseModel):
    """Request body for updating an annotation template; every field is optional."""
    # New template name: 1 to 100 characters when given.
    name: Optional[str] = Field(None, min_length=1, max_length=100, description="模板名称")
    # New template description: at most 500 characters when given.
    description: Optional[str] = Field(None, max_length=500, description="模板描述")
    # New data type; alias "dataType".
    data_type: Optional[str] = Field(None, alias="dataType", description="数据类型")
    # New labeling type; alias "labelingType".
    labeling_type: Optional[str] = Field(None, alias="labelingType", description="标注类型")
    # New labeling configuration.
    configuration: Optional[TemplateConfiguration] = Field(None, description="标注配置")
    # New style setting.
    style: Optional[str] = Field(None, description="样式配置")
    # New template category.
    category: Optional[str] = Field(None, description="模板分类")
    # Allow populating fields by their Python names as well as by alias.
    model_config = ConfigDict(populate_by_name=True)
class AnnotationTemplateResponse(BaseModel):
    """Response model describing one annotation template."""
    # Template ID.
    id: str = Field(..., description="模板ID")
    # Template name.
    name: str = Field(..., description="模板名称")
    # Template description (optional).
    description: Optional[str] = Field(None, description="模板描述")
    # Data type; alias "dataType".
    data_type: str = Field(alias="dataType", description="数据类型")
    # Labeling type; alias "labelingType".
    labeling_type: str = Field(alias="labelingType", description="标注类型")
    # Labeling configuration.
    configuration: TemplateConfiguration = Field(..., description="标注配置")
    # Generated Label Studio XML configuration; alias "labelConfig".
    label_config: Optional[str] = Field(None, alias="labelConfig", description="生成的Label Studio XML配置")
    # Style setting.
    style: str = Field(..., description="样式配置")
    # Template category.
    category: str = Field(..., description="模板分类")
    # Whether the template is built in; alias "builtIn".
    built_in: bool = Field(alias="builtIn", description="是否内置模板")
    # Version string.
    version: str = Field(..., description="版本号")
    # Creation time; alias "createdAt".
    created_at: datetime = Field(alias="createdAt", description="创建时间")
    # Last update time (optional); alias "updatedAt".
    updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
    # from_attributes lets the model be built from an object's attributes
    # (used with model_validate on the ORM template object).
    model_config = ConfigDict(populate_by_name=True, from_attributes=True)
class AnnotationTemplateListResponse(BaseModel):
    """Paginated list of annotation templates."""
    # Templates on the current page.
    content: List[AnnotationTemplateResponse] = Field(..., description="模板列表")
    # Total number of matching templates.
    total: int = Field(..., description="总数")
    # Current page number.
    page: int = Field(..., description="当前页")
    # Page size.
    size: int = Field(..., description="每页大小")
    # Total number of pages; alias "totalPages".
    total_pages: int = Field(alias="totalPages", description="总页数")
    # Allow populating fields by their Python names as well as by alias.
    model_config = ConfigDict(populate_by_name=True)

View File

@@ -1,12 +1,14 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, func
from sqlalchemy.orm import aliased
from typing import Optional, List, Tuple
from datetime import datetime
import uuid
from app.core.logging import get_logger
from app.db.models import LabelingProject
from app.db.models.dataset_management import Dataset
from app.module.annotation.schema import (
DatasetMappingCreateRequest,
DatasetMappingUpdateRequest,
@@ -21,6 +23,61 @@ class DatasetMappingService:
def __init__(self, db: AsyncSession):
self.db = db
def _build_query_with_dataset_name(self):
    """Return the base SELECT of LabelingProject rows with the dataset name attached.

    The dataset table is outer-joined on the project's dataset_id, and the
    dataset's name is exposed as the extra column ``dataset_name``.
    """
    dataset_name_column = Dataset.name.label('dataset_name')
    base_query = select(LabelingProject, dataset_name_column)
    return base_query.outerjoin(
        Dataset,
        LabelingProject.dataset_id == Dataset.id,
    )
def _to_response_from_row(self, row) -> DatasetMappingResponse:
    """Build a response from a query row of (LabelingProject, dataset_name)."""
    # Position 0 holds the ORM object, position 1 the joined dataset name.
    project, joined_dataset_name = row[0], row[1]
    return DatasetMappingResponse(
        id=project.id,
        dataset_id=project.dataset_id,
        dataset_name=joined_dataset_name,
        labeling_project_id=project.labeling_project_id,
        name=project.name,
        # The description attribute may be absent on the object; fall back to None.
        description=getattr(project, 'description', None),
        created_at=project.created_at,
        updated_at=project.updated_at,
        deleted_at=project.deleted_at,
    )
async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse:
    """Build a response for a single ORM object, looking up its dataset name.

    Used for single-entity operations: when the mapping has a dataset_id, one
    extra query fetches the dataset's name; otherwise the name stays None.
    """
    source_dataset_id = getattr(mapping, 'dataset_id', None)
    resolved_name = None
    if source_dataset_id:
        name_query = select(Dataset.name).where(Dataset.id == source_dataset_id)
        name_result = await self.db.execute(name_query)
        resolved_name = name_result.scalar_one_or_none()

    return DatasetMappingResponse(
        id=mapping.id,
        dataset_id=source_dataset_id,
        dataset_name=resolved_name,
        labeling_project_id=mapping.labeling_project_id,
        name=mapping.name,
        # The description attribute may be absent on the object; fall back to None.
        description=getattr(mapping, 'description', None),
        created_at=mapping.created_at,
        updated_at=mapping.updated_at,
        deleted_at=mapping.deleted_at,
    )
async def create_mapping(
self,
labeling_project: LabelingProject
@@ -28,19 +85,13 @@ class DatasetMappingService:
"""创建数据集映射"""
logger.info(f"Create dataset mapping: {labeling_project.dataset_id} -> {labeling_project.labeling_project_id}")
db_mapping = LabelingProject(
id=str(uuid.uuid4()),
dataset_id=labeling_project.dataset_id,
labeling_project_id=labeling_project.labeling_project_id,
name=labeling_project.name
)
self.db.add(db_mapping)
# Use the passed object directly
self.db.add(labeling_project)
await self.db.commit()
await self.db.refresh(db_mapping)
await self.db.refresh(labeling_project)
logger.debug(f"Mapping created: {db_mapping.id}")
return DatasetMappingResponse.model_validate(db_mapping)
logger.debug(f"Mapping created: {labeling_project.id}")
return await self._to_response(labeling_project)
async def get_mapping_by_source_uuid(
self,
@@ -59,7 +110,7 @@ class DatasetMappingService:
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping)
return await self._to_response(mapping)
logger.debug(f"No mapping found for source dataset id: {dataset_id}")
return None
@@ -72,7 +123,7 @@ class DatasetMappingService:
"""根据源数据集ID获取所有映射关系"""
logger.debug(f"Get all mappings by source dataset id: {dataset_id}")
query = select(LabelingProject).where(
query = self._build_query_with_dataset_name().where(
LabelingProject.dataset_id == dataset_id
)
@@ -82,10 +133,10 @@ class DatasetMappingService:
result = await self.db.execute(
query.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
async def get_mapping_by_labeling_project_id(
self,
@@ -103,8 +154,8 @@ class DatasetMappingService:
mapping = result.scalar_one_or_none()
if mapping:
logger.debug(f"Found mapping: {mapping.mapping_id}")
return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"Found mapping: {mapping.id}")
return await self._to_response(mapping)
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
return None
@@ -123,9 +174,9 @@ class DatasetMappingService:
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping)
return await self._to_response(mapping)
logger.debug(f"Mapping not found: {mapping_id}")
logger.debug(f"No mapping found for mapping id: {mapping_id}")
return None
async def update_mapping(
@@ -184,17 +235,20 @@ class DatasetMappingService:
"""获取所有有效映射"""
logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
query = self._build_query_with_dataset_name().where(
LabelingProject.deleted_at.is_(None)
)
result = await self.db.execute(
select(LabelingProject)
.where(LabelingProject.deleted_at.is_(None))
query
.offset(skip)
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
async def count_mappings(self, include_deleted: bool = False) -> int:
"""统计映射总数"""
@@ -216,7 +270,7 @@ class DatasetMappingService:
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
# 构建查询
query = select(LabelingProject)
query = self._build_query_with_dataset_name()
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
@@ -235,10 +289,10 @@ class DatasetMappingService:
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total
async def get_mappings_by_source_with_count(
self,
@@ -251,7 +305,7 @@ class DatasetMappingService:
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
# 构建查询
query = select(LabelingProject).where(
query = self._build_query_with_dataset_name().where(
LabelingProject.dataset_id == dataset_id
)
@@ -275,7 +329,7 @@ class DatasetMappingService:
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total

View File

@@ -0,0 +1,327 @@
"""
Annotation Template Service
"""
from typing import Optional, List
from datetime import datetime
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from uuid import uuid4
from fastapi import HTTPException
from app.db.models.annotation_management import AnnotationTemplate
from app.module.annotation.schema.template import (
CreateAnnotationTemplateRequest,
UpdateAnnotationTemplateRequest,
AnnotationTemplateResponse,
AnnotationTemplateListResponse,
TemplateConfiguration
)
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator
class AnnotationTemplateService:
    """Service for creating, reading, updating and soft-deleting annotation templates.

    The Label Studio XML is never stored: it is generated from the stored
    configuration whenever a response object is built.
    """

    # Case-insensitive lookup of control tag names, so that a type written as
    # "rectanglelabels" or "RectangleLabels" both yield the exact tag name.
    _CANONICAL_CONTROL_TAGS = {
        tag.lower(): tag
        for tag in (
            'Choices', 'RectangleLabels', 'PolygonLabels', 'Labels',
            'TextArea', 'Rating', 'KeyPointLabels', 'BrushLabels',
            'EllipseLabels', 'VideoRectangle', 'AudioPlus',
        )
    }

    @staticmethod
    def _control_tag(label_type: Optional[str]) -> str:
        """Return the XML tag name for a control type.

        Known types are matched case-insensitively. Unknown types keep their
        casing except for an upper-cased first character (``str.capitalize``
        must not be used: it lower-cases the rest and would turn
        "RectangleLabels" into "Rectanglelabels"). An empty type falls back to
        "Choices".
        """
        if not label_type:
            return "Choices"
        canonical = AnnotationTemplateService._CANONICAL_CONTROL_TAGS.get(label_type.lower())
        if canonical:
            return canonical
        return label_type[0].upper() + label_type[1:]

    @staticmethod
    def generate_label_studio_config(config: TemplateConfiguration) -> str:
        """
        Generate the Label Studio XML configuration from a template configuration.

        Args:
            config: template configuration object (reads ``objects`` and ``labels``)

        Returns:
            Label Studio XML string
        """
        # Local import: quoteattr escapes &, <, > and quotes so that user-supplied
        # names and label values cannot break or inject XML.
        from xml.sax.saxutils import quoteattr

        xml_parts = ['<View>']

        # Data objects
        for obj in config.objects:
            obj_attrs = [
                f'name={quoteattr(obj.name)}',
                f'value={quoteattr(obj.value)}',
            ]
            xml_parts.append(f' <{obj.type} {" ".join(obj_attrs)}/>')

        # Labeling controls
        for label in config.labels:
            label_attrs = [
                f'name={quoteattr(label.from_name)}',
                f'toName={quoteattr(label.to_name)}',
            ]
            # Optional attribute
            if label.required:
                label_attrs.append('required="true"')

            tag_type = AnnotationTemplateService._control_tag(label.type)
            attrs = " ".join(label_attrs)
            choices = label.options or label.labels or []

            if choices:
                # Controls that carry a list of options
                # NOTE(review): Label Studio's <Choices> tag conventionally uses
                # <Choice> children; <Label> is kept here because this project's
                # validate_xml expects <Label> — confirm.
                xml_parts.append(f' <{tag_type} {attrs}>')
                xml_parts.extend(
                    f' <Label value={quoteattr(choice)}/>' for choice in choices
                )
                xml_parts.append(f' </{tag_type}>')
            else:
                # Simple controls without options
                xml_parts.append(f' <{tag_type} {attrs}/>')

        xml_parts.append('</View>')
        return '\n'.join(xml_parts)

    def _validate_configuration(self, configuration: TemplateConfiguration) -> dict:
        """Validate a configuration and return it as a JSON-compatible dict.

        Checks the JSON structure first, then generates the XML and validates
        that too (the XML is used for validation only and is not stored).

        Raises:
            HTTPException: 400 when either validation step fails
        """
        config_dict = configuration.model_dump(mode='json', by_alias=False)
        valid, error = LabelStudioConfigValidator.validate_configuration_json(config_dict)
        if not valid:
            raise HTTPException(status_code=400, detail=f"Invalid configuration: {error}")

        label_config = self.generate_label_studio_config(configuration)
        valid, error = LabelStudioConfigValidator.validate_xml(label_config)
        if not valid:
            raise HTTPException(status_code=400, detail=f"Generated XML is invalid: {error}")
        return config_dict

    async def create_template(
        self,
        db: AsyncSession,
        request: CreateAnnotationTemplateRequest
    ) -> AnnotationTemplateResponse:
        """
        Create an annotation template.

        Args:
            db: database session
            request: creation request

        Returns:
            Response for the created template

        Raises:
            HTTPException: 400 when the configuration or its generated XML is invalid
        """
        config_dict = self._validate_configuration(request.configuration)

        # The template row does not include a label_config value.
        template = AnnotationTemplate(
            id=str(uuid4()),
            name=request.name,
            description=request.description,
            data_type=request.data_type,
            labeling_type=request.labeling_type,
            configuration=config_dict,
            style=request.style,
            category=request.category,
            built_in=False,
            version="1.0.0",
            created_at=datetime.now()
        )

        db.add(template)
        await db.commit()
        await db.refresh(template)

        return self._to_response(template)

    async def get_template(
        self,
        db: AsyncSession,
        template_id: str
    ) -> Optional[AnnotationTemplateResponse]:
        """
        Fetch a single template that has not been soft-deleted.

        Args:
            db: database session
            template_id: template ID

        Returns:
            Template response, or None when no such template exists
        """
        result = await db.execute(
            select(AnnotationTemplate)
            .where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None)
            )
        )
        template = result.scalar_one_or_none()

        if template:
            return self._to_response(template)
        return None

    async def list_templates(
        self,
        db: AsyncSession,
        page: int = 1,
        size: int = 10,
        category: Optional[str] = None,
        data_type: Optional[str] = None,
        labeling_type: Optional[str] = None,
        built_in: Optional[bool] = None
    ) -> AnnotationTemplateListResponse:
        """
        List templates that have not been soft-deleted, newest first.

        Args:
            db: database session
            page: page number
            size: page size
            category: category filter
            data_type: data type filter
            labeling_type: labeling type filter
            built_in: built-in flag filter (None disables the filter)

        Returns:
            Template list response
        """
        # Build the filter conditions
        conditions: List = [AnnotationTemplate.deleted_at.is_(None)]

        if category:
            conditions.append(AnnotationTemplate.category == category)  # type: ignore
        if data_type:
            conditions.append(AnnotationTemplate.data_type == data_type)  # type: ignore
        if labeling_type:
            conditions.append(AnnotationTemplate.labeling_type == labeling_type)  # type: ignore
        if built_in is not None:
            conditions.append(AnnotationTemplate.built_in == built_in)  # type: ignore

        # Total count
        count_result = await db.execute(
            select(func.count()).select_from(AnnotationTemplate).where(*conditions)
        )
        total = count_result.scalar() or 0

        # Page query
        result = await db.execute(
            select(AnnotationTemplate)
            .where(*conditions)
            .order_by(AnnotationTemplate.created_at.desc())
            .limit(size)
            .offset((page - 1) * size)
        )
        templates = result.scalars().all()

        # Ceiling division; guarded so a direct call with size <= 0 cannot
        # raise ZeroDivisionError.
        total_pages = (total + size - 1) // size if size > 0 else 0

        return AnnotationTemplateListResponse(
            content=[self._to_response(t) for t in templates],
            total=total,
            page=page,
            size=size,
            totalPages=total_pages
        )

    async def update_template(
        self,
        db: AsyncSession,
        template_id: str,
        request: UpdateAnnotationTemplateRequest
    ) -> Optional[AnnotationTemplateResponse]:
        """
        Update a template with the fields explicitly set on the request.

        Args:
            db: database session
            template_id: template ID
            request: update request

        Returns:
            Updated template response, or None when no such template exists

        Raises:
            HTTPException: 400 when a supplied configuration or its XML is invalid
        """
        result = await db.execute(
            select(AnnotationTemplate)
            .where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None)
            )
        )
        template = result.scalar_one_or_none()

        if not template:
            return None

        update_data = request.model_dump(exclude_unset=True, by_alias=False)

        # model_dump() has already converted the nested configuration model into
        # a plain dict, so validation must use the model object on the request.
        # Validation runs before any attribute is changed.
        if update_data.get('configuration') is not None:
            update_data['configuration'] = self._validate_configuration(request.configuration)

        # Only the configuration dict is stored; label_config is not.
        for field, value in update_data.items():
            setattr(template, field, value)

        template.updated_at = datetime.now()  # type: ignore

        await db.commit()
        await db.refresh(template)

        return self._to_response(template)

    async def delete_template(
        self,
        db: AsyncSession,
        template_id: str
    ) -> bool:
        """
        Soft-delete a template by setting its deleted_at timestamp.

        Args:
            db: database session
            template_id: template ID

        Returns:
            True when a template was marked deleted, False when none was found
        """
        result = await db.execute(
            select(AnnotationTemplate)
            .where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None)
            )
        )
        template = result.scalar_one_or_none()

        if not template:
            return False

        template.deleted_at = datetime.now()  # type: ignore
        await db.commit()

        return True

    def _to_response(self, template: AnnotationTemplate) -> AnnotationTemplateResponse:
        """
        Convert a database model object into a response object.

        Args:
            template: database model object

        Returns:
            Template response with the generated Label Studio XML attached
        """
        from typing import cast, Dict, Any

        # Rebuild the typed configuration from the stored JSON
        config_dict = cast(Dict[str, Any], template.configuration)
        config = TemplateConfiguration(**config_dict)

        # Generate the Label Studio XML on the fly
        label_config = self.generate_label_studio_config(config)

        # Build the response from the ORM object, then attach the derived values
        response = AnnotationTemplateResponse.model_validate(template)
        response.configuration = config
        response.label_config = label_config  # type: ignore

        return response

View File

@@ -0,0 +1,6 @@
"""
Annotation Module Utilities
"""
from .config_validator import LabelStudioConfigValidator
__all__ = ['LabelStudioConfigValidator']

View File

@@ -0,0 +1,263 @@
"""
Label Studio Configuration Validation Utilities
"""
from typing import Dict, List, Tuple, Optional
import xml.etree.ElementTree as ET
class LabelStudioConfigValidator:
    """Utility class for validating Label Studio configurations.

    Every public method reports problems through its return value
    ``(is_valid, error_message)`` rather than by raising.
    """

    # Supported control (annotation) tag names
    CONTROL_TYPES = {
        'Choices', 'RectangleLabels', 'PolygonLabels', 'Labels',
        'TextArea', 'Rating', 'KeyPointLabels', 'BrushLabels',
        'EllipseLabels', 'VideoRectangle', 'AudioPlus'
    }

    # Supported data object tag names
    OBJECT_TYPES = {
        'Image', 'Text', 'Audio', 'Video', 'HyperText',
        'AudioPlus', 'Paragraphs', 'Table'
    }

    # Controls that require option children
    LABEL_BASED_CONTROLS = {
        'Choices', 'RectangleLabels', 'PolygonLabels', 'Labels',
        'KeyPointLabels', 'BrushLabels', 'EllipseLabels'
    }

    @staticmethod
    def _option_children(control: ET.Element) -> List[ET.Element]:
        """Return the direct option children of a control, in document order.

        ``<Label>`` children are always accepted; for ``<Choices>`` the
        ``<Choice>`` child tag used by Label Studio is accepted as well.
        """
        accepted = {'Label', 'Choice'} if control.tag == 'Choices' else {'Label'}
        return [child for child in control if child.tag in accepted]

    @staticmethod
    def validate_xml(xml_string: str) -> Tuple[bool, Optional[str]]:
        """
        Validate a Label Studio XML configuration.

        Only the direct children of the root ``<View>`` element are inspected.

        Args:
            xml_string: Label Studio XML configuration string

        Returns:
            (is_valid, error_message)
        """
        try:
            root = ET.fromstring(xml_string)

            # Root element
            if root.tag != 'View':
                return False, "Root element must be <View>"

            # At least one data object
            objects = [child for child in root if child.tag in LabelStudioConfigValidator.OBJECT_TYPES]
            if not objects:
                return False, "No data objects (Image, Text, etc.) found"

            # At least one control
            controls = [child for child in root if child.tag in LabelStudioConfigValidator.CONTROL_TYPES]
            if not controls:
                return False, "No annotation controls found"

            # Each control individually
            for control in controls:
                valid, error = LabelStudioConfigValidator._validate_control(control)
                if not valid:
                    return False, f"Control {control.tag}: {error}"

            return True, None

        except ET.ParseError as e:
            return False, f"XML parse error: {str(e)}"
        except Exception as e:
            # Boundary catch-all: any other failure is reported, never raised.
            return False, f"Validation error: {str(e)}"

    @staticmethod
    def _validate_control(control: ET.Element) -> Tuple[bool, Optional[str]]:
        """
        Validate a single control element.

        Args:
            control: control XML element

        Returns:
            (is_valid, error_message)
        """
        # Required attributes
        if 'name' not in control.attrib:
            return False, "Missing 'name' attribute"

        if 'toName' not in control.attrib:
            return False, "Missing 'toName' attribute"

        # Label-based controls need at least one option child, each with a value
        if control.tag in LabelStudioConfigValidator.LABEL_BASED_CONTROLS:
            options = LabelStudioConfigValidator._option_children(control)
            if not options:
                return False, f"{control.tag} must have at least one <Label> child"

            for option in options:
                if 'value' not in option.attrib:
                    return False, "Label missing 'value' attribute"

        return True, None

    @staticmethod
    def extract_label_values(xml_string: str) -> Dict[str, List[str]]:
        """
        Extract the option values of every label-based control.

        Best effort: input that cannot be parsed yields an empty dict.

        Args:
            xml_string: Label Studio XML configuration string

        Returns:
            Dict mapping control name to its list of option values
        """
        try:
            root = ET.fromstring(xml_string)
        except (ET.ParseError, TypeError, ValueError):
            # Unparseable or non-string input: nothing to extract.
            return {}

        result: Dict[str, List[str]] = {}
        for control in root:
            if control.tag not in LabelStudioConfigValidator.LABEL_BASED_CONTROLS:
                continue
            control_name = control.get('name', 'unknown')
            result[control_name] = [
                option.get('value', '')
                for option in LabelStudioConfigValidator._option_children(control)
            ]
        return result

    @staticmethod
    def validate_configuration_json(config: Dict) -> Tuple[bool, Optional[str]]:
        """
        Validate the structure of a configuration dict.

        Args:
            config: configuration dict with 'labels' and 'objects' lists

        Returns:
            (is_valid, error_message)
        """
        if not isinstance(config, dict):
            return False, "Configuration must be an object"

        # Required fields
        if 'labels' not in config:
            return False, "Missing 'labels' field"

        if 'objects' not in config:
            return False, "Missing 'objects' field"

        if not isinstance(config['labels'], list):
            return False, "'labels' must be an array"

        if not isinstance(config['objects'], list):
            return False, "'objects' must be an array"

        if not config['labels']:
            return False, "'labels' array cannot be empty"

        if not config['objects']:
            return False, "'objects' array cannot be empty"

        # Each label definition
        for idx, label in enumerate(config['labels']):
            valid, error = LabelStudioConfigValidator._validate_label_definition(label)
            if not valid:
                return False, f"Label {idx}: {error}"

        # Each object definition
        for idx, obj in enumerate(config['objects']):
            valid, error = LabelStudioConfigValidator._validate_object_definition(obj)
            if not valid:
                return False, f"Object {idx}: {error}"

        # Every label must point at a declared object
        object_names = {obj['name'] for obj in config['objects']}
        for label in config['labels']:
            to_name = label.get('toName') or label.get('to_name')
            from_name = label.get('fromName') or label.get('from_name')

            if to_name not in object_names:
                return False, f"Label '{from_name}' references unknown object '{to_name}'"

        return True, None

    @staticmethod
    def _validate_label_definition(label: Dict) -> Tuple[bool, Optional[str]]:
        """Validate one label definition (camelCase or snake_case keys)."""
        if not isinstance(label, dict):
            return False, "Label definition must be an object"

        # Support both camelCase and snake_case
        from_name = label.get('fromName') or label.get('from_name')
        to_name = label.get('toName') or label.get('to_name')
        label_type = label.get('type')

        if not from_name:
            return False, "Missing required field 'fromName'"

        if not to_name:
            return False, "Missing required field 'toName'"

        if not label_type:
            return False, "Missing required field 'type'"

        # Supported type (non-string values can never match)
        if not isinstance(label_type, str) or label_type not in LabelStudioConfigValidator.CONTROL_TYPES:
            return False, f"Unsupported control type '{label_type}'"

        # Label-based controls need a non-empty option list. The values are
        # tested rather than key presence, because a dumped model carries both
        # keys set to None when nothing was provided.
        if label_type in LabelStudioConfigValidator.LABEL_BASED_CONTROLS:
            if not (label.get('options') or label.get('labels')):
                return False, f"{label_type} must have 'options' or 'labels' field"

        return True, None

    @staticmethod
    def _validate_object_definition(obj: Dict) -> Tuple[bool, Optional[str]]:
        """Validate one object definition."""
        if not isinstance(obj, dict):
            return False, "Object definition must be an object"

        required_fields = ['name', 'type', 'value']

        for field in required_fields:
            if field not in obj:
                return False, f"Missing required field '{field}'"

        # Supported type (non-string values can never match)
        if not isinstance(obj['type'], str) or obj['type'] not in LabelStudioConfigValidator.OBJECT_TYPES:
            return False, f"Unsupported object type '{obj['type']}'"

        # Value format; a non-string value is reported instead of raising
        value = obj['value']
        if not isinstance(value, str) or not value.startswith('$'):
            return False, "Object value must start with '$' (e.g., '$image')"

        return True, None
# Usage example: runs only when this module is executed as a script.
if __name__ == "__main__":
    # Validate an XML configuration
    xml = """<View>
<Image name="image" value="$image"/>
<Choices name="choice" toName="image" required="true">
<Label value="Cat"/>
<Label value="Dog"/>
</Choices>
</View>"""
    valid, error = LabelStudioConfigValidator.validate_xml(xml)
    print(f"XML Valid: {valid}, Error: {error}")
    # Validate a configuration dict
    config = {
        "labels": [
            {
                "fromName": "choice",
                "toName": "image",
                "type": "Choices",
                "options": ["Cat", "Dog"],
                "required": True
            }
        ],
        "objects": [
            {
                "name": "image",
                "type": "Image",
                "value": "$image"
            }
        ]
    }
    valid, error = LabelStudioConfigValidator.validate_configuration_json(config)
    print(f"Config Valid: {valid}, Error: {error}")