You've already forked DataMate
Add Label Studio adapter module and its build scripts.
This commit is contained in:
11
runtime/label-studio-adapter/app/api/project/__init__.py
Normal file
11
runtime/label-studio-adapter/app/api/project/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
标注工程相关API路由模块
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
# Shared router for all project endpoints. The submodules imported below do
# `from . import project_router` and register their routes on it via decorators.
project_router = APIRouter()

# These imports intentionally come AFTER the router is created: each submodule
# imports `project_router` from this package at import time, so the name must
# already exist. Importing a submodule is what registers its routes.
from . import create
from . import sync
# NOTE(review): this binds the name `list` to the submodule inside this package
# namespace, hiding the builtin here — avoid using `list(...)` in this module.
from . import list
from . import delete
|
||||
130
runtime/label-studio-adapter/app/api/project/create.py
Normal file
130
runtime/label-studio-adapter/app/api/project/create.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Optional
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.services.dataset_mapping_service import DatasetMappingService
|
||||
from app.clients import get_clients
|
||||
from app.schemas.dataset_mapping import (
|
||||
DatasetMappingCreateRequest,
|
||||
DatasetMappingCreateResponse,
|
||||
)
|
||||
from app.schemas import StandardResponse
|
||||
from app.core.logging import get_logger
|
||||
from app.core.config import settings
|
||||
from . import project_router
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def _infer_data_type(dataset_info) -> str:
    """Derive the Label Studio data type from the DM dataset type code ("image" by default)."""
    type_info = dataset_info.type
    if type_info and type_info.code:
        type_code = type_info.code.lower()
        # Checked in the same priority order as before: audio, video, text.
        for candidate in ("audio", "video", "text"):
            if candidate in type_code:
                return candidate
    return "image"


async def _attach_local_storage(ls_client, project_id, *, path, title, description) -> bool:
    """Attach one local-storage source to a project and log the outcome.

    A falsy result from the client is treated as a non-fatal failure: it is
    logged as a warning and reported through the return value, never raised.
    """
    storage_result = await ls_client.create_local_storage(
        project_id=project_id,
        path=path,
        title=title,
        use_blob_urls=True,
        description=description,
    )
    if not storage_result:
        logger.warning(f"Failed to configure local storage for project {project_id}: {path}")
        return False
    logger.info(f"Local storage configured for project {project_id}: {path}")
    return True


