feat: Add labeling template (#72)

* feat: Enhance annotation module with template management and validation

- Added DatasetMappingCreateRequest and DatasetMappingUpdateRequest schemas to handle dataset mapping requests with camelCase and snake_case support.
- Introduced Annotation Template schemas including CreateAnnotationTemplateRequest, UpdateAnnotationTemplateRequest, and AnnotationTemplateResponse for managing annotation templates.
- Implemented AnnotationTemplateService for creating, updating, retrieving, and deleting annotation templates, including validation of configurations and XML generation.
- Added utility class LabelStudioConfigValidator for validating Label Studio configurations and XML formats.
- Updated database schema for annotation templates and labeling projects to include new fields and constraints.
- Seeded initial annotation templates for various use cases including image classification, object detection, and text classification.

* feat: Enhance TemplateForm with improved validation and dynamic field rendering; update LabelStudio config validation for camelCase support

* feat: Update docker-compose.yml to mark datamate dataset volume and network as external
This commit is contained in:
Jason Wang
2025-11-11 09:14:14 +08:00
committed by GitHub
parent 451d3c8207
commit c5ccc56cca
24 changed files with 2794 additions and 253 deletions

View File

@@ -1,12 +1,14 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, func
from sqlalchemy.orm import aliased
from typing import Optional, List, Tuple
from datetime import datetime
import uuid
from app.core.logging import get_logger
from app.db.models import LabelingProject
from app.db.models.dataset_management import Dataset
from app.module.annotation.schema import (
DatasetMappingCreateRequest,
DatasetMappingUpdateRequest,
@@ -21,6 +23,61 @@ class DatasetMappingService:
def __init__(self, db: AsyncSession):
self.db = db
def _build_query_with_dataset_name(self):
"""Build base query with dataset name joined"""
return select(
LabelingProject,
Dataset.name.label('dataset_name')
).outerjoin(
Dataset,
LabelingProject.dataset_id == Dataset.id
)
def _to_response_from_row(self, row) -> DatasetMappingResponse:
"""Convert query row (mapping + dataset_name) to response"""
mapping = row[0] # LabelingProject object
dataset_name = row[1] # dataset_name from join
response_data = {
"id": mapping.id,
"dataset_id": mapping.dataset_id,
"dataset_name": dataset_name,
"labeling_project_id": mapping.labeling_project_id,
"name": mapping.name,
"description": getattr(mapping, 'description', None),
"created_at": mapping.created_at,
"updated_at": mapping.updated_at,
"deleted_at": mapping.deleted_at,
}
return DatasetMappingResponse(**response_data)
async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse:
"""Convert ORM model to response with dataset name (for single entity operations)"""
# Fetch dataset name
dataset_name = None
dataset_id = getattr(mapping, 'dataset_id', None)
if dataset_id:
dataset_result = await self.db.execute(
select(Dataset.name).where(Dataset.id == dataset_id)
)
dataset_name = dataset_result.scalar_one_or_none()
# Create response dict with all fields
response_data = {
"id": mapping.id,
"dataset_id": dataset_id,
"dataset_name": dataset_name,
"labeling_project_id": mapping.labeling_project_id,
"name": mapping.name,
"description": getattr(mapping, 'description', None),
"created_at": mapping.created_at,
"updated_at": mapping.updated_at,
"deleted_at": mapping.deleted_at,
}
return DatasetMappingResponse(**response_data)
async def create_mapping(
self,
labeling_project: LabelingProject
@@ -28,19 +85,13 @@ class DatasetMappingService:
"""创建数据集映射"""
logger.info(f"Create dataset mapping: {labeling_project.dataset_id} -> {labeling_project.labeling_project_id}")
db_mapping = LabelingProject(
id=str(uuid.uuid4()),
dataset_id=labeling_project.dataset_id,
labeling_project_id=labeling_project.labeling_project_id,
name=labeling_project.name
)
self.db.add(db_mapping)
# Use the passed object directly
self.db.add(labeling_project)
await self.db.commit()
await self.db.refresh(db_mapping)
await self.db.refresh(labeling_project)
logger.debug(f"Mapping created: {db_mapping.id}")
return DatasetMappingResponse.model_validate(db_mapping)
logger.debug(f"Mapping created: {labeling_project.id}")
return await self._to_response(labeling_project)
async def get_mapping_by_source_uuid(
self,
@@ -59,7 +110,7 @@ class DatasetMappingService:
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping)
return await self._to_response(mapping)
logger.debug(f"No mapping found for source dataset id: {dataset_id}")
return None
@@ -72,7 +123,7 @@ class DatasetMappingService:
"""根据源数据集ID获取所有映射关系"""
logger.debug(f"Get all mappings by source dataset id: {dataset_id}")
query = select(LabelingProject).where(
query = self._build_query_with_dataset_name().where(
LabelingProject.dataset_id == dataset_id
)
@@ -82,10 +133,10 @@ class DatasetMappingService:
result = await self.db.execute(
query.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
async def get_mapping_by_labeling_project_id(
self,
@@ -103,8 +154,8 @@ class DatasetMappingService:
mapping = result.scalar_one_or_none()
if mapping:
logger.debug(f"Found mapping: {mapping.mapping_id}")
return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"Found mapping: {mapping.id}")
return await self._to_response(mapping)
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
return None
@@ -123,9 +174,9 @@ class DatasetMappingService:
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping)
return await self._to_response(mapping)
logger.debug(f"Mapping not found: {mapping_id}")
logger.debug(f"No mapping found for mapping id: {mapping_id}")
return None
async def update_mapping(
@@ -184,17 +235,20 @@ class DatasetMappingService:
"""获取所有有效映射"""
logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
query = self._build_query_with_dataset_name().where(
LabelingProject.deleted_at.is_(None)
)
result = await self.db.execute(
select(LabelingProject)
.where(LabelingProject.deleted_at.is_(None))
query
.offset(skip)
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
async def count_mappings(self, include_deleted: bool = False) -> int:
"""统计映射总数"""
@@ -216,7 +270,7 @@ class DatasetMappingService:
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
# 构建查询
query = select(LabelingProject)
query = self._build_query_with_dataset_name()
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
@@ -235,10 +289,10 @@ class DatasetMappingService:
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total
async def get_mappings_by_source_with_count(
self,
@@ -251,7 +305,7 @@ class DatasetMappingService:
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
# 构建查询
query = select(LabelingProject).where(
query = self._build_query_with_dataset_name().where(
LabelingProject.dataset_id == dataset_id
)
@@ -275,7 +329,7 @@ class DatasetMappingService:
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
rows = result.all()
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total