You've already forked DataMate
- 在前端页面中新增标注类型列并使用Tag组件展示 - 添加AnnotationTypeMap常量用于标注类型的映射 - 修改接口定义支持labelingType字段的传递 - 更新后端项目创建和更新逻辑以存储标注类型 - 添加标注类型配置键常量统一管理 - 扩展数据传输对象支持标注类型属性 - 实现模板标注类型的继承逻辑
608 lines
22 KiB
Python
608 lines
22 KiB
Python
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy import update, func, insert
|
|
from sqlalchemy.orm import aliased
|
|
from typing import Optional, List, Tuple
|
|
from datetime import datetime
|
|
import uuid
|
|
|
|
from app.core.logging import get_logger
|
|
from app.db.models import LabelingProject, AnnotationResult, LabelingProjectFile
|
|
from app.db.models.annotation_management import ANNOTATION_STATUS_IN_PROGRESS
|
|
from app.db.models.dataset_management import Dataset, DatasetFiles
|
|
from app.module.annotation.schema import (
|
|
DatasetMappingCreateRequest,
|
|
DatasetMappingUpdateRequest,
|
|
DatasetMappingResponse,
|
|
AnnotationTemplateResponse
|
|
)
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
|
|
# Key in a LabelingProject's ``configuration`` dict under which the project's
# labeling type is looked up when DatasetMappingService builds a response.
LABELING_TYPE_CONFIG_KEY = "labeling_type"
|
|
|
|
class DatasetMappingService:
|
|
"""数据集映射服务"""
|
|
|
|
def __init__(self, db: AsyncSession):
    """Bind the service to the async database session used by every query.

    Args:
        db: SQLAlchemy ``AsyncSession``; stored as ``self.db`` and used by
            all methods of this service to execute statements and commit.
    """
    self.db = db
|
|
|
|
# Number of accumulated snapshot rows at which _insert_snapshot_records
# flushes them to the database in a single INSERT statement.
SNAPSHOT_INSERT_BATCH_SIZE = 500
|
|
|
|
def _build_query_with_dataset_name(self):
    """Return the base SELECT of labeling projects with their dataset name.

    Each result row is ``(LabelingProject, dataset_name)``. The dataset table
    is attached with an outer join on ``LabelingProject.dataset_id``, so a
    project whose dataset row is missing is still returned (its
    ``dataset_name`` is then ``None``). No filtering, ordering or paging is
    applied here; callers add those.
    """
    dataset_name_column = Dataset.name.label('dataset_name')
    projects_and_names = select(LabelingProject, dataset_name_column)
    return projects_and_names.outerjoin(
        Dataset,
        LabelingProject.dataset_id == Dataset.id,
    )
|
|
|
|
async def _get_project_stats(
    self,
    project_id: str,
    dataset_id: str
) -> Tuple[int, int, int]:
    """Compute the file counters shown for one labeling project.

    Three queries are executed, in this order:

    1. total: rows of ``LabelingProjectFile`` (the project's file snapshot)
       for ``project_id`` whose joined ``DatasetFiles`` row belongs to
       ``dataset_id``;
    2. annotated: number of distinct ``file_id`` values in
       ``AnnotationResult`` for ``project_id``;
    3. in progress: same as (2), restricted to results whose
       ``annotation_status`` equals ``ANNOTATION_STATUS_IN_PROGRESS``.

    Args:
        project_id: Labeling project ID.
        dataset_id: Dataset ID the snapshot files must belong to.

    Returns:
        ``(total_count, annotated_count, in_progress_count)``; a missing or
        NULL scalar is reported as ``0``.
    """
    # Only files recorded in the project's snapshot are counted.
    snapshot_total_stmt = (
        select(func.count())
        .select_from(LabelingProjectFile)
        .join(DatasetFiles, LabelingProjectFile.file_id == DatasetFiles.id)
        .where(
            LabelingProjectFile.project_id == project_id,
            DatasetFiles.dataset_id == dataset_id,
        )
    )

    # A file counts once no matter how many result rows it has.
    annotated_stmt = select(
        func.count(func.distinct(AnnotationResult.file_id))
    ).where(
        AnnotationResult.project_id == project_id
    )

    in_progress_stmt = select(
        func.count(func.distinct(AnnotationResult.file_id))
    ).where(
        AnnotationResult.project_id == project_id,
        AnnotationResult.annotation_status == ANNOTATION_STATUS_IN_PROGRESS,
    )

    counters = []
    for stmt in (snapshot_total_stmt, annotated_stmt, in_progress_stmt):
        outcome = await self.db.execute(stmt)
        counters.append(int(outcome.scalar() or 0))

    total_count, annotated_count, in_progress_count = counters
    return total_count, annotated_count, in_progress_count
