Files
DataMate/runtime/datamate-python/app/module/annotation/interface/project.py
Jerry Yan 0bb9abb200 feat(annotation): 添加标注类型显示功能
- 在前端页面中新增标注类型列并使用Tag组件展示
- 添加AnnotationTypeMap常量用于标注类型的映射
- 修改接口定义支持labelingType字段的传递
- 更新后端项目创建和更新逻辑以存储标注类型
- 添加标注类型配置键常量统一管理
- 扩展数据传输对象支持标注类型属性
- 实现模板标注类型的继承逻辑
2026-02-01 19:08:11 +08:00

504 lines
18 KiB
Python

from typing import Optional
import math
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Path
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
from app.db.models import LabelingProject, DatasetFiles
from app.module.shared.schema import StandardResponse, PaginatedData
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from ..service.mapping import DatasetMappingService
from ..service.template import AnnotationTemplateService
from ..schema import (
DatasetMappingCreateRequest,
DatasetMappingCreateResponse,
DatasetMappingUpdateRequest,
DeleteDatasetResponse,
DatasetMappingResponse,
)
# Router for the annotation-project (dataset mapping) endpoints, all mounted under /project.
router = APIRouter(prefix="/project", tags=["annotation/project"])

logger = get_logger(__name__)

# Upper-cased dataset type that enables text-specific handling (segmentation flag,
# source-document filtering of the file snapshot).
TEXT_DATASET_TYPE = "TEXT"

# Lower-cased `file_type` values that are left out of the file snapshot of TEXT datasets.
SOURCE_DOCUMENT_FILE_TYPES = {"doc", "docx", "pdf", "xls", "xlsx"}

