feat: Enhance file tag update functionality with automatic format conversion (#84)

- Updated `update_file_tags` to support both simplified and full tag formats.
- Introduced `TagFormatConverter` to handle conversion from simplified external tags to internal storage format.
- Added logic to fetch and utilize the appropriate annotation template for conversion.
- Improved error handling for missing templates and unknown controls during tag updates.
- Created example script demonstrating the usage of the new tag format conversion feature.
- Added unit tests for `TagFormatConverter` to ensure correct functionality and edge case handling.
This commit is contained in:
Jason Wang
2025-11-14 12:42:39 +08:00
committed by GitHub
parent 5cef9cb273
commit df853a5177
10 changed files with 1127 additions and 54 deletions

View File

@@ -150,12 +150,19 @@ async def create_mapping(
async def list_mappings(
page: int = Query(1, ge=1, description="页码(从1开始)"),
page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"),
include_template: bool = Query(False, description="是否包含模板详情", alias="includeTemplate"),
db: AsyncSession = Depends(get_db)
):
"""
查询所有映射关系(分页)
返回所有有效的数据集映射关系(未被软删除的),支持分页查询
返回所有有效的数据集映射关系(未被软删除的),支持分页查询
可选择是否包含完整的标注模板信息(默认不包含,以提高列表查询性能)。
参数:
- page: 页码(从1开始)
- pageSize: 每页记录数
- includeTemplate: 是否包含模板详情(默认false)
"""
try:
service = DatasetMappingService(db)
@@ -163,10 +170,14 @@ async def list_mappings(
# 计算 skip
skip = (page - 1) * page_size
logger.info(f"List mappings: page={page}, page_size={page_size}, include_template={include_template}")
# 获取数据和总数
mappings, total = await service.get_all_mappings_with_count(
skip=skip,
limit=page_size
skip=skip,
limit=page_size,
include_deleted=False,
include_template=include_template
)
# 计算总页数
@@ -181,7 +192,7 @@ async def list_mappings(
content=mappings
)
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}")
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
return StandardResponse(
code=200,
@@ -199,14 +210,21 @@ async def get_mapping(
db: AsyncSession = Depends(get_db)
):
"""
根据 UUID 查询单个映射关系
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
返回数据集映射关系以及关联的完整标注模板信息,包括:
- 映射基本信息
- 数据集信息
- Label Studio 项目信息
- 完整的标注模板配置(如果存在)
"""
try:
service = DatasetMappingService(db)
logger.info(f"Get mapping: {mapping_id}")
logger.info(f"Get mapping with template details: {mapping_id}")
mapping = await service.get_mapping_by_uuid(mapping_id)
# 获取映射,并包含完整的模板信息
mapping = await service.get_mapping_by_uuid(mapping_id, include_template=True)
if not mapping:
raise HTTPException(
@@ -214,7 +232,7 @@ async def get_mapping(
detail=f"Mapping not found: {mapping_id}"
)
logger.info(f"Found mapping: {mapping.id}")
logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}")
return StandardResponse(
code=200,
@@ -233,12 +251,20 @@ async def get_mappings_by_source(
dataset_id: str,
page: int = Query(1, ge=1, description="页码(从1开始)"),
page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"),
include_template: bool = Query(True, description="是否包含模板详情", alias="includeTemplate"),
db: AsyncSession = Depends(get_db)
):
"""
根据源数据集 ID 查询所有映射关系(分页)
根据源数据集 ID 查询所有映射关系(分页,包含模板详情
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询
默认包含关联的完整标注模板信息。
参数:
- dataset_id: 数据集ID
- page: 页码(从1开始)
- pageSize: 每页记录数
- includeTemplate: 是否包含模板详情(默认true)
"""
try:
service = DatasetMappingService(db)
@@ -246,13 +272,14 @@ async def get_mappings_by_source(
# 计算 skip
skip = (page - 1) * page_size
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}")
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}, include_template={include_template}")
# 获取数据和总数
# 获取数据和总数(包含模板信息)
mappings, total = await service.get_mappings_by_source_with_count(
dataset_id=dataset_id,
skip=skip,
limit=page_size
limit=page_size,
include_template=include_template
)
# 计算总页数
@@ -267,7 +294,7 @@ async def get_mappings_by_source(
content=mappings
)
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}")
return StandardResponse(
code=200,

View File

@@ -244,16 +244,53 @@ async def update_file_tags(
db: AsyncSession = Depends(get_db)
):
"""
Update File Tags (Partial Update)
Update File Tags (Partial Update with Auto Format Conversion)
接收部分标签更新并合并到指定文件(只修改提交的标签,其余保持不变),并更新 `tags_updated_at`。
支持两种标签格式:
1. 简化格式(外部用户提交):
[{"from_name": "label", "to_name": "image", "values": ["cat", "dog"]}]
2. 完整格式(内部存储):
[{"id": "...", "from_name": "label", "to_name": "image", "type": "choices",
"value": {"choices": ["cat", "dog"]}}]
系统会自动根据数据集关联的模板将简化格式转换为完整格式。
请求与响应使用 Pydantic 模型 `UpdateFileTagsRequest` / `UpdateFileTagsResponse`。
"""
service = DatasetManagementService(db)
# 首先获取文件所属的数据集
from sqlalchemy.future import select
from app.db.models import DatasetFiles
result = await db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)
file_record = result.scalar_one_or_none()
if not file_record:
raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
# 查找数据集关联的模板ID
from ..service.mapping import DatasetMappingService
mapping_service = DatasetMappingService(db)
template_id = await mapping_service.get_template_id_by_dataset_id(dataset_id)
if template_id:
logger.info(f"Found template {template_id} for dataset {dataset_id}, will auto-convert tag format")
else:
logger.warning(f"No template found for dataset {dataset_id}, tags must be in full format")
# 更新标签(如果有模板ID则自动转换格式)
success, error_msg, updated_at = await service.update_file_tags_partial(
file_id=file_id,
new_tags=request.tags
new_tags=request.tags,
template_id=template_id # 传递模板ID以启用自动转换
)
if not success:
@@ -261,10 +298,7 @@ async def update_file_tags(
raise HTTPException(status_code=404, detail=error_msg)
raise HTTPException(status_code=500, detail=error_msg or "更新标签失败")
# 获取更新后的完整标签列表
from sqlalchemy.future import select
from app.db.models import DatasetFiles
# 重新获取更新后的文件记录(获取完整标签列表
result = await db.execute(
select(DatasetFiles).where(DatasetFiles.id == file_id)
)

View File

@@ -3,14 +3,6 @@ from .config import (
TagConfigResponse
)
from .mapping import (
DatasetMappingCreateRequest,
DatasetMappingCreateResponse,
DatasetMappingUpdateRequest,
DatasetMappingResponse,
DeleteDatasetResponse,
)
from .sync import (
SyncDatasetRequest,
SyncDatasetResponse,
@@ -30,6 +22,17 @@ from .template import (
AnnotationTemplateListResponse
)
from .mapping import (
DatasetMappingCreateRequest,
DatasetMappingCreateResponse,
DatasetMappingUpdateRequest,
DatasetMappingResponse,
DeleteDatasetResponse,
)
# Rebuild model to resolve forward references
DatasetMappingResponse.model_rebuild()
__all__ = [
"ConfigResponse",
"TagConfigResponse",

View File

@@ -1,10 +1,13 @@
from pydantic import Field, BaseModel
from typing import Optional
from typing import Optional, TYPE_CHECKING
from datetime import datetime
from app.module.shared.schema import BaseResponseModel
from app.module.shared.schema import StandardResponse
if TYPE_CHECKING:
from .template import AnnotationTemplateResponse
class DatasetMappingCreateRequest(BaseModel):
"""数据集映射 创建 请求模型
@@ -42,6 +45,8 @@ class DatasetMappingResponse(BaseModel):
labeling_project_id: str = Field(..., alias="labelingProjectId", description="标注项目ID")
name: Optional[str] = Field(None, description="标注项目名称")
description: Optional[str] = Field(None, description="标注项目描述")
template_id: Optional[str] = Field(None, alias="templateId", description="关联的模板ID")
template: Optional['AnnotationTemplateResponse'] = Field(None, description="关联的标注模板详情")
created_at: datetime = Field(..., alias="createdAt", description="创建时间")
updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
deleted_at: Optional[datetime] = Field(None, alias="deletedAt", description="删除时间")

View File

@@ -7,12 +7,13 @@ from datetime import datetime
import uuid
from app.core.logging import get_logger
from app.db.models import LabelingProject
from app.db.models import LabelingProject, AnnotationTemplate
from app.db.models.dataset_management import Dataset
from app.module.annotation.schema import (
DatasetMappingCreateRequest,
DatasetMappingUpdateRequest,
DatasetMappingResponse
DatasetMappingResponse,
AnnotationTemplateResponse
)
logger = get_logger(__name__)
@@ -33,11 +34,32 @@ class DatasetMappingService:
LabelingProject.dataset_id == Dataset.id
)
def _to_response_from_row(self, row) -> DatasetMappingResponse:
"""Convert query row (mapping + dataset_name) to response"""
async def _to_response_from_row(
self,
row,
include_template: bool = False
) -> DatasetMappingResponse:
"""
Convert query row (mapping + dataset_name) to response
Args:
row: Query result row containing (LabelingProject, dataset_name)
include_template: If True, fetch and include full template details
"""
mapping = row[0] # LabelingProject object
dataset_name = row[1] # dataset_name from join
# Get template_id from mapping
template_id = getattr(mapping, 'template_id', None)
# Optionally fetch full template details
template_response = None
if include_template and template_id:
from ..service.template import AnnotationTemplateService
template_service = AnnotationTemplateService()
template_response = await template_service.get_template(self.db, template_id)
logger.debug(f"Included template details for template_id: {template_id}")
response_data = {
"id": mapping.id,
"dataset_id": mapping.dataset_id,
@@ -45,6 +67,8 @@ class DatasetMappingService:
"labeling_project_id": mapping.labeling_project_id,
"name": mapping.name,
"description": getattr(mapping, 'description', None),
"template_id": template_id,
"template": template_response,
"created_at": mapping.created_at,
"updated_at": mapping.updated_at,
"deleted_at": mapping.deleted_at,
@@ -52,8 +76,18 @@ class DatasetMappingService:
return DatasetMappingResponse(**response_data)
async def _to_response(self, mapping: LabelingProject) -> DatasetMappingResponse:
"""Convert ORM model to response with dataset name (for single entity operations)"""
async def _to_response(
self,
mapping: LabelingProject,
include_template: bool = False
) -> DatasetMappingResponse:
"""
Convert ORM model to response with dataset name (for single entity operations)
Args:
mapping: LabelingProject ORM object
include_template: If True, fetch and include full template details
"""
# Fetch dataset name
dataset_name = None
dataset_id = getattr(mapping, 'dataset_id', None)
@@ -63,6 +97,17 @@ class DatasetMappingService:
)
dataset_name = dataset_result.scalar_one_or_none()
# Get template_id from mapping
template_id = getattr(mapping, 'template_id', None)
# Optionally fetch full template details
template_response = None
if include_template and template_id:
from ..service.template import AnnotationTemplateService
template_service = AnnotationTemplateService()
template_response = await template_service.get_template(self.db, template_id)
logger.debug(f"Included template details for template_id: {template_id}")
# Create response dict with all fields
response_data = {
"id": mapping.id,
@@ -71,6 +116,8 @@ class DatasetMappingService:
"labeling_project_id": mapping.labeling_project_id,
"name": mapping.name,
"description": getattr(mapping, 'description', None),
"template_id": template_id,
"template": template_response,
"created_at": mapping.created_at,
"updated_at": mapping.updated_at,
"deleted_at": mapping.deleted_at,
@@ -136,7 +183,12 @@ class DatasetMappingService:
rows = result.all()
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
# Convert rows to responses (async comprehension)
responses = []
for row in rows:
response = await self._to_response_from_row(row, include_template=False)
responses.append(response)
return responses
async def get_mapping_by_labeling_project_id(
self,
@@ -160,9 +212,19 @@ class DatasetMappingService:
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
return None
async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]:
"""根据映射UUID获取映射"""
logger.debug(f"Get mapping: {mapping_id}")
async def get_mapping_by_uuid(
self,
mapping_id: str,
include_template: bool = False
) -> Optional[DatasetMappingResponse]:
"""
根据映射UUID获取映射
Args:
mapping_id: 映射UUID
include_template: 是否包含完整的模板信息
"""
logger.debug(f"Get mapping: {mapping_id}, include_template={include_template}")
result = await self.db.execute(
select(LabelingProject).where(
@@ -174,7 +236,7 @@ class DatasetMappingService:
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return await self._to_response(mapping)
return await self._to_response(mapping, include_template=include_template)
logger.debug(f"No mapping found for mapping id: {mapping_id}")
return None
@@ -248,7 +310,12 @@ class DatasetMappingService:
rows = result.all()
logger.debug(f"Found {len(rows)} mappings")
return [self._to_response_from_row(row) for row in rows]
# Convert rows to responses (async comprehension)
responses = []
for row in rows:
response = await self._to_response_from_row(row, include_template=False)
responses.append(response)
return responses
async def count_mappings(self, include_deleted: bool = False) -> int:
"""统计映射总数"""
@@ -264,10 +331,19 @@ class DatasetMappingService:
self,
skip: int = 0,
limit: int = 100,
include_deleted: bool = False
include_deleted: bool = False,
include_template: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
"""获取所有映射及总数(用于分页)"""
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
"""
获取所有映射及总数(用于分页)
Args:
skip: 跳过记录数
limit: 返回记录数
include_deleted: 是否包含已删除的记录
include_template: 是否包含完整的模板信息
"""
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}, include_template={include_template}")
# 构建查询
query = self._build_query_with_dataset_name()
@@ -292,17 +368,62 @@ class DatasetMappingService:
rows = result.all()
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total
# Convert rows to responses (async comprehension)
responses = []
for row in rows:
response = await self._to_response_from_row(row, include_template=include_template)
responses.append(response)
return responses, total
async def get_template_id_by_dataset_id(self, dataset_id: str) -> Optional[str]:
    """
    Get the template ID for a dataset by looking at its labeling projects.

    Only non-deleted labeling projects that actually reference a template
    are considered. When several qualify, the most recently created one
    wins, so the result is deterministic and a project without a template
    cannot mask a sibling project that has one.

    Args:
        dataset_id: Dataset UUID

    Returns:
        Template ID, or None if the dataset has no active labeling project
        with an associated template
    """
    logger.debug(f"Looking up template for dataset: {dataset_id}")

    result = await self.db.execute(
        select(LabelingProject.template_id)
        .where(
            LabelingProject.dataset_id == dataset_id,
            LabelingProject.deleted_at.is_(None),
            # Skip projects without a template: with LIMIT 1 such a row
            # could otherwise be the one returned, yielding a false None.
            LabelingProject.template_id.is_not(None),
        )
        # Explicit ordering makes LIMIT 1 pick a well-defined row.
        .order_by(LabelingProject.created_at.desc())
        .limit(1)
    )
    template_id = result.scalar_one_or_none()

    if template_id:
        logger.debug(f"Found template {template_id} for dataset {dataset_id}")
    else:
        logger.warning(f"No template found for dataset {dataset_id}")

    return template_id
async def get_mappings_by_source_with_count(
self,
dataset_id: str,
skip: int = 0,
limit: int = 100,
include_deleted: bool = False
include_deleted: bool = False,
include_template: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
"""根据源数据集ID获取映射关系及总数(用于分页)"""
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
"""
根据源数据集ID获取映射关系及总数(用于分页)
Args:
dataset_id: 数据集ID
skip: 跳过记录数
limit: 返回记录数
include_deleted: 是否包含已删除的记录
include_template: 是否包含完整的模板信息
"""
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}, include_template={include_template}")
# 构建查询
query = self._build_query_with_dataset_name().where(
@@ -332,4 +453,9 @@ class DatasetMappingService:
rows = result.all()
logger.debug(f"Found {len(rows)} mappings, total: {total}")
return [self._to_response_from_row(row) for row in rows], total
# Convert rows to responses (async comprehension)
responses = []
for row in rows:
response = await self._to_response_from_row(row, include_template=include_template)
responses.append(response)
return responses, total

View File

@@ -2,5 +2,10 @@
Annotation Module Utilities
"""
from .config_validator import LabelStudioConfigValidator
from .tag_converter import TagFormatConverter, create_converter_from_template_config
__all__ = ['LabelStudioConfigValidator']
__all__ = [
'LabelStudioConfigValidator',
'TagFormatConverter',
'create_converter_from_template_config'
]

View File

@@ -0,0 +1,232 @@
"""
Tag Format Converter
Converts simplified external tag format to internal storage format by looking up
the type from the annotation template configuration.
External format (from users):
[
{
"from_name": "label",
"to_name": "image",
"values": ["cat", "dog"]
}
]
Internal storage format:
[
{
"id": "unique_id",
"from_name": "label",
"to_name": "image",
"type": "choices",
"value": {
"choices": ["cat", "dog"]
}
}
]
"""
import uuid
from typing import List, Dict, Any, Optional
from datetime import datetime
from app.core.logging import get_logger
from ..schema.template import TemplateConfiguration
logger = get_logger(__name__)
class TagFormatConverter:
    """Convert simplified external tags into the full internal storage format.

    The converter is bound to one template configuration. It indexes the
    template's label definitions by ``from_name`` so that the control type,
    which simplified tags do not carry, can be looked up during conversion.
    """

    def __init__(self, template_config: "TemplateConfiguration"):
        """
        Initialize converter with template configuration

        Args:
            template_config: The template configuration containing label definitions
        """
        self.template_config = template_config
        # Lookup map built once up front: from_name -> control type
        self._type_map = self._build_type_map()

    def _build_type_map(self) -> Dict[str, str]:
        """
        Build a mapping from from_name to type from template labels

        If two label definitions share a from_name, the later one wins.

        Returns:
            Dictionary mapping from_name to control type
        """
        type_map: Dict[str, str] = {}
        for label_def in self.template_config.labels:
            from_name = label_def.from_name
            control_type = label_def.type
            type_map[from_name] = control_type
            logger.debug(f"Registered control: {from_name} -> {control_type}")
        return type_map

    def get_type_for_from_name(self, from_name: str) -> Optional[str]:
        """
        Get the control type for a given from_name

        Args:
            from_name: The control name

        Returns:
            Control type or None if not found
        """
        return self._type_map.get(from_name)

    def convert_simplified_to_full(
        self,
        simplified_tags: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Convert simplified tag format to full internal storage format

        Tags that cannot be converted are logged and skipped rather than
        raising: tags missing from_name/to_name, tags without ``values``,
        and tags whose from_name is not a control of the template.

        Args:
            simplified_tags: List of tags in simplified format with structure:
                [
                    {
                        "from_name": "label",
                        "to_name": "image",
                        "values": ["cat", "dog"]  # Stored as given (list or single value)
                    }
                ]
                camelCase keys (fromName / toName) are accepted as well, and
                an optional "id" is preserved.

        Returns:
            List of tags in full internal format:
                [
                    {
                        "id": "unique_id",
                        "from_name": "label",
                        "to_name": "image",
                        "type": "choices",
                        "value": {
                            "choices": ["cat", "dog"]
                        }
                    }
                ]
        """
        full_tags = []

        for simplified_tag in simplified_tags:
            # Support both camelCase and snake_case from external sources
            from_name = simplified_tag.get('from_name') or simplified_tag.get('fromName')
            to_name = simplified_tag.get('to_name') or simplified_tag.get('toName')
            values = simplified_tag.get('values')
            tag_id = simplified_tag.get('id')  # Use existing ID if provided

            if not from_name or not to_name:
                logger.warning(f"Skipping tag with missing from_name or to_name: {simplified_tag}")
                continue

            # A missing/null 'values' would otherwise be stored as
            # {"<type>": None}, i.e. a full-format tag without a payload.
            # An empty list is still accepted (it is an explicit value).
            if values is None:
                logger.warning(f"Skipping tag with missing values: {simplified_tag}")
                continue

            # Look up the type from template configuration
            control_type = self.get_type_for_from_name(from_name)
            if not control_type:
                logger.warning(
                    f"Could not find type for from_name '{from_name}' in template. "
                    f"Tag will be skipped. Available controls: {list(self._type_map.keys())}"
                )
                continue

            # Generate ID if not provided
            if not tag_id:
                tag_id = str(uuid.uuid4())

            # Convert values to the proper nested structure
            # The key in the value dict is the control type itself
            full_tag = {
                "id": tag_id,
                "from_name": from_name,
                "to_name": to_name,
                "type": control_type,
                "value": {
                    control_type: values
                }
            }

            full_tags.append(full_tag)
            logger.debug(f"Converted tag: {from_name} ({control_type}) with {len(values) if isinstance(values, list) else 1} values")

        return full_tags

    def is_simplified_format(self, tag: Dict[str, Any]) -> bool:
        """
        Check if a tag is in simplified format (missing type field)

        Args:
            tag: Tag dictionary to check

        Returns:
            True only when the tag has a top-level 'values' key and neither
            a 'type' nor a 'value' key
        """
        has_values = 'values' in tag
        has_type = 'type' in tag
        has_value = 'value' in tag

        # 'values' without 'type'/'value' -> simplified.
        # Anything carrying 'type' or a nested 'value' is treated as full format.
        return has_values and not has_type and not has_value

    def convert_if_needed(
        self,
        tags: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Convert tags to full format if they are in simplified format

        This method can handle mixed formats - it will convert simplified tags
        and pass through (unchanged, same object) tags that are not detected
        as simplified. Simplified tags that cannot be converted are dropped,
        so the result may be shorter than the input.

        Args:
            tags: List of tags in either format

        Returns:
            List of tags in full internal format
        """
        if not tags:
            return []

        result = []

        for tag in tags:
            if self.is_simplified_format(tag):
                # Convert simplified format (yields zero or one tag)
                converted = self.convert_simplified_to_full([tag])
                result.extend(converted)
            else:
                # Already in full format, pass through
                result.append(tag)

        return result
def create_converter_from_template_config(
    template_config_dict: Dict[str, Any]
) -> TagFormatConverter:
    """
    Create a TagFormatConverter from a template configuration dictionary

    Args:
        template_config_dict: Template configuration as dict (from database JSON)

    Returns:
        TagFormatConverter instance

    Raises:
        ValueError: If the configuration cannot be parsed or the converter
            cannot be built from it. The underlying error is attached as
            the exception's cause.
    """
    try:
        # Parse the configuration using the Pydantic model imported at module level
        template_config = TemplateConfiguration(**template_config_dict)
        return TagFormatConverter(template_config)
    except Exception as e:
        logger.error(f"Failed to create tag converter from template config: {e}")
        # Chain explicitly so callers and logs keep the original traceback
        raise ValueError(f"Invalid template configuration: {e}") from e

View File

@@ -0,0 +1,337 @@
"""
Unit tests for TagFormatConverter
Run with: pytest app/module/annotation/utils/test_tag_converter.py -v
"""
import pytest
from .tag_converter import TagFormatConverter, create_converter_from_template_config
from ..schema.template import TemplateConfiguration, LabelDefinition, ObjectDefinition
@pytest.fixture
def sample_template_config():
    """Build the three-control template shared by the converter tests.

    Controls: ``sentiment`` (choices on text), ``bbox`` (rectanglelabels on
    image) and ``comment`` (textarea on text). Objects: ``text`` and ``image``.
    """
    sentiment_control = LabelDefinition(
        fromName="sentiment",
        toName="text",
        type="choices",
        options=["positive", "negative", "neutral"],
        required=True,
        labels=None,
        description=None,
    )
    bbox_control = LabelDefinition(
        fromName="bbox",
        toName="image",
        type="rectanglelabels",
        labels=["cat", "dog", "bird"],
        required=False,
        options=None,
        description=None,
    )
    comment_control = LabelDefinition(
        fromName="comment",
        toName="text",
        type="textarea",
        required=False,
        options=None,
        labels=None,
        description=None,
    )

    text_object = ObjectDefinition(name="text", type="Text", value="$text")
    image_object = ObjectDefinition(name="image", type="Image", value="$image")

    return TemplateConfiguration(
        labels=[sentiment_control, bbox_control, comment_control],
        objects=[text_object, image_object],
        metadata=None,
    )
@pytest.fixture
def converter(sample_template_config):
    """Provide a TagFormatConverter bound to the sample template."""
    tag_converter = TagFormatConverter(sample_template_config)
    return tag_converter
class TestTagFormatConverter:
    """Behavioural checks for TagFormatConverter against the sample template."""

    def test_type_map_building(self, converter):
        """Each template control resolves to its type; unknown names give None."""
        expected_types = {
            "sentiment": "choices",
            "bbox": "rectanglelabels",
            "comment": "textarea",
        }
        for control_name, control_type in expected_types.items():
            assert converter.get_type_for_from_name(control_name) == control_type
        assert converter.get_type_for_from_name("nonexistent") is None

    def test_convert_simplified_to_full_single_value(self, converter):
        """A one-element ``values`` list is nested under the control type."""
        payload = [
            {"from_name": "sentiment", "to_name": "text", "values": ["positive"]}
        ]

        converted = converter.convert_simplified_to_full(payload)

        assert len(converted) == 1
        first = converted[0]
        assert first["from_name"] == "sentiment"
        assert first["to_name"] == "text"
        assert first["type"] == "choices"
        assert first["value"] == {"choices": ["positive"]}
        assert "id" in first

    def test_convert_simplified_to_full_multiple_values(self, converter):
        """A multi-element ``values`` list is carried over intact."""
        payload = [
            {"from_name": "bbox", "to_name": "image", "values": ["cat", "dog"]}
        ]

        converted = converter.convert_simplified_to_full(payload)

        assert len(converted) == 1
        first = converted[0]
        assert first["type"] == "rectanglelabels"
        assert first["value"] == {"rectanglelabels": ["cat", "dog"]}

    def test_convert_simplified_camelcase(self, converter):
        """camelCase keys (fromName / toName) are accepted on input."""
        payload = [
            {
                "fromName": "sentiment",  # camelCase
                "toName": "text",  # camelCase
                "values": ["neutral"],
            }
        ]

        converted = converter.convert_simplified_to_full(payload)

        assert len(converted) == 1
        assert converted[0]["from_name"] == "sentiment"
        assert converted[0]["to_name"] == "text"

    def test_convert_multiple_tags(self, converter):
        """Several tags in one call are converted in input order."""
        sentiment_tag = {"from_name": "sentiment", "to_name": "text", "values": ["positive"]}
        bbox_tag = {"from_name": "bbox", "to_name": "image", "values": ["cat"]}

        converted = converter.convert_simplified_to_full([sentiment_tag, bbox_tag])

        assert len(converted) == 2
        assert converted[0]["type"] == "choices"
        assert converted[1]["type"] == "rectanglelabels"

    def test_convert_with_existing_id(self, converter):
        """An ID supplied by the caller is kept rather than regenerated."""
        existing_id = "my-custom-id-123"
        payload = [
            {
                "id": existing_id,
                "from_name": "sentiment",
                "to_name": "text",
                "values": ["positive"],
            }
        ]

        converted = converter.convert_simplified_to_full(payload)

        assert converted[0]["id"] == existing_id

    def test_skip_unknown_from_name(self, converter):
        """A from_name absent from the template causes the tag to be dropped."""
        payload = [
            {"from_name": "unknown_control", "to_name": "text", "values": ["value"]}
        ]

        converted = converter.convert_simplified_to_full(payload)

        assert len(converted) == 0  # Should be skipped

    def test_skip_missing_fields(self, converter):
        """A tag lacking to_name is dropped."""
        payload = [
            {
                "from_name": "sentiment",
                # Missing to_name
                "values": ["positive"],
            }
        ]

        converted = converter.convert_simplified_to_full(payload)

        assert len(converted) == 0  # Should be skipped

    def test_is_simplified_format(self, converter):
        """Only tags with 'values' and without 'type'/'value' count as simplified."""
        simplified_tag = {"from_name": "x", "to_name": "y", "values": ["a"]}
        full_tag = {
            "id": "123",
            "from_name": "x",
            "to_name": "y",
            "type": "choices",
            "value": {"choices": ["a"]},
        }
        # Edge case: has both 'type' and 'values' (should not be considered simplified)
        hybrid_tag = {
            "from_name": "x",
            "to_name": "y",
            "type": "choices",
            "values": ["a"],
        }

        assert converter.is_simplified_format(simplified_tag) is True
        assert converter.is_simplified_format(full_tag) is False
        assert converter.is_simplified_format(hybrid_tag) is False

    def test_convert_if_needed_mixed_formats(self, converter):
        """Simplified tags are converted while full tags pass through."""
        simplified_tag = {"from_name": "sentiment", "to_name": "text", "values": ["positive"]}
        full_tag = {
            "id": "existing-123",
            "from_name": "bbox",
            "to_name": "image",
            "type": "rectanglelabels",
            "value": {"rectanglelabels": ["cat"]},
        }

        converted = converter.convert_if_needed([simplified_tag, full_tag])

        assert len(converted) == 2

        # First should be converted
        head = converted[0]
        assert head["type"] == "choices"
        assert head["value"] == {"choices": ["positive"]}

        # Second should pass through unchanged
        tail = converted[1]
        assert tail["id"] == "existing-123"
        assert tail["type"] == "rectanglelabels"
class TestCreateConverterFromDict:
    """Checks for the dict-based converter factory."""

    def test_create_from_valid_dict(self):
        """A well-formed configuration dict yields a working converter."""
        label_entry = {
            "fromName": "label",
            "toName": "image",
            "type": "choices",
            "options": ["a", "b"],
        }
        object_entry = {"name": "image", "type": "Image", "value": "$image"}
        config_dict = {"labels": [label_entry], "objects": [object_entry]}

        built = create_converter_from_template_config(config_dict)

        assert isinstance(built, TagFormatConverter)
        assert built.get_type_for_from_name("label") == "choices"

    def test_create_from_invalid_dict(self):
        """A malformed configuration surfaces as ValueError."""
        invalid_config = {
            "labels": "not-a-list",  # Should be a list
            "objects": [],
        }

        with pytest.raises(ValueError, match="Invalid template configuration"):
            create_converter_from_template_config(invalid_config)
class TestIntegrationScenarios:
    """End-to-end style scenarios built on the converter."""

    def test_external_api_submission(self, converter):
        """An external camelCase submission ends up in storage format."""
        # What an API client would send (simplified, camelCase keys)
        user_submission = [
            {
                "fromName": "sentiment",
                "toName": "text",
                "values": ["positive", "negative"],
            }
        ]

        # Conversion step performed by the system
        internal_tags = converter.convert_if_needed(user_submission)

        # The stored representation must be the full format
        assert len(internal_tags) == 1
        stored = internal_tags[0]
        assert stored["type"] == "choices"
        assert stored["value"] == {"choices": ["positive", "negative"]}
        assert "id" in stored

    def test_update_existing_tags(self, converter):
        """A simplified update keeps the tag ID so a merge can replace it."""
        # Tag as already persisted (full format); kept for illustration only
        existing_tags = [
            {
                "id": "tag-001",
                "from_name": "sentiment",
                "to_name": "text",
                "type": "choices",
                "value": {"choices": ["positive"]},
            }
        ]

        # Incoming update in simplified format, reusing the same ID
        update_request = [
            {
                "id": "tag-001",
                "from_name": "sentiment",
                "to_name": "text",
                "values": ["negative"],  # New value
            }
        ]

        converted_update = converter.convert_if_needed(update_request)

        # Merge logic would replace tag-001 based on the preserved ID
        updated = converted_update[0]
        assert updated["id"] == "tag-001"
        assert updated["value"] == {"choices": ["negative"]}
# Allow running this test module directly as a script: invokes pytest on
# this file in verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -165,14 +165,20 @@ class Service:
async def update_file_tags_partial(
self,
file_id: str,
new_tags: List[Dict[str, Any]]
new_tags: List[Dict[str, Any]],
template_id: Optional[str] = None
) -> tuple[bool, Optional[str], Optional[datetime]]:
"""
部分更新文件标签
部分更新文件标签,支持自动格式转换
如果提供了 template_id,会自动将简化格式的标签转换为完整格式。
简化格式: {"from_name": "x", "to_name": "y", "values": [...]}
完整格式: {"id": "...", "from_name": "x", "to_name": "y", "type": "...", "value": {"type": [...]}}
Args:
file_id: 文件ID
new_tags: 新的标签列表(部分更新)
new_tags: 新的标签列表(部分更新),可以是简化格式或完整格式
template_id: 可选的模板ID,用于格式转换
Returns:
(成功标志, 错误信息, 更新时间)
@@ -190,6 +196,38 @@ class Service:
logger.error(f"File not found: {file_id}")
return False, f"File not found: {file_id}", None
# 如果提供了 template_id,尝试进行格式转换
processed_tags = new_tags
if template_id:
logger.debug(f"Converting tags using template: {template_id}")
try:
# 获取模板配置
from app.db.models import AnnotationTemplate
template_result = await self.db.execute(
select(AnnotationTemplate).where(
AnnotationTemplate.id == template_id,
AnnotationTemplate.deleted_at.is_(None)
)
)
template = template_result.scalar_one_or_none()
if not template:
logger.warning(f"Template {template_id} not found, skipping conversion")
else:
# 使用 converter 转换标签格式
from app.module.annotation.utils import create_converter_from_template_config
converter = create_converter_from_template_config(template.configuration) # type: ignore
processed_tags = converter.convert_if_needed(new_tags)
logger.info(f"Converted {len(new_tags)} tags to full format")
except Exception as e:
logger.error(f"Failed to convert tags using template: {e}")
# 继续使用原始标签格式
logger.warning("Continuing with original tag format")
# 获取现有标签
existing_tags: List[Dict[str, Any]] = file_record.tags or [] # type: ignore
@@ -197,7 +235,7 @@ class Service:
tag_id_map = {tag.get('id'): idx for idx, tag in enumerate(existing_tags) if tag.get('id')}
# 更新或追加标签
for new_tag in new_tags:
for new_tag in processed_tags:
tag_id = new_tag.get('id')
if tag_id and tag_id in tag_id_map:
# 更新现有标签