|
|
|
|
async def _to_response_from_row(
    self,
    row,
    include_template: bool = False
) -> DatasetMappingResponse:
    """Build a ``DatasetMappingResponse`` from a joined query row.

    Args:
        row: Result row whose element 0 is the ``LabelingProject`` and whose
            element 1 is the joined dataset name.
        include_template: If True and the project has a ``template_id``,
            fetch the full template and include it in the response.

    Returns:
        The response, including the counters from ``_get_project_stats``.
    """
    mapping, dataset_name = row[0], row[1]
    template_id = getattr(mapping, 'template_id', None)

    # label_config / description / segmentation flag / labeling type all live
    # inside the project's ``configuration`` value; they stay None when that
    # value is not a dict.
    configuration = getattr(mapping, 'configuration', None) or {}
    if isinstance(configuration, dict):
        label_config = configuration.get('label_config')
        description = configuration.get('description')
        segmentation_enabled = configuration.get('segmentation_enabled')
        labeling_type = configuration.get(LABELING_TYPE_CONFIG_KEY)
    else:
        label_config = None
        description = None
        segmentation_enabled = None
        labeling_type = None

    # Full template details are loaded only on request.
    template_response = None
    if include_template and template_id:
        from ..service.template import AnnotationTemplateService
        template_response = await AnnotationTemplateService().get_template(self.db, template_id)
        logger.debug(f"Included template details for template_id: {template_id}")

    # Fall back to the template's labeling type when the project has none.
    if not labeling_type and template_response:
        labeling_type = getattr(template_response, "labeling_type", None)

    total_count, annotated_count, in_progress_count = await self._get_project_stats(
        mapping.id, mapping.dataset_id
    )

    return DatasetMappingResponse(
        id=mapping.id,
        dataset_id=mapping.dataset_id,
        dataset_name=dataset_name,
        labeling_project_id=mapping.labeling_project_id,
        name=mapping.name,
        description=description,
        template_id=template_id,
        labeling_type=labeling_type,
        template=template_response,
        label_config=label_config,
        segmentation_enabled=segmentation_enabled,
        total_count=total_count,
        annotated_count=annotated_count,
        in_progress_count=in_progress_count,
        created_at=mapping.created_at,
        updated_at=mapping.updated_at,
        deleted_at=mapping.deleted_at,
    )
|
|
|
|
async def _to_response(
    self,
    mapping: LabelingProject,
    include_template: bool = False
) -> DatasetMappingResponse:
    """Build a ``DatasetMappingResponse`` from a single ORM object.

    Unlike the row-based converter, the dataset name is not available yet
    and is looked up with its own query when the project has a dataset ID.

    Args:
        mapping: ``LabelingProject`` ORM object.
        include_template: If True and the project has a ``template_id``,
            fetch the full template and include it in the response.

    Returns:
        The response; all three counters are ``0`` when the project has no
        dataset ID.
    """
    dataset_id = getattr(mapping, 'dataset_id', None)

    # Resolve the dataset name (stays None without a dataset ID or match).
    dataset_name = None
    if dataset_id:
        name_lookup = await self.db.execute(
            select(Dataset.name).where(Dataset.id == dataset_id)
        )
        dataset_name = name_lookup.scalar_one_or_none()

    template_id = getattr(mapping, 'template_id', None)

    # label_config / description / segmentation flag / labeling type all live
    # inside the project's ``configuration`` value; they stay None when that
    # value is not a dict.
    configuration = getattr(mapping, 'configuration', None) or {}
    if isinstance(configuration, dict):
        label_config = configuration.get('label_config')
        description = configuration.get('description')
        segmentation_enabled = configuration.get('segmentation_enabled')
        labeling_type = configuration.get(LABELING_TYPE_CONFIG_KEY)
    else:
        label_config = None
        description = None
        segmentation_enabled = None
        labeling_type = None

    # Full template details are loaded only on request.
    template_response = None
    if include_template and template_id:
        from ..service.template import AnnotationTemplateService
        template_response = await AnnotationTemplateService().get_template(self.db, template_id)
        logger.debug(f"Included template details for template_id: {template_id}")

    # Fall back to the template's labeling type when the project has none.
    if not labeling_type and template_response:
        labeling_type = getattr(template_response, "labeling_type", None)

    # Statistics are only computed for projects tied to a dataset.
    if dataset_id:
        total_count, annotated_count, in_progress_count = await self._get_project_stats(
            mapping.id, dataset_id
        )
    else:
        total_count, annotated_count, in_progress_count = 0, 0, 0

    return DatasetMappingResponse(
        id=mapping.id,
        dataset_id=dataset_id,
        dataset_name=dataset_name,
        labeling_project_id=mapping.labeling_project_id,
        name=mapping.name,
        description=description,
        template_id=template_id,
        labeling_type=labeling_type,
        template=template_response,
        label_config=label_config,
        segmentation_enabled=segmentation_enabled,
        total_count=total_count,
        annotated_count=annotated_count,
        in_progress_count=in_progress_count,
        created_at=mapping.created_at,
        updated_at=mapping.updated_at,
        deleted_at=mapping.deleted_at,
    )
