diff --git a/runtime/datamate-python/app/module/annotation/interface/project.py b/runtime/datamate-python/app/module/annotation/interface/project.py index 737da21..71a90c5 100644 --- a/runtime/datamate-python/app/module/annotation/interface/project.py +++ b/runtime/datamate-python/app/module/annotation/interface/project.py @@ -150,12 +150,19 @@ async def create_mapping( async def list_mappings( page: int = Query(1, ge=1, description="页码(从1开始)"), page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"), + include_template: bool = Query(False, description="是否包含模板详情", alias="includeTemplate"), db: AsyncSession = Depends(get_db) ): """ 查询所有映射关系(分页) - 返回所有有效的数据集映射关系(未被软删除的),支持分页查询 + 返回所有有效的数据集映射关系(未被软删除的),支持分页查询。 + 可选择是否包含完整的标注模板信息(默认不包含,以提高列表查询性能)。 + + 参数: + - page: 页码(从1开始) + - pageSize: 每页记录数 + - includeTemplate: 是否包含模板详情(默认false) """ try: service = DatasetMappingService(db) @@ -163,10 +170,14 @@ async def list_mappings( # 计算 skip skip = (page - 1) * page_size + logger.info(f"List mappings: page={page}, page_size={page_size}, include_template={include_template}") + # 获取数据和总数 mappings, total = await service.get_all_mappings_with_count( - skip=skip, - limit=page_size + skip=skip, + limit=page_size, + include_deleted=False, + include_template=include_template ) # 计算总页数 @@ -181,7 +192,7 @@ async def list_mappings( content=mappings ) - logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}") + logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}") return StandardResponse( code=200, @@ -199,14 +210,21 @@ async def get_mapping( db: AsyncSession = Depends(get_db) ): """ - 根据 UUID 查询单个映射关系 + 根据 UUID 查询单个映射关系(包含关联的标注模板详情) + + 返回数据集映射关系以及关联的完整标注模板信息,包括: + - 映射基本信息 + - 数据集信息 + - Label Studio 项目信息 + - 完整的标注模板配置(如果存在) """ try: service = DatasetMappingService(db) - logger.info(f"Get mapping: {mapping_id}") + logger.info(f"Get mapping with template details: {mapping_id}") - mapping = await service.get_mapping_by_uuid(mapping_id) + # 获取映射,并包含完整的模板信息 + mapping = await service.get_mapping_by_uuid(mapping_id, include_template=True) if not mapping: raise HTTPException( @@ -214,7 +232,7 @@ async def get_mapping( detail=f"Mapping not found: {mapping_id}" ) - logger.info(f"Found mapping: {mapping.id}") + logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}") return StandardResponse( code=200, @@ -233,12 +251,20 @@ async def get_mappings_by_source( dataset_id: str, page: int = Query(1, ge=1, description="页码(从1开始)"), page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"), + include_template: bool = Query(True, description="是否包含模板详情", alias="includeTemplate"), db: AsyncSession = Depends(get_db) ): """ - 根据源数据集 ID 查询所有映射关系(分页) + 根据源数据集 ID 查询所有映射关系(分页,包含模板详情) - 返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询 + 返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询。 + 默认包含关联的完整标注模板信息。 + + 参数: + - dataset_id: 数据集ID + - page: 页码(从1开始) + - pageSize: 每页记录数 + - includeTemplate: 是否包含模板详情(默认true) """ try: service = DatasetMappingService(db) @@ -246,13 +272,14 @@ async def get_mappings_by_source( # 计算 skip skip = (page - 1) * page_size - logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}") + logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}, include_template={include_template}") - # 获取数据和总数 + # 获取数据和总数(包含模板信息) mappings, total = await service.get_mappings_by_source_with_count( dataset_id=dataset_id, skip=skip, - limit=page_size + limit=page_size, + include_template=include_template ) # 计算总页数 @@ -267,7 +294,7 @@ async def get_mappings_by_source( content=mappings ) - logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}") + logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}") return StandardResponse( code=200, diff --git a/runtime/datamate-python/app/module/annotation/interface/task.py b/runtime/datamate-python/app/module/annotation/interface/task.py index e8f049a..13b6674 100644 --- a/runtime/datamate-python/app/module/annotation/interface/task.py +++ b/runtime/datamate-python/app/module/annotation/interface/task.py @@ -244,16 +244,53 @@ async def update_file_tags( db: AsyncSession = Depends(get_db) ): """ - Update File Tags (Partial Update) + Update File Tags (Partial Update with Auto Format Conversion) 接收部分标签更新并合并到指定文件(只修改提交的标签,其余保持不变),并更新 `tags_updated_at`。 + + 支持两种标签格式: + 1. 简化格式(外部用户提交): + [{"from_name": "label", "to_name": "image", "values": ["cat", "dog"]}] + + 2. 完整格式(内部存储): + [{"id": "...", "from_name": "label", "to_name": "image", "type": "choices", + "value": {"choices": ["cat", "dog"]}}] + + 系统会自动根据数据集关联的模板将简化格式转换为完整格式。 请求与响应使用 Pydantic 模型 `UpdateFileTagsRequest` / `UpdateFileTagsResponse`。 """ service = DatasetManagementService(db) + # 首先获取文件所属的数据集 + from sqlalchemy.future import select + from app.db.models import DatasetFiles + + result = await db.execute( + select(DatasetFiles).where(DatasetFiles.id == file_id) + ) + file_record = result.scalar_one_or_none() + + if not file_record: + raise HTTPException(status_code=404, detail=f"File not found: {file_id}") + + dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str + + # 查找数据集关联的模板ID + from ..service.mapping import DatasetMappingService + + mapping_service = DatasetMappingService(db) + template_id = await mapping_service.get_template_id_by_dataset_id(dataset_id) + + if template_id: + logger.info(f"Found template {template_id} for dataset {dataset_id}, will auto-convert tag format") + else: + logger.warning(f"No template found for dataset {dataset_id}, tags must be in full format") + + # 更新标签(如果有模板ID则自动转换格式) success, error_msg, updated_at = await service.update_file_tags_partial( file_id=file_id, - new_tags=request.tags + new_tags=request.tags, + template_id=template_id # 传递模板ID以启用自动转换 ) if not success: @@ -261,10 +298,7 @@ async def update_file_tags( raise HTTPException(status_code=404, detail=error_msg) raise HTTPException(status_code=500, detail=error_msg or "更新标签失败") - # 获取更新后的完整标签列表 - from sqlalchemy.future import select - from app.db.models import DatasetFiles - + # 重新获取更新后的文件记录(获取完整标签列表) result = await db.execute( select(DatasetFiles).where(DatasetFiles.id == file_id) ) diff --git a/runtime/datamate-python/app/module/annotation/schema/__init__.py b/runtime/datamate-python/app/module/annotation/schema/__init__.py index f849582..6fb8707 100644 --- a/runtime/datamate-python/app/module/annotation/schema/__init__.py +++ b/runtime/datamate-python/app/module/annotation/schema/__init__.py @@ -3,14 +3,6 @@ from .config import ( TagConfigResponse ) -from .mapping import ( - DatasetMappingCreateRequest, - DatasetMappingCreateResponse, - DatasetMappingUpdateRequest, - DatasetMappingResponse, - DeleteDatasetResponse, -) - from .sync import ( SyncDatasetRequest, SyncDatasetResponse, @@ -30,6 +22,17 @@ from .template import ( AnnotationTemplateListResponse ) +from .mapping import ( + DatasetMappingCreateRequest, + DatasetMappingCreateResponse, + DatasetMappingUpdateRequest, + DatasetMappingResponse, + DeleteDatasetResponse, +) + +# Rebuild model to resolve forward references +DatasetMappingResponse.model_rebuild() + __all__ = [ "ConfigResponse", "TagConfigResponse", diff --git a/runtime/datamate-python/app/module/annotation/schema/mapping.py b/runtime/datamate-python/app/module/annotation/schema/mapping.py index 3dda8ca..a3fcab1 100644 --- a/runtime/datamate-python/app/module/annotation/schema/mapping.py +++ b/runtime/datamate-python/app/module/annotation/schema/mapping.py @@ -1,10 +1,13 @@ from pydantic import Field, BaseModel -from typing import Optional +from typing import Optional, TYPE_CHECKING from datetime import datetime from app.module.shared.schema import BaseResponseModel from app.module.shared.schema import StandardResponse +if TYPE_CHECKING: + from .template import AnnotationTemplateResponse + class DatasetMappingCreateRequest(BaseModel): """数据集映射 创建 请求模型 @@ -42,6 +45,8 @@ class DatasetMappingResponse(BaseModel): labeling_project_id: str = Field(..., alias="labelingProjectId", description="标注项目ID") name: Optional[str] = Field(None, description="标注项目名称") description: Optional[str] = Field(None, description="标注项目描述") + template_id: Optional[str] = Field(None, alias="templateId", description="关联的模板ID") + template: Optional['AnnotationTemplateResponse'] = Field(None, description="关联的标注模板详情") created_at: datetime = Field(..., alias="createdAt", description="创建时间") updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间") deleted_at: Optional[datetime] = Field(None, alias="deletedAt", description="删除时间") diff --git a/runtime/datamate-python/app/module/annotation/service/mapping.py b/runtime/datamate-python/app/module/annotation/service/mapping.py index 4702f9f..43a627c 100644 --- a/runtime/datamate-python/app/module/annotation/service/mapping.py +++ b/runtime/datamate-python/app/module/annotation/service/mapping.py @@ -7,12 +7,13 @@ from datetime import datetime import uuid from app.core.logging import get_logger -from app.db.models import LabelingProject +from app.db.models import LabelingProject, AnnotationTemplate from app.db.models.dataset_management import Dataset from app.module.annotation.schema import ( DatasetMappingCreateRequest, DatasetMappingUpdateRequest, - DatasetMappingResponse + DatasetMappingResponse, + AnnotationTemplateResponse ) logger = get_logger(__name__) @@ -33,11 +34,32 @@ class DatasetMappingService: LabelingProject.dataset_id == Dataset.id ) - def _to_response_from_row(self, row) -> DatasetMappingResponse: - """Convert query row (mapping + dataset_name) to response""" + async def _to_response_from_row( + self, + row, + include_template: bool = False + ) -> DatasetMappingResponse: + """ + Convert query row (mapping + dataset_name) to response + + Args: + row: Query result row containing (LabelingProject, dataset_name) + include_template: If True, fetch and include full template details + """ mapping = row[0] # LabelingProject object dataset_name = row[1] # dataset_name from join + # Get template_id from mapping + template_id = getattr(mapping, 'template_id', None) + + # Optionally fetch full template details + template_response = None + if include_template and template_id: + from ..service.template import AnnotationTemplateService + template_service = AnnotationTemplateService() + template_response = await template_service.get_template(self.db, template_id) + logger.debug(f"Included template details for template_id: {template_id}") + response_data = { "id": mapping.id, "dataset_id": mapping.dataset_id, @@ -45,6 +67,8 @@ class DatasetMappingService: "labeling_project_id": mapping.labeling_project_id, "name": mapping.name, "description": getattr(mapping, 'description', None), + "template_id": template_id, + "template": template_response, "created_at": mapping.created_at, "updated_at": mapping.updated_at, "deleted_at": mapping.deleted_at, @@ -52,8 +76,18 @@ class DatasetMappingService: return DatasetMappingResponse(**response_data) - async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse: - """Convert ORM model to response with dataset name (for single entity operations)""" + async def _to_response( + self, + mapping: LabelingProject, + include_template: bool = False + ) -> DatasetMappingResponse: + """ + Convert ORM model to response with dataset name (for single entity operations) + + Args: + mapping: LabelingProject ORM object + include_template: If True, fetch and include full template details + """ # Fetch dataset name dataset_name = None dataset_id = getattr(mapping, 'dataset_id', None) @@ -63,6 +97,17 @@ class DatasetMappingService: ) dataset_name = dataset_result.scalar_one_or_none() + # Get template_id from mapping + template_id = getattr(mapping, 'template_id', None) + + # Optionally fetch full template details + template_response = None + if include_template and template_id: + from ..service.template import AnnotationTemplateService + template_service = AnnotationTemplateService() + template_response = await template_service.get_template(self.db, template_id) + logger.debug(f"Included template details for template_id: {template_id}") + # Create response dict with all fields response_data = { "id": mapping.id, @@ -71,6 +116,8 @@ class DatasetMappingService: "labeling_project_id": mapping.labeling_project_id, "name": mapping.name, "description": getattr(mapping, 'description', None), + "template_id": template_id, + "template": template_response, "created_at": mapping.created_at, "updated_at": mapping.updated_at, "deleted_at": mapping.deleted_at, @@ -136,7 +183,12 @@ class DatasetMappingService: rows = result.all() logger.debug(f"Found {len(rows)} mappings") - return [self._to_response_from_row(row) for row in rows] + # Convert rows to responses (async comprehension) + responses = [] + for row in rows: + response = await self._to_response_from_row(row, include_template=False) + responses.append(response) + return responses async def get_mapping_by_labeling_project_id( self, @@ -160,9 +212,19 @@ class DatasetMappingService: logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}") return None - async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]: - """根据映射UUID获取映射""" - logger.debug(f"Get mapping: {mapping_id}") + async def get_mapping_by_uuid( + self, + mapping_id: str, + include_template: bool = False + ) -> Optional[DatasetMappingResponse]: + """ + 根据映射UUID获取映射 + + Args: + mapping_id: 映射UUID + include_template: 是否包含完整的模板信息 + """ + logger.debug(f"Get mapping: {mapping_id}, include_template={include_template}") result = await self.db.execute( select(LabelingProject).where( @@ -174,7 +236,7 @@ class DatasetMappingService: if mapping: logger.debug(f"Found mapping: {mapping.id}") - return await self._to_response(mapping) + return await self._to_response(mapping, include_template=include_template) logger.debug(f"No mapping found for mapping id: {mapping_id}") return None @@ -248,7 +310,12 @@ class DatasetMappingService: rows = result.all() logger.debug(f"Found {len(rows)} mappings") - return [self._to_response_from_row(row) for row in rows] + # Convert rows to responses (async comprehension) + responses = [] + for row in rows: + response = await self._to_response_from_row(row, include_template=False) + responses.append(response) + return responses async def count_mappings(self, include_deleted: bool = False) -> int: """统计映射总数""" @@ -264,10 +331,19 @@ class DatasetMappingService: self, skip: int = 0, limit: int = 100, - include_deleted: bool = False + include_deleted: bool = False, + include_template: bool = False ) -> Tuple[List[DatasetMappingResponse], int]: - """获取所有映射及总数(用于分页)""" - logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}") + """ + 获取所有映射及总数(用于分页) + + Args: + skip: 跳过记录数 + limit: 返回记录数 + include_deleted: 是否包含已删除的记录 + include_template: 是否包含完整的模板信息 + """ + logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}, include_template={include_template}") # 构建查询 query = self._build_query_with_dataset_name() @@ -292,17 +368,62 @@ class DatasetMappingService: rows = result.all() logger.debug(f"Found {len(rows)} mappings, total: {total}") - return [self._to_response_from_row(row) for row in rows], total + # Convert rows to responses (async comprehension) + responses = [] + for row in rows: + response = await self._to_response_from_row(row, include_template=include_template) + responses.append(response) + return responses, total + + async def get_template_id_by_dataset_id(self, dataset_id: str) -> Optional[str]: + """ + Get template ID for a dataset by finding its labeling project + + Args: + dataset_id: Dataset UUID + + Returns: + Template ID or None if no labeling project found or no template associated + """ + logger.debug(f"Looking up template for dataset: {dataset_id}") + + result = await self.db.execute( + select(LabelingProject.template_id) + .where( + LabelingProject.dataset_id == dataset_id, + LabelingProject.deleted_at.is_(None) + ) + .limit(1) + ) + + template_id = result.scalar_one_or_none() + + if template_id: + logger.debug(f"Found template {template_id} for dataset {dataset_id}") + else: + logger.warning(f"No template found for dataset {dataset_id}") + + return template_id async def get_mappings_by_source_with_count( self, dataset_id: str, skip: int = 0, limit: int = 100, - include_deleted: bool = False + include_deleted: bool = False, + include_template: bool = False ) -> Tuple[List[DatasetMappingResponse], int]: - """根据源数据集ID获取映射关系及总数(用于分页)""" - logger.debug(f"Get mappings by source dataset id with count: {dataset_id}") + """ + 根据源数据集ID获取映射关系及总数(用于分页) + + Args: + dataset_id: 数据集ID + skip: 跳过记录数 + limit: 返回记录数 + include_deleted: 是否包含已删除的记录 + include_template: 是否包含完整的模板信息 + """ + logger.debug(f"Get mappings by source dataset id with count: {dataset_id}, include_template={include_template}") # 构建查询 query = self._build_query_with_dataset_name().where( @@ -332,4 +453,9 @@ class DatasetMappingService: rows = result.all() logger.debug(f"Found {len(rows)} mappings, total: {total}") - return [self._to_response_from_row(row) for row in rows], total \ No newline at end of file + # Convert rows to responses (async comprehension) + responses = [] + for row in rows: + response = await self._to_response_from_row(row, include_template=include_template) + responses.append(response) + return responses, total \ No newline at end of file diff --git a/runtime/datamate-python/app/module/annotation/utils/__init__.py b/runtime/datamate-python/app/module/annotation/utils/__init__.py index 1db0f0c..232f8a8 100644 --- a/runtime/datamate-python/app/module/annotation/utils/__init__.py +++ b/runtime/datamate-python/app/module/annotation/utils/__init__.py @@ -2,5 +2,10 @@ Annotation Module Utilities """ from .config_validator import LabelStudioConfigValidator +from .tag_converter import TagFormatConverter, create_converter_from_template_config -__all__ = ['LabelStudioConfigValidator'] +__all__ = [ + 'LabelStudioConfigValidator', + 'TagFormatConverter', + 'create_converter_from_template_config' +] diff --git a/runtime/datamate-python/app/module/annotation/utils/tag_converter.py b/runtime/datamate-python/app/module/annotation/utils/tag_converter.py new file mode 100644 index 0000000..357d7cc --- /dev/null +++ b/runtime/datamate-python/app/module/annotation/utils/tag_converter.py @@ -0,0 +1,232 @@ +""" +Tag Format Converter + +Converts simplified external tag format to internal storage format by looking up +the type from the annotation template configuration. + +External format (from users): +[ + { + "from_name": "label", + "to_name": "image", + "values": ["cat", "dog"] + } +] + +Internal storage format: +[ + { + "id": "unique_id", + "from_name": "label", + "to_name": "image", + "type": "choices", + "value": { + "choices": ["cat", "dog"] + } + } +] +""" + +import uuid +from typing import List, Dict, Any, Optional +from datetime import datetime + +from app.core.logging import get_logger +from ..schema.template import TemplateConfiguration + +logger = get_logger(__name__) + + +class TagFormatConverter: + """Convert between simplified external tag format and internal storage format""" + + def __init__(self, template_config: TemplateConfiguration): + """ + Initialize converter with template configuration + + Args: + template_config: The template configuration containing label definitions + """ + self.template_config = template_config + # Build a lookup map: from_name -> type + self._type_map = self._build_type_map() + + def _build_type_map(self) -> Dict[str, str]: + """ + Build a mapping from from_name to type from template labels + + Returns: + Dictionary mapping from_name to control type + """ + type_map = {} + for label_def in self.template_config.labels: + from_name = label_def.from_name + control_type = label_def.type + type_map[from_name] = control_type + logger.debug(f"Registered control: {from_name} -> {control_type}") + + return type_map + + def get_type_for_from_name(self, from_name: str) -> Optional[str]: + """ + Get the control type for a given from_name + + Args: + from_name: The control name + + Returns: + Control type or None if not found + """ + return self._type_map.get(from_name) + + def convert_simplified_to_full( + self, + simplified_tags: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Convert simplified tag format to full internal storage format + + Args: + simplified_tags: List of tags in simplified format with structure: + [ + { + "from_name": "label", + "to_name": "image", + "values": ["cat", "dog"] # Can be list or single value + } + ] + + Returns: + List of tags in full internal format: + [ + { + "id": "unique_id", + "from_name": "label", + "to_name": "image", + "type": "choices", + "value": { + "choices": ["cat", "dog"] + } + } + ] + """ + full_tags = [] + + for simplified_tag in simplified_tags: + # Support both camelCase and snake_case from external sources + from_name = simplified_tag.get('from_name') or simplified_tag.get('fromName') + to_name = simplified_tag.get('to_name') or simplified_tag.get('toName') + values = simplified_tag.get('values') + tag_id = simplified_tag.get('id') # Use existing ID if provided + + if not from_name or not to_name: + logger.warning(f"Skipping tag with missing from_name or to_name: {simplified_tag}") + continue + + # Look up the type from template configuration + control_type = self.get_type_for_from_name(from_name) + + if not control_type: + logger.warning( + f"Could not find type for from_name '{from_name}' in template. " + f"Tag will be skipped. Available controls: {list(self._type_map.keys())}" + ) + continue + + # Generate ID if not provided + if not tag_id: + tag_id = str(uuid.uuid4()) + + # Convert values to the proper nested structure + # The key in the value dict should match the control type + full_tag = { + "id": tag_id, + "from_name": from_name, + "to_name": to_name, + "type": control_type, + "value": { + control_type: values + } + } + + full_tags.append(full_tag) + logger.debug(f"Converted tag: {from_name} ({control_type}) with {len(values) if isinstance(values, list) else 1} values") + + return full_tags + + def is_simplified_format(self, tag: Dict[str, Any]) -> bool: + """ + Check if a tag is in simplified format (missing type field) + + Args: + tag: Tag dictionary to check + + Returns: + True if tag appears to be in simplified format + """ + # Simplified format has 'values' at top level and no 'type' field + has_values = 'values' in tag + has_type = 'type' in tag + has_value = 'value' in tag + + # If it has 'values' but no 'type', it's simplified + # If it has 'type' and nested 'value', it's already full format + return has_values and not has_type and not has_value + + def convert_if_needed( + self, + tags: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Convert tags to full format if they are in simplified format + + This method can handle mixed formats - it will convert simplified tags + and pass through tags that are already in full format. + + Args: + tags: List of tags in either format + + Returns: + List of tags in full internal format + """ + if not tags: + return [] + + result = [] + + for tag in tags: + if self.is_simplified_format(tag): + # Convert simplified format + converted = self.convert_simplified_to_full([tag]) + result.extend(converted) + else: + # Already in full format, pass through + result.append(tag) + + return result + + +def create_converter_from_template_config( + template_config_dict: Dict[str, Any] +) -> TagFormatConverter: + """ + Create a TagFormatConverter from a template configuration dictionary + + Args: + template_config_dict: Template configuration as dict (from database JSON) + + Returns: + TagFormatConverter instance + + Raises: + ValueError: If template configuration is invalid + """ + try: + # Parse the configuration using Pydantic model + from ..schema.template import TemplateConfiguration + + template_config = TemplateConfiguration(**template_config_dict) + return TagFormatConverter(template_config) + except Exception as e: + logger.error(f"Failed to create tag converter from template config: {e}") + raise ValueError(f"Invalid template configuration: {e}") diff --git a/runtime/datamate-python/app/module/annotation/utils/test_tag_converter.py b/runtime/datamate-python/app/module/annotation/utils/test_tag_converter.py new file mode 100644 index 0000000..4685704 --- /dev/null +++ b/runtime/datamate-python/app/module/annotation/utils/test_tag_converter.py @@ -0,0 +1,337 @@ +""" +Unit tests for TagFormatConverter + +Run with: pytest app/module/annotation/utils/test_tag_converter.py -v +""" + +import pytest +from .tag_converter import TagFormatConverter, create_converter_from_template_config +from ..schema.template import TemplateConfiguration, LabelDefinition, ObjectDefinition + + +@pytest.fixture +def sample_template_config(): + """Create a sample template configuration for testing""" + return TemplateConfiguration( + labels=[ + LabelDefinition( + fromName="sentiment", + toName="text", + type="choices", + options=["positive", "negative", "neutral"], + required=True, + labels=None, + description=None + ), + LabelDefinition( + fromName="bbox", + toName="image", + type="rectanglelabels", + labels=["cat", "dog", "bird"], + required=False, + options=None, + description=None + ), + LabelDefinition( + fromName="comment", + toName="text", + type="textarea", + required=False, + options=None, + labels=None, + description=None + ) + ], + objects=[ + ObjectDefinition(name="text", type="Text", value="$text"), + ObjectDefinition(name="image", type="Image", value="$image") + ], + metadata=None + ) + + +@pytest.fixture +def converter(sample_template_config): + """Create a converter instance""" + return TagFormatConverter(sample_template_config) + + +class TestTagFormatConverter: + """Test TagFormatConverter functionality""" + + def test_type_map_building(self, converter): + """Test that type map is built correctly from template""" + assert converter.get_type_for_from_name("sentiment") == "choices" + assert converter.get_type_for_from_name("bbox") == "rectanglelabels" + assert converter.get_type_for_from_name("comment") == "textarea" + assert converter.get_type_for_from_name("nonexistent") is None + + def test_convert_simplified_to_full_single_value(self, converter): + """Test conversion of simplified format with single value""" + simplified = [ + { + "from_name": "sentiment", + "to_name": "text", + "values": ["positive"] + } + ] + + result = converter.convert_simplified_to_full(simplified) + + assert len(result) == 1 + tag = result[0] + assert tag["from_name"] == "sentiment" + assert tag["to_name"] == "text" + assert tag["type"] == "choices" + assert tag["value"] == {"choices": ["positive"]} + assert "id" in tag + + def test_convert_simplified_to_full_multiple_values(self, converter): + """Test conversion of simplified format with multiple values""" + simplified = [ + { + "from_name": "bbox", + "to_name": "image", + "values": ["cat", "dog"] + } + ] + + result = converter.convert_simplified_to_full(simplified) + + assert len(result) == 1 + tag = result[0] + assert tag["type"] == "rectanglelabels" + assert tag["value"] == {"rectanglelabels": ["cat", "dog"]} + + def test_convert_simplified_camelcase(self, converter): + """Test that camelCase field names are supported""" + simplified = [ + { + "fromName": "sentiment", # camelCase + "toName": "text", # camelCase + "values": ["neutral"] + } + ] + + result = converter.convert_simplified_to_full(simplified) + + assert len(result) == 1 + assert result[0]["from_name"] == "sentiment" + assert result[0]["to_name"] == "text" + + def test_convert_multiple_tags(self, converter): + """Test conversion of multiple tags at once""" + simplified = [ + { + "from_name": "sentiment", + "to_name": "text", + "values": ["positive"] + }, + { + "from_name": "bbox", + "to_name": "image", + "values": ["cat"] + } + ] + + result = converter.convert_simplified_to_full(simplified) + + assert len(result) == 2 + assert result[0]["type"] == "choices" + assert result[1]["type"] == "rectanglelabels" + + def test_convert_with_existing_id(self, converter): + """Test that existing IDs are preserved""" + existing_id = "my-custom-id-123" + simplified = [ + { + "id": existing_id, + "from_name": "sentiment", + "to_name": "text", + "values": ["positive"] + } + ] + + result = converter.convert_simplified_to_full(simplified) + + assert result[0]["id"] == existing_id + + def test_skip_unknown_from_name(self, converter): + """Test that tags with unknown from_name are skipped""" + simplified = [ + { + "from_name": "unknown_control", + "to_name": "text", + "values": ["value"] + } + ] + + result = converter.convert_simplified_to_full(simplified) + + assert len(result) == 0 # Should be skipped + + def test_skip_missing_fields(self, converter): + """Test that tags with missing required fields are skipped""" + simplified = [ + { + "from_name": "sentiment", + # Missing to_name + "values": ["positive"] + } + ] + + result = converter.convert_simplified_to_full(simplified) + + assert len(result) == 0 # Should be skipped + + def test_is_simplified_format(self, converter): + """Test detection of simplified format""" + # Simplified format + assert converter.is_simplified_format({ + "from_name": "x", + "to_name": "y", + "values": ["a"] + }) is True + + # Full format + assert converter.is_simplified_format({ + "id": "123", + "from_name": "x", + "to_name": "y", + "type": "choices", + "value": {"choices": ["a"]} + }) is False + + # Edge case: has both (should not be considered simplified) + assert converter.is_simplified_format({ + "from_name": "x", + "to_name": "y", + "type": "choices", + "values": ["a"] + }) is False + + def test_convert_if_needed_mixed_formats(self, converter): + """Test conversion of mixed format tags""" + mixed = [ + # Simplified format + { + "from_name": "sentiment", + "to_name": "text", + "values": ["positive"] + }, + # Full format + { + "id": "existing-123", + "from_name": "bbox", + "to_name": "image", + "type": "rectanglelabels", + "value": {"rectanglelabels": ["cat"]} + } + ] + + result = converter.convert_if_needed(mixed) + + assert len(result) == 2 + # First should be converted + assert result[0]["type"] == "choices" + assert result[0]["value"] == {"choices": ["positive"]} + # Second should pass through unchanged + assert result[1]["id"] == "existing-123" + assert result[1]["type"] == "rectanglelabels" + + +class TestCreateConverterFromDict: + """Test the factory function for creating converter from dict""" + + def test_create_from_valid_dict(self): + """Test creating converter from valid configuration dict""" + config_dict = { + "labels": [ + { + "fromName": "label", + "toName": "image", + "type": "choices", + "options": ["a", "b"] + } + ], + "objects": [ + { + "name": "image", + "type": "Image", + "value": "$image" + } + ] + } + + converter = create_converter_from_template_config(config_dict) + + assert isinstance(converter, TagFormatConverter) + assert converter.get_type_for_from_name("label") == "choices" + + def test_create_from_invalid_dict(self): + """Test that invalid config raises ValueError""" + invalid_config = { + "labels": "not-a-list", # Should be a list + "objects": [] + } + + with pytest.raises(ValueError, match="Invalid template configuration"): + create_converter_from_template_config(invalid_config) + + +class TestIntegrationScenarios: + """Test real-world usage scenarios""" + + def test_external_api_submission(self, converter): + """Simulate external user submitting tags via API""" + # User submits simplified format + user_submission = [ + { + "fromName": "sentiment", # User uses camelCase + "toName": "text", + "values": ["positive", "negative"] + } + ] + + # System converts to internal format + internal_tags = converter.convert_if_needed(user_submission) + + # Verify correct storage format + assert len(internal_tags) == 1 + assert internal_tags[0]["type"] == "choices" + assert internal_tags[0]["value"] == {"choices": ["positive", "negative"]} + assert "id" in internal_tags[0] + + def test_update_existing_tags(self, converter): + """Simulate updating existing tags with new values""" + # Existing tags in database (full format) + existing_tags = [ + { + "id": "tag-001", + "from_name": "sentiment", + "to_name": "text", + "type": "choices", + "value": {"choices": ["positive"]} + } + ] + + # User updates with simplified format + update_request = [ + { + "id": "tag-001", # Same ID to update + "from_name": "sentiment", + "to_name": "text", + "values": ["negative"] # New value + } + ] + + # Convert update request + converted_update = converter.convert_if_needed(update_request) + + # Merge logic would replace tag-001 + assert converted_update[0]["id"] == "tag-001" + assert converted_update[0]["value"] == {"choices": ["negative"]} + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/runtime/datamate-python/app/module/dataset/service/service.py b/runtime/datamate-python/app/module/dataset/service/service.py index 210fb13..aff2f10 100644 --- a/runtime/datamate-python/app/module/dataset/service/service.py +++ b/runtime/datamate-python/app/module/dataset/service/service.py @@ -165,14 +165,20 @@ class Service: async def update_file_tags_partial( self, file_id: str, - new_tags: List[Dict[str, Any]] + new_tags: List[Dict[str, Any]], + template_id: Optional[str] = None ) -> tuple[bool, Optional[str], Optional[datetime]]: """ - 部分更新文件标签 + 部分更新文件标签,支持自动格式转换 + + 如果提供了 template_id,会自动将简化格式的标签转换为完整格式。 + 简化格式: {"from_name": "x", "to_name": "y", "values": [...]} + 完整格式: {"id": "...", "from_name": "x", "to_name": "y", "type": "...", "value": {"type": [...]}} Args: file_id: 文件ID - new_tags: 新的标签列表(部分更新) + new_tags: 新的标签列表(部分更新),可以是简化格式或完整格式 + template_id: 可选的模板ID,用于格式转换 Returns: (成功标志, 错误信息, 更新时间) @@ -190,6 +196,38 @@ class Service: logger.error(f"File not found: {file_id}") return False, f"File not found: {file_id}", None + # 如果提供了 template_id,尝试进行格式转换 + processed_tags = new_tags + if template_id: + logger.debug(f"Converting tags using template: {template_id}") + + try: + # 获取模板配置 + from app.db.models import AnnotationTemplate + template_result = await self.db.execute( + select(AnnotationTemplate).where( + AnnotationTemplate.id == template_id, + AnnotationTemplate.deleted_at.is_(None) + ) + ) + template = template_result.scalar_one_or_none() + + if not template: + logger.warning(f"Template {template_id} not found, skipping conversion") + else: + # 使用 converter 转换标签格式 + from app.module.annotation.utils import create_converter_from_template_config + + converter = create_converter_from_template_config(template.configuration) # type: ignore + processed_tags = converter.convert_if_needed(new_tags) + + logger.info(f"Converted {len(new_tags)} tags to full format") + + except Exception as e: + logger.error(f"Failed to convert tags using template: {e}") + # 继续使用原始标签格式 + logger.warning("Continuing with original tag format") + # 获取现有标签 existing_tags: List[Dict[str, Any]] = file_record.tags or [] # type: ignore @@ -197,7 +235,7 @@ class Service: tag_id_map = {tag.get('id'): idx for idx, tag in enumerate(existing_tags) if tag.get('id')} # 更新或追加标签 - for new_tag in new_tags: + for new_tag in processed_tags: tag_id = new_tag.get('id') if tag_id and tag_id in tag_id_map: # 更新现有标签 diff --git a/runtime/datamate-python/examples/tag_format_conversion_examples.py b/runtime/datamate-python/examples/tag_format_conversion_examples.py new file mode 100644 index 0000000..207c647 --- /dev/null +++ b/runtime/datamate-python/examples/tag_format_conversion_examples.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python3 +""" +Example: Tag Format Conversion Usage + +This script demonstrates how to use the tag format conversion feature +to update file tags using the simplified external format. + +Run this script after: +1. Creating a dataset +2. Creating an annotation template +3. Creating a labeling project that links the dataset and template +4. Uploading files to the dataset +""" + +import asyncio +import httpx +import json +from typing import List, Dict, Any + + +# Configuration +API_BASE_URL = "http://localhost:8000" +FILE_ID = "your-file-uuid-here" # Replace with actual file ID + + +async def update_file_tags_simplified( + file_id: str, + tags: List[Dict[str, Any]] +) -> Dict[str, Any]: + """ + Update file tags using simplified format + + Args: + file_id: UUID of the file to update + tags: List of tags in simplified format + + Returns: + API response + """ + url = f"{API_BASE_URL}/api/annotation/task/{file_id}" + + payload = { + "tags": tags + } + + async with httpx.AsyncClient() as client: + response = await client.put(url, json=payload, timeout=30.0) + response.raise_for_status() + return response.json() + + +async def example_1_basic_update(): + """Example 1: Basic tag update with simplified format""" + print("\n=== Example 1: Basic Tag Update ===\n") + + # Simplified format - no type or nested value required + tags = [ + { + "from_name": "sentiment", + "to_name": "text", + "values": ["positive", "negative"] + } + ] + + print("Submitting tags in simplified format:") + print(json.dumps(tags, indent=2)) + + try: + result = await update_file_tags_simplified(FILE_ID, tags) + + print("\n✓ Success! Response:") + print(json.dumps(result, indent=2)) + + # The response will contain tags in full internal format + if result.get("data", {}).get("tags"): + stored_tag = result["data"]["tags"][0] + print("\n📝 Stored tag format:") + print(f" - ID: {stored_tag.get('id')}") + print(f" - Type: {stored_tag.get('type')}") + print(f" - Value: {stored_tag.get('value')}") + + except Exception as e: + print(f"\n✗ Error: {e}") + + +async def example_2_multiple_controls(): + """Example 2: Update multiple different control types""" + print("\n=== Example 2: Multiple Control Types ===\n") + + tags = [ + # Text classification + { + "from_name": "sentiment", + "to_name": "text", + "values": ["positive"] + }, + # Image bounding boxes + { + "from_name": "bbox", + "to_name": "image", + "values": ["cat", "dog"] + }, + # Text comment + { + "from_name": "comment", + "to_name": "text", + "values": ["This is a great example"] + } + ] + + print("Submitting multiple control types:") + print(json.dumps(tags, indent=2)) + + try: + result = await update_file_tags_simplified(FILE_ID, tags) + print("\n✓ Success! All tags converted and stored.") + print(f" Total tags: {len(result['data']['tags'])}") + + except Exception as e: + print(f"\n✗ Error: {e}") + + +async def example_3_update_existing(): + """Example 3: Update existing tag by preserving ID""" + print("\n=== Example 3: Update Existing Tag ===\n") + + # First, let's assume we know the ID of an existing tag + existing_tag_id = "some-existing-uuid" + + tags = [ + { + "id": existing_tag_id, # Preserve ID to update, not create new + "from_name": "sentiment", + "to_name": "text", + "values": ["neutral"] # Change value + } + ] + + print("Updating existing tag by ID:") + print(json.dumps(tags, indent=2)) + + try: + result = await update_file_tags_simplified(FILE_ID, tags) + print("\n✓ Success! Existing tag updated.") + + except Exception as e: + print(f"\n✗ Error: {e}") + + +async def example_4_camelcase(): + """Example 4: Using camelCase field names (frontend style)""" + print("\n=== Example 4: CamelCase Field Names ===\n") + + # Frontend typically sends camelCase + tags = [ + { + "fromName": "sentiment", # camelCase + "toName": "text", # camelCase + "values": ["positive"] + } + ] + + print("Submitting with camelCase:") + print(json.dumps(tags, indent=2)) + + try: + result = await update_file_tags_simplified(FILE_ID, tags) + print("\n✓ Success! CamelCase automatically handled.") + + except Exception as e: + print(f"\n✗ Error: {e}") + + +async def example_5_mixed_format(): + """Example 5: Mixed format - some simplified, some full""" + print("\n=== Example 5: Mixed Format Support ===\n") + + tags = [ + # Simplified format + { + "from_name": "new_label", + "to_name": "image", + "values": ["new_value"] + }, + # Full format (already has type and nested value) + { + "id": "existing-uuid-123", + "from_name": "old_label", + "to_name": "image", + "type": "choices", + "value": { + "choices": ["old_value"] + } + } + ] + + print("Submitting mixed format:") + print(json.dumps(tags, indent=2)) + + try: + result = await update_file_tags_simplified(FILE_ID, tags) + print("\n✓ Success! Mixed format handled correctly.") + + except Exception as e: + print(f"\n✗ Error: {e}") + + +async def example_6_error_handling(): + """Example 6: Error handling - unknown control""" + print("\n=== Example 6: Error Handling ===\n") + + tags = [ + # Valid control + { + "from_name": "sentiment", + "to_name": "text", + "values": ["positive"] + }, + # Invalid control - doesn't exist in template + { + "from_name": "unknown_control", + "to_name": "text", + "values": ["some_value"] + } + ] + + print("Submitting with unknown control:") + print(json.dumps(tags, indent=2)) + + try: + result = await update_file_tags_simplified(FILE_ID, tags) + print("\n⚠ Partial Success:") + print(f" Valid tags stored: {len(result['data']['tags'])}") + print(" (Unknown control was skipped)") + + except Exception as e: + print(f"\n✗ Error: {e}") + + +async def main(): + """Run all examples""" + print("\n" + "="*60) + print("Tag Format Conversion Examples") + print("="*60) + + print(f"\nAPI Base URL: {API_BASE_URL}") + print(f"File ID: {FILE_ID}") + print("\n⚠ Make sure to replace FILE_ID with an actual file UUID!") + + # Uncomment the examples you want to run: + + # await example_1_basic_update() + # await example_2_multiple_controls() + # await example_3_update_existing() + # await example_4_camelcase() + # await example_5_mixed_format() + # await example_6_error_handling() + + print("\n" + "="*60) + print("Tip: Edit this script to uncomment examples you want to run") + print("="*60 + "\n") + + +if __name__ == "__main__": + # Run the examples + asyncio.run(main())