You've already forked DataMate
feat: Enhance file tag update functionality with automatic format conversion (#84)
- Updated `update_file_tags` to support both simplified and full tag formats. - Introduced `TagFormatConverter` to handle conversion from simplified external tags to internal storage format. - Added logic to fetch and utilize the appropriate annotation template for conversion. - Improved error handling for missing templates and unknown controls during tag updates. - Created example script demonstrating the usage of the new tag format conversion feature. - Added unit tests for `TagFormatConverter` to ensure correct functionality and edge case handling.
This commit is contained in:
@@ -150,12 +150,19 @@ async def create_mapping(
|
||||
async def list_mappings(
|
||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"),
|
||||
include_template: bool = Query(False, description="是否包含模板详情", alias="includeTemplate"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
查询所有映射关系(分页)
|
||||
|
||||
返回所有有效的数据集映射关系(未被软删除的),支持分页查询
|
||||
返回所有有效的数据集映射关系(未被软删除的),支持分页查询。
|
||||
可选择是否包含完整的标注模板信息(默认不包含,以提高列表查询性能)。
|
||||
|
||||
参数:
|
||||
- page: 页码(从1开始)
|
||||
- pageSize: 每页记录数
|
||||
- includeTemplate: 是否包含模板详情(默认false)
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
@@ -163,10 +170,14 @@ async def list_mappings(
|
||||
# 计算 skip
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
logger.info(f"List mappings: page={page}, page_size={page_size}, include_template={include_template}")
|
||||
|
||||
# 获取数据和总数
|
||||
mappings, total = await service.get_all_mappings_with_count(
|
||||
skip=skip,
|
||||
limit=page_size
|
||||
skip=skip,
|
||||
limit=page_size,
|
||||
include_deleted=False,
|
||||
include_template=include_template
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
@@ -181,7 +192,7 @@ async def list_mappings(
|
||||
content=mappings
|
||||
)
|
||||
|
||||
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}")
|
||||
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
@@ -199,14 +210,21 @@ async def get_mapping(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据 UUID 查询单个映射关系
|
||||
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
|
||||
|
||||
返回数据集映射关系以及关联的完整标注模板信息,包括:
|
||||
- 映射基本信息
|
||||
- 数据集信息
|
||||
- Label Studio 项目信息
|
||||
- 完整的标注模板配置(如果存在)
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
|
||||
logger.info(f"Get mapping: {mapping_id}")
|
||||
logger.info(f"Get mapping with template details: {mapping_id}")
|
||||
|
||||
mapping = await service.get_mapping_by_uuid(mapping_id)
|
||||
# 获取映射,并包含完整的模板信息
|
||||
mapping = await service.get_mapping_by_uuid(mapping_id, include_template=True)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(
|
||||
@@ -214,7 +232,7 @@ async def get_mapping(
|
||||
detail=f"Mapping not found: {mapping_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Found mapping: {mapping.id}")
|
||||
logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
@@ -233,12 +251,20 @@ async def get_mappings_by_source(
|
||||
dataset_id: str,
|
||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页记录数", alias="pageSize"),
|
||||
include_template: bool = Query(True, description="是否包含模板详情", alias="includeTemplate"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据源数据集 ID 查询所有映射关系(分页)
|
||||
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
|
||||
|
||||
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询
|
||||
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询。
|
||||
默认包含关联的完整标注模板信息。
|
||||
|
||||
参数:
|
||||
- dataset_id: 数据集ID
|
||||
- page: 页码(从1开始)
|
||||
- pageSize: 每页记录数
|
||||
- includeTemplate: 是否包含模板详情(默认true)
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
@@ -246,13 +272,14 @@ async def get_mappings_by_source(
|
||||
# 计算 skip
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}")
|
||||
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}, include_template={include_template}")
|
||||
|
||||
# 获取数据和总数
|
||||
# 获取数据和总数(包含模板信息)
|
||||
mappings, total = await service.get_mappings_by_source_with_count(
|
||||
dataset_id=dataset_id,
|
||||
skip=skip,
|
||||
limit=page_size
|
||||
limit=page_size,
|
||||
include_template=include_template
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
@@ -267,7 +294,7 @@ async def get_mappings_by_source(
|
||||
content=mappings
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
|
||||
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
|
||||
@@ -244,16 +244,53 @@ async def update_file_tags(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Update File Tags (Partial Update)
|
||||
Update File Tags (Partial Update with Auto Format Conversion)
|
||||
|
||||
接收部分标签更新并合并到指定文件(只修改提交的标签,其余保持不变),并更新 `tags_updated_at`。
|
||||
|
||||
支持两种标签格式:
|
||||
1. 简化格式(外部用户提交):
|
||||
[{"from_name": "label", "to_name": "image", "values": ["cat", "dog"]}]
|
||||
|
||||
2. 完整格式(内部存储):
|
||||
[{"id": "...", "from_name": "label", "to_name": "image", "type": "choices",
|
||||
"value": {"choices": ["cat", "dog"]}}]
|
||||
|
||||
系统会自动根据数据集关联的模板将简化格式转换为完整格式。
|
||||
请求与响应使用 Pydantic 模型 `UpdateFileTagsRequest` / `UpdateFileTagsResponse`。
|
||||
"""
|
||||
service = DatasetManagementService(db)
|
||||
|
||||
# 首先获取文件所属的数据集
|
||||
from sqlalchemy.future import select
|
||||
from app.db.models import DatasetFiles
|
||||
|
||||
result = await db.execute(
|
||||
select(DatasetFiles).where(DatasetFiles.id == file_id)
|
||||
)
|
||||
file_record = result.scalar_one_or_none()
|
||||
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
|
||||
|
||||
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
|
||||
|
||||
# 查找数据集关联的模板ID
|
||||
from ..service.mapping import DatasetMappingService
|
||||
|
||||
mapping_service = DatasetMappingService(db)
|
||||
template_id = await mapping_service.get_template_id_by_dataset_id(dataset_id)
|
||||
|
||||
if template_id:
|
||||
logger.info(f"Found template {template_id} for dataset {dataset_id}, will auto-convert tag format")
|
||||
else:
|
||||
logger.warning(f"No template found for dataset {dataset_id}, tags must be in full format")
|
||||
|
||||
# 更新标签(如果有模板ID则自动转换格式)
|
||||
success, error_msg, updated_at = await service.update_file_tags_partial(
|
||||
file_id=file_id,
|
||||
new_tags=request.tags
|
||||
new_tags=request.tags,
|
||||
template_id=template_id # 传递模板ID以启用自动转换
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -261,10 +298,7 @@ async def update_file_tags(
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
raise HTTPException(status_code=500, detail=error_msg or "更新标签失败")
|
||||
|
||||
# 获取更新后的完整标签列表
|
||||
from sqlalchemy.future import select
|
||||
from app.db.models import DatasetFiles
|
||||
|
||||
# 重新获取更新后的文件记录(获取完整标签列表)
|
||||
result = await db.execute(
|
||||
select(DatasetFiles).where(DatasetFiles.id == file_id)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user