Develop labeling module (#25)

* refactor: remove DB table management from the LS adapter (to be moved to scripts later); change the adapter to use the same MySQL DB as the other modules (see the connection sketch after this list).

* refactor: Rename LS Adapter module to datamate-python
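
On the adapter side, the MySQL switch is essentially a matter of pointing the service's async SQLAlchemy engine at the shared database. A minimal sketch of that wiring follows; the environment variable name, the aiomysql driver, and the default URL are assumptions for illustration, not the module's actual configuration.

# Illustrative wiring only -- the real settings live in the datamate-python config.
import os
from typing import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

# Assumed env var and URL; the point is simply that the adapter now reuses
# the MySQL instance shared with the other modules.
DATABASE_URL = os.getenv(
    "DATAMATE_DATABASE_URL",
    "mysql+aiomysql://datamate:datamate@mysql:3306/datamate",
)

engine = create_async_engine(DATABASE_URL, pool_pre_ping=True)
SessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)


async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield one session per request (FastAPI-style dependency)."""
    async with SessionLocal() as session:
        yield session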
Author: Jinglong Wang
Date: 2025-10-27 16:16:14 +08:00 (committed by GitHub)
Parent: 46dfb389f1
Commit: 7f819563db
69 changed files with 1104 additions and 703 deletions


@@ -0,0 +1,298 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, func
from typing import Optional, List, Tuple
from datetime import datetime
import uuid

from app.models.dm.labeling_project import LabelingProject
from app.schemas.dataset_mapping import (
    DatasetMappingCreateRequest,
    DatasetMappingUpdateRequest,
    DatasetMappingResponse
)
from app.core.logging import get_logger

logger = get_logger(__name__)


class DatasetMappingService:
    """Dataset mapping service."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def create_mapping(
        self,
        mapping_data: DatasetMappingCreateRequest,
        labelling_project_id: str,
        labelling_project_name: str
    ) -> DatasetMappingResponse:
        """Create a dataset mapping."""
        logger.info(f"Create dataset mapping: {mapping_data.dataset_id} -> {labelling_project_id}")
        db_mapping = LabelingProject(
            mapping_id=str(uuid.uuid4()),
            dataset_id=mapping_data.dataset_id,
            labelling_project_id=labelling_project_id,
            labelling_project_name=labelling_project_name
        )
        self.db.add(db_mapping)
        await self.db.commit()
        await self.db.refresh(db_mapping)
        logger.info(f"Mapping created: {db_mapping.id}")
        return DatasetMappingResponse.model_validate(db_mapping)

    async def get_mapping_by_source_uuid(
        self,
        dataset_id: str
    ) -> Optional[DatasetMappingResponse]:
        """Get a mapping by source dataset ID (returns the first non-deleted one)."""
        logger.debug(f"Get mapping by source dataset id: {dataset_id}")
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.dataset_id == dataset_id,
                LabelingProject.deleted_at.is_(None)
            )
        )
        mapping = result.scalars().first()
        if mapping:
            logger.debug(f"Found mapping: {mapping.id}")
            return DatasetMappingResponse.model_validate(mapping)
        logger.debug(f"No mapping found for source dataset id: {dataset_id}")
        return None

    async def get_mappings_by_dataset_id(
        self,
        dataset_id: str,
        include_deleted: bool = False
    ) -> List[DatasetMappingResponse]:
        """Get all mappings for a source dataset ID."""
        logger.debug(f"Get all mappings by source dataset id: {dataset_id}")
        query = select(LabelingProject).where(
            LabelingProject.dataset_id == dataset_id
        )
        if not include_deleted:
            query = query.where(LabelingProject.deleted_at.is_(None))
        result = await self.db.execute(
            query.order_by(LabelingProject.created_at.desc())
        )
        mappings = result.scalars().all()
        logger.debug(f"Found {len(mappings)} mappings")
        return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]

    async def get_mapping_by_labelling_project_id(
        self,
        labelling_project_id: str
    ) -> Optional[DatasetMappingResponse]:
        """Get a mapping by Label Studio project ID."""
        logger.debug(f"Get mapping by Label Studio project id: {labelling_project_id}")
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.labelling_project_id == labelling_project_id,
                LabelingProject.deleted_at.is_(None)
            )
        )
        mapping = result.scalar_one_or_none()
        if mapping:
            logger.debug(f"Found mapping: {mapping.mapping_id}")
            return DatasetMappingResponse.model_validate(mapping)
        logger.debug(f"No mapping found for Label Studio project id: {labelling_project_id}")
        return None

    async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]:
        """Get a mapping by its mapping UUID."""
        logger.debug(f"Get mapping: {mapping_id}")
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.id == mapping_id,
                LabelingProject.deleted_at.is_(None)
            )
        )
        mapping = result.scalar_one_or_none()
        if mapping:
            logger.debug(f"Found mapping: {mapping.id}")
            return DatasetMappingResponse.model_validate(mapping)
        logger.debug(f"Mapping not found: {mapping_id}")
        return None

    async def update_mapping(
        self,
        mapping_id: str,
        update_data: DatasetMappingUpdateRequest
    ) -> Optional[DatasetMappingResponse]:
        """Update mapping information."""
        logger.info(f"Update mapping: {mapping_id}")
        mapping = await self.get_mapping_by_uuid(mapping_id)
        if not mapping:
            return None
        update_values = update_data.model_dump(exclude_unset=True)
        update_values["last_updated_at"] = datetime.now()
        result = await self.db.execute(
            update(LabelingProject)
            .where(LabelingProject.id == mapping_id)
            .values(**update_values)
        )
        await self.db.commit()
        if result.rowcount > 0:
            return await self.get_mapping_by_uuid(mapping_id)
        return None

    async def update_last_updated_at(self, mapping_id: str) -> bool:
        """Update the last-updated timestamp."""
        logger.debug(f"Update mapping last updated at: {mapping_id}")
        result = await self.db.execute(
            update(LabelingProject)
            .where(
                LabelingProject.id == mapping_id,
                LabelingProject.deleted_at.is_(None)
            )
            .values(last_updated_at=datetime.utcnow())
        )
        await self.db.commit()
        return result.rowcount > 0

    async def soft_delete_mapping(self, mapping_id: str) -> bool:
        """Soft-delete a mapping."""
        logger.info(f"Soft delete mapping: {mapping_id}")
        result = await self.db.execute(
            update(LabelingProject)
            .where(
                LabelingProject.id == mapping_id,
                LabelingProject.deleted_at.is_(None)
            )
            .values(deleted_at=datetime.now())
        )
        await self.db.commit()
        success = result.rowcount > 0
        if success:
            logger.info(f"Mapping soft-deleted: {mapping_id}")
        else:
            logger.warning(f"Mapping does not exist or is already deleted: {mapping_id}")
        return success

    async def get_all_mappings(
        self,
        skip: int = 0,
        limit: int = 100
    ) -> List[DatasetMappingResponse]:
        """Get all active (non-deleted) mappings."""
        logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
        result = await self.db.execute(
            select(LabelingProject)
            .where(LabelingProject.deleted_at.is_(None))
            .offset(skip)
            .limit(limit)
            .order_by(LabelingProject.created_at.desc())
        )
        mappings = result.scalars().all()
        logger.debug(f"Found {len(mappings)} mappings")
        return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]

    async def count_mappings(self, include_deleted: bool = False) -> int:
        """Count the total number of mappings."""
        query = select(func.count()).select_from(LabelingProject)
        if not include_deleted:
            query = query.where(LabelingProject.deleted_at.is_(None))
        result = await self.db.execute(query)
        return result.scalar_one()

    async def get_all_mappings_with_count(
        self,
        skip: int = 0,
        limit: int = 100,
        include_deleted: bool = False
    ) -> Tuple[List[DatasetMappingResponse], int]:
        """Get all mappings together with the total count (for pagination)."""
        logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
        # Build the base query
        query = select(LabelingProject)
        if not include_deleted:
            query = query.where(LabelingProject.deleted_at.is_(None))
        # Get the total count
        count_query = select(func.count()).select_from(LabelingProject)
        if not include_deleted:
            count_query = count_query.where(LabelingProject.deleted_at.is_(None))
        count_result = await self.db.execute(count_query)
        total = count_result.scalar_one()
        # Fetch the page of data
        result = await self.db.execute(
            query
            .offset(skip)
            .limit(limit)
            .order_by(LabelingProject.created_at.desc())
        )
        mappings = result.scalars().all()
        logger.debug(f"Found {len(mappings)} mappings, total: {total}")
        return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total

    async def get_mappings_by_source_with_count(
        self,
        dataset_id: str,
        skip: int = 0,
        limit: int = 100,
        include_deleted: bool = False
    ) -> Tuple[List[DatasetMappingResponse], int]:
        """Get mappings for a source dataset ID together with the total count (for pagination)."""
        logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
        # Build the base query
        query = select(LabelingProject).where(
            LabelingProject.dataset_id == dataset_id
        )
        if not include_deleted:
            query = query.where(LabelingProject.deleted_at.is_(None))
        # Get the total count
        count_query = select(func.count()).select_from(LabelingProject).where(
            LabelingProject.dataset_id == dataset_id
        )
        if not include_deleted:
            count_query = count_query.where(LabelingProject.deleted_at.is_(None))
        count_result = await self.db.execute(count_query)
        total = count_result.scalar_one()
        # Fetch the page of data
        result = await self.db.execute(
            query
            .offset(skip)
            .limit(limit)
            .order_by(LabelingProject.created_at.desc())
        )
        mappings = result.scalars().all()
        logger.debug(f"Found {len(mappings)} mappings, total: {total}")
        return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
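
For context, here is a minimal sketch of how DatasetMappingService could be driven from a FastAPI route; the router path, the get_db dependency, and the import paths are assumptions for illustration and are not part of this commit.

# Illustrative usage only -- endpoint path, dependency, and import paths are assumed.
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.db import get_db  # assumed location of the session dependency
from app.schemas.dataset_mapping import DatasetMappingResponse
from app.services.dataset_mapping import DatasetMappingService  # assumed module path

router = APIRouter(prefix="/mappings", tags=["labeling"])


@router.get("/{mapping_id}", response_model=DatasetMappingResponse)
async def read_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)) -> DatasetMappingResponse:
    """Look up a single dataset-to-Label-Studio mapping by its UUID."""
    service = DatasetMappingService(db)
    mapping = await service.get_mapping_by_uuid(mapping_id)
    if mapping is None:
        raise HTTPException(status_code=404, detail="Mapping not found")
    return mapping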