refactor: Reorganize datamate-python (#34)

refactor: Reorganize datamate-python (previously label-studio-adapter) into a DDD style structure.
This commit is contained in:
Jason Wang
2025-10-30 01:32:59 +08:00
committed by GitHub
parent 0614157c0b
commit 2f7341dc1f
79 changed files with 1077 additions and 1577 deletions

View File

@@ -0,0 +1,283 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, func
from typing import Optional, List, Tuple
from datetime import datetime
import uuid
from app.core.logging import get_logger
from app.db.models import LabelingProject
from app.module.annotation.schema import (
DatasetMappingCreateRequest,
DatasetMappingUpdateRequest,
DatasetMappingResponse
)
logger = get_logger(__name__)
class DatasetMappingService:
"""数据集映射服务"""
def __init__(self, db: AsyncSession):
self.db = db
async def create_mapping(
self,
mapping_data: DatasetMappingCreateRequest,
labeling_project_id: str,
labeling_project_name: str
) -> DatasetMappingResponse:
"""创建数据集映射"""
logger.info(f"Create dataset mapping: {mapping_data.dataset_id} -> {labeling_project_id}")
db_mapping = LabelingProject(
id=str(uuid.uuid4()),
dataset_id=mapping_data.dataset_id,
labeling_project_id=labeling_project_id,
name=labeling_project_name
)
self.db.add(db_mapping)
await self.db.commit()
await self.db.refresh(db_mapping)
logger.debug(f"Mapping created: {db_mapping.id}")
return DatasetMappingResponse.model_validate(db_mapping)
async def get_mapping_by_source_uuid(
self,
dataset_id: str
) -> Optional[DatasetMappingResponse]:
"""根据源数据集ID获取映射(返回第一个未删除的)"""
logger.debug(f"Get mapping by source dataset id: {dataset_id}")
result = await self.db.execute(
select(LabelingProject).where(
LabelingProject.dataset_id == dataset_id,
LabelingProject.deleted_at.is_(None)
)
)
mapping = result.scalar_one_or_none()
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"No mapping found for source dataset id: {dataset_id}")
return None
async def get_mappings_by_dataset_id(
self,
dataset_id: str,
include_deleted: bool = False
) -> List[DatasetMappingResponse]:
"""根据源数据集ID获取所有映射关系"""
logger.debug(f"Get all mappings by source dataset id: {dataset_id}")
query = select(LabelingProject).where(
LabelingProject.dataset_id == dataset_id
)
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
result = await self.db.execute(
query.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
async def get_mapping_by_labeling_project_id(
self,
labeling_project_id: str
) -> Optional[DatasetMappingResponse]:
"""根据Label Studio项目ID获取映射"""
logger.debug(f"Get mapping by Label Studio project id: {labeling_project_id}")
result = await self.db.execute(
select(LabelingProject).where(
LabelingProject.labeling_project_id == labeling_project_id,
LabelingProject.deleted_at.is_(None)
)
)
mapping = result.scalar_one_or_none()
if mapping:
logger.debug(f"Found mapping: {mapping.mapping_id}")
return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"No mapping found for Label Studio project id: {labeling_project_id}")
return None
async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]:
"""根据映射UUID获取映射"""
logger.debug(f"Get mapping: {mapping_id}")
result = await self.db.execute(
select(LabelingProject).where(
LabelingProject.id == mapping_id,
LabelingProject.deleted_at.is_(None)
)
)
mapping = result.scalar_one_or_none()
if mapping:
logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"Mapping not found: {mapping_id}")
return None
async def update_mapping(
self,
mapping_id: str,
update_data: DatasetMappingUpdateRequest
) -> Optional[DatasetMappingResponse]:
"""更新映射信息"""
logger.info(f"Update mapping: {mapping_id}")
mapping = await self.get_mapping_by_uuid(mapping_id)
if not mapping:
return None
update_values = update_data.model_dump(exclude_unset=True)
update_values["last_updated_at"] = datetime.now()
result = await self.db.execute(
update(LabelingProject)
.where(LabelingProject.id == mapping_id)
.values(**update_values)
)
await self.db.commit()
if result.rowcount > 0:
return await self.get_mapping_by_uuid(mapping_id)
return None
async def soft_delete_mapping(self, mapping_id: str) -> bool:
"""软删除映射"""
logger.info(f"Soft delete mapping: {mapping_id}")
result = await self.db.execute(
update(LabelingProject)
.where(
LabelingProject.id == mapping_id,
LabelingProject.deleted_at.is_(None)
)
.values(deleted_at=datetime.now())
)
await self.db.commit()
success = result.rowcount > 0
if success:
logger.info(f"Mapping soft-deleted: {mapping_id}")
else:
logger.warning(f"Mapping not exists or already deleted: {mapping_id}")
return success
async def get_all_mappings(
self,
skip: int = 0,
limit: int = 100
) -> List[DatasetMappingResponse]:
"""获取所有有效映射"""
logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
result = await self.db.execute(
select(LabelingProject)
.where(LabelingProject.deleted_at.is_(None))
.offset(skip)
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
async def count_mappings(self, include_deleted: bool = False) -> int:
"""统计映射总数"""
query = select(func.count()).select_from(LabelingProject)
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
result = await self.db.execute(query)
return result.scalar_one()
async def get_all_mappings_with_count(
self,
skip: int = 0,
limit: int = 100,
include_deleted: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
"""获取所有映射及总数(用于分页)"""
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
# 构建查询
query = select(LabelingProject)
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
# 获取总数
count_query = select(func.count()).select_from(LabelingProject)
if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
count_result = await self.db.execute(count_query)
total = count_result.scalar_one()
# 获取数据
result = await self.db.execute(
query
.offset(skip)
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
async def get_mappings_by_source_with_count(
self,
dataset_id: str,
skip: int = 0,
limit: int = 100,
include_deleted: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
"""根据源数据集ID获取映射关系及总数(用于分页)"""
logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
# 构建查询
query = select(LabelingProject).where(
LabelingProject.dataset_id == dataset_id
)
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
# 获取总数
count_query = select(func.count()).select_from(LabelingProject).where(
LabelingProject.dataset_id == dataset_id
)
if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
count_result = await self.db.execute(count_query)
total = count_result.scalar_one()
# 获取数据
result = await self.db.execute(
query
.offset(skip)
.limit(limit)
.order_by(LabelingProject.created_at.desc())
)
mappings = result.scalars().all()
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total

View File

@@ -0,0 +1,272 @@
from typing import Optional, List, Dict, Any, Tuple
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from app.core.config import settings
from app.exception import NoDatasetInfoFoundError
from ..client import LabelStudioClient
from ..schema import SyncDatasetResponse
from ..service.mapping import DatasetMappingService
logger = get_logger(__name__)
class SyncService:
"""数据同步服务"""
def __init__(
self,
dm_client: DatasetManagementService,
ls_client: LabelStudioClient,
mapping_service: DatasetMappingService
):
self.dm_client = dm_client
self.ls_client = ls_client
self.mapping_service = mapping_service
def determine_data_type(self, file_type: str) -> str:
"""根据文件类型确定数据类型"""
file_type_lower = file_type.lower()
if any(ext in file_type_lower for ext in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'svg', 'webp']):
return 'image'
elif any(ext in file_type_lower for ext in ['mp3', 'wav', 'flac', 'aac', 'ogg']):
return 'audio'
elif any(ext in file_type_lower for ext in ['mp4', 'avi', 'mov', 'wmv', 'flv', 'webm']):
return 'video'
elif any(ext in file_type_lower for ext in ['txt', 'doc', 'docx', 'pdf']):
return 'text'
else:
return 'image' # 默认为图像类型
async def get_existing_dm_file_mapping(self, project_id: str) -> Dict[str, int]:
"""
获取Label Studio项目中已存在的DM文件ID到任务ID的映射
Args:
project_id: Label Studio项目ID
Returns:
file_id到task_id的映射字典
"""
try:
logger.info(f"Fetching existing task mappings for project {project_id} (page_size={settings.ls_task_page_size})")
dm_file_to_task_mapping = {}
# 使用Label Studio客户端封装的方法获取所有任务
page_size = getattr(settings, 'ls_task_page_size', 1000)
# 调用封装好的方法获取所有任务,page=None表示获取全部
result = await self.ls_client.get_project_tasks(
project_id=project_id,
page=None, # 不指定page,获取所有任务
page_size=page_size
)
logger.info(f"Fetched tasks result: {result}")
if not result:
logger.warning(f"Failed to fetch tasks for project {project_id}")
return {}
logger.info(f"Successfully fetched tasks for project {project_id}")
all_tasks = result.get("tasks", [])
# 遍历所有任务,构建映射
for task in all_tasks:
# logger.debug(task)
try:
file_id = task.get('data', {}).get('file_id')
task_id = task.get('id')
dm_file_to_task_mapping[str(file_id)] = task_id
except Exception as e:
logger.error(f"Error processing task {task.get('id')}: {e}")
continue
logger.debug(dm_file_to_task_mapping)
logger.info(f"Found {len(dm_file_to_task_mapping)} existing task mappings")
return dm_file_to_task_mapping
except Exception as e:
logger.error(f"Error while fetching existing tasks: {e}")
return {} # 发生错误时返回空字典,会同步所有文件
async def sync_dataset_files(
self,
id: str,
batch_size: int = 50
) -> SyncDatasetResponse:
"""同步数据集文件到Label Studio"""
logger.info(f"Start syncing dataset by mapping: {id}")
# 获取映射关系
mapping = await self.mapping_service.get_mapping_by_uuid(id)
if not mapping:
logger.error(f"Dataset mapping not found: {id}")
return SyncDatasetResponse(
id="",
status="error",
synced_files=0,
total_files=0,
message=f"Dataset mapping not found: {id}"
)
try:
# 获取数据集信息
dataset_info = await self.dm_client.get_dataset(mapping.dataset_id)
if not dataset_info:
raise NoDatasetInfoFoundError(mapping.dataset_id)
synced_files = 0
deleted_tasks = 0
total_files = dataset_info.fileCount
page = 0
logger.info(f"Total files in dataset: {total_files}")
# 获取Label Studio中已存在的DM文件ID到任务ID的映射
existing_dm_file_mapping = await self.get_existing_dm_file_mapping(mapping.labeling_project_id)
existing_file_ids = set(existing_dm_file_mapping.keys())
logger.info(f"{len(existing_file_ids)} tasks already exist in Label Studio")
# 收集DM中当前存在的所有文件ID
current_file_ids = set()
while True:
files_response = await self.dm_client.get_dataset_files(
mapping.dataset_id,
page=page,
size=batch_size,
)
if not files_response or not files_response.content:
logger.info(f"No more files on page {page + 1}")
break
logger.info(f"Processing page {page + 1}, total {len(files_response.content)} files")
# 筛选出新文件并批量创建任务
tasks = []
new_files_count = 0
existing_files_count = 0
for file_info in files_response.content:
# 记录当前DM中存在的文件ID
current_file_ids.add(str(file_info.id))
# 检查文件是否已存在
if str(file_info.id) in existing_file_ids:
existing_files_count += 1
logger.debug(f"Skip existing file: {file_info.originalName} (ID: {file_info.id})")
continue
new_files_count += 1
data_type = self.determine_data_type(file_info.fileType)
# 替换文件路径前缀:只替换开头的前缀,不影响路径中间可能出现的相同字符串
file_path = file_info.filePath.removeprefix(settings.dm_file_path_prefix)
file_path = settings.label_studio_file_path_prefix + file_path
# 构造任务数据
task_data = {
"data": {
f"{data_type}": file_path,
"file_path": file_info.filePath,
"file_id": file_info.id,
"original_name": file_info.originalName,
"dataset_id": mapping.dataset_id,
}
}
tasks.append(task_data)
logger.info(f"Page {page + 1}: new files {new_files_count}, existing files {existing_files_count}")
# 批量创建Label Studio任务
if tasks:
batch_result = await self.ls_client.create_tasks_batch(
mapping.labeling_project_id,
tasks
)
if batch_result:
synced_files += len(tasks)
logger.info(f"Successfully synced {len(tasks)} files")
else:
logger.warning(f"Batch task creation failed, fallback to single creation")
# 如果批量创建失败,尝试单个创建
for task_data in tasks:
task_result = await self.ls_client.create_task(
mapping.labeling_project_id,
task_data["data"],
task_data.get("meta")
)
if task_result:
synced_files += 1
# 检查是否还有更多页面
if page >= files_response.totalPages - 1:
break
page += 1
# 清理在DM中不存在但在Label Studio中存在的任务
tasks_to_delete = []
for file_id, task_id in existing_dm_file_mapping.items():
if file_id not in current_file_ids:
tasks_to_delete.append(task_id)
logger.debug(f"Mark task for deletion: {task_id} (DM file ID: {file_id})")
if tasks_to_delete:
logger.info(f"Deleting {len(tasks_to_delete)} tasks not present in DM")
delete_result = await self.ls_client.delete_tasks_batch(tasks_to_delete)
deleted_tasks = delete_result.get("successful", 0)
logger.info(f"Successfully deleted {deleted_tasks} tasks")
else:
logger.info("No tasks to delete")
logger.info(f"Sync completed: total_files={total_files}, created={synced_files}, deleted={deleted_tasks}")
return SyncDatasetResponse(
id=mapping.id,
status="success",
synced_files=synced_files,
total_files=total_files,
message=f"Sync completed: created {synced_files} files, deleted {deleted_tasks} tasks"
)
except Exception as e:
logger.error(f"Error while syncing dataset: {e}")
return SyncDatasetResponse(
id=mapping.id,
status="error",
synced_files=0,
total_files=0,
message=f"Sync failed: {str(e)}"
)
async def get_sync_status(
self,
dataset_id: str
) -> Optional[Dict[str, Any]]:
"""获取同步状态"""
mapping = await self.mapping_service.get_mapping_by_source_uuid(dataset_id)
if not mapping:
return None
# 获取DM数据集信息
dataset_info = await self.dm_client.get_dataset(dataset_id)
# 获取Label Studio项目任务数量
tasks_info = await self.ls_client.get_project_tasks(mapping.labeling_project_id)
return {
"id": mapping.id,
"dataset_id": dataset_id,
"labeling_project_id": mapping.labeling_project_id,
"dm_total_files": dataset_info.fileCount if dataset_info else 0,
"ls_total_tasks": tasks_info.get("count", 0) if tasks_info else 0,
"sync_ratio": (
tasks_info.get("count", 0) / dataset_info.fileCount
if dataset_info and dataset_info.fileCount > 0 and tasks_info else 0
)
}