|
|
|
|
async def create_mapping(
    self,
    labeling_project: LabelingProject
) -> DatasetMappingResponse:
    """Persist a new labeling project (dataset mapping) and return it.

    The passed ORM object is added to the session as-is, committed, and then
    refreshed before being converted to a response.

    Args:
        labeling_project: Fully populated ``LabelingProject`` to insert.

    Returns:
        The response built from the stored project.
    """
    logger.debug(f"Create dataset mapping: {labeling_project.dataset_id} -> {labeling_project.labeling_project_id}")

    session = self.db
    session.add(labeling_project)
    await session.commit()
    # Reload the instance's attributes after the commit.
    await session.refresh(labeling_project)

    logger.debug(f"Mapping created: {labeling_project.id}")
    return await self._to_response(labeling_project)
|
|
|
|
async def create_mapping_with_snapshot(
    self,
    labeling_project: LabelingProject,
    file_ids: List[str],
) -> DatasetMappingResponse:
    """Persist a new labeling project together with its file snapshot.

    The project is flushed first so its ID is available, the snapshot rows
    are inserted for ``file_ids`` (skipped when the list is empty), and only
    then is everything committed in one go.

    Args:
        labeling_project: ``LabelingProject`` to insert.
        file_ids: IDs of the files that make up the project's snapshot.

    Returns:
        The response built from the stored project.

    Raises:
        AssertionError: If the project still has no ID after the flush.
    """
    logger.debug(
        "Create dataset mapping with snapshot: %s -> %s, files=%d",
        labeling_project.dataset_id,
        labeling_project.labeling_project_id,
        len(file_ids),
    )

    session = self.db
    session.add(labeling_project)
    # Flush (not commit) so the project row exists before snapshot rows are
    # written, while keeping both writes in the same transaction.
    await session.flush()
    assert labeling_project.id, "labeling_project.id must be set before snapshot insert"

    if file_ids:
        await self._insert_snapshot_records(labeling_project.id, file_ids)

    await session.commit()
    await session.refresh(labeling_project)

    logger.debug("Mapping created with snapshot: %s", labeling_project.id)
    return await self._to_response(labeling_project)
|
|
|
|
async def _insert_snapshot_records(self, project_id: str, file_ids: List[str]) -> None:
    """Insert one ``LabelingProjectFile`` row per file ID, in batches.

    Each row gets a freshly generated UUID4 string as its ``id``. Rows are
    accumulated and written with a single multi-row INSERT whenever
    ``SNAPSHOT_INSERT_BATCH_SIZE`` rows are pending; any remainder is
    written at the end. Nothing is committed here.

    Args:
        project_id: ID of the labeling project the snapshot belongs to.
        file_ids: IDs of the files to record, processed in the given order.
    """
    flush_threshold = self.SNAPSHOT_INSERT_BATCH_SIZE
    pending: List[dict] = []

    for file_id in file_ids:
        pending.append(
            {
                "id": str(uuid.uuid4()),
                "project_id": project_id,
                "file_id": file_id,
            }
        )
        if len(pending) < flush_threshold:
            continue
        await self.db.execute(insert(LabelingProjectFile).values(pending))
        pending = []

    # Write whatever is left after the last full batch.
    if pending:
        await self.db.execute(insert(LabelingProjectFile).values(pending))
