You've already forked DataMate
feat: Enhance file tag update functionality with automatic format conversion (#84)
- Updated `update_file_tags` to support both simplified and full tag formats.
- Introduced `TagFormatConverter` to handle conversion from simplified external tags to the internal storage format.
- Added logic to fetch and use the appropriate annotation template for conversion.
- Improved error handling for missing templates and unknown controls during tag updates.
- Created an example script demonstrating the usage of the new tag format conversion feature.
- Added unit tests for `TagFormatConverter` to ensure correct functionality and edge-case handling.
This commit is contained in:
@@ -7,12 +7,13 @@ from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import LabelingProject
|
||||
from app.db.models import LabelingProject, AnnotationTemplate
|
||||
from app.db.models.dataset_management import Dataset
|
||||
from app.module.annotation.schema import (
|
||||
DatasetMappingCreateRequest,
|
||||
DatasetMappingUpdateRequest,
|
||||
DatasetMappingResponse
|
||||
DatasetMappingResponse,
|
||||
AnnotationTemplateResponse
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -33,11 +34,32 @@ class DatasetMappingService:
|
||||
LabelingProject.dataset_id == Dataset.id
|
||||
)
|
||||
|
||||
def _to_response_from_row(self, row) -> DatasetMappingResponse:
|
||||
"""Convert query row (mapping + dataset_name) to response"""
|
||||
async def _to_response_from_row(
|
||||
self,
|
||||
row,
|
||||
include_template: bool = False
|
||||
) -> DatasetMappingResponse:
|
||||
"""
|
||||
Convert query row (mapping + dataset_name) to response
|
||||
|
||||
Args:
|
||||
row: Query result row containing (LabelingProject, dataset_name)
|
||||
include_template: If True, fetch and include full template details
|
||||
"""
|
||||
mapping = row[0] # LabelingProject object
|
||||
dataset_name = row[1] # dataset_name from join
|
||||
|
||||
# Get template_id from mapping
|
||||
template_id = getattr(mapping, 'template_id', None)
|
||||
|
||||
# Optionally fetch full template details
|
||||
template_response = None
|
||||
if include_template and template_id:
|
||||
from ..service.template import AnnotationTemplateService
|
||||
template_service = AnnotationTemplateService()
|
||||
template_response = await template_service.get_template(self.db, template_id)
|
||||
logger.debug(f"Included template details for template_id: {template_id}")
|
||||
|
||||
response_data = {
|
||||
"id": mapping.id,
|
||||
"dataset_id": mapping.dataset_id,
|
||||
@@ -45,6 +67,8 @@ class DatasetMappingService:
|
||||
"labeling_project_id": mapping.labeling_project_id,
|
||||
"name": mapping.name,
|
||||
"description": getattr(mapping, 'description', None),
|
||||
"template_id": template_id,
|
||||
"template": template_response,
|
||||
"created_at": mapping.created_at,
|
||||
"updated_at": mapping.updated_at,
|
||||
"deleted_at": mapping.deleted_at,
|
||||
@@ -52,8 +76,18 @@ class DatasetMappingService:
|
||||
|
||||
return DatasetMappingResponse(**response_data)
|
||||
|
||||
async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse:
|
||||
"""Convert ORM model to response with dataset name (for single entity operations)"""
|
||||
async def _to_response(
|
||||
self,
|
||||
mapping: LabelingProject,
|
||||
include_template: bool = False
|
||||
) -> DatasetMappingResponse:
|
||||
"""
|
||||
Convert ORM model to response with dataset name (for single entity operations)
|
||||
|
||||
Args:
|
||||
mapping: LabelingProject ORM object
|
||||
include_template: If True, fetch and include full template details
|
||||
"""
|
||||
# Fetch dataset name
|
||||
dataset_name = None
|
||||
dataset_id = getattr(mapping, 'dataset_id', None)
|
||||
@@ -63,6 +97,17 @@ class DatasetMappingService:
|
||||
)
|
||||
dataset_name = dataset_result.scalar_one_or_none()
|
||||
|
||||
# Get template_id from mapping
|
||||
template_id = getattr(mapping, 'template_id', None)
|
||||
|
||||
# Optionally fetch full template details
|
||||
template_response = None
|
||||
if include_template and template_id:
|
||||
from ..service.template import AnnotationTemplateService
|
||||
template_service = AnnotationTemplateService()
|
||||
template_response = await template_service.get_template(self.db, template_id)
|
||||
logger.debug(f"Included template details for template_id: {template_id}")
|
||||
|
||||
# Create response dict with all fields
|
||||
response_data = {
|
||||
"id": mapping.id,
|
||||
@@ -71,6 +116,8 @@ class DatasetMappingService:
|
||||
"labeling_project_id": mapping.labeling_project_id,
|
||||
"name": mapping.name,
|
||||
"description": getattr(mapping, 'description', None),
|
||||
"template_id": template_id,
|
||||
"template": template_response,
|
||||
"created_at": mapping.created_at,
|
||||
"updated_at": mapping.updated_at,
|
||||
"deleted_at": mapping.deleted_at,
|
||||
@@ -136,7 +183,12 @@ class DatasetMappingService:
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(rows)} mappings")
|
||||
return [self._to_response_from_row(row) for row in rows]
|
||||
# Convert rows to responses (async comprehension)
|
||||
responses = []
|
||||
for row in rows:
|
||||
response = await self._to_response_from_row(row, include_template=False)
|
||||
responses.append(response)
|
||||
return responses
|
||||
|
||||
async def get_mapping_by_labeling_project_id(
|
||||
self,
|
||||
@@ -160,9 +212,19 @@ class DatasetMappingService:
|
||||
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
|
||||
return None
|
||||
|
||||
async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]:
|
||||
"""根据映射UUID获取映射"""
|
||||
logger.debug(f"Get mapping: {mapping_id}")
|
||||
async def get_mapping_by_uuid(
|
||||
self,
|
||||
mapping_id: str,
|
||||
include_template: bool = False
|
||||
) -> Optional[DatasetMappingResponse]:
|
||||
"""
|
||||
根据映射UUID获取映射
|
||||
|
||||
Args:
|
||||
mapping_id: 映射UUID
|
||||
include_template: 是否包含完整的模板信息
|
||||
"""
|
||||
logger.debug(f"Get mapping: {mapping_id}, include_template={include_template}")
|
||||
|
||||
result = await self.db.execute(
|
||||
select(LabelingProject).where(
|
||||
@@ -174,7 +236,7 @@ class DatasetMappingService:
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"Found mapping: {mapping.id}")
|
||||
return await self._to_response(mapping)
|
||||
return await self._to_response(mapping, include_template=include_template)
|
||||
|
||||
logger.debug(f"No mapping found for mapping id: {mapping_id}")
|
||||
return None
|
||||
@@ -248,7 +310,12 @@ class DatasetMappingService:
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(rows)} mappings")
|
||||
return [self._to_response_from_row(row) for row in rows]
|
||||
# Convert rows to responses (async comprehension)
|
||||
responses = []
|
||||
for row in rows:
|
||||
response = await self._to_response_from_row(row, include_template=False)
|
||||
responses.append(response)
|
||||
return responses
|
||||
|
||||
async def count_mappings(self, include_deleted: bool = False) -> int:
|
||||
"""统计映射总数"""
|
||||
@@ -264,10 +331,19 @@ class DatasetMappingService:
|
||||
self,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False
|
||||
include_deleted: bool = False,
|
||||
include_template: bool = False
|
||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||
"""获取所有映射及总数(用于分页)"""
|
||||
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
|
||||
"""
|
||||
获取所有映射及总数(用于分页)
|
||||
|
||||
Args:
|
||||
skip: 跳过记录数
|
||||
limit: 返回记录数
|
||||
include_deleted: 是否包含已删除的记录
|
||||
include_template: 是否包含完整的模板信息
|
||||
"""
|
||||
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}, include_template={include_template}")
|
||||
|
||||
# 构建查询
|
||||
query = self._build_query_with_dataset_name()
|
||||
@@ -292,17 +368,62 @@ class DatasetMappingService:
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(rows)} mappings, total: {total}")
|
||||
return [self._to_response_from_row(row) for row in rows], total
|
||||
# Convert rows to responses (async comprehension)
|
||||
responses = []
|
||||
for row in rows:
|
||||
response = await self._to_response_from_row(row, include_template=include_template)
|
||||
responses.append(response)
|
||||
return responses, total
|
||||
|
||||
async def get_template_id_by_dataset_id(self, dataset_id: str) -> Optional[str]:
    """
    Look up the template ID recorded on a dataset's labeling project.

    Selects ``LabelingProject.template_id`` for a row whose ``dataset_id``
    matches and whose ``deleted_at`` is NULL, limited to a single row.

    Args:
        dataset_id: Dataset UUID to search labeling projects by

    Returns:
        The selected template ID, or None when the query yields no row
        or the selected row has no template ID
    """
    logger.debug(f"Looking up template for dataset: {dataset_id}")

    # Build the statement first so the execute call stays short.
    statement = (
        select(LabelingProject.template_id)
        .where(
            LabelingProject.dataset_id == dataset_id,
            LabelingProject.deleted_at.is_(None),
        )
        .limit(1)
    )
    query_result = await self.db.execute(statement)
    template_id = query_result.scalar_one_or_none()

    # Guard clause: a falsy value is logged as "not found" and returned as-is.
    if not template_id:
        logger.warning(f"No template found for dataset {dataset_id}")
        return template_id

    logger.debug(f"Found template {template_id} for dataset {dataset_id}")
    return template_id
|
||||
|
||||
async def get_mappings_by_source_with_count(
|
||||
self,
|
||||
dataset_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False
|
||||
include_deleted: bool = False,
|
||||
include_template: bool = False
|
||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||
"""根据源数据集ID获取映射关系及总数(用于分页)"""
|
||||
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
|
||||
"""
|
||||
根据源数据集ID获取映射关系及总数(用于分页)
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集ID
|
||||
skip: 跳过记录数
|
||||
limit: 返回记录数
|
||||
include_deleted: 是否包含已删除的记录
|
||||
include_template: 是否包含完整的模板信息
|
||||
"""
|
||||
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}, include_template={include_template}")
|
||||
|
||||
# 构建查询
|
||||
query = self._build_query_with_dataset_name().where(
|
||||
@@ -332,4 +453,9 @@ class DatasetMappingService:
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(rows)} mappings, total: {total}")
|
||||
return [self._to_response_from_row(row) for row in rows], total
|
||||
# Convert rows to responses (async comprehension)
|
||||
responses = []
|
||||
for row in rows:
|
||||
response = await self._to_response_from_row(row, include_template=include_template)
|
||||
responses.append(response)
|
||||
return responses, total
|
||||
Reference in New Issue
Block a user