@project_router.post("/create", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
async def create_dataset_mapping(
    request: DatasetMappingCreateRequest,
    db: AsyncSession = Depends(get_db)
):
    """
    Create a dataset mapping.

    Looks up the dataset identified by ``request.source_dataset_id`` in the DM
    service, creates a Label Studio project named after it, attaches two local
    storages to that project (the dataset directory and the upload directory),
    and records the dataset <-> project association through the mapping service.

    Note (from the original author): one dataset may back several labelling
    projects.

    Returns:
        StandardResponse (code 201) wrapping a DatasetMappingCreateResponse.

    Raises:
        HTTPException: 404 when the DM service returns no dataset; 500 when the
            Label Studio project cannot be created or any unexpected error occurs.
    """
    try:
        # Fetch the shared client instances.
        dm_client_instance, ls_client_instance = get_clients()
        service = DatasetMappingService(db)

        logger.info(f"Create dataset mapping request: {request.source_dataset_id}")

        # Fetch the dataset description from the DM service.
        dataset_info = await dm_client_instance.get_dataset(request.source_dataset_id)
        if not dataset_info:
            raise HTTPException(
                status_code=404,
                detail=f"Dataset not found in DM service: {request.source_dataset_id}"
            )

        data_type = _infer_data_type(dataset_info)

        # The project is named after the dataset.
        project_name = f"{dataset_info.name}"

        # Create the project in Label Studio.
        project_data = await ls_client_instance.create_project(
            title=project_name,
            description=dataset_info.description or f"Imported from DM dataset {dataset_info.id}",
            data_type=data_type
        )
        if not project_data:
            raise HTTPException(
                status_code=500,
                detail="Fail to create Label Studio project."
            )

        project_id = project_data["id"]
        storage_description = f"Local storage for dataset {dataset_info.name}"

        # Each storage is checked on its own. Previously the result of the
        # dataset storage was overwritten by the upload storage before being
        # inspected, so a failure of the first one went unnoticed.
        # Storage failures are logged but deliberately do not abort the request.
        await _attach_local_storage(
            ls_client_instance,
            project_id,
            path=f"{settings.label_studio_local_storage_dataset_base_path}/{request.source_dataset_id}",
            title="Dataset_BLOB",
            description=storage_description,
        )
        await _attach_local_storage(
            ls_client_instance,
            project_id,
            path=f"{settings.label_studio_local_storage_upload_base_path}",
            title="Upload_BLOB",
            description=storage_description,
        )

        # Persist the association, including the project name.
        mapping = await service.create_mapping(
            request,
            str(project_id),
            project_name
        )

        logger.debug(
            f"Dataset mapping created: {mapping.mapping_id} -> S {mapping.source_dataset_id} <> L {mapping.labelling_project_id}"
        )

        response_data = DatasetMappingCreateResponse(
            mapping_id=mapping.mapping_id,
            labelling_project_id=mapping.labelling_project_id,
            labelling_project_name=mapping.labelling_project_name or project_name,
            message="Dataset mapping created successfully"
        )

        return StandardResponse(
            code=201,
            message="success",
            data=response_data
        )

    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Error while creating dataset mapping: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
|
||||
106
runtime/label-studio-adapter/app/api/project/delete.py
Normal file
106
runtime/label-studio-adapter/app/api/project/delete.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import Optional
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.services.dataset_mapping_service import DatasetMappingService
|
||||
from app.clients import get_clients
|
||||
from app.schemas.dataset_mapping import DeleteDatasetResponse
|
||||
from app.schemas import StandardResponse
|
||||
from app.core.logging import get_logger
|
||||
from . import project_router
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@project_router.delete("/mappings", response_model=StandardResponse[DeleteDatasetResponse])
async def delete_mapping(
    m: Optional[str] = Query(None, description="映射UUID"),
    proj: Optional[str] = Query(None, description="Label Studio项目ID"),
    db: AsyncSession = Depends(get_db)
):
    """
    Delete a mapping and its Label Studio project.

    The mapping is selected by either query parameter:
    - m: mapping UUID
    - proj: Label Studio project ID
    - when both are given, ``m`` wins and ``proj`` is ignored

    Steps performed:
    1. Try to delete the project in Label Studio (best effort: a failure is
       logged and does not stop step 2).
    2. Soft-delete the mapping record through the mapping service.

    Returns:
        StandardResponse (code 200) wrapping a DeleteDatasetResponse whose
        message states whether the Label Studio project was actually deleted.

    Raises:
        HTTPException: 400 when neither parameter is given; 404 when no mapping
            matches; 500 when the soft delete fails or on any unexpected error.
    """
    try:
        # At least one selector is required.
        if not m and not proj:
            raise HTTPException(
                status_code=400,
                detail="Either 'm' (mapping UUID) or 'proj' (project ID) must be provided"
            )

        # Only the Label Studio client is needed here.
        _, ls_client_instance = get_clients()
        service = DatasetMappingService(db)

        if m:
            # The mapping UUID takes precedence.
            logger.info(f"Deleting by mapping UUID: {m}")
            mapping = await service.get_mapping_by_uuid(m)
        else:
            # The guard above guarantees `proj` is set when `m` is not.
            logger.info(f"Deleting by project ID: {proj}")
            mapping = await service.get_mapping_by_labelling_project_id(proj)

        if not mapping:
            raise HTTPException(
                status_code=404,
                detail="Mapping not found"
            )

        mapping_id = mapping.mapping_id
        labelling_project_id = mapping.labelling_project_id
        labelling_project_name = mapping.labelling_project_name

        logger.info(f"Found mapping: {mapping_id}, Label Studio project ID: {labelling_project_id}")

        # 1. Delete the Label Studio project (best effort).
        ls_project_deleted = False
        try:
            ls_project_deleted = bool(
                await ls_client_instance.delete_project(int(labelling_project_id))
            )
            if ls_project_deleted:
                logger.info(f"Successfully deleted Label Studio project: {labelling_project_id}")
            else:
                logger.warning(f"Failed to delete Label Studio project or project not found: {labelling_project_id}")
        except Exception as e:
            # Keep going: the mapping record must be removed even when the
            # Label Studio side fails (this also covers a non-numeric project ID).
            logger.error(f"Error deleting Label Studio project: {e}")

        # 2. Soft-delete the mapping record.
        soft_delete_success = await service.soft_delete_mapping(mapping_id)
        if not soft_delete_success:
            raise HTTPException(
                status_code=500,
                detail="Failed to delete mapping record"
            )

        logger.info(f"Successfully deleted mapping: {mapping_id}")

        # Report what really happened. Previously the message always claimed
        # the Label Studio project had been deleted, even after a failure.
        if ls_project_deleted:
            message = f"Successfully deleted mapping and Label Studio project '{labelling_project_name}'"
        else:
            message = (
                f"Successfully deleted mapping; Label Studio project '{labelling_project_name}' "
                "could not be deleted or was not found"
            )

        response_data = DeleteDatasetResponse(
            mapping_id=mapping_id,
            status="success",
            message=message
        )

        return StandardResponse(
            code=200,
            message="success",
            data=response_data
        )

    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Error deleting mapping: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
|
||||
110
runtime/label-studio-adapter/app/api/project/list.py
Normal file
110
runtime/label-studio-adapter/app/api/project/list.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.services.dataset_mapping_service import DatasetMappingService
|
||||
from app.schemas.dataset_mapping import DatasetMappingResponse
|
||||
from app.schemas import StandardResponse
|
||||
from app.core.logging import get_logger
|
||||
from . import project_router
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@project_router.get("/mappings/list", response_model=StandardResponse[List[DatasetMappingResponse]])
async def list_mappings(
    skip: int = Query(0, ge=0, description="Number of records to skip"),
    limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"),
    db: AsyncSession = Depends(get_db)
):
    """
    List dataset mappings, paginated by ``skip`` and ``limit``.

    Per the original author's note, only mappings that have not been
    soft-deleted are returned; that filtering is done by the mapping service,
    not in this handler.

    Raises:
        HTTPException: 500 on any error.
    """
    try:
        mapping_service = DatasetMappingService(db)
        logger.info(f"Listing mappings, skip={skip}, limit={limit}")

        found = await mapping_service.get_all_mappings(skip=skip, limit=limit)
        logger.info(f"Found {len(found)} mappings")

        return StandardResponse(code=200, message="success", data=found)
    except Exception as e:
        logger.error(f"Error listing mappings: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@project_router.get("/mappings/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping(
    mapping_id: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Look up a single mapping by its UUID.

    Raises:
        HTTPException: 404 when no mapping matches ``mapping_id``; 500 on any
            unexpected error.
    """
    try:
        mapping_service = DatasetMappingService(db)
        logger.info(f"Get mapping: {mapping_id}")

        record = await mapping_service.get_mapping_by_uuid(mapping_id)
        if record:
            logger.info(f"Found mapping: {record.mapping_id}")
            return StandardResponse(code=200, message="success", data=record)

        # Nothing matched: answer with a 404 rather than an empty payload.
        raise HTTPException(
            status_code=404,
            detail=f"Mapping not found: {mapping_id}"
        )
    except HTTPException:
        # Pass the 404 above through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting mapping: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[List[DatasetMappingResponse]])
async def get_mappings_by_source(
    source_dataset_id: str,
    db: AsyncSession = Depends(get_db)
):
    """
    List every mapping created from one source dataset.

    The original author's note says deleted labelling projects are included;
    NOTE(review): that depends on the mapping service — confirm there.

    Raises:
        HTTPException: 500 on any unexpected error; an HTTPException raised
            further down is passed through unchanged.
    """
    try:
        mapping_service = DatasetMappingService(db)
        logger.info(f"Get mappings by source dataset id: {source_dataset_id}")

        found = await mapping_service.get_mappings_by_source_dataset_id(source_dataset_id)
        logger.info(f"Found {len(found)} mappings")

        return StandardResponse(code=200, message="success", data=found)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting mappings: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
68
runtime/label-studio-adapter/app/api/project/sync.py
Normal file
68
runtime/label-studio-adapter/app/api/project/sync.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Optional
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.services.dataset_mapping_service import DatasetMappingService
|
||||
from app.services.sync_service import SyncService
|
||||
from app.clients import get_clients
|
||||
from app.exceptions import NoDatasetInfoFoundError, DatasetMappingNotFoundError
|
||||
from app.schemas.dataset_mapping import (
|
||||
DatasetMappingResponse,
|
||||
SyncDatasetRequest,
|
||||
SyncDatasetResponse,
|
||||
)
|
||||
from app.schemas import StandardResponse
|
||||
from app.core.logging import get_logger
|
||||
from . import project_router
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@project_router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
async def sync_dataset_content(
    request: SyncDatasetRequest,
    db: AsyncSession = Depends(get_db)
):
    """
    Synchronise dataset content into Label Studio for one mapping.

    Verifies that the mapping named by ``request.mapping_id`` exists, then
    delegates to ``SyncService.sync_dataset_files`` with that mapping ID and
    ``request.batch_size``. Per the original author's note, the sync also
    records the update time in the database; that happens inside the service.

    Raises:
        HTTPException: 404 when the mapping is missing or when the sync raises
            NoDatasetInfoFoundError / DatasetMappingNotFoundError; 500 on any
            other error.
    """
    try:
        dm_client, ls_client = get_clients()
        mapping_service = DatasetMappingService(db)
        sync_service = SyncService(dm_client, ls_client, mapping_service)

        logger.info(f"Sync dataset content request: mapping_id={request.mapping_id}")

        # Existence check only; the record itself is not used below.
        existing = await mapping_service.get_mapping_by_uuid(request.mapping_id)
        if not existing:
            raise HTTPException(
                status_code=404,
                detail=f"Mapping not found: {request.mapping_id}"
            )

        # The sync is keyed by the mapping ID, not by the source dataset ID.
        outcome = await sync_service.sync_dataset_files(request.mapping_id, request.batch_size)
        logger.info(f"Sync completed: {outcome.synced_files}/{outcome.total_files} files")

        return StandardResponse(code=200, message="success", data=outcome)
    except HTTPException:
        raise
    except NoDatasetInfoFoundError as e:
        logger.error(f"Failed to get dataset info: {e}")
        raise HTTPException(status_code=404, detail=str(e))
    except DatasetMappingNotFoundError as e:
        logger.error(f"Mapping not found: {e}")
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logger.error(f"Error syncing dataset content: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
Reference in New Issue
Block a user