feat: Add labeling template (#72)

* feat: Enhance annotation module with template management and validation

- Added DatasetMappingCreateRequest and DatasetMappingUpdateRequest schemas to handle dataset mapping requests with camelCase and snake_case support.
- Introduced Annotation Template schemas including CreateAnnotationTemplateRequest, UpdateAnnotationTemplateRequest, and AnnotationTemplateResponse for managing annotation templates.
- Implemented AnnotationTemplateService for creating, updating, retrieving, and deleting annotation templates, including validation of configurations and XML generation.
- Added utility class LabelStudioConfigValidator for validating Label Studio configurations and XML formats.
- Updated database schema for annotation templates and labeling projects to include new fields and constraints.
- Seeded initial annotation templates for various use cases including image classification, object detection, and text classification.

* feat: Enhance TemplateForm with improved validation and dynamic field rendering; update LabelStudio config validation for camelCase support

* feat: Update docker-compose.yml to mark datamate dataset volume and network as external
This commit is contained in:
Jason Wang
2025-11-11 09:14:14 +08:00
committed by GitHub
parent 451d3c8207
commit c5ccc56cca
24 changed files with 2794 additions and 253 deletions

View File

@@ -1,12 +1,14 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, func
from sqlalchemy.orm import aliased
from typing import Optional, List, Tuple
from datetime import datetime
import uuid
from app.core.logging import get_logger
from app.db.models import LabelingProject
from app.db.models.dataset_management import Dataset
from app.module.annotation.schema import (
DatasetMappingCreateRequest,
DatasetMappingUpdateRequest,
@@ -21,6 +23,61 @@ class DatasetMappingService:
def __init__(self, db: AsyncSession):
self.db = db
def _build_query_with_dataset_name(self):
"""Build base query with dataset name joined"""
return select(
LabelingProject,
Dataset.name.label('dataset_name')
).outerjoin(
Dataset,
LabelingProject.dataset_id == Dataset.id
)
def _to_response_from_row(self, row) -> DatasetMappingResponse:
"""Convert query row (mapping + dataset_name) to response"""
mapping = row[0] # LabelingProject object
dataset_name = row[1] # dataset_name from join
response_data = {
"id": mapping.id,
"dataset_id": mapping.dataset_id,
"dataset_name": dataset_name,
"labeling_project_id": mapping.labeling_project_id,
"name": mapping.name,
"description": getattr(mapping, 'description', None),
"created_at": mapping.created_at,
"updated_at": mapping.updated_at,
"deleted_at": mapping.deleted_at,
}
return DatasetMappingResponse(**response_data)
async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse:
"""Convert ORM model to response with dataset name (for single entity operations)"""
# Fetch dataset name
dataset_name = None
dataset_id = getattr(mapping, 'dataset_id', None)
if dataset_id:
dataset_result = await self.db.execute(
select(Dataset.name).where(Dataset.id == dataset_id)
)
dataset_name = dataset_result.scalar_one_or_none()
# Create response dict with all fields
response_data = {
"id": mapping.id,
"dataset_id": dataset_id,
"dataset_name": dataset_name,
"labeling_project_id": mapping.labeling_project_id,
"name": mapping.name,
"description": getattr(mapping, 'description', None),
"created_at": mapping.created_at,
"updated_at": mapping.updated_at,
"deleted_at": mapping.deleted_at,
}
return DatasetMappingResponse(**response_data)
async def create_mapping(
self,
labeling_project: LabelingProject
@@ -28,19 +85,13 @@ class DatasetMappingService:
"""创建数据集映射"""
logger.info(f"Create dataset mapping: {labeling_project.dataset_id} -> {labeling_project.labeling_project_id}")
db_mapping = LabelingProject(
id=str(uuid.uuid4()),
dataset_id=labeling_project.dataset_id,
labeling_project_id=labeling_project.labeling_project_id,
name=labeling_project.name
)
self.db.add(db_mapping)
# Use the passed object directly
self.db.add(labeling_project)
await self.db.commit()
await self.db.refresh(db_mapping)
await self.db.refresh(labeling_project)
logger.debug(f"Mapping created: {db_mapping.id}")
return DatasetMappingResponse.model_validate(db_mapping)
logger.debug(f"Mapping created: {labeling_project.id}")
return await self._to_response(labeling_project)
async def get_mapping_by_source_uuid(
self,
@@ -59,7 +110,7 @@ class DatasetMappingService:
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping)
return await self._to_response(mapping)
logger.debug(f"No mapping found for source dataset id: {dataset_id}")
return None
@@ -72,7 +123,7 @@ class DatasetMappingService:
"""根据源数据集ID获取所有映射关系"""
logger.debug(f"Get all mappings by source dataset id: {dataset_id}")
query = select(LabelingProject).where(
query = self._build_query_with_dataset_name().where(
LabelingProject.dataset_id == dataset_id
)
@@ -82,10 +133,10 @@ class DatasetMappingService:
result = await self.db.execute(
query.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
async def get_mapping_by_labeling_project_id(
self,
@@ -103,8 +154,8 @@ class DatasetMappingService:
mapping = result.scalar_one_or_none()
if mapping:
logger.debug(f"Found mapping: {mapping.mapping_id}")
return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"Found mapping: {mapping.id}")
return await self._to_response(mapping)
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
return None
@@ -123,9 +174,9 @@ class DatasetMappingService:
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping)
return await self._to_response(mapping)
logger.debug(f"Mapping not found: {mapping_id}")
logger.debug(f"No mapping found for mapping id: {mapping_id}")
return None
async def update_mapping(
@@ -184,17 +235,20 @@ class DatasetMappingService:
"""获取所有有效映射"""
logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
query = self._build_query_with_dataset_name().where(
LabelingProject.deleted_at.is_(None)
)
result = await self.db.execute(
select(LabelingProject)
.where(LabelingProject.deleted_at.is_(None))
query
.offset(skip)
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
async def count_mappings(self, include_deleted: bool = False) -> int:
"""统计映射总数"""
@@ -216,7 +270,7 @@ class DatasetMappingService:
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
# 构建查询
query = select(LabelingProject)
query = self._build_query_with_dataset_name()
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
@@ -235,10 +289,10 @@ class DatasetMappingService:
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total
async def get_mappings_by_source_with_count(
self,
@@ -251,7 +305,7 @@ class DatasetMappingService:
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
# 构建查询
query = select(LabelingProject).where(
query = self._build_query_with_dataset_name().where(
LabelingProject.dataset_id == dataset_id
)
@@ -275,7 +329,7 @@ class DatasetMappingService:
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total

View File

@@ -0,0 +1,327 @@
"""
Annotation Template Service
"""
from typing import Optional, List
from datetime import datetime
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from uuid import uuid4
from fastapi import HTTPException
from app.db.models.annotation_management import AnnotationTemplate
from app.module.annotation.schema.template import (
CreateAnnotationTemplateRequest,
UpdateAnnotationTemplateRequest,
AnnotationTemplateResponse,
AnnotationTemplateListResponse,
TemplateConfiguration
)
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator
class AnnotationTemplateService:
"""标注模板服务"""
@staticmethod
def generate_label_studio_config(config: TemplateConfiguration) -> str:
"""
从配置JSON生成Label Studio XML配置
Args:
config: 模板配置对象
Returns:
Label Studio XML字符串
"""
xml_parts = ['<View>']
# 生成对象定义
for obj in config.objects:
obj_attrs = [
f'name="{obj.name}"',
f'value="{obj.value}"'
]
xml_parts.append(f' <{obj.type} {" ".join(obj_attrs)}/>')
# 生成标签定义
for label in config.labels:
label_attrs = [
f'name="{label.from_name}"',
f'toName="{label.to_name}"'
]
# 添加可选属性
if label.required:
label_attrs.append('required="true"')
tag_type = label.type.capitalize() if label.type else "Choices"
# 处理带选项的标签类型
if label.options or label.labels:
choices = label.options or label.labels or []
xml_parts.append(f' <{tag_type} {" ".join(label_attrs)}>')
for choice in choices:
xml_parts.append(f' <Label value="{choice}"/>')
xml_parts.append(f' </{tag_type}>')
else:
# 处理简单标签类型
xml_parts.append(f' <{tag_type} {" ".join(label_attrs)}/>')
xml_parts.append('</View>')
return '\n'.join(xml_parts)
async def create_template(
self,
db: AsyncSession,
request: CreateAnnotationTemplateRequest
) -> AnnotationTemplateResponse:
"""
创建标注模板
Args:
db: 数据库会话
request: 创建请求
Returns:
创建的模板响应
"""
# 验证配置JSON
config_dict = request.configuration.model_dump(mode='json', by_alias=False)
valid, error = LabelStudioConfigValidator.validate_configuration_json(config_dict)
if not valid:
raise HTTPException(status_code=400, detail=f"Invalid configuration: {error}")
# 生成Label Studio XML配置(用于验证,但不存储)
label_config = self.generate_label_studio_config(request.configuration)
# 验证生成的XML
valid, error = LabelStudioConfigValidator.validate_xml(label_config)
if not valid:
raise HTTPException(status_code=400, detail=f"Generated XML is invalid: {error}")
# 创建模板对象(不包含label_config字段)
template = AnnotationTemplate(
id=str(uuid4()),
name=request.name,
description=request.description,
data_type=request.data_type,
labeling_type=request.labeling_type,
configuration=config_dict,
style=request.style,
category=request.category,
built_in=False,
version="1.0.0",
created_at=datetime.now()
)
db.add(template)
await db.commit()
await db.refresh(template)
return self._to_response(template)
async def get_template(
self,
db: AsyncSession,
template_id: str
) -> Optional[AnnotationTemplateResponse]:
"""
获取单个模板
Args:
db: 数据库会话
template_id: 模板ID
Returns:
模板响应或None
"""
result = await db.execute(
select(AnnotationTemplate)
.where(
AnnotationTemplate.id == template_id,
AnnotationTemplate.deleted_at.is_(None)
)
)
template = result.scalar_one_or_none()
if template:
return self._to_response(template)
return None
async def list_templates(
self,
db: AsyncSession,
page: int = 1,
size: int = 10,
category: Optional[str] = None,
data_type: Optional[str] = None,
labeling_type: Optional[str] = None,
built_in: Optional[bool] = None
) -> AnnotationTemplateListResponse:
"""
获取模板列表
Args:
db: 数据库会话
page: 页码
size: 每页大小
category: 分类筛选
data_type: 数据类型筛选
labeling_type: 标注类型筛选
built_in: 是否内置模板筛选
Returns:
模板列表响应
"""
# 构建查询条件
conditions: List = [AnnotationTemplate.deleted_at.is_(None)]
if category:
conditions.append(AnnotationTemplate.category == category) # type: ignore
if data_type:
conditions.append(AnnotationTemplate.data_type == data_type) # type: ignore
if labeling_type:
conditions.append(AnnotationTemplate.labeling_type == labeling_type) # type: ignore
if built_in is not None:
conditions.append(AnnotationTemplate.built_in == built_in) # type: ignore
# 查询总数
count_result = await db.execute(
select(func.count()).select_from(AnnotationTemplate).where(*conditions)
)
total = count_result.scalar() or 0
# 分页查询
result = await db.execute(
select(AnnotationTemplate)
.where(*conditions)
.order_by(AnnotationTemplate.created_at.desc())
.limit(size)
.offset((page - 1) * size)
)
templates = result.scalars().all()
return AnnotationTemplateListResponse(
content=[self._to_response(t) for t in templates],
total=total,
page=page,
size=size,
totalPages=(total + size - 1) // size
)
async def update_template(
self,
db: AsyncSession,
template_id: str,
request: UpdateAnnotationTemplateRequest
) -> Optional[AnnotationTemplateResponse]:
"""
更新模板
Args:
db: 数据库会话
template_id: 模板ID
request: 更新请求
Returns:
更新后的模板响应或None
"""
result = await db.execute(
select(AnnotationTemplate)
.where(
AnnotationTemplate.id == template_id,
AnnotationTemplate.deleted_at.is_(None)
)
)
template = result.scalar_one_or_none()
if not template:
return None
# 更新字段
update_data = request.model_dump(exclude_unset=True, by_alias=False)
for field, value in update_data.items():
if field == 'configuration' and value is not None:
# 验证配置JSON
config_dict = value.model_dump(mode='json', by_alias=False)
valid, error = LabelStudioConfigValidator.validate_configuration_json(config_dict)
if not valid:
raise HTTPException(status_code=400, detail=f"Invalid configuration: {error}")
# 重新生成Label Studio XML配置(用于验证)
label_config = self.generate_label_studio_config(value)
# 验证生成的XML
valid, error = LabelStudioConfigValidator.validate_xml(label_config)
if not valid:
raise HTTPException(status_code=400, detail=f"Generated XML is invalid: {error}")
# 只更新configuration字段,不存储label_config
setattr(template, field, config_dict)
else:
setattr(template, field, value)
template.updated_at = datetime.now() # type: ignore
await db.commit()
await db.refresh(template)
return self._to_response(template)
async def delete_template(
self,
db: AsyncSession,
template_id: str
) -> bool:
"""
删除模板(软删除)
Args:
db: 数据库会话
template_id: 模板ID
Returns:
是否删除成功
"""
result = await db.execute(
select(AnnotationTemplate)
.where(
AnnotationTemplate.id == template_id,
AnnotationTemplate.deleted_at.is_(None)
)
)
template = result.scalar_one_or_none()
if not template:
return False
template.deleted_at = datetime.now() # type: ignore
await db.commit()
return True
def _to_response(self, template: AnnotationTemplate) -> AnnotationTemplateResponse:
"""
转换为响应对象
Args:
template: 数据库模型对象
Returns:
模板响应对象
"""
# 将配置JSON转换为TemplateConfiguration对象
from typing import cast, Dict, Any
config_dict = cast(Dict[str, Any], template.configuration)
config = TemplateConfiguration(**config_dict)
# 动态生成Label Studio XML配置
label_config = self.generate_label_studio_config(config)
# 使用model_validate从ORM对象创建响应对象
response = AnnotationTemplateResponse.model_validate(template)
response.configuration = config
response.label_config = label_config # type: ignore
return response