|
|
|
|
async def get_mapping_by_source_uuid(
    self,
    dataset_id: str
) -> Optional[DatasetMappingResponse]:
    """Return one non-deleted mapping for the given source dataset.

    A dataset may have several live labeling projects. The query is
    therefore limited to a single row; without the limit,
    ``scalar_one_or_none()`` would raise ``MultipleResultsFound`` as soon as
    a second project exists. Rows are ordered by ``created_at`` descending,
    matching the list queries of this service, so the most recently created
    project is the one returned.

    Args:
        dataset_id: ID of the source dataset.

    Returns:
        The mapping response, or ``None`` if the dataset has no non-deleted
        labeling project.
    """
    logger.debug(f"Get mapping by source dataset id: {dataset_id}")

    result = await self.db.execute(
        select(LabelingProject)
        .where(
            LabelingProject.dataset_id == dataset_id,
            LabelingProject.deleted_at.is_(None)
        )
        # Deterministic choice of "first" when several projects exist.
        .order_by(LabelingProject.created_at.desc())
        .limit(1)
    )
    mapping = result.scalar_one_or_none()

    if mapping:
        logger.debug(f"Found mapping: {mapping.id}")
        return await self._to_response(mapping)

    logger.debug(f"No mapping found for source dataset id: {dataset_id}")
    return None
|
|
|
|
async def get_mappings_by_dataset_id(
    self,
    dataset_id: str,
    include_deleted: bool = False
) -> List[DatasetMappingResponse]:
    """Return every mapping of a source dataset, newest first.

    Args:
        dataset_id: ID of the source dataset.
        include_deleted: If True, soft-deleted mappings are returned too.

    Returns:
        Responses ordered by ``created_at`` descending, without template
        details.
    """
    logger.debug(f"Get all mappings by source dataset id: {dataset_id}")

    stmt = self._build_query_with_dataset_name().where(
        LabelingProject.dataset_id == dataset_id
    )
    if not include_deleted:
        stmt = stmt.where(LabelingProject.deleted_at.is_(None))
    stmt = stmt.order_by(LabelingProject.created_at.desc())

    rows = (await self.db.execute(stmt)).all()
    logger.debug(f"Found {len(rows)} mappings")

    # Rows are converted one after another; each conversion awaits its own
    # follow-up queries.
    return [
        await self._to_response_from_row(row, include_template=False)
        for row in rows
    ]
|
|
|
|
async def get_mapping_by_labeling_project_id(
    self,
    labeling_project_id: str
) -> Optional[DatasetMappingResponse]:
    """Return the non-deleted mapping for a Label Studio project ID.

    Args:
        labeling_project_id: Value matched against
            ``LabelingProject.labeling_project_id``.

    Returns:
        The mapping response, or ``None`` if no non-deleted project matches.
    """
    logger.debug(f"Get mapping by Label Studio project id: {labeling_project_id}")

    stmt = select(LabelingProject).where(
        LabelingProject.labeling_project_id == labeling_project_id,
        LabelingProject.deleted_at.is_(None)
    )
    mapping = (await self.db.execute(stmt)).scalar_one_or_none()

    if not mapping:
        logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
        return None

    logger.debug(f"Found mapping: {mapping.id}")
    return await self._to_response(mapping)
|
|
|
|
async def get_mapping_by_uuid(
    self,
    mapping_id: str,
    include_template: bool = False
) -> Optional[DatasetMappingResponse]:
    """Return the non-deleted mapping with the given UUID.

    Args:
        mapping_id: Mapping (labeling project) UUID.
        include_template: Whether to include the full template details.

    Returns:
        The mapping response, or ``None`` if no non-deleted project has
        this ID.
    """
    logger.debug(f"Get mapping: {mapping_id}, include_template={include_template}")

    stmt = select(LabelingProject).where(
        LabelingProject.id == mapping_id,
        LabelingProject.deleted_at.is_(None)
    )
    mapping = (await self.db.execute(stmt)).scalar_one_or_none()

    if not mapping:
        logger.debug(f"No mapping found for mapping id: {mapping_id}")
        return None

    logger.debug(f"Found mapping: {mapping.id}")
    return await self._to_response(mapping, include_template=include_template)
