You've already forked DataMate
feat: Add labeling template (#72)
* feat: Enhance annotation module with template management and validation - Added DatasetMappingCreateRequest and DatasetMappingUpdateRequest schemas to handle dataset mapping requests with camelCase and snake_case support. - Introduced Annotation Template schemas including CreateAnnotationTemplateRequest, UpdateAnnotationTemplateRequest, and AnnotationTemplateResponse for managing annotation templates. - Implemented AnnotationTemplateService for creating, updating, retrieving, and deleting annotation templates, including validation of configurations and XML generation. - Added utility class LabelStudioConfigValidator for validating Label Studio configurations and XML formats. - Updated database schema for annotation templates and labeling projects to include new fields and constraints. - Seeded initial annotation templates for various use cases including image classification, object detection, and text classification. * feat: Enhance TemplateForm with improved validation and dynamic field rendering; update LabelStudio config validation for camelCase support * feat: Update docker-compose.yml to mark datamate dataset volume and network as external
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import update, func
|
||||
from sqlalchemy.orm import aliased
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import LabelingProject
|
||||
from app.db.models.dataset_management import Dataset
|
||||
from app.module.annotation.schema import (
|
||||
DatasetMappingCreateRequest,
|
||||
DatasetMappingUpdateRequest,
|
||||
@@ -21,6 +23,61 @@ class DatasetMappingService:
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
def _build_query_with_dataset_name(self):
    """Return the base SELECT pairing each labeling project with its dataset name.

    Every result row carries the ``LabelingProject`` entity first and the
    dataset's name second, exposed under the ``dataset_name`` label. The
    dataset table is attached with an outer join on
    ``LabelingProject.dataset_id == Dataset.id``, so a project whose dataset
    row is missing is still returned (with a NULL name).
    """
    dataset_name_column = Dataset.name.label('dataset_name')
    base_select = select(LabelingProject, dataset_name_column)
    return base_select.outerjoin(
        Dataset,
        LabelingProject.dataset_id == Dataset.id,
    )
|
||||
|
||||
def _to_response_from_row(self, row) -> DatasetMappingResponse:
    """Build a ``DatasetMappingResponse`` from one joined query row.

    Args:
        row: Indexable result row whose element 0 is the ``LabelingProject``
            entity and whose element 1 is the joined dataset name.

    Returns:
        The response object populated from the entity plus the dataset name.
    """
    project, joined_dataset_name = row[0], row[1]

    return DatasetMappingResponse(
        id=project.id,
        dataset_id=project.dataset_id,
        dataset_name=joined_dataset_name,
        labeling_project_id=project.labeling_project_id,
        name=project.name,
        # Read defensively: falls back to None when the attribute is absent.
        description=getattr(project, 'description', None),
        created_at=project.created_at,
        updated_at=project.updated_at,
        deleted_at=project.deleted_at,
    )
|
||||
|
||||
async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse:
    """Build a ``DatasetMappingResponse`` for a single ORM entity.

    Unlike the row-based converter, the dataset name is not available on the
    entity itself, so it is looked up with one extra query when the mapping
    has a truthy ``dataset_id``.

    Args:
        mapping: The ``LabelingProject`` entity to convert.

    Returns:
        The response object, with ``dataset_name`` set to ``None`` when there
        is no dataset id or no matching dataset row.
    """
    dataset_id = getattr(mapping, 'dataset_id', None)

    dataset_name = None
    if dataset_id:
        name_query = select(Dataset.name).where(Dataset.id == dataset_id)
        name_result = await self.db.execute(name_query)
        dataset_name = name_result.scalar_one_or_none()

    return DatasetMappingResponse(
        id=mapping.id,
        dataset_id=dataset_id,
        dataset_name=dataset_name,
        labeling_project_id=mapping.labeling_project_id,
        name=mapping.name,
        description=getattr(mapping, 'description', None),
        created_at=mapping.created_at,
        updated_at=mapping.updated_at,
        deleted_at=mapping.deleted_at,
    )
|
||||
|
||||
async def create_mapping(
|
||||
self,
|
||||
labeling_project: LabelingProject
|
||||
@@ -28,19 +85,13 @@ class DatasetMappingService:
|
||||
"""创建数据集映射"""
|
||||
logger.info(f"Create dataset mapping: {labeling_project.dataset_id} -> {labeling_project.labeling_project_id}")
|
||||
|
||||
db_mapping = LabelingProject(
|
||||
id=str(uuid.uuid4()),
|
||||
dataset_id=labeling_project.dataset_id,
|
||||
labeling_project_id=labeling_project.labeling_project_id,
|
||||
name=labeling_project.name
|
||||
)
|
||||
|
||||
self.db.add(db_mapping)
|
||||
# Use the passed object directly
|
||||
self.db.add(labeling_project)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(db_mapping)
|
||||
await self.db.refresh(labeling_project)
|
||||
|
||||
logger.debug(f"Mapping created: {db_mapping.id}")
|
||||
return DatasetMappingResponse.model_validate(db_mapping)
|
||||
logger.debug(f"Mapping created: {labeling_project.id}")
|
||||
return await self._to_response(labeling_project)
|
||||
|
||||
async def get_mapping_by_source_uuid(
|
||||
self,
|
||||
@@ -59,7 +110,7 @@ class DatasetMappingService:
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"Found mapping: {mapping.id}")
|
||||
return DatasetMappingResponse.model_validate(mapping)
|
||||
return await self._to_response(mapping)
|
||||
|
||||
logger.debug(f"No mapping found for source dataset id: {dataset_id}")
|
||||
return None
|
||||
@@ -72,7 +123,7 @@ class DatasetMappingService:
|
||||
"""根据源数据集ID获取所有映射关系"""
|
||||
logger.debug(f"Get all mappings by source dataset id: {dataset_id}")
|
||||
|
||||
query = select(LabelingProject).where(
|
||||
query = self._build_query_with_dataset_name().where(
|
||||
LabelingProject.dataset_id == dataset_id
|
||||
)
|
||||
|
||||
@@ -82,10 +133,10 @@ class DatasetMappingService:
|
||||
result = await self.db.execute(
|
||||
query.order_by(LabelingProject.created_at.desc())
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(mappings)} mappings")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
|
||||
logger.debug(f"Found {len(rows)} mappings")
|
||||
return [self._to_response_from_row(row) for row in rows]
|
||||
|
||||
async def get_mapping_by_labeling_project_id(
|
||||
self,
|
||||
@@ -103,8 +154,8 @@ class DatasetMappingService:
|
||||
mapping = result.scalar_one_or_none()
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"Found mapping: {mapping.mapping_id}")
|
||||
return DatasetMappingResponse.model_validate(mapping)
|
||||
logger.debug(f"Found mapping: {mapping.id}")
|
||||
return await self._to_response(mapping)
|
||||
|
||||
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
|
||||
return None
|
||||
@@ -123,9 +174,9 @@ class DatasetMappingService:
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"Found mapping: {mapping.id}")
|
||||
return DatasetMappingResponse.model_validate(mapping)
|
||||
return await self._to_response(mapping)
|
||||
|
||||
logger.debug(f"Mapping not found: {mapping_id}")
|
||||
logger.debug(f"No mapping found for mapping id: {mapping_id}")
|
||||
return None
|
||||
|
||||
async def update_mapping(
|
||||
@@ -184,17 +235,20 @@ class DatasetMappingService:
|
||||
"""获取所有有效映射"""
|
||||
logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
|
||||
|
||||
query = self._build_query_with_dataset_name().where(
|
||||
LabelingProject.deleted_at.is_(None)
|
||||
)
|
||||
|
||||
result = await self.db.execute(
|
||||
select(LabelingProject)
|
||||
.where(LabelingProject.deleted_at.is_(None))
|
||||
query
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.order_by(LabelingProject.created_at.desc())
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(mappings)} mappings")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
|
||||
logger.debug(f"Found {len(rows)} mappings")
|
||||
return [self._to_response_from_row(row) for row in rows]
|
||||
|
||||
async def count_mappings(self, include_deleted: bool = False) -> int:
|
||||
"""统计映射总数"""
|
||||
@@ -216,7 +270,7 @@ class DatasetMappingService:
|
||||
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
|
||||
|
||||
# 构建查询
|
||||
query = select(LabelingProject)
|
||||
query = self._build_query_with_dataset_name()
|
||||
if not include_deleted:
|
||||
query = query.where(LabelingProject.deleted_at.is_(None))
|
||||
|
||||
@@ -235,10 +289,10 @@ class DatasetMappingService:
|
||||
.limit(limit)
|
||||
.order_by(LabelingProject.created_at.desc())
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
|
||||
logger.debug(f"Found {len(rows)} mappings, total: {total}")
|
||||
return [self._to_response_from_row(row) for row in rows], total
|
||||
|
||||
async def get_mappings_by_source_with_count(
|
||||
self,
|
||||
@@ -251,7 +305,7 @@ class DatasetMappingService:
|
||||
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
|
||||
|
||||
# 构建查询
|
||||
query = select(LabelingProject).where(
|
||||
query = self._build_query_with_dataset_name().where(
|
||||
LabelingProject.dataset_id == dataset_id
|
||||
)
|
||||
|
||||
@@ -275,7 +329,7 @@ class DatasetMappingService:
|
||||
.limit(limit)
|
||||
.order_by(LabelingProject.created_at.desc())
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
|
||||
logger.debug(f"Found {len(rows)} mappings, total: {total}")
|
||||
return [self._to_response_from_row(row) for row in rows], total
|
||||
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
Annotation Template Service
|
||||
"""
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from uuid import uuid4
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.db.models.annotation_management import AnnotationTemplate
|
||||
from app.module.annotation.schema.template import (
|
||||
CreateAnnotationTemplateRequest,
|
||||
UpdateAnnotationTemplateRequest,
|
||||
AnnotationTemplateResponse,
|
||||
AnnotationTemplateListResponse,
|
||||
TemplateConfiguration
|
||||
)
|
||||
from app.module.annotation.utils.config_validator import LabelStudioConfigValidator
|
||||
|
||||
|
||||
class AnnotationTemplateService:
    """Annotation template service: CRUD operations plus Label Studio XML generation."""

    @staticmethod
    def generate_label_studio_config(config: TemplateConfiguration) -> str:
        """
        Generate a Label Studio XML configuration from a template configuration.

        Args:
            config: Template configuration object; its ``objects`` and
                ``labels`` collections are rendered in that order

        Returns:
            Label Studio XML string, one element per line, wrapped in ``<View>``
        """
        # Imported locally so the escaping fix needs no new module-level import.
        from xml.sax.saxutils import escape

        def attr(attr_name, raw_value):
            """Render ``attr_name="raw_value"`` with XML-special characters escaped."""
            # f-string formatting keeps the original stringification semantics;
            # the extra entity makes the value safe inside double quotes.
            escaped = escape(f'{raw_value}', {'"': '&quot;'})
            return f'{attr_name}="{escaped}"'

        xml_parts = ['<View>']

        # Object (data source) definitions
        for obj in config.objects:
            obj_attrs = ' '.join([
                attr('name', obj.name),
                attr('value', obj.value),
            ])
            xml_parts.append(f' <{obj.type} {obj_attrs}/>')

        # Label (control tag) definitions
        for label in config.labels:
            label_attrs = [
                attr('name', label.from_name),
                attr('toName', label.to_name),
            ]

            # Optional attributes
            if label.required:
                label_attrs.append('required="true"')
            joined_attrs = ' '.join(label_attrs)

            # NOTE(review): str.capitalize() lower-cases everything after the
            # first character (e.g. "RectangleLabels" -> "Rectanglelabels"), and
            # every child is emitted as <Label> even under <Choices>. Confirm
            # both against the Label Studio tag reference and the validator.
            tag_type = label.type.capitalize() if label.type else "Choices"

            choices = label.options or label.labels
            if choices:
                # Tag types that carry a list of options
                xml_parts.append(f' <{tag_type} {joined_attrs}>')
                for choice in choices:
                    choice_attr = attr('value', choice)
                    xml_parts.append(f' <Label {choice_attr}/>')
                xml_parts.append(f' </{tag_type}>')
            else:
                # Simple self-closing tag types
                xml_parts.append(f' <{tag_type} {joined_attrs}/>')

        xml_parts.append('</View>')
        return '\n'.join(xml_parts)

    async def create_template(
        self,
        db: AsyncSession,
        request: CreateAnnotationTemplateRequest
    ) -> AnnotationTemplateResponse:
        """
        Create an annotation template.

        Args:
            db: Database session
            request: Creation request

        Returns:
            Response for the newly created template

        Raises:
            HTTPException: 400 when the configuration JSON or the XML generated
                from it fails validation
        """
        # Validate the configuration JSON
        config_dict = request.configuration.model_dump(mode='json', by_alias=False)
        valid, error = LabelStudioConfigValidator.validate_configuration_json(config_dict)
        if not valid:
            raise HTTPException(status_code=400, detail=f"Invalid configuration: {error}")

        # Generate the Label Studio XML (used for validation only, never stored)
        label_config = self.generate_label_studio_config(request.configuration)

        # Validate the generated XML
        valid, error = LabelStudioConfigValidator.validate_xml(label_config)
        if not valid:
            raise HTTPException(status_code=400, detail=f"Generated XML is invalid: {error}")

        # Build the template entity (no label_config field is persisted)
        template = AnnotationTemplate(
            id=str(uuid4()),
            name=request.name,
            description=request.description,
            data_type=request.data_type,
            labeling_type=request.labeling_type,
            configuration=config_dict,
            style=request.style,
            category=request.category,
            built_in=False,
            version="1.0.0",
            created_at=datetime.now()
        )

        db.add(template)
        await db.commit()
        await db.refresh(template)

        return self._to_response(template)

    async def get_template(
        self,
        db: AsyncSession,
        template_id: str
    ) -> Optional[AnnotationTemplateResponse]:
        """
        Fetch a single template that has not been soft-deleted.

        Args:
            db: Database session
            template_id: Template ID

        Returns:
            Template response, or None when no live template has that ID
        """
        result = await db.execute(
            select(AnnotationTemplate)
            .where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None)
            )
        )
        template = result.scalar_one_or_none()

        if template:
            return self._to_response(template)
        return None

    async def list_templates(
        self,
        db: AsyncSession,
        page: int = 1,
        size: int = 10,
        category: Optional[str] = None,
        data_type: Optional[str] = None,
        labeling_type: Optional[str] = None,
        built_in: Optional[bool] = None
    ) -> AnnotationTemplateListResponse:
        """
        List templates with optional filters and pagination.

        Args:
            db: Database session
            page: 1-based page number (values below 1 are treated as 1)
            size: Page size (values below 1 are treated as 1)
            category: Category filter
            data_type: Data type filter
            labeling_type: Labeling type filter
            built_in: Filter on whether the template is built in

        Returns:
            Template list response, newest templates first
        """
        # Guard against a negative OFFSET and a division by zero below.
        page = max(page, 1)
        size = max(size, 1)

        # Build the filter conditions; soft-deleted templates are always excluded
        conditions: List = [AnnotationTemplate.deleted_at.is_(None)]

        if category:
            conditions.append(AnnotationTemplate.category == category)  # type: ignore
        if data_type:
            conditions.append(AnnotationTemplate.data_type == data_type)  # type: ignore
        if labeling_type:
            conditions.append(AnnotationTemplate.labeling_type == labeling_type)  # type: ignore
        if built_in is not None:
            conditions.append(AnnotationTemplate.built_in == built_in)  # type: ignore

        # Total count for the same filters
        count_result = await db.execute(
            select(func.count()).select_from(AnnotationTemplate).where(*conditions)
        )
        total = count_result.scalar() or 0

        # Page query
        result = await db.execute(
            select(AnnotationTemplate)
            .where(*conditions)
            .order_by(AnnotationTemplate.created_at.desc())
            .limit(size)
            .offset((page - 1) * size)
        )
        templates = result.scalars().all()

        return AnnotationTemplateListResponse(
            content=[self._to_response(t) for t in templates],
            total=total,
            page=page,
            size=size,
            # Ceiling division
            totalPages=(total + size - 1) // size
        )

    async def update_template(
        self,
        db: AsyncSession,
        template_id: str,
        request: UpdateAnnotationTemplateRequest
    ) -> Optional[AnnotationTemplateResponse]:
        """
        Update a template with the fields explicitly set on the request.

        Args:
            db: Database session
            template_id: Template ID
            request: Update request

        Returns:
            Updated template response, or None when no live template has that ID

        Raises:
            HTTPException: 400 when a supplied configuration, or the XML
                generated from it, fails validation
        """
        result = await db.execute(
            select(AnnotationTemplate)
            .where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None)
            )
        )
        template = result.scalar_one_or_none()

        if not template:
            return None

        # Only fields the caller actually sent
        update_data = request.model_dump(exclude_unset=True, by_alias=False)

        # Validate the configuration BEFORE touching the entity, so a rejected
        # request never leaves a half-updated object in the session.
        if update_data.get('configuration') is not None:
            # model_dump() has already turned the nested model into a plain
            # dict, so take the model from the request itself; rebuild it from
            # the dumped dict only if the request field is not a model.
            configuration = getattr(request, 'configuration', None)
            if not isinstance(configuration, TemplateConfiguration):
                configuration = TemplateConfiguration(**update_data['configuration'])

            # Validate the configuration JSON
            config_dict = configuration.model_dump(mode='json', by_alias=False)
            valid, error = LabelStudioConfigValidator.validate_configuration_json(config_dict)
            if not valid:
                raise HTTPException(status_code=400, detail=f"Invalid configuration: {error}")

            # Regenerate the Label Studio XML (used for validation only)
            label_config = self.generate_label_studio_config(configuration)

            # Validate the generated XML
            valid, error = LabelStudioConfigValidator.validate_xml(label_config)
            if not valid:
                raise HTTPException(status_code=400, detail=f"Generated XML is invalid: {error}")

            # Persist only the configuration JSON; label_config is never stored
            update_data['configuration'] = config_dict

        for field, value in update_data.items():
            setattr(template, field, value)

        template.updated_at = datetime.now()  # type: ignore

        await db.commit()
        await db.refresh(template)

        return self._to_response(template)

    async def delete_template(
        self,
        db: AsyncSession,
        template_id: str
    ) -> bool:
        """
        Soft-delete a template by stamping its ``deleted_at`` column.

        Args:
            db: Database session
            template_id: Template ID

        Returns:
            True when a live template was found and marked deleted, else False
        """
        result = await db.execute(
            select(AnnotationTemplate)
            .where(
                AnnotationTemplate.id == template_id,
                AnnotationTemplate.deleted_at.is_(None)
            )
        )
        template = result.scalar_one_or_none()

        if not template:
            return False

        template.deleted_at = datetime.now()  # type: ignore
        await db.commit()

        return True

    def _to_response(self, template: AnnotationTemplate) -> AnnotationTemplateResponse:
        """
        Convert an ORM template into its response object.

        Args:
            template: Database model object

        Returns:
            Template response with the parsed configuration and a freshly
            generated ``label_config`` XML string
        """
        # Rebuild the TemplateConfiguration object from the stored JSON
        from typing import cast, Dict, Any
        config_dict = cast(Dict[str, Any], template.configuration)
        config = TemplateConfiguration(**config_dict)

        # The Label Studio XML is derived on the fly rather than read from storage
        label_config = self.generate_label_studio_config(config)

        # Build the response from the ORM object, then attach the derived fields
        response = AnnotationTemplateResponse.model_validate(template)
        response.configuration = config
        response.label_config = label_config  # type: ignore

        return response
|
||||
Reference in New Issue
Block a user