You've already forked DataMate
feat: Enhance file tag update functionality with automatic format conversion (#84)
- Updated `update_file_tags` to support both simplified and full tag formats. - Introduced `TagFormatConverter` to handle conversion from simplified external tags to internal storage format. - Added logic to fetch and utilize the appropriate annotation template for conversion. - Improved error handling for missing templates and unknown controls during tag updates. - Created example script demonstrating the usage of the new tag format conversion feature. - Added unit tests for `TagFormatConverter` to ensure correct functionality and edge case handling.
This commit is contained in:
@@ -150,12 +150,19 @@ async def create_mapping(
|
||||
async def list_mappings(
|
||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"),
|
||||
include_template: bool = Query(False, description="是否包含模板详情", alias="includeTemplate"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
查询所有映射关系(分页)
|
||||
|
||||
返回所有有效的数据集映射关系(未被软删除的),支持分页查询
|
||||
返回所有有效的数据集映射关系(未被软删除的),支持分页查询。
|
||||
可选择是否包含完整的标注模板信息(默认不包含,以提高列表查询性能)。
|
||||
|
||||
参数:
|
||||
- page: 页码(从1开始)
|
||||
- pageSize: 每页记录数
|
||||
- includeTemplate: 是否包含模板详情(默认false)
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
@@ -163,10 +170,14 @@ async def list_mappings(
|
||||
# 计算 skip
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
logger.info(f"List mappings: page={page}, page_size={page_size}, include_template={include_template}")
|
||||
|
||||
# 获取数据和总数
|
||||
mappings, total = await service.get_all_mappings_with_count(
|
||||
skip=skip,
|
||||
limit=page_size
|
||||
skip=skip,
|
||||
limit=page_size,
|
||||
include_deleted=False,
|
||||
include_template=include_template
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
@@ -181,7 +192,7 @@ async def list_mappings(
|
||||
content=mappings
|
||||
)
|
||||
|
||||
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}")
|
||||
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
@@ -199,14 +210,21 @@ async def get_mapping(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据 UUID 查询单个映射关系
|
||||
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
|
||||
|
||||
返回数据集映射关系以及关联的完整标注模板信息,包括:
|
||||
- 映射基本信息
|
||||
- 数据集信息
|
||||
- Label Studio 项目信息
|
||||
- 完整的标注模板配置(如果存在)
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
|
||||
logger.info(f"Get mapping: {mapping_id}")
|
||||
logger.info(f"Get mapping with template details: {mapping_id}")
|
||||
|
||||
mapping = await service.get_mapping_by_uuid(mapping_id)
|
||||
# 获取映射,并包含完整的模板信息
|
||||
mapping = await service.get_mapping_by_uuid(mapping_id, include_template=True)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(
|
||||
@@ -214,7 +232,7 @@ async def get_mapping(
|
||||
detail=f"Mapping not found: {mapping_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Found mapping: {mapping.id}")
|
||||
logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
@@ -233,12 +251,20 @@ async def get_mappings_by_source(
|
||||
dataset_id: str,
|
||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"),
|
||||
include_template: bool = Query(True, description="是否包含模板详情", alias="includeTemplate"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据源数据集 ID 查询所有映射关系(分页)
|
||||
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
|
||||
|
||||
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询
|
||||
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询。
|
||||
默认包含关联的完整标注模板信息。
|
||||
|
||||
参数:
|
||||
- dataset_id: 数据集ID
|
||||
- page: 页码(从1开始)
|
||||
- pageSize: 每页记录数
|
||||
- includeTemplate: 是否包含模板详情(默认true)
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
@@ -246,13 +272,14 @@ async def get_mappings_by_source(
|
||||
# 计算 skip
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}")
|
||||
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}, include_template={include_template}")
|
||||
|
||||
# 获取数据和总数
|
||||
# 获取数据和总数(包含模板信息)
|
||||
mappings, total = await service.get_mappings_by_source_with_count(
|
||||
dataset_id=dataset_id,
|
||||
skip=skip,
|
||||
limit=page_size
|
||||
limit=page_size,
|
||||
include_template=include_template
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
@@ -267,7 +294,7 @@ async def get_mappings_by_source(
|
||||
content=mappings
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
|
||||
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
|
||||
@@ -244,16 +244,53 @@ async def update_file_tags(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Update File Tags (Partial Update)
|
||||
Update File Tags (Partial Update with Auto Format Conversion)
|
||||
|
||||
接收部分标签更新并合并到指定文件(只修改提交的标签,其余保持不变),并更新 `tags_updated_at`。
|
||||
|
||||
支持两种标签格式:
|
||||
1. 简化格式(外部用户提交):
|
||||
[{"from_name": "label", "to_name": "image", "values": ["cat", "dog"]}]
|
||||
|
||||
2. 完整格式(内部存储):
|
||||
[{"id": "...", "from_name": "label", "to_name": "image", "type": "choices",
|
||||
"value": {"choices": ["cat", "dog"]}}]
|
||||
|
||||
系统会自动根据数据集关联的模板将简化格式转换为完整格式。
|
||||
请求与响应使用 Pydantic 模型 `UpdateFileTagsRequest` / `UpdateFileTagsResponse`。
|
||||
"""
|
||||
service = DatasetManagementService(db)
|
||||
|
||||
# 首先获取文件所属的数据集
|
||||
from sqlalchemy.future import select
|
||||
from app.db.models import DatasetFiles
|
||||
|
||||
result = await db.execute(
|
||||
select(DatasetFiles).where(DatasetFiles.id == file_id)
|
||||
)
|
||||
file_record = result.scalar_one_or_none()
|
||||
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
|
||||
|
||||
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
|
||||
|
||||
# 查找数据集关联的模板ID
|
||||
from ..service.mapping import DatasetMappingService
|
||||
|
||||
mapping_service = DatasetMappingService(db)
|
||||
template_id = await mapping_service.get_template_id_by_dataset_id(dataset_id)
|
||||
|
||||
if template_id:
|
||||
logger.info(f"Found template {template_id} for dataset {dataset_id}, will auto-convert tag format")
|
||||
else:
|
||||
logger.warning(f"No template found for dataset {dataset_id}, tags must be in full format")
|
||||
|
||||
# 更新标签(如果有模板ID则自动转换格式)
|
||||
success, error_msg, updated_at = await service.update_file_tags_partial(
|
||||
file_id=file_id,
|
||||
new_tags=request.tags
|
||||
new_tags=request.tags,
|
||||
template_id=template_id # 传递模板ID以启用自动转换
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -261,10 +298,7 @@ async def update_file_tags(
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
raise HTTPException(status_code=500, detail=error_msg or "更新标签失败")
|
||||
|
||||
# 获取更新后的完整标签列表
|
||||
from sqlalchemy.future import select
|
||||
from app.db.models import DatasetFiles
|
||||
|
||||
# 重新获取更新后的文件记录(获取完整标签列表)
|
||||
result = await db.execute(
|
||||
select(DatasetFiles).where(DatasetFiles.id == file_id)
|
||||
)
|
||||
|
||||
@@ -3,14 +3,6 @@ from .config import (
|
||||
TagConfigResponse
|
||||
)
|
||||
|
||||
from .mapping import (
|
||||
DatasetMappingCreateRequest,
|
||||
DatasetMappingCreateResponse,
|
||||
DatasetMappingUpdateRequest,
|
||||
DatasetMappingResponse,
|
||||
DeleteDatasetResponse,
|
||||
)
|
||||
|
||||
from .sync import (
|
||||
SyncDatasetRequest,
|
||||
SyncDatasetResponse,
|
||||
@@ -30,6 +22,17 @@ from .template import (
|
||||
AnnotationTemplateListResponse
|
||||
)
|
||||
|
||||
from .mapping import (
|
||||
DatasetMappingCreateRequest,
|
||||
DatasetMappingCreateResponse,
|
||||
DatasetMappingUpdateRequest,
|
||||
DatasetMappingResponse,
|
||||
DeleteDatasetResponse,
|
||||
)
|
||||
|
||||
# Rebuild model to resolve forward references
|
||||
DatasetMappingResponse.model_rebuild()
|
||||
|
||||
__all__ = [
|
||||
"ConfigResponse",
|
||||
"TagConfigResponse",
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from pydantic import Field, BaseModel
|
||||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
|
||||
from app.module.shared.schema import BaseResponseModel
|
||||
from app.module.shared.schema import StandardResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .template import AnnotationTemplateResponse
|
||||
|
||||
|
||||
class DatasetMappingCreateRequest(BaseModel):
|
||||
"""数据集映射 创建 请求模型
|
||||
@@ -42,6 +45,8 @@ class DatasetMappingResponse(BaseModel):
|
||||
labeling_project_id: str = Field(..., alias="labelingProjectId", description="标注项目ID")
|
||||
name: Optional[str] = Field(None, description="标注项目名称")
|
||||
description: Optional[str] = Field(None, description="标注项目描述")
|
||||
template_id: Optional[str] = Field(None, alias="templateId", description="关联的模板ID")
|
||||
template: Optional['AnnotationTemplateResponse'] = Field(None, description="关联的标注模板详情")
|
||||
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
|
||||
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
|
||||
deleted_at: Optional[datetime] = Field(None, alias="deletedAt", description="删除时间")
|
||||
|
||||
@@ -7,12 +7,13 @@ from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import LabelingProject
|
||||
from app.db.models import LabelingProject, AnnotationTemplate
|
||||
from app.db.models.dataset_management import Dataset
|
||||
from app.module.annotation.schema import (
|
||||
DatasetMappingCreateRequest,
|
||||
DatasetMappingUpdateRequest,
|
||||
DatasetMappingResponse
|
||||
DatasetMappingResponse,
|
||||
AnnotationTemplateResponse
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -33,11 +34,32 @@ class DatasetMappingService:
|
||||
LabelingProject.dataset_id == Dataset.id
|
||||
)
|
||||
|
||||
def _to_response_from_row(self, row) -> DatasetMappingResponse:
|
||||
"""Convert query row (mapping + dataset_name) to response"""
|
||||
async def _to_response_from_row(
|
||||
self,
|
||||
row,
|
||||
include_template: bool = False
|
||||
) -> DatasetMappingResponse:
|
||||
"""
|
||||
Convert query row (mapping + dataset_name) to response
|
||||
|
||||
Args:
|
||||
row: Query result row containing (LabelingProject, dataset_name)
|
||||
include_template: If True, fetch and include full template details
|
||||
"""
|
||||
mapping = row[0] # LabelingProject object
|
||||
dataset_name = row[1] # dataset_name from join
|
||||
|
||||
# Get template_id from mapping
|
||||
template_id = getattr(mapping, 'template_id', None)
|
||||
|
||||
# Optionally fetch full template details
|
||||
template_response = None
|
||||
if include_template and template_id:
|
||||
from ..service.template import AnnotationTemplateService
|
||||
template_service = AnnotationTemplateService()
|
||||
template_response = await template_service.get_template(self.db, template_id)
|
||||
logger.debug(f"Included template details for template_id: {template_id}")
|
||||
|
||||
response_data = {
|
||||
"id": mapping.id,
|
||||
"dataset_id": mapping.dataset_id,
|
||||
@@ -45,6 +67,8 @@ class DatasetMappingService:
|
||||
"labeling_project_id": mapping.labeling_project_id,
|
||||
"name": mapping.name,
|
||||
"description": getattr(mapping, 'description', None),
|
||||
"template_id": template_id,
|
||||
"template": template_response,
|
||||
"created_at": mapping.created_at,
|
||||
"updated_at": mapping.updated_at,
|
||||
"deleted_at": mapping.deleted_at,
|
||||
@@ -52,8 +76,18 @@ class DatasetMappingService:
|
||||
|
||||
return DatasetMappingResponse(**response_data)
|
||||
|
||||
async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse:
|
||||
"""Convert ORM model to response with dataset name (for single entity operations)"""
|
||||
async def _to_response(
|
||||
self,
|
||||
mapping: LabelingProject,
|
||||
include_template: bool = False
|
||||
) -> DatasetMappingResponse:
|
||||
"""
|
||||
Convert ORM model to response with dataset name (for single entity operations)
|
||||
|
||||
Args:
|
||||
mapping: LabelingProject ORM object
|
||||
include_template: If True, fetch and include full template details
|
||||
"""
|
||||
# Fetch dataset name
|
||||
dataset_name = None
|
||||
dataset_id = getattr(mapping, 'dataset_id', None)
|
||||
@@ -63,6 +97,17 @@ class DatasetMappingService:
|
||||
)
|
||||
dataset_name = dataset_result.scalar_one_or_none()
|
||||
|
||||
# Get template_id from mapping
|
||||
template_id = getattr(mapping, 'template_id', None)
|
||||
|
||||
# Optionally fetch full template details
|
||||
template_response = None
|
||||
if include_template and template_id:
|
||||
from ..service.template import AnnotationTemplateService
|
||||
template_service = AnnotationTemplateService()
|
||||
template_response = await template_service.get_template(self.db, template_id)
|
||||
logger.debug(f"Included template details for template_id: {template_id}")
|
||||
|
||||
# Create response dict with all fields
|
||||
response_data = {
|
||||
"id": mapping.id,
|
||||
@@ -71,6 +116,8 @@ class DatasetMappingService:
|
||||
"labeling_project_id": mapping.labeling_project_id,
|
||||
"name": mapping.name,
|
||||
"description": getattr(mapping, 'description', None),
|
||||
"template_id": template_id,
|
||||
"template": template_response,
|
||||
"created_at": mapping.created_at,
|
||||
"updated_at": mapping.updated_at,
|
||||
"deleted_at": mapping.deleted_at,
|
||||
@@ -136,7 +183,12 @@ class DatasetMappingService:
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(rows)} mappings")
|
||||
return [self._to_response_from_row(row) for row in rows]
|
||||
# Convert rows to responses (async comprehension)
|
||||
responses = []
|
||||
for row in rows:
|
||||
response = await self._to_response_from_row(row, include_template=False)
|
||||
responses.append(response)
|
||||
return responses
|
||||
|
||||
async def get_mapping_by_labeling_project_id(
|
||||
self,
|
||||
@@ -160,9 +212,19 @@ class DatasetMappingService:
|
||||
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
|
||||
return None
|
||||
|
||||
async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]:
|
||||
"""根据映射UUID获取映射"""
|
||||
logger.debug(f"Get mapping: {mapping_id}")
|
||||
async def get_mapping_by_uuid(
|
||||
self,
|
||||
mapping_id: str,
|
||||
include_template: bool = False
|
||||
) -> Optional[DatasetMappingResponse]:
|
||||
"""
|
||||
根据映射UUID获取映射
|
||||
|
||||
Args:
|
||||
mapping_id: 映射UUID
|
||||
include_template: 是否包含完整的模板信息
|
||||
"""
|
||||
logger.debug(f"Get mapping: {mapping_id}, include_template={include_template}")
|
||||
|
||||
result = await self.db.execute(
|
||||
select(LabelingProject).where(
|
||||
@@ -174,7 +236,7 @@ class DatasetMappingService:
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"Found mapping: {mapping.id}")
|
||||
return await self._to_response(mapping)
|
||||
return await self._to_response(mapping, include_template=include_template)
|
||||
|
||||
logger.debug(f"No mapping found for mapping id: {mapping_id}")
|
||||
return None
|
||||
@@ -248,7 +310,12 @@ class DatasetMappingService:
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(rows)} mappings")
|
||||
return [self._to_response_from_row(row) for row in rows]
|
||||
# Convert rows to responses (async comprehension)
|
||||
responses = []
|
||||
for row in rows:
|
||||
response = await self._to_response_from_row(row, include_template=False)
|
||||
responses.append(response)
|
||||
return responses
|
||||
|
||||
async def count_mappings(self, include_deleted: bool = False) -> int:
|
||||
"""统计映射总数"""
|
||||
@@ -264,10 +331,19 @@ class DatasetMappingService:
|
||||
self,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False
|
||||
include_deleted: bool = False,
|
||||
include_template: bool = False
|
||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||
"""获取所有映射及总数(用于分页)"""
|
||||
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
|
||||
"""
|
||||
获取所有映射及总数(用于分页)
|
||||
|
||||
Args:
|
||||
skip: 跳过记录数
|
||||
limit: 返回记录数
|
||||
include_deleted: 是否包含已删除的记录
|
||||
include_template: 是否包含完整的模板信息
|
||||
"""
|
||||
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}, include_template={include_template}")
|
||||
|
||||
# 构建查询
|
||||
query = self._build_query_with_dataset_name()
|
||||
@@ -292,17 +368,62 @@ class DatasetMappingService:
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(rows)} mappings, total: {total}")
|
||||
return [self._to_response_from_row(row) for row in rows], total
|
||||
# Convert rows to responses (async comprehension)
|
||||
responses = []
|
||||
for row in rows:
|
||||
response = await self._to_response_from_row(row, include_template=include_template)
|
||||
responses.append(response)
|
||||
return responses, total
|
||||
|
||||
async def get_template_id_by_dataset_id(self, dataset_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get template ID for a dataset by finding its labeling project
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset UUID
|
||||
|
||||
Returns:
|
||||
Template ID or None if no labeling project found or no template associated
|
||||
"""
|
||||
logger.debug(f"Looking up template for dataset: {dataset_id}")
|
||||
|
||||
result = await self.db.execute(
|
||||
select(LabelingProject.template_id)
|
||||
.where(
|
||||
LabelingProject.dataset_id == dataset_id,
|
||||
LabelingProject.deleted_at.is_(None)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
template_id = result.scalar_one_or_none()
|
||||
|
||||
if template_id:
|
||||
logger.debug(f"Found template {template_id} for dataset {dataset_id}")
|
||||
else:
|
||||
logger.warning(f"No template found for dataset {dataset_id}")
|
||||
|
||||
return template_id
|
||||
|
||||
async def get_mappings_by_source_with_count(
|
||||
self,
|
||||
dataset_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False
|
||||
include_deleted: bool = False,
|
||||
include_template: bool = False
|
||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||
"""根据源数据集ID获取映射关系及总数(用于分页)"""
|
||||
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
|
||||
"""
|
||||
根据源数据集ID获取映射关系及总数(用于分页)
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集ID
|
||||
skip: 跳过记录数
|
||||
limit: 返回记录数
|
||||
include_deleted: 是否包含已删除的记录
|
||||
include_template: 是否包含完整的模板信息
|
||||
"""
|
||||
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}, include_template={include_template}")
|
||||
|
||||
# 构建查询
|
||||
query = self._build_query_with_dataset_name().where(
|
||||
@@ -332,4 +453,9 @@ class DatasetMappingService:
|
||||
rows = result.all()
|
||||
|
||||
logger.debug(f"Found {len(rows)} mappings, total: {total}")
|
||||
return [self._to_response_from_row(row) for row in rows], total
|
||||
# Convert rows to responses (async comprehension)
|
||||
responses = []
|
||||
for row in rows:
|
||||
response = await self._to_response_from_row(row, include_template=include_template)
|
||||
responses.append(response)
|
||||
return responses, total
|
||||
@@ -2,5 +2,10 @@
|
||||
Annotation Module Utilities
|
||||
"""
|
||||
from .config_validator import LabelStudioConfigValidator
|
||||
from .tag_converter import TagFormatConverter, create_converter_from_template_config
|
||||
|
||||
__all__ = ['LabelStudioConfigValidator']
|
||||
__all__ = [
|
||||
'LabelStudioConfigValidator',
|
||||
'TagFormatConverter',
|
||||
'create_converter_from_template_config'
|
||||
]
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Tag Format Converter
|
||||
|
||||
Converts simplified external tag format to internal storage format by looking up
|
||||
the type from the annotation template configuration.
|
||||
|
||||
External format (from users):
|
||||
[
|
||||
{
|
||||
"from_name": "label",
|
||||
"to_name": "image",
|
||||
"values": ["cat", "dog"]
|
||||
}
|
||||
]
|
||||
|
||||
Internal storage format:
|
||||
[
|
||||
{
|
||||
"id": "unique_id",
|
||||
"from_name": "label",
|
||||
"to_name": "image",
|
||||
"type": "choices",
|
||||
"value": {
|
||||
"choices": ["cat", "dog"]
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from ..schema.template import TemplateConfiguration
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TagFormatConverter:
|
||||
"""Convert between simplified external tag format and internal storage format"""
|
||||
|
||||
def __init__(self, template_config: TemplateConfiguration):
|
||||
"""
|
||||
Initialize converter with template configuration
|
||||
|
||||
Args:
|
||||
template_config: The template configuration containing label definitions
|
||||
"""
|
||||
self.template_config = template_config
|
||||
# Build a lookup map: from_name -> type
|
||||
self._type_map = self._build_type_map()
|
||||
|
||||
def _build_type_map(self) -> Dict[str, str]:
|
||||
"""
|
||||
Build a mapping from from_name to type from template labels
|
||||
|
||||
Returns:
|
||||
Dictionary mapping from_name to control type
|
||||
"""
|
||||
type_map = {}
|
||||
for label_def in self.template_config.labels:
|
||||
from_name = label_def.from_name
|
||||
control_type = label_def.type
|
||||
type_map[from_name] = control_type
|
||||
logger.debug(f"Registered control: {from_name} -> {control_type}")
|
||||
|
||||
return type_map
|
||||
|
||||
def get_type_for_from_name(self, from_name: str) -> Optional[str]:
|
||||
"""
|
||||
Get the control type for a given from_name
|
||||
|
||||
Args:
|
||||
from_name: The control name
|
||||
|
||||
Returns:
|
||||
Control type or None if not found
|
||||
"""
|
||||
return self._type_map.get(from_name)
|
||||
|
||||
def convert_simplified_to_full(
|
||||
self,
|
||||
simplified_tags: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert simplified tag format to full internal storage format
|
||||
|
||||
Args:
|
||||
simplified_tags: List of tags in simplified format with structure:
|
||||
[
|
||||
{
|
||||
"from_name": "label",
|
||||
"to_name": "image",
|
||||
"values": ["cat", "dog"] # Can be list or single value
|
||||
}
|
||||
]
|
||||
|
||||
Returns:
|
||||
List of tags in full internal format:
|
||||
[
|
||||
{
|
||||
"id": "unique_id",
|
||||
"from_name": "label",
|
||||
"to_name": "image",
|
||||
"type": "choices",
|
||||
"value": {
|
||||
"choices": ["cat", "dog"]
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
full_tags = []
|
||||
|
||||
for simplified_tag in simplified_tags:
|
||||
# Support both camelCase and snake_case from external sources
|
||||
from_name = simplified_tag.get('from_name') or simplified_tag.get('fromName')
|
||||
to_name = simplified_tag.get('to_name') or simplified_tag.get('toName')
|
||||
values = simplified_tag.get('values')
|
||||
tag_id = simplified_tag.get('id') # Use existing ID if provided
|
||||
|
||||
if not from_name or not to_name:
|
||||
logger.warning(f"Skipping tag with missing from_name or to_name: {simplified_tag}")
|
||||
continue
|
||||
|
||||
# Look up the type from template configuration
|
||||
control_type = self.get_type_for_from_name(from_name)
|
||||
|
||||
if not control_type:
|
||||
logger.warning(
|
||||
f"Could not find type for from_name '{from_name}' in template. "
|
||||
f"Tag will be skipped. Available controls: {list(self._type_map.keys())}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Generate ID if not provided
|
||||
if not tag_id:
|
||||
tag_id = str(uuid.uuid4())
|
||||
|
||||
# Convert values to the proper nested structure
|
||||
# The key in the value dict should match the control type
|
||||
full_tag = {
|
||||
"id": tag_id,
|
||||
"from_name": from_name,
|
||||
"to_name": to_name,
|
||||
"type": control_type,
|
||||
"value": {
|
||||
control_type: values
|
||||
}
|
||||
}
|
||||
|
||||
full_tags.append(full_tag)
|
||||
logger.debug(f"Converted tag: {from_name} ({control_type}) with {len(values) if isinstance(values, list) else 1} values")
|
||||
|
||||
return full_tags
|
||||
|
||||
def is_simplified_format(self, tag: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if a tag is in simplified format (missing type field)
|
||||
|
||||
Args:
|
||||
tag: Tag dictionary to check
|
||||
|
||||
Returns:
|
||||
True if tag appears to be in simplified format
|
||||
"""
|
||||
# Simplified format has 'values' at top level and no 'type' field
|
||||
has_values = 'values' in tag
|
||||
has_type = 'type' in tag
|
||||
has_value = 'value' in tag
|
||||
|
||||
# If it has 'values' but no 'type', it's simplified
|
||||
# If it has 'type' and nested 'value', it's already full format
|
||||
return has_values and not has_type and not has_value
|
||||
|
||||
def convert_if_needed(
|
||||
self,
|
||||
tags: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert tags to full format if they are in simplified format
|
||||
|
||||
This method can handle mixed formats - it will convert simplified tags
|
||||
and pass through tags that are already in full format.
|
||||
|
||||
Args:
|
||||
tags: List of tags in either format
|
||||
|
||||
Returns:
|
||||
List of tags in full internal format
|
||||
"""
|
||||
if not tags:
|
||||
return []
|
||||
|
||||
result = []
|
||||
|
||||
for tag in tags:
|
||||
if self.is_simplified_format(tag):
|
||||
# Convert simplified format
|
||||
converted = self.convert_simplified_to_full([tag])
|
||||
result.extend(converted)
|
||||
else:
|
||||
# Already in full format, pass through
|
||||
result.append(tag)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def create_converter_from_template_config(
|
||||
template_config_dict: Dict[str, Any]
|
||||
) -> TagFormatConverter:
|
||||
"""
|
||||
Create a TagFormatConverter from a template configuration dictionary
|
||||
|
||||
Args:
|
||||
template_config_dict: Template configuration as dict (from database JSON)
|
||||
|
||||
Returns:
|
||||
TagFormatConverter instance
|
||||
|
||||
Raises:
|
||||
ValueError: If template configuration is invalid
|
||||
"""
|
||||
try:
|
||||
# Parse the configuration using Pydantic model
|
||||
from ..schema.template import TemplateConfiguration
|
||||
|
||||
template_config = TemplateConfiguration(**template_config_dict)
|
||||
return TagFormatConverter(template_config)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create tag converter from template config: {e}")
|
||||
raise ValueError(f"Invalid template configuration: {e}")
|
||||
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
Unit tests for TagFormatConverter
|
||||
|
||||
Run with: pytest app/module/annotation/utils/test_tag_converter.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from .tag_converter import TagFormatConverter, create_converter_from_template_config
|
||||
from ..schema.template import TemplateConfiguration, LabelDefinition, ObjectDefinition
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_template_config():
|
||||
"""Create a sample template configuration for testing"""
|
||||
return TemplateConfiguration(
|
||||
labels=[
|
||||
LabelDefinition(
|
||||
fromName="sentiment",
|
||||
toName="text",
|
||||
type="choices",
|
||||
options=["positive", "negative", "neutral"],
|
||||
required=True,
|
||||
labels=None,
|
||||
description=None
|
||||
),
|
||||
LabelDefinition(
|
||||
fromName="bbox",
|
||||
toName="image",
|
||||
type="rectanglelabels",
|
||||
labels=["cat", "dog", "bird"],
|
||||
required=False,
|
||||
options=None,
|
||||
description=None
|
||||
),
|
||||
LabelDefinition(
|
||||
fromName="comment",
|
||||
toName="text",
|
||||
type="textarea",
|
||||
required=False,
|
||||
options=None,
|
||||
labels=None,
|
||||
description=None
|
||||
)
|
||||
],
|
||||
objects=[
|
||||
ObjectDefinition(name="text", type="Text", value="$text"),
|
||||
ObjectDefinition(name="image", type="Image", value="$image")
|
||||
],
|
||||
metadata=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def converter(sample_template_config):
|
||||
"""Create a converter instance"""
|
||||
return TagFormatConverter(sample_template_config)
|
||||
|
||||
|
||||
class TestTagFormatConverter:
|
||||
"""Test TagFormatConverter functionality"""
|
||||
|
||||
def test_type_map_building(self, converter):
|
||||
"""Test that type map is built correctly from template"""
|
||||
assert converter.get_type_for_from_name("sentiment") == "choices"
|
||||
assert converter.get_type_for_from_name("bbox") == "rectanglelabels"
|
||||
assert converter.get_type_for_from_name("comment") == "textarea"
|
||||
assert converter.get_type_for_from_name("nonexistent") is None
|
||||
|
||||
def test_convert_simplified_to_full_single_value(self, converter):
|
||||
"""Test conversion of simplified format with single value"""
|
||||
simplified = [
|
||||
{
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"values": ["positive"]
|
||||
}
|
||||
]
|
||||
|
||||
result = converter.convert_simplified_to_full(simplified)
|
||||
|
||||
assert len(result) == 1
|
||||
tag = result[0]
|
||||
assert tag["from_name"] == "sentiment"
|
||||
assert tag["to_name"] == "text"
|
||||
assert tag["type"] == "choices"
|
||||
assert tag["value"] == {"choices": ["positive"]}
|
||||
assert "id" in tag
|
||||
|
||||
def test_convert_simplified_to_full_multiple_values(self, converter):
|
||||
"""Test conversion of simplified format with multiple values"""
|
||||
simplified = [
|
||||
{
|
||||
"from_name": "bbox",
|
||||
"to_name": "image",
|
||||
"values": ["cat", "dog"]
|
||||
}
|
||||
]
|
||||
|
||||
result = converter.convert_simplified_to_full(simplified)
|
||||
|
||||
assert len(result) == 1
|
||||
tag = result[0]
|
||||
assert tag["type"] == "rectanglelabels"
|
||||
assert tag["value"] == {"rectanglelabels": ["cat", "dog"]}
|
||||
|
||||
def test_convert_simplified_camelcase(self, converter):
|
||||
"""Test that camelCase field names are supported"""
|
||||
simplified = [
|
||||
{
|
||||
"fromName": "sentiment", # camelCase
|
||||
"toName": "text", # camelCase
|
||||
"values": ["neutral"]
|
||||
}
|
||||
]
|
||||
|
||||
result = converter.convert_simplified_to_full(simplified)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["from_name"] == "sentiment"
|
||||
assert result[0]["to_name"] == "text"
|
||||
|
||||
def test_convert_multiple_tags(self, converter):
|
||||
"""Test conversion of multiple tags at once"""
|
||||
simplified = [
|
||||
{
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"values": ["positive"]
|
||||
},
|
||||
{
|
||||
"from_name": "bbox",
|
||||
"to_name": "image",
|
||||
"values": ["cat"]
|
||||
}
|
||||
]
|
||||
|
||||
result = converter.convert_simplified_to_full(simplified)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["type"] == "choices"
|
||||
assert result[1]["type"] == "rectanglelabels"
|
||||
|
||||
def test_convert_with_existing_id(self, converter):
|
||||
"""Test that existing IDs are preserved"""
|
||||
existing_id = "my-custom-id-123"
|
||||
simplified = [
|
||||
{
|
||||
"id": existing_id,
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"values": ["positive"]
|
||||
}
|
||||
]
|
||||
|
||||
result = converter.convert_simplified_to_full(simplified)
|
||||
|
||||
assert result[0]["id"] == existing_id
|
||||
|
||||
def test_skip_unknown_from_name(self, converter):
|
||||
"""Test that tags with unknown from_name are skipped"""
|
||||
simplified = [
|
||||
{
|
||||
"from_name": "unknown_control",
|
||||
"to_name": "text",
|
||||
"values": ["value"]
|
||||
}
|
||||
]
|
||||
|
||||
result = converter.convert_simplified_to_full(simplified)
|
||||
|
||||
assert len(result) == 0 # Should be skipped
|
||||
|
||||
def test_skip_missing_fields(self, converter):
|
||||
"""Test that tags with missing required fields are skipped"""
|
||||
simplified = [
|
||||
{
|
||||
"from_name": "sentiment",
|
||||
# Missing to_name
|
||||
"values": ["positive"]
|
||||
}
|
||||
]
|
||||
|
||||
result = converter.convert_simplified_to_full(simplified)
|
||||
|
||||
assert len(result) == 0 # Should be skipped
|
||||
|
||||
def test_is_simplified_format(self, converter):
|
||||
"""Test detection of simplified format"""
|
||||
# Simplified format
|
||||
assert converter.is_simplified_format({
|
||||
"from_name": "x",
|
||||
"to_name": "y",
|
||||
"values": ["a"]
|
||||
}) is True
|
||||
|
||||
# Full format
|
||||
assert converter.is_simplified_format({
|
||||
"id": "123",
|
||||
"from_name": "x",
|
||||
"to_name": "y",
|
||||
"type": "choices",
|
||||
"value": {"choices": ["a"]}
|
||||
}) is False
|
||||
|
||||
# Edge case: has both (should not be considered simplified)
|
||||
assert converter.is_simplified_format({
|
||||
"from_name": "x",
|
||||
"to_name": "y",
|
||||
"type": "choices",
|
||||
"values": ["a"]
|
||||
}) is False
|
||||
|
||||
def test_convert_if_needed_mixed_formats(self, converter):
|
||||
"""Test conversion of mixed format tags"""
|
||||
mixed = [
|
||||
# Simplified format
|
||||
{
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"values": ["positive"]
|
||||
},
|
||||
# Full format
|
||||
{
|
||||
"id": "existing-123",
|
||||
"from_name": "bbox",
|
||||
"to_name": "image",
|
||||
"type": "rectanglelabels",
|
||||
"value": {"rectanglelabels": ["cat"]}
|
||||
}
|
||||
]
|
||||
|
||||
result = converter.convert_if_needed(mixed)
|
||||
|
||||
assert len(result) == 2
|
||||
# First should be converted
|
||||
assert result[0]["type"] == "choices"
|
||||
assert result[0]["value"] == {"choices": ["positive"]}
|
||||
# Second should pass through unchanged
|
||||
assert result[1]["id"] == "existing-123"
|
||||
assert result[1]["type"] == "rectanglelabels"
|
||||
|
||||
|
||||
class TestCreateConverterFromDict:
|
||||
"""Test the factory function for creating converter from dict"""
|
||||
|
||||
def test_create_from_valid_dict(self):
|
||||
"""Test creating converter from valid configuration dict"""
|
||||
config_dict = {
|
||||
"labels": [
|
||||
{
|
||||
"fromName": "label",
|
||||
"toName": "image",
|
||||
"type": "choices",
|
||||
"options": ["a", "b"]
|
||||
}
|
||||
],
|
||||
"objects": [
|
||||
{
|
||||
"name": "image",
|
||||
"type": "Image",
|
||||
"value": "$image"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
converter = create_converter_from_template_config(config_dict)
|
||||
|
||||
assert isinstance(converter, TagFormatConverter)
|
||||
assert converter.get_type_for_from_name("label") == "choices"
|
||||
|
||||
def test_create_from_invalid_dict(self):
|
||||
"""Test that invalid config raises ValueError"""
|
||||
invalid_config = {
|
||||
"labels": "not-a-list", # Should be a list
|
||||
"objects": []
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid template configuration"):
|
||||
create_converter_from_template_config(invalid_config)
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Test real-world usage scenarios"""
|
||||
|
||||
def test_external_api_submission(self, converter):
|
||||
"""Simulate external user submitting tags via API"""
|
||||
# User submits simplified format
|
||||
user_submission = [
|
||||
{
|
||||
"fromName": "sentiment", # User uses camelCase
|
||||
"toName": "text",
|
||||
"values": ["positive", "negative"]
|
||||
}
|
||||
]
|
||||
|
||||
# System converts to internal format
|
||||
internal_tags = converter.convert_if_needed(user_submission)
|
||||
|
||||
# Verify correct storage format
|
||||
assert len(internal_tags) == 1
|
||||
assert internal_tags[0]["type"] == "choices"
|
||||
assert internal_tags[0]["value"] == {"choices": ["positive", "negative"]}
|
||||
assert "id" in internal_tags[0]
|
||||
|
||||
def test_update_existing_tags(self, converter):
|
||||
"""Simulate updating existing tags with new values"""
|
||||
# Existing tags in database (full format)
|
||||
existing_tags = [
|
||||
{
|
||||
"id": "tag-001",
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"type": "choices",
|
||||
"value": {"choices": ["positive"]}
|
||||
}
|
||||
]
|
||||
|
||||
# User updates with simplified format
|
||||
update_request = [
|
||||
{
|
||||
"id": "tag-001", # Same ID to update
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"values": ["negative"] # New value
|
||||
}
|
||||
]
|
||||
|
||||
# Convert update request
|
||||
converted_update = converter.convert_if_needed(update_request)
|
||||
|
||||
# Merge logic would replace tag-001
|
||||
assert converted_update[0]["id"] == "tag-001"
|
||||
assert converted_update[0]["value"] == {"choices": ["negative"]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -165,14 +165,20 @@ class Service:
|
||||
async def update_file_tags_partial(
|
||||
self,
|
||||
file_id: str,
|
||||
new_tags: List[Dict[str, Any]]
|
||||
new_tags: List[Dict[str, Any]],
|
||||
template_id: Optional[str] = None
|
||||
) -> tuple[bool, Optional[str], Optional[datetime]]:
|
||||
"""
|
||||
部分更新文件标签
|
||||
部分更新文件标签,支持自动格式转换
|
||||
|
||||
如果提供了 template_id,会自动将简化格式的标签转换为完整格式。
|
||||
简化格式: {"from_name": "x", "to_name": "y", "values": [...]}
|
||||
完整格式: {"id": "...", "from_name": "x", "to_name": "y", "type": "...", "value": {"type": [...]}}
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
new_tags: 新的标签列表(部分更新)
|
||||
new_tags: 新的标签列表(部分更新),可以是简化格式或完整格式
|
||||
template_id: 可选的模板ID,用于格式转换
|
||||
|
||||
Returns:
|
||||
(成功标志, 错误信息, 更新时间)
|
||||
@@ -190,6 +196,38 @@ class Service:
|
||||
logger.error(f"File not found: {file_id}")
|
||||
return False, f"File not found: {file_id}", None
|
||||
|
||||
# 如果提供了 template_id,尝试进行格式转换
|
||||
processed_tags = new_tags
|
||||
if template_id:
|
||||
logger.debug(f"Converting tags using template: {template_id}")
|
||||
|
||||
try:
|
||||
# 获取模板配置
|
||||
from app.db.models import AnnotationTemplate
|
||||
template_result = await self.db.execute(
|
||||
select(AnnotationTemplate).where(
|
||||
AnnotationTemplate.id == template_id,
|
||||
AnnotationTemplate.deleted_at.is_(None)
|
||||
)
|
||||
)
|
||||
template = template_result.scalar_one_or_none()
|
||||
|
||||
if not template:
|
||||
logger.warning(f"Template {template_id} not found, skipping conversion")
|
||||
else:
|
||||
# 使用 converter 转换标签格式
|
||||
from app.module.annotation.utils import create_converter_from_template_config
|
||||
|
||||
converter = create_converter_from_template_config(template.configuration) # type: ignore
|
||||
processed_tags = converter.convert_if_needed(new_tags)
|
||||
|
||||
logger.info(f"Converted {len(new_tags)} tags to full format")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert tags using template: {e}")
|
||||
# 继续使用原始标签格式
|
||||
logger.warning("Continuing with original tag format")
|
||||
|
||||
# 获取现有标签
|
||||
existing_tags: List[Dict[str, Any]] = file_record.tags or [] # type: ignore
|
||||
|
||||
@@ -197,7 +235,7 @@ class Service:
|
||||
tag_id_map = {tag.get('id'): idx for idx, tag in enumerate(existing_tags) if tag.get('id')}
|
||||
|
||||
# 更新或追加标签
|
||||
for new_tag in new_tags:
|
||||
for new_tag in processed_tags:
|
||||
tag_id = new_tag.get('id')
|
||||
if tag_id and tag_id in tag_id_map:
|
||||
# 更新现有标签
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example: Tag Format Conversion Usage
|
||||
|
||||
This script demonstrates how to use the tag format conversion feature
|
||||
to update file tags using the simplified external format.
|
||||
|
||||
Run this script after:
|
||||
1. Creating a dataset
|
||||
2. Creating an annotation template
|
||||
3. Creating a labeling project that links the dataset and template
|
||||
4. Uploading files to the dataset
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
# Configuration
|
||||
API_BASE_URL = "http://localhost:8000"
|
||||
FILE_ID = "your-file-uuid-here" # Replace with actual file ID
|
||||
|
||||
|
||||
async def update_file_tags_simplified(
|
||||
file_id: str,
|
||||
tags: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Update file tags using simplified format
|
||||
|
||||
Args:
|
||||
file_id: UUID of the file to update
|
||||
tags: List of tags in simplified format
|
||||
|
||||
Returns:
|
||||
API response
|
||||
"""
|
||||
url = f"{API_BASE_URL}/api/annotation/task/{file_id}"
|
||||
|
||||
payload = {
|
||||
"tags": tags
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.put(url, json=payload, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
async def example_1_basic_update():
|
||||
"""Example 1: Basic tag update with simplified format"""
|
||||
print("\n=== Example 1: Basic Tag Update ===\n")
|
||||
|
||||
# Simplified format - no type or nested value required
|
||||
tags = [
|
||||
{
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"values": ["positive", "negative"]
|
||||
}
|
||||
]
|
||||
|
||||
print("Submitting tags in simplified format:")
|
||||
print(json.dumps(tags, indent=2))
|
||||
|
||||
try:
|
||||
result = await update_file_tags_simplified(FILE_ID, tags)
|
||||
|
||||
print("\n✓ Success! Response:")
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
# The response will contain tags in full internal format
|
||||
if result.get("data", {}).get("tags"):
|
||||
stored_tag = result["data"]["tags"][0]
|
||||
print("\n📝 Stored tag format:")
|
||||
print(f" - ID: {stored_tag.get('id')}")
|
||||
print(f" - Type: {stored_tag.get('type')}")
|
||||
print(f" - Value: {stored_tag.get('value')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {e}")
|
||||
|
||||
|
||||
async def example_2_multiple_controls():
|
||||
"""Example 2: Update multiple different control types"""
|
||||
print("\n=== Example 2: Multiple Control Types ===\n")
|
||||
|
||||
tags = [
|
||||
# Text classification
|
||||
{
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"values": ["positive"]
|
||||
},
|
||||
# Image bounding boxes
|
||||
{
|
||||
"from_name": "bbox",
|
||||
"to_name": "image",
|
||||
"values": ["cat", "dog"]
|
||||
},
|
||||
# Text comment
|
||||
{
|
||||
"from_name": "comment",
|
||||
"to_name": "text",
|
||||
"values": ["This is a great example"]
|
||||
}
|
||||
]
|
||||
|
||||
print("Submitting multiple control types:")
|
||||
print(json.dumps(tags, indent=2))
|
||||
|
||||
try:
|
||||
result = await update_file_tags_simplified(FILE_ID, tags)
|
||||
print("\n✓ Success! All tags converted and stored.")
|
||||
print(f" Total tags: {len(result['data']['tags'])}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {e}")
|
||||
|
||||
|
||||
async def example_3_update_existing():
|
||||
"""Example 3: Update existing tag by preserving ID"""
|
||||
print("\n=== Example 3: Update Existing Tag ===\n")
|
||||
|
||||
# First, let's assume we know the ID of an existing tag
|
||||
existing_tag_id = "some-existing-uuid"
|
||||
|
||||
tags = [
|
||||
{
|
||||
"id": existing_tag_id, # Preserve ID to update, not create new
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"values": ["neutral"] # Change value
|
||||
}
|
||||
]
|
||||
|
||||
print("Updating existing tag by ID:")
|
||||
print(json.dumps(tags, indent=2))
|
||||
|
||||
try:
|
||||
result = await update_file_tags_simplified(FILE_ID, tags)
|
||||
print("\n✓ Success! Existing tag updated.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {e}")
|
||||
|
||||
|
||||
async def example_4_camelcase():
|
||||
"""Example 4: Using camelCase field names (frontend style)"""
|
||||
print("\n=== Example 4: CamelCase Field Names ===\n")
|
||||
|
||||
# Frontend typically sends camelCase
|
||||
tags = [
|
||||
{
|
||||
"fromName": "sentiment", # camelCase
|
||||
"toName": "text", # camelCase
|
||||
"values": ["positive"]
|
||||
}
|
||||
]
|
||||
|
||||
print("Submitting with camelCase:")
|
||||
print(json.dumps(tags, indent=2))
|
||||
|
||||
try:
|
||||
result = await update_file_tags_simplified(FILE_ID, tags)
|
||||
print("\n✓ Success! CamelCase automatically handled.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {e}")
|
||||
|
||||
|
||||
async def example_5_mixed_format():
|
||||
"""Example 5: Mixed format - some simplified, some full"""
|
||||
print("\n=== Example 5: Mixed Format Support ===\n")
|
||||
|
||||
tags = [
|
||||
# Simplified format
|
||||
{
|
||||
"from_name": "new_label",
|
||||
"to_name": "image",
|
||||
"values": ["new_value"]
|
||||
},
|
||||
# Full format (already has type and nested value)
|
||||
{
|
||||
"id": "existing-uuid-123",
|
||||
"from_name": "old_label",
|
||||
"to_name": "image",
|
||||
"type": "choices",
|
||||
"value": {
|
||||
"choices": ["old_value"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
print("Submitting mixed format:")
|
||||
print(json.dumps(tags, indent=2))
|
||||
|
||||
try:
|
||||
result = await update_file_tags_simplified(FILE_ID, tags)
|
||||
print("\n✓ Success! Mixed format handled correctly.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {e}")
|
||||
|
||||
|
||||
async def example_6_error_handling():
|
||||
"""Example 6: Error handling - unknown control"""
|
||||
print("\n=== Example 6: Error Handling ===\n")
|
||||
|
||||
tags = [
|
||||
# Valid control
|
||||
{
|
||||
"from_name": "sentiment",
|
||||
"to_name": "text",
|
||||
"values": ["positive"]
|
||||
},
|
||||
# Invalid control - doesn't exist in template
|
||||
{
|
||||
"from_name": "unknown_control",
|
||||
"to_name": "text",
|
||||
"values": ["some_value"]
|
||||
}
|
||||
]
|
||||
|
||||
print("Submitting with unknown control:")
|
||||
print(json.dumps(tags, indent=2))
|
||||
|
||||
try:
|
||||
result = await update_file_tags_simplified(FILE_ID, tags)
|
||||
print("\n⚠ Partial Success:")
|
||||
print(f" Valid tags stored: {len(result['data']['tags'])}")
|
||||
print(" (Unknown control was skipped)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all examples"""
|
||||
print("\n" + "="*60)
|
||||
print("Tag Format Conversion Examples")
|
||||
print("="*60)
|
||||
|
||||
print(f"\nAPI Base URL: {API_BASE_URL}")
|
||||
print(f"File ID: {FILE_ID}")
|
||||
print("\n⚠ Make sure to replace FILE_ID with an actual file UUID!")
|
||||
|
||||
# Uncomment the examples you want to run:
|
||||
|
||||
# await example_1_basic_update()
|
||||
# await example_2_multiple_controls()
|
||||
# await example_3_update_existing()
|
||||
# await example_4_camelcase()
|
||||
# await example_5_mixed_format()
|
||||
# await example_6_error_handling()
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("Tip: Edit this script to uncomment examples you want to run")
|
||||
print("="*60 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the examples
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user