# Key under which a template's labeling type is stored in a project's `configuration` JSON.
LABELING_TYPE_CONFIG_KEY = "labeling_type"
@router.get("/{mapping_id}/login")
async def login_label_studio(
    mapping_id: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Retired Label Studio login-proxy endpoint.

    The service now works only with the embedded editor, so this route always
    responds with 410 Gone. `mapping_id` and `db` are not referenced in the body;
    they are kept so the route signature is unchanged.

    Raises:
        HTTPException: always, with status 410.
    """
    gone_detail = "当前为内嵌编辑器模式,不再支持 Label Studio 登录代理接口"
    raise HTTPException(status_code=410, detail=gone_detail)
def _resolve_dataset_type(dataset_info) -> str:
    """Return the DM dataset's type upper-cased, or "" when none is set.

    The type is read from the camelCase `datasetType` attribute, falling back to
    the snake_case `dataset_type` attribute, whichever is present and non-empty.
    """
    raw_type = (
        getattr(dataset_info, "datasetType", None)
        or getattr(dataset_info, "dataset_type", None)
        or ""
    )
    return raw_type.upper()


def _select_snapshot_file_ids(file_records, dataset_type: str) -> list[str]:
    """Return the string IDs of the dataset files to snapshot into a new project.

    Records with an empty/missing `id` are always skipped. For TEXT datasets,
    records whose lower-cased `file_type` is in SOURCE_DOCUMENT_FILE_TYPES are
    skipped as well; other dataset types keep every record that has an ID.
    """
    exclude_source_documents = dataset_type == TEXT_DATASET_TYPE
    selected: list[str] = []
    for file_record in file_records:
        if not file_record.id:
            continue
        if exclude_source_documents:
            file_type = str(getattr(file_record, "file_type", "") or "").lower()
            if file_type in SOURCE_DOCUMENT_FILE_TYPES:
                continue
        selected.append(str(file_record.id))
    return selected


@router.post("", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
async def create_mapping(
    request: DatasetMappingCreateRequest,
    db: AsyncSession = Depends(get_db)
):
    """
    Create a dataset mapping (annotation project).

    Persists a DataMate annotation project (t_dm_labeling_projects) for use with
    the embedded Label Studio editor. One dataset may back several projects.

    When `template_id` is given, the template's label config (and its labeling
    type, if set) seeds the project configuration; a non-empty `label_config` in
    the request overrides the template's label config. The project is created
    together with a snapshot of the dataset's file IDs.

    Raises:
        HTTPException 404: the dataset or the referenced template does not exist.
        HTTPException 500: any other failure (original error is chained and logged).
    """
    try:
        dm_client = DatasetManagementService(db)
        mapping_service = DatasetMappingService(db)
        template_service = AnnotationTemplateService()
        logger.info(f"Create dataset mapping request: {request.dataset_id}")

        # Fetch the source dataset from the DM service.
        dataset_info = await dm_client.get_dataset(request.dataset_id)
        if not dataset_info:
            raise HTTPException(
                status_code=404,
                detail=f"Dataset not found in DM service: {request.dataset_id}"
            )
        dataset_type = _resolve_dataset_type(dataset_info)

        project_name = request.name or dataset_info.name or "A new project from DataMate"
        project_description = (
            request.description
            or dataset_info.description
            or f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
        )

        # Seed label config / labeling type from the template, if one was requested.
        label_config = None
        template_labeling_type = None
        if request.template_id:
            logger.info(f"Using template: {request.template_id}")
            template = await template_service.get_template(db, request.template_id)
            if not template:
                raise HTTPException(
                    status_code=404,
                    detail=f"Template not found: {request.template_id}"
                )
            label_config = template.label_config
            template_labeling_type = getattr(template, "labeling_type", None)
            logger.debug(f"Template label config loaded for template: {template.name}")

        # A label_config sent directly in the request (custom or edited) overrides the template's.
        if request.label_config:
            label_config = request.label_config
            logger.debug("Using custom label config from request")

        # DataMate-only mode: no Label Studio Server project is created. The legacy
        # labeling_project_id column is 8 characters long, so generate an 8-digit ID.
        labeling_project_id = str(uuid.uuid4().int % 10**8).zfill(8)

        project_configuration = {}
        if label_config:
            project_configuration["label_config"] = label_config
        if project_description:
            project_configuration["description"] = project_description
        # The segmentation flag only applies to TEXT datasets.
        if dataset_type == TEXT_DATASET_TYPE and request.segmentation_enabled is not None:
            project_configuration["segmentation_enabled"] = bool(request.segmentation_enabled)
        if template_labeling_type:
            project_configuration[LABELING_TYPE_CONFIG_KEY] = template_labeling_type

        labeling_project = LabelingProject(
            id=str(uuid.uuid4()),
            dataset_id=request.dataset_id,
            labeling_project_id=labeling_project_id,
            name=project_name,
            template_id=request.template_id,
            configuration=project_configuration or None,
        )

        file_result = await db.execute(
            select(DatasetFiles).where(DatasetFiles.dataset_id == request.dataset_id)
        )
        snapshot_file_ids = _select_snapshot_file_ids(
            file_result.scalars().all(), dataset_type
        )

        # Persist the project together with its file snapshot.
        mapping = await mapping_service.create_mapping_with_snapshot(
            labeling_project, snapshot_file_ids
        )

        response_data = DatasetMappingCreateResponse(
            id=mapping.id,
            labeling_project_id=str(mapping.labeling_project_id),
            labeling_project_name=mapping.name or project_name
        )
        return StandardResponse(
            code=201,
            message="success",
            data=response_data
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the log and chain the cause; the client only sees a generic 500.
        logger.exception(f"Error while creating dataset mapping: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
@router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
async def list_mappings(
    page: int = Query(1, ge=1, description="页码(从1开始)"),
    size: int = Query(20, ge=1, le=100, description="每页记录数"),
    include_template: bool = Query(False, description="是否包含模板详情", alias="includeTemplate"),
    db: AsyncSession = Depends(get_db)
):
    """
    List all mappings (paginated).

    Returns the non-soft-deleted dataset mappings for the requested page. Full
    annotation-template details are included only on request (off by default,
    to keep the list query cheap).

    Parameters:
    - page: 1-based page number
    - size: records per page (1-100)
    - includeTemplate: whether to embed template details (default false)

    Raises:
        HTTPException: re-raised unchanged if raised by the service layer.
        HTTPException 500: any other failure (original error is chained and logged).
    """
    try:
        service = DatasetMappingService(db)
        skip = (page - 1) * size
        logger.info(f"List mappings: page={page}, size={size}, include_template={include_template}")

        mappings, total = await service.get_all_mappings_with_count(
            skip=skip,
            limit=size,
            include_deleted=False,
            include_template=include_template
        )

        total_pages = math.ceil(total / size) if total > 0 else 0
        paginated_data = PaginatedData(
            page=page,
            size=size,
            total_elements=total,
            total_pages=total_pages,
            content=mappings
        )
        logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
        return StandardResponse(
            code=200,
            message="success",
            data=paginated_data
        )
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of turning them into 500s,
        # matching the other endpoints in this router.
        raise
    except Exception as e:
        logger.exception(f"Error listing mappings: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping(
    mapping_id: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Fetch a single mapping by its UUID, including the linked annotation template.

    The mapping is loaded with `include_template=True`, so the complete template
    configuration is embedded in the response when the mapping has one.

    Raises:
        HTTPException 404: no mapping exists with this UUID.
        HTTPException 500: any other failure.
    """
    try:
        mapping_service = DatasetMappingService(db)
        logger.info(f"Get mapping with template details: {mapping_id}")

        found = await mapping_service.get_mapping_by_uuid(mapping_id, include_template=True)
        if not found:
            raise HTTPException(status_code=404, detail=f"Mapping not found: {mapping_id}")

        logger.info(f"Found mapping: {found.id}, template_included: {found.template is not None}")
        return StandardResponse(code=200, message="success", data=found)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting mapping: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/by-source/{dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
async def get_mappings_by_source(
    dataset_id: str,
    page: int = Query(1, ge=1, description="页码(从1开始)"),
    size: int = Query(20, ge=1, le=100, description="每页记录数"),
    include_template: bool = Query(True, description="是否包含模板详情", alias="includeTemplate"),
    db: AsyncSession = Depends(get_db)
):
    """
    List the mappings (annotation projects) created from one source dataset, paginated.

    Template details are embedded by default. Any filtering of soft-deleted
    projects is done by the service's `get_mappings_by_source_with_count`.

    Parameters:
    - dataset_id: source dataset ID
    - page: 1-based page number
    - size: records per page (1-100)
    - includeTemplate: whether to embed template details (default true)

    Raises:
        HTTPException: re-raised unchanged; any other failure becomes a 500.
    """
    try:
        mapping_service = DatasetMappingService(db)
        offset = (page - 1) * size
        logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, size={size}, include_template={include_template}")

        page_items, total = await mapping_service.get_mappings_by_source_with_count(
            dataset_id=dataset_id,
            skip=offset,
            limit=size,
            include_template=include_template
        )

        # No records means zero pages; otherwise round the page count up.
        total_pages = 0
        if total > 0:
            total_pages = math.ceil(total / size)

        page_payload = PaginatedData(
            content=page_items,
            page=page,
            size=size,
            total_pages=total_pages,
            total_elements=total
        )
        logger.info(f"Found {len(page_items)} mappings on page {page}, total: {total}, templates_included: {include_template}")
        return StandardResponse(code=200, message="success", data=page_payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting mappings: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/{project_id}", response_model=StandardResponse[DeleteDatasetResponse])
async def delete_mapping(
    project_id: str = Path(..., description="映射UUID(path param)"),
    db: AsyncSession = Depends(get_db)
):
    """
    Soft-delete a mapping (annotation project).

    The mapping is identified by the `project_id` path parameter (the mapping
    UUID). Only the DataMate project record is soft-deleted via the mapping
    service; no Label Studio Server project is touched (embedded-editor mode).

    Raises:
        HTTPException 404: no mapping exists with this UUID.
        HTTPException 500: the soft delete reported failure, or any other error.
    """
    try:
        logger.debug(f"Delete mapping request received: project_id={project_id!r}")
        service = DatasetMappingService(db)

        logger.debug(f"Deleting by mapping UUID: {project_id}")
        mapping = await service.get_mapping_by_uuid(project_id)
        logger.debug(f"Mapping lookup result: {mapping}")
        if not mapping:
            raise HTTPException(
                status_code=404,
                detail="Mapping either not found or not specified."
            )

        # Named mapping_id (not `id`) to avoid shadowing the builtin.
        mapping_id = mapping.id
        logger.debug(f"Found mapping: {mapping_id}")

        soft_delete_success = await service.soft_delete_mapping(mapping_id)
        logger.debug(f"Soft delete result for mapping {mapping_id}: {soft_delete_success}")
        if not soft_delete_success:
            raise HTTPException(
                status_code=500,
                detail="Failed to delete mapping record"
            )

        logger.info(f"Successfully deleted mapping: {mapping_id}")
        return StandardResponse(
            code=200,
            message="success",
            data=DeleteDatasetResponse(
                id=mapping_id,
                status="success"
            )
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error deleting mapping: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
@router.put("/{project_id}", response_model=StandardResponse[DatasetMappingResponse])
async def update_mapping(
    project_id: str = Path(..., description="映射UUID(path param)"),
    request: DatasetMappingUpdateRequest = None,
    db: AsyncSession = Depends(get_db)
):
    """
    Update an annotation project (dataset mapping).

    The project is identified by the `project_id` path parameter (the mapping
    UUID). Updatable fields:
    - name: project name
    - description: project description (stored in `configuration`)
    - template_id: annotation template ID; if the template defines a labeling
      type, it is copied into `configuration` under LABELING_TYPE_CONFIG_KEY
    - label_config: Label Studio XML config (stored in `configuration`)

    A missing request body is treated as "nothing to update": the current
    project is returned unchanged.

    Raises:
        HTTPException 404: the project (not soft-deleted) or the referenced template does not exist.
        HTTPException 500: the UPDATE matched no row, or any other failure.
    """
    from datetime import datetime

    try:
        logger.info(f"Update mapping request received: project_id={project_id!r}")
        service = DatasetMappingService(db)

        # Load the raw ORM row so the stored configuration JSON can be merged.
        result = await db.execute(
            select(LabelingProject).where(
                LabelingProject.id == project_id,
                LabelingProject.deleted_at.is_(None)
            )
        )
        mapping_orm = result.scalar_one_or_none()
        if not mapping_orm:
            raise HTTPException(
                status_code=404,
                detail=f"Mapping not found: {project_id}"
            )

        if request is None:
            # No body was sent: nothing to update, return the current state.
            current_mapping = await service.get_mapping_by_uuid(project_id)
            return StandardResponse(
                code=200,
                message="success",
                data=current_mapping
            )

        update_values = {}
        if request.name is not None:
            update_values["name"] = request.name

        # description / label_config / labeling type live inside the `configuration`
        # JSON; work on a shallow copy so unrelated keys are preserved.
        stored_config = mapping_orm.configuration
        configuration = stored_config.copy() if isinstance(stored_config, dict) else {}
        if request.description is not None:
            configuration["description"] = request.description
        if request.label_config is not None:
            configuration["label_config"] = request.label_config

        if request.template_id is not None:
            update_values["template_id"] = request.template_id
        if request.template_id:
            template_service = AnnotationTemplateService()
            template = await template_service.get_template(db, request.template_id)
            if not template:
                raise HTTPException(
                    status_code=404,
                    detail=f"Template not found: {request.template_id}"
                )
            template_labeling_type = getattr(template, "labeling_type", None)
            # NOTE(review): switching to a template without a labeling type keeps the
            # previous value in configuration; confirm whether it should be cleared.
            if template_labeling_type:
                configuration[LABELING_TYPE_CONFIG_KEY] = template_labeling_type

        # Decide whether to persist `configuration` only after every contributor
        # (including the template's labeling type) has written into it. Deciding
        # earlier dropped the labeling type whenever the stored config was empty.
        if configuration:
            update_values["configuration"] = configuration

        if not update_values:
            # Nothing to update: return the current data as-is.
            response_data = await service.get_mapping_by_uuid(project_id)
            return StandardResponse(
                code=200,
                message="success",
                data=response_data
            )

        update_values["updated_at"] = datetime.now()
        result = await db.execute(
            update(LabelingProject)
            .where(LabelingProject.id == project_id)
            .values(**update_values)
        )
        await db.commit()
        if result.rowcount == 0:
            raise HTTPException(
                status_code=500,
                detail="Failed to update mapping"
            )

        # Re-read so the response reflects the persisted state.
        updated_mapping = await service.get_mapping_by_uuid(project_id)
        logger.info(f"Successfully updated mapping: {project_id}")
        return StandardResponse(
            code=200,
            message="success",
            data=updated_mapping
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error updating mapping: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e