|
|
|
|
async def update_mapping(
    self,
    mapping_id: str,
    update_data: DatasetMappingUpdateRequest
) -> Optional[DatasetMappingResponse]:
    """Apply the explicitly set request fields to a mapping.

    Only fields that were set on ``update_data`` are written
    (``model_dump(exclude_unset=True)``), plus a modification timestamp.
    The change is committed before the result is inspected.

    Args:
        mapping_id: Mapping (labeling project) UUID.
        update_data: Request carrying the fields to change.

    Returns:
        The refreshed mapping response, or ``None`` if the mapping does not
        exist (or is soft-deleted) or the UPDATE affected no rows.
    """
    logger.info(f"Update mapping: {mapping_id}")

    # Bail out early when there is no live mapping with this ID.
    existing = await self.get_mapping_by_uuid(mapping_id)
    if not existing:
        return None

    changes = update_data.model_dump(exclude_unset=True)
    # NOTE(review): the response builders in this service read
    # ``updated_at``; confirm that LabelingProject really has a
    # ``last_updated_at`` column, otherwise this UPDATE cannot succeed.
    changes["last_updated_at"] = datetime.now()

    stmt = (
        update(LabelingProject)
        .where(LabelingProject.id == mapping_id)
        .values(**changes)
    )
    result = await self.db.execute(stmt)
    await self.db.commit()

    affected = result.rowcount  # type: ignore
    if not (affected and affected > 0):
        return None
    return await self.get_mapping_by_uuid(mapping_id)
|
|
|
|
async def soft_delete_mapping(self, mapping_id: str) -> bool:
    """Soft-delete a mapping by stamping its ``deleted_at`` column.

    Only a mapping that is not already soft-deleted is touched, so calling
    this twice reports failure the second time. The change is committed
    before the outcome is evaluated.

    Args:
        mapping_id: Mapping (labeling project) UUID.

    Returns:
        ``True`` if a row was marked deleted, ``False`` if the mapping does
        not exist or was already deleted. The value is always a real
        ``bool``; the raw ``rowcount`` expression would otherwise leak
        ``0`` or ``None`` to callers.
    """
    logger.info(f"Soft delete mapping: {mapping_id}")

    result = await self.db.execute(
        update(LabelingProject)
        .where(
            LabelingProject.id == mapping_id,
            LabelingProject.deleted_at.is_(None)
        )
        .values(deleted_at=datetime.now())
    )
    await self.db.commit()

    affected = result.rowcount  # type: ignore
    # Coerce explicitly: ``affected and affected > 0`` alone evaluates to
    # 0 / None when nothing matched, which violates the ``-> bool`` contract.
    success = bool(affected and affected > 0)
    if success:
        logger.info(f"Mapping soft-deleted: {mapping_id}")
    else:
        logger.warning(f"Mapping not exists or already deleted: {mapping_id}")

    return success
|
|
|
|
async def get_all_mappings(
    self,
    skip: int = 0,
    limit: int = 100
) -> List[DatasetMappingResponse]:
    """Return one page of non-deleted mappings, newest first.

    Args:
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).

    Returns:
        Responses ordered by ``created_at`` descending, without template
        details.
    """
    logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")

    live_projects = self._build_query_with_dataset_name().where(
        LabelingProject.deleted_at.is_(None)
    )
    page_stmt = (
        live_projects
        .offset(skip)
        .limit(limit)
        .order_by(LabelingProject.created_at.desc())
    )

    rows = (await self.db.execute(page_stmt)).all()
    logger.debug(f"Found {len(rows)} mappings")

    # Rows are converted one after another; each conversion awaits its own
    # follow-up queries.
    return [
        await self._to_response_from_row(row, include_template=False)
        for row in rows
    ]
|
|
|
|
async def count_mappings(self, include_deleted: bool = False) -> int:
    """Count labeling projects (mappings).

    Args:
        include_deleted: If True, soft-deleted mappings are counted too;
            otherwise only rows whose ``deleted_at`` is NULL are counted.

    Returns:
        The number of matching rows.
    """
    count_stmt = select(func.count()).select_from(LabelingProject)
    if not include_deleted:
        count_stmt = count_stmt.where(LabelingProject.deleted_at.is_(None))

    outcome = await self.db.execute(count_stmt)
    return outcome.scalar_one()
