feat: Enhance file tag update functionality with automatic format conversion (#84)

- Updated `update_file_tags` to support both simplified and full tag formats.
- Introduced `TagFormatConverter` to handle conversion from simplified external tags to internal storage format.
- Added logic to fetch and utilize the appropriate annotation template for conversion.
- Improved error handling for missing templates and unknown controls during tag updates.
- Created example script demonstrating the usage of the new tag format conversion feature.
- Added unit tests for `TagFormatConverter` to ensure correct functionality and edge case handling.
This commit is contained in:
Jason Wang
2025-11-14 12:42:39 +08:00
committed by GitHub
parent 5cef9cb273
commit df853a5177
10 changed files with 1127 additions and 54 deletions

View File

@@ -7,12 +7,13 @@ from datetime import datetime
import uuid
from app.core.logging import get_logger
from app.db.models import LabelingProject
from app.db.models import LabelingProject, AnnotationTemplate
from app.db.models.dataset_management import Dataset
from app.module.annotation.schema import (
DatasetMappingCreateRequest,
DatasetMappingUpdateRequest,
DatasetMappingResponse
DatasetMappingResponse,
AnnotationTemplateResponse
)
logger = get_logger(__name__)
@@ -33,11 +34,32 @@ class DatasetMappingService:
LabelingProject.dataset_id == Dataset.id
)
def _to_response_from_row(self, row) -> DatasetMappingResponse:
"""Convert query row (mapping + dataset_name) to response"""
async def _to_response_from_row(
self,
row,
include_template: bool = False
) -> DatasetMappingResponse:
"""
Convert query row (mapping + dataset_name) to response
Args:
row: Query result row containing (LabelingProject, dataset_name)
include_template: If True, fetch and include full template details
"""
mapping = row[0] # LabelingProject object
dataset_name = row[1] # dataset_name from join
# Get template_id from mapping
template_id = getattr(mapping, 'template_id', None)
# Optionally fetch full template details
template_response = None
if include_template and template_id:
from ..service.template import AnnotationTemplateService
template_service = AnnotationTemplateService()
template_response = await template_service.get_template(self.db, template_id)
logger.debug(f"Included template details for template_id: {template_id}")
response_data = {
"id": mapping.id,
"dataset_id": mapping.dataset_id,
@@ -45,6 +67,8 @@ class DatasetMappingService:
"labeling_project_id": mapping.labeling_project_id,
"name": mapping.name,
"description": getattr(mapping, 'description', None),
"template_id": template_id,
"template": template_response,
"created_at": mapping.created_at,
"updated_at": mapping.updated_at,
"deleted_at": mapping.deleted_at,
@@ -52,8 +76,18 @@ class DatasetMappingService:
return DatasetMappingResponse(**response_data)
async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse:
"""Convert ORM model to response with dataset name (for single entity operations)"""
async def _to_response(
self,
mapping: LabelingProject,
include_template: bool = False
) -> DatasetMappingResponse:
"""
Convert ORM model to response with dataset name (for single entity operations)
Args:
mapping: LabelingProject ORM object
include_template: If True, fetch and include full template details
"""
# Fetch dataset name
dataset_name = None
dataset_id = getattr(mapping, 'dataset_id', None)
@@ -63,6 +97,17 @@ class DatasetMappingService:
)
dataset_name = dataset_result.scalar_one_or_none()
# Get template_id from mapping
template_id = getattr(mapping, 'template_id', None)
# Optionally fetch full template details
template_response = None
if include_template and template_id:
from ..service.template import AnnotationTemplateService
template_service = AnnotationTemplateService()
template_response = await template_service.get_template(self.db, template_id)
logger.debug(f"Included template details for template_id: {template_id}")
# Create response dict with all fields
response_data = {
"id": mapping.id,
@@ -71,6 +116,8 @@ class DatasetMappingService:
"labeling_project_id": mapping.labeling_project_id,
"name": mapping.name,
"description": getattr(mapping, 'description', None),
"template_id": template_id,
"template": template_response,
"created_at": mapping.created_at,
"updated_at": mapping.updated_at,
"deleted_at": mapping.deleted_at,
@@ -136,7 +183,12 @@ class DatasetMappingService:
rows = result.all()
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
# Convert rows to responses (async comprehension)
responses = []
for row in rows:
response = await self._to_response_from_row(row, include_template=False)
responses.append(response)
return responses
async def get_mapping_by_labeling_project_id(
self,
@@ -160,9 +212,19 @@ class DatasetMappingService:
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
return None
async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]:
"""根据映射UUID获取映射"""
logger.debug(f"Get mapping: {mapping_id}")
async def get_mapping_by_uuid(
self,
mapping_id: str,
include_template: bool = False
) -> Optional[DatasetMappingResponse]:
"""
根据映射UUID获取映射
Args:
mapping_id: 映射UUID
include_template: 是否包含完整的模板信息
"""
logger.debug(f"Get mapping: {mapping_id}, include_template={include_template}")
result = await self.db.execute(
select(LabelingProject).where(
@@ -174,7 +236,7 @@ class DatasetMappingService:
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return await self._to_response(mapping)
return await self._to_response(mapping, include_template=include_template)
logger.debug(f"No mapping found for mapping id: {mapping_id}")
return None
@@ -248,7 +310,12 @@ class DatasetMappingService:
rows = result.all()
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
# Convert rows to responses (async comprehension)
responses = []
for row in rows:
response = await self._to_response_from_row(row, include_template=False)
responses.append(response)
return responses
async def count_mappings(self, include_deleted: bool = False) -> int:
"""统计映射总数"""
@@ -264,10 +331,19 @@ class DatasetMappingService:
self,
skip: int = 0,
limit: int = 100,
include_deleted: bool = False
include_deleted: bool = False,
include_template: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
"""获取所有映射及总数(用于分页)"""
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
"""
获取所有映射及总数(用于分页)
Args:
skip: 跳过记录数
limit: 返回记录数
include_deleted: 是否包含已删除的记录
include_template: 是否包含完整的模板信息
"""
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}, include_template={include_template}")
# 构建查询
query = self._build_query_with_dataset_name()
@@ -292,17 +368,62 @@ class DatasetMappingService:
rows = result.all()
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total
# Convert rows to responses (async comprehension)
responses = []
for row in rows:
response = await self._to_response_from_row(row, include_template=include_template)
responses.append(response)
return responses, total
async def get_template_id_by_dataset_id(self, dataset_id: str) -> Optional[str]:
"""
Get template ID for a dataset by finding its labeling project
Args:
dataset_id: Dataset UUID
Returns:
Template ID or None if no labeling project found or no template associated
"""
logger.debug(f"Looking up template for dataset: {dataset_id}")
result = await self.db.execute(
select(LabelingProject.template_id)
.where(
LabelingProject.dataset_id == dataset_id,
LabelingProject.deleted_at.is_(None)
)
.limit(1)
)
template_id = result.scalar_one_or_none()
if template_id:
logger.debug(f"Found template {template_id} for dataset {dataset_id}")
else:
logger.warning(f"No template found for dataset {dataset_id}")
return template_id
async def get_mappings_by_source_with_count(
self,
dataset_id: str,
skip: int = 0,
limit: int = 100,
include_deleted: bool = False
include_deleted: bool = False,
include_template: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
"""根据源数据集ID获取映射关系及总数(用于分页)"""
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
"""
根据源数据集ID获取映射关系及总数(用于分页)
Args:
dataset_id: 数据集ID
skip: 跳过记录数
limit: 返回记录数
include_deleted: 是否包含已删除的记录
include_template: 是否包含完整的模板信息
"""
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}, include_template={include_template}")
# 构建查询
query = self._build_query_with_dataset_name().where(
@@ -332,4 +453,9 @@ class DatasetMappingService:
rows = result.all()
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total
# Convert rows to responses (async comprehension)
responses = []
for row in rows:
response = await self._to_response_from_row(row, include_template=include_template)
responses.append(response)
return responses, total