|
|
|
|
async def get_all_mappings_with_count(
    self,
    skip: int = 0,
    limit: int = 100,
    include_deleted: bool = False,
    include_template: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
    """Return one page of mappings plus the total row count (for paging).

    Args:
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
        include_deleted: Whether soft-deleted mappings are included, both in
            the page and in the total.
        include_template: Whether each response carries the full template.

    Returns:
        ``(responses, total)`` where ``responses`` is ordered by
        ``created_at`` descending and ``total`` ignores ``skip``/``limit``.
    """
    logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}, include_template={include_template}")

    page_stmt = self._build_query_with_dataset_name()
    count_stmt = select(func.count()).select_from(LabelingProject)
    if not include_deleted:
        # The same filter goes on both statements so page and total agree.
        page_stmt = page_stmt.where(LabelingProject.deleted_at.is_(None))
        count_stmt = count_stmt.where(LabelingProject.deleted_at.is_(None))

    # Total first, then the requested page.
    total = (await self.db.execute(count_stmt)).scalar_one()

    page_stmt = (
        page_stmt
        .offset(skip)
        .limit(limit)
        .order_by(LabelingProject.created_at.desc())
    )
    rows = (await self.db.execute(page_stmt)).all()

    logger.debug(f"Found {len(rows)} mappings, total: {total}")

    # Rows are converted one after another; each conversion awaits its own
    # follow-up queries.
    responses = [
        await self._to_response_from_row(row, include_template=include_template)
        for row in rows
    ]
    return responses, total
|
|
|
|
async def get_template_id_by_dataset_id(self, dataset_id: str) -> Optional[str]:
    """Look up the template ID of a dataset's labeling project.

    At most one non-deleted labeling project of the dataset is read
    (``LIMIT 1``, no explicit ordering) and its ``template_id`` returned.

    Args:
        dataset_id: Dataset UUID.

    Returns:
        The template ID, or ``None`` when the dataset has no non-deleted
        labeling project or that project has no template.
    """
    logger.debug(f"Looking up template for dataset: {dataset_id}")

    stmt = (
        select(LabelingProject.template_id)
        .where(
            LabelingProject.dataset_id == dataset_id,
            LabelingProject.deleted_at.is_(None)
        )
        .limit(1)
    )
    template_id = (await self.db.execute(stmt)).scalar_one_or_none()

    if not template_id:
        logger.warning(f"No template found for dataset {dataset_id}")
    else:
        logger.debug(f"Found template {template_id} for dataset {dataset_id}")

    return template_id
|
|
|
|
async def get_mappings_by_source_with_count(
    self,
    dataset_id: str,
    skip: int = 0,
    limit: int = 100,
    include_deleted: bool = False,
    include_template: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
    """Return one page of a dataset's mappings plus the total count.

    Args:
        dataset_id: ID of the source dataset.
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
        include_deleted: Whether soft-deleted mappings are included, both in
            the page and in the total.
        include_template: Whether each response carries the full template.

    Returns:
        ``(responses, total)`` where ``responses`` is ordered by
        ``created_at`` descending and ``total`` ignores ``skip``/``limit``.
    """
    logger.debug(f"Get mappings by source dataset id with count: {dataset_id}, include_template={include_template}")

    page_stmt = self._build_query_with_dataset_name().where(
        LabelingProject.dataset_id == dataset_id
    )
    count_stmt = select(func.count()).select_from(LabelingProject).where(
        LabelingProject.dataset_id == dataset_id
    )
    if not include_deleted:
        # The same filter goes on both statements so page and total agree.
        page_stmt = page_stmt.where(LabelingProject.deleted_at.is_(None))
        count_stmt = count_stmt.where(LabelingProject.deleted_at.is_(None))

    # Total first, then the requested page.
    total = (await self.db.execute(count_stmt)).scalar_one()

    page_stmt = (
        page_stmt
        .offset(skip)
        .limit(limit)
        .order_by(LabelingProject.created_at.desc())
    )
    rows = (await self.db.execute(page_stmt)).all()

    logger.debug(f"Found {len(rows)} mappings, total: {total}")

    # Rows are converted one after another; each conversion awaits its own
    # follow-up queries.
    responses = [
        await self._to_response_from_row(row, include_template=include_template)
        for row in rows
    ]
    return responses, total
|