Develop labeling module (#25)

* refactor: remove DB table management from the LS adapter (move to scripts later); change the adapter to use the same MySQL DB as the other modules.

* refactor: Rename LS Adapter module to datamate-python
This commit is contained in:
Jinglong Wang
2025-10-27 16:16:14 +08:00
committed by GitHub
parent 46dfb389f1
commit 7f819563db
69 changed files with 1104 additions and 703 deletions

View File

@@ -0,0 +1 @@
# app/__init__.py

View File

@@ -0,0 +1,19 @@
"""
API 路由模块
集中管理所有API路由的组织结构
"""
from fastapi import APIRouter
from .system import router as system_router
from .project import project_router
# 创建主API路由器
api_router = APIRouter()
# 注册到主路由器
api_router.include_router(system_router, tags=["系统"])
api_router.include_router(project_router, prefix="/project", tags=["项目"])
# 导出路由器供 main.py 使用
__all__ = ["api_router"]

View File

@@ -0,0 +1,11 @@
"""
标注工程相关API路由模块
"""
from fastapi import APIRouter
project_router = APIRouter()
from . import create
from . import sync
from . import list
from . import delete

View File

@@ -0,0 +1,130 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.infrastructure import DatamateClient, LabelStudioClient
from app.schemas.dataset_mapping import (
DatasetMappingCreateRequest,
DatasetMappingCreateResponse,
)
from app.schemas import StandardResponse
from app.core.logging import get_logger
from app.core.config import settings
from . import project_router
logger = get_logger(__name__)
@project_router.post("/create", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
async def create_dataset_mapping(
request: DatasetMappingCreateRequest,
db: AsyncSession = Depends(get_db)
):
"""
创建数据集映射
根据指定的DM程序中的数据集,创建Label Studio中的数据集,
在数据库中记录这一关联关系,返回Label Studio数据集的ID
注意:一个数据集可以创建多个标注项目
"""
try:
dm_client = DatamateClient(db)
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token)
service = DatasetMappingService(db)
logger.info(f"Create dataset mapping request: {request.dataset_id}")
# 从DM服务获取数据集信息
dataset_info = await dm_client.get_dataset(request.dataset_id)
if not dataset_info:
raise HTTPException(
status_code=404,
detail=f"Dataset not found in DM service: {request.dataset_id}"
)
# 确定数据类型(基于数据集类型)
data_type = "image" # 默认值
if dataset_info.type and dataset_info.type.code:
type_code = dataset_info.type.code.lower()
if "audio" in type_code:
data_type = "audio"
elif "video" in type_code:
data_type = "video"
elif "text" in type_code:
data_type = "text"
project_name = f"{dataset_info.name}"
# 在Label Studio中创建项目
project_data = await ls_client.create_project(
title=project_name,
description=dataset_info.description or f"Imported from DM dataset {dataset_info.id}",
data_type=data_type
)
if not project_data:
raise HTTPException(
status_code=500,
detail="Fail to create Label Studio project."
)
project_id = project_data["id"]
# 配置本地存储:dataset/<id>
local_storage_path = f"{settings.label_studio_local_storage_dataset_base_path}/{request.dataset_id}"
storage_result = await ls_client.create_local_storage(
project_id=project_id,
path=local_storage_path,
title="Dataset_BLOB",
use_blob_urls=True,
description=f"Local storage for dataset {dataset_info.name}"
)
# 配置本地存储:upload
local_storage_path = f"{settings.label_studio_local_storage_upload_base_path}"
storage_result = await ls_client.create_local_storage(
project_id=project_id,
path=local_storage_path,
title="Upload_BLOB",
use_blob_urls=True,
description=f"Local storage for dataset {dataset_info.name}"
)
if not storage_result:
# 本地存储配置失败,记录警告但不中断流程
logger.warning(f"Failed to configure local storage for project {project_id}")
else:
logger.info(f"Local storage configured for project {project_id}: {local_storage_path}")
# 创建映射关系,包含项目名称
mapping = await service.create_mapping(
request,
str(project_id),
project_name
)
logger.debug(
f"Dataset mapping created: {mapping.mapping_id} -> S {mapping.dataset_id} <> L {mapping.labelling_project_id}"
)
response_data = DatasetMappingCreateResponse(
mapping_id=mapping.mapping_id,
labelling_project_id=mapping.labelling_project_id,
labelling_project_name=mapping.labelling_project_name or project_name,
message="Dataset mapping created successfully"
)
return StandardResponse(
code=201,
message="success",
data=response_data
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error while creating dataset mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,106 @@
from fastapi import Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.infrastructure import DatamateClient, LabelStudioClient
from app.schemas.dataset_mapping import DeleteDatasetResponse
from app.schemas import StandardResponse
from app.core.logging import get_logger
from app.core.config import settings
from . import project_router
logger = get_logger(__name__)
@project_router.delete("/mappings", response_model=StandardResponse[DeleteDatasetResponse])
async def delete_mapping(
m: Optional[str] = Query(None, description="映射UUID"),
proj: Optional[str] = Query(None, description="Label Studio项目ID"),
db: AsyncSession = Depends(get_db)
):
"""
删除映射关系和对应的 Label Studio 项目
可以通过以下任一方式指定要删除的映射:
- m: 映射UUID
- proj: Label Studio项目ID
- 两者都提供(优先使用 m)
此操作会:
1. 删除 Label Studio 中的项目
2. 软删除数据库中的映射记录
"""
try:
# 至少需要提供一个参数
if not m and not proj:
raise HTTPException(
status_code=400,
detail="Either 'm' (mapping UUID) or 'proj' (project ID) must be provided"
)
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token)
service = DatasetMappingService(db)
# 优先使用 mapping_id 查询
if m:
logger.debug(f"Deleting by mapping UUID: {m}")
mapping = await service.get_mapping_by_uuid(m)
# 如果没有提供 m,使用 proj 查询
elif proj:
logger.debug(f"Deleting by project ID: {proj}")
mapping = await service.get_mapping_by_labelling_project_id(proj)
else:
mapping = None
if not mapping:
raise HTTPException(
status_code=404,
detail=f"Mapping either not found or not specified."
)
mapping_id = mapping.mapping_id
labelling_project_id = mapping.labelling_project_id
labelling_project_name = mapping.labelling_project_name
logger.debug(f"Found mapping: {mapping_id}, Label Studio project ID: {labelling_project_id}")
# 1. 删除 Label Studio 项目
try:
delete_success = await ls_client.delete_project(int(labelling_project_id))
if delete_success:
logger.debug(f"Successfully deleted Label Studio project: {labelling_project_id}")
else:
logger.warning(f"Failed to delete Label Studio project or project not found: {labelling_project_id}")
except Exception as e:
logger.error(f"Error deleting Label Studio project: {e}")
# 继续执行,即使 Label Studio 项目删除失败也要删除映射记录
# 2. 软删除映射记录
soft_delete_success = await service.soft_delete_mapping(mapping_id)
if not soft_delete_success:
raise HTTPException(
status_code=500,
detail="Failed to delete mapping record"
)
logger.info(f"Successfully deleted mapping: {mapping_id}, Label Studio project: {labelling_project_id}")
return StandardResponse(
code=200,
message="success",
data=DeleteDatasetResponse(
mapping_id=mapping_id,
status="success",
message=f"Successfully deleted mapping and Label Studio project '{labelling_project_name}'"
)
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error deleting mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,152 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
import math
from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.schemas.dataset_mapping import DatasetMappingResponse
from app.schemas.common import StandardResponse, PaginatedData
from app.core.logging import get_logger
from . import project_router
logger = get_logger(__name__)
@project_router.get("/mappings/list", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
async def list_mappings(
page: int = Query(1, ge=1, description="页码(从1开始)"),
page_size: int = Query(20, ge=1, le=100, description="每页记录数"),
db: AsyncSession = Depends(get_db)
):
"""
查询所有映射关系(分页)
返回所有有效的数据集映射关系(未被软删除的),支持分页查询
"""
try:
service = DatasetMappingService(db)
# 计算 skip
skip = (page - 1) * page_size
logger.info(f"Listing mappings, page={page}, page_size={page_size}")
# 获取数据和总数
mappings, total = await service.get_all_mappings_with_count(
skip=skip,
limit=page_size
)
# 计算总页数
total_pages = math.ceil(total / page_size) if total > 0 else 0
# 构造分页响应
paginated_data = PaginatedData(
page=page,
size=page_size,
total_elements=total,
total_pages=total_pages,
content=mappings
)
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
return StandardResponse(
code=200,
message="success",
data=paginated_data
)
except Exception as e:
logger.error(f"Error listing mappings: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@project_router.get("/mappings/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping(
mapping_id: str,
db: AsyncSession = Depends(get_db)
):
"""
根据 UUID 查询单个映射关系
"""
try:
service = DatasetMappingService(db)
logger.info(f"Get mapping: {mapping_id}")
mapping = await service.get_mapping_by_uuid(mapping_id)
if not mapping:
raise HTTPException(
status_code=404,
detail=f"Mapping not found: {mapping_id}"
)
logger.info(f"Found mapping: {mapping.mapping_id}")
return StandardResponse(
code=200,
message="success",
data=mapping
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@project_router.get("/mappings/by-source/{dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
async def get_mappings_by_source(
dataset_id: str,
page: int = Query(1, ge=1, description="页码(从1开始)"),
page_size: int = Query(20, ge=1, le=100, description="每页记录数"),
db: AsyncSession = Depends(get_db)
):
"""
根据源数据集 ID 查询所有映射关系(分页)
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询
"""
try:
service = DatasetMappingService(db)
# 计算 skip
skip = (page - 1) * page_size
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}")
# 获取数据和总数
mappings, total = await service.get_mappings_by_source_with_count(
dataset_id=dataset_id,
skip=skip,
limit=page_size
)
# 计算总页数
total_pages = math.ceil(total / page_size) if total > 0 else 0
# 构造分页响应
paginated_data = PaginatedData(
page=page,
size=page_size,
total_elements=total,
total_pages=total_pages,
content=mappings
)
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
return StandardResponse(
code=200,
message="success",
data=paginated_data
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting mappings: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,71 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.services.sync_service import SyncService
from app.infrastructure import DatamateClient, LabelStudioClient
from app.exceptions import NoDatasetInfoFoundError, DatasetMappingNotFoundError
from app.schemas.dataset_mapping import (
DatasetMappingResponse,
SyncDatasetRequest,
SyncDatasetResponse,
)
from app.schemas import StandardResponse
from app.core.logging import get_logger
from app.core.config import settings
from . import project_router
logger = get_logger(__name__)
@project_router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
async def sync_dataset_content(
request: SyncDatasetRequest,
db: AsyncSession = Depends(get_db)
):
"""
同步数据集内容
根据指定的mapping ID,同步DM程序数据集中的内容到Label Studio数据集中,
在数据库中记录更新时间,返回更新状态
"""
try:
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token)
dm_client = DatamateClient(db)
mapping_service = DatasetMappingService(db)
sync_service = SyncService(dm_client, ls_client, mapping_service)
logger.info(f"Sync dataset content request: mapping_id={request.mapping_id}")
# 根据 mapping_id 获取映射关系
mapping = await mapping_service.get_mapping_by_uuid(request.mapping_id)
if not mapping:
raise HTTPException(
status_code=404,
detail=f"Mapping not found: {request.mapping_id}"
)
# 执行同步(使用映射中的源数据集UUID)
result = await sync_service.sync_dataset_files(request.mapping_id, request.batch_size)
logger.info(f"Sync completed: {result.synced_files}/{result.total_files} files")
return StandardResponse(
code=200,
message="success",
data=result
)
except HTTPException:
raise
except NoDatasetInfoFoundError as e:
logger.error(f"Failed to get dataset info: {e}")
raise HTTPException(status_code=404, detail=str(e))
except DatasetMappingNotFoundError as e:
logger.error(f"Mapping not found: {e}")
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f"Error syncing dataset content: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,33 @@
from fastapi import APIRouter
from typing import Dict, Any
from app.core.config import settings
from app.schemas import StandardResponse
router = APIRouter()
@router.get("/health", response_model=StandardResponse[Dict[str, Any]])
async def health_check():
"""健康检查端点"""
return StandardResponse(
code=200,
message="success",
data={
"status": "healthy",
"service": "Label Studio Adapter",
"version": settings.app_version
}
)
@router.get("/config", response_model=StandardResponse[Dict[str, Any]])
async def get_config():
"""获取配置信息"""
return StandardResponse(
code=200,
message="success",
data={
"app_name": settings.app_name,
"version": settings.app_version,
"label_studio_url": settings.label_studio_base_url,
"debug": settings.debug
}
)

View File

@@ -0,0 +1 @@
# app/core/__init__.py

View File

@@ -0,0 +1,145 @@
from pydantic_settings import BaseSettings
from typing import Optional
import os
from pathlib import Path
class Settings(BaseSettings):
    """Application configuration loaded from the environment / .env via pydantic-settings."""

    class Config:
        env_file = ".env"
        case_sensitive = False
        extra = 'ignore'  # tolerate extra fields (e.g. env vars used only by shell scripts)

    # =========================
    # Adapter service settings
    # =========================
    app_name: str = "Label Studio Adapter"
    app_version: str = "1.0.0"
    app_description: str = "Adapter for integrating Data Management System with Label Studio"
    debug: bool = True
    # Server binding
    host: str = "0.0.0.0"
    port: int = 8000
    # CORS settings
    allowed_origins: list = ["*"]
    allowed_methods: list = ["*"]
    allowed_headers: list = ["*"]
    # MySQL database settings (priority 1)
    mysql_host: Optional[str] = None
    mysql_port: int = 3306
    mysql_user: Optional[str] = None
    mysql_password: Optional[str] = None
    mysql_database: Optional[str] = None
    # PostgreSQL database settings (priority 2)
    postgres_host: Optional[str] = None
    postgres_port: int = 5432
    postgres_user: Optional[str] = None
    postgres_password: Optional[str] = None
    postgres_database: Optional[str] = None
    # SQLite database settings (priority 3 - fallback)
    sqlite_path: str = "data/labelstudio_adapter.db"
    # Direct database URL; when provided it overrides all of the above.
    database_url: Optional[str] = None
    # Logging
    log_level: str = "INFO"
    # Security
    secret_key: str = "your-secret-key-change-this-in-production"
    access_token_expire_minutes: int = 30
    # =========================
    # Label Studio service settings
    # =========================
    label_studio_base_url: str = "http://label-studio:8080"
    label_studio_username: Optional[str] = None  # Label Studio username (for login)
    label_studio_password: Optional[str] = None  # Label Studio password (for login)
    label_studio_user_token: Optional[str] = None  # Legacy Token
    label_studio_local_storage_dataset_base_path: str = "/label-studio/local_files/dataset"  # local-storage base path inside the Label Studio container
    label_studio_local_storage_upload_base_path: str = "/label-studio/local_files/upload"  # local-storage base path inside the Label Studio container
    label_studio_file_path_prefix: str = "/data/local-files/?d="  # Label Studio local-file serving path prefix
    ls_task_page_size: int = 1000
    # =========================
    # Data Management service settings
    # =========================
    dm_file_path_prefix: str = "/"  # DM storage folder prefix

    @property
    def computed_database_url(self) -> str:
        """
        Select the database connection URL automatically.

        Priority: explicit ``database_url`` > MySQL > PostgreSQL > SQLite.
        """
        # An explicitly configured URL always wins.
        if self.database_url:
            return self.database_url
        # Priority 1: MySQL
        if all([self.mysql_host, self.mysql_user, self.mysql_password, self.mysql_database]):
            return f"mysql+aiomysql://{self.mysql_user}:{self.mysql_password}@{self.mysql_host}:{self.mysql_port}/{self.mysql_database}"
        # Priority 2: PostgreSQL
        if all([self.postgres_host, self.postgres_user, self.postgres_password, self.postgres_database]):
            return f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}@{self.postgres_host}:{self.postgres_port}/{self.postgres_database}"
        # Priority 3: SQLite (fallback)
        sqlite_full_path = Path(self.sqlite_path).absolute()
        # Ensure the parent directory exists (note: side effect at access time).
        sqlite_full_path.parent.mkdir(parents=True, exist_ok=True)
        return f"sqlite+aiosqlite:///{sqlite_full_path}"

    @property
    def sync_database_url(self) -> str:
        """
        Synchronous connection URL for database migrations.

        Replaces the async driver of ``computed_database_url`` with its
        synchronous counterpart.
        """
        async_url = self.computed_database_url
        # Map each async driver prefix to its sync equivalent.
        sync_replacements = {
            "mysql+aiomysql://": "mysql+pymysql://",
            "postgresql+asyncpg://": "postgresql+psycopg2://",
            "sqlite+aiosqlite:///": "sqlite:///"
        }
        for async_driver, sync_driver in sync_replacements.items():
            if async_url.startswith(async_driver):
                return async_url.replace(async_driver, sync_driver)
        return async_url

    def get_database_info(self) -> dict:
        """Return the selected database type plus its async and sync URLs."""
        url = self.computed_database_url
        if url.startswith("mysql"):
            db_type = "MySQL"
        elif url.startswith("postgresql"):
            db_type = "PostgreSQL"
        elif url.startswith("sqlite"):
            db_type = "SQLite"
        else:
            db_type = "Unknown"
        return {
            "type": db_type,
            "url": url,
            "sync_url": self.sync_database_url
        }


# Global settings instance shared across the application.
settings = Settings()

View File

@@ -0,0 +1,53 @@
import logging
import sys
from pathlib import Path
from app.core.config import settings
def setup_logging():
    """Configure application-wide logging.

    Creates a ``logs/`` directory, then attaches a console handler, a main
    file handler (``app.log``) and an error-only file handler (``error.log``)
    to the root logger, and quiets noisy third-party loggers.

    Safe to call more than once: the original implementation re-added all
    three handlers on every call, duplicating every log line; a guard flag
    on the root logger now makes repeated calls no-ops.
    """
    root_logger = logging.getLogger()
    # Idempotency guard: skip if this function already ran in this process.
    if getattr(root_logger, "_adapter_logging_configured", False):
        return
    # Make sure the logs directory exists.
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)
    # Shared format for all handlers.
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    # Resolve the configured level once instead of three times.
    level = getattr(logging, settings.log_level.upper())
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(level)
    file_handler = logging.FileHandler(
        log_dir / "app.log",
        encoding="utf-8"
    )
    file_handler.setLevel(level)
    error_handler = logging.FileHandler(
        log_dir / "error.log",
        encoding="utf-8"
    )
    error_handler.setLevel(logging.ERROR)
    # Apply the shared formatter and attach everything to the root logger.
    formatter = logging.Formatter(log_format, date_format)
    for handler in (console_handler, file_handler, error_handler):
        handler.setFormatter(formatter)
        root_logger.addHandler(handler)
    root_logger.setLevel(level)
    # Reduce third-party verbosity.
    logging.getLogger("uvicorn").setLevel(logging.WARNING)
    logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)  # hide SQL query logs
    logging.getLogger("httpx").setLevel(logging.WARNING)
    root_logger._adapter_logging_configured = True
def get_logger(name: str) -> logging.Logger:
    """Return the logger registered under *name* (thin stdlib wrapper)."""
    return logging.getLogger(name)

View File

@@ -0,0 +1 @@
# app/db/__init__.py

View File

@@ -0,0 +1,39 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base
from app.core.config import settings
from app.core.logging import get_logger
from typing import AsyncGenerator
logger = get_logger(__name__)

# Log which backend was selected (MySQL / PostgreSQL / SQLite).
db_info = settings.get_database_info()
logger.info(f"使用数据库: {db_info['type']}")
logger.info(f"连接URL: {db_info['url']}")

# Create the async database engine.
engine = create_async_engine(
    settings.computed_database_url,
    echo=False,  # disable SQL debug logging to reduce output
    future=True,
    # SQLite-specific: allow use from multiple threads.
    connect_args={"check_same_thread": False} if "sqlite" in settings.computed_database_url else {}
)

# Session factory used by the get_db dependency below.
AsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False
)

# Declarative base class for ORM models.
Base = declarative_base()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency that yields a database session.

    The ``async with`` block closes the session on exit (normal or
    exceptional), so the original redundant
    ``try/finally: await session.close()`` was removed.
    """
    async with AsyncSessionLocal() as session:
        yield session

View File

@@ -0,0 +1,31 @@
"""
自定义异常类定义
"""
class LabelStudioAdapterException(Exception):
    """Root of the Label Studio Adapter exception hierarchy."""


class DatasetMappingNotFoundError(LabelStudioAdapterException):
    """Raised when a dataset mapping cannot be located."""

    def __init__(self, mapping_id: str):
        self.mapping_id = mapping_id
        message = f"Dataset mapping not found: {mapping_id}"
        super().__init__(message)


class NoDatasetInfoFoundError(LabelStudioAdapterException):
    """Raised when dataset information cannot be retrieved."""

    def __init__(self, dataset_uuid: str):
        self.dataset_uuid = dataset_uuid
        message = f"Failed to get dataset info: {dataset_uuid}"
        super().__init__(message)


class LabelStudioClientError(LabelStudioAdapterException):
    """Error raised by the Label Studio client."""


class DMServiceClientError(LabelStudioAdapterException):
    """Error raised by the DM service client."""


class SyncServiceError(LabelStudioAdapterException):
    """Error raised by the synchronization service."""

View File

@@ -0,0 +1,6 @@
# app/clients/__init__.py
from .label_studio import Client as LabelStudioClient
from .datamate import Client as DatamateClient
__all__ = ["LabelStudioClient", "DatamateClient"]

View File

@@ -0,0 +1,159 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func
from typing import Optional
from app.core.config import settings
from app.core.logging import get_logger
from app.schemas.dm_service import DatasetResponse, PagedDatasetFileResponse, DatasetFileResponse
from app.models.dm.dataset import Dataset
from app.models.dm.dataset_files import DatasetFiles
logger = get_logger(__name__)
class Client:
    """Data Management (DM) service client — reads the DM database directly.

    Despite the name, this client performs no HTTP calls: it queries the
    shared database through the supplied SQLAlchemy async session.
    """

    def __init__(self, db: AsyncSession):
        """
        Initialize the DM client.

        Args:
            db: Async database session (typically the request-scoped session).
        """
        self.db = db
        logger.info("Initialize DM service client (Database mode)")

    async def get_dataset(self, dataset_id: str) -> Optional[DatasetResponse]:
        """Return the dataset's details, or ``None`` if missing or on any error."""
        try:
            logger.info(f"Getting dataset detail: {dataset_id} ...")
            result = await self.db.execute(
                select(Dataset).where(Dataset.id == dataset_id)
            )
            dataset = result.scalar_one_or_none()
            if not dataset:
                logger.error(f"Dataset not found: {dataset_id}")
                return None
            # Convert the ORM row into the response schema.
            # NOTE(review): the camelCase keyword names suggest the pydantic
            # schema declares field aliases — confirm against DatasetResponse.
            # The `type: ignore` comments silence SQLAlchemy attribute typing.
            return DatasetResponse(
                id=dataset.id,  # type: ignore
                name=dataset.name,  # type: ignore
                description=dataset.description or "",  # type: ignore
                datasetType=dataset.dataset_type,  # type: ignore
                status=dataset.status,  # type: ignore
                fileCount=dataset.file_count or 0,  # type: ignore
                totalSize=dataset.size_bytes or 0,  # type: ignore
                createdAt=dataset.created_at,  # type: ignore
                updatedAt=dataset.updated_at,  # type: ignore
                createdBy=dataset.created_by  # type: ignore
            )
        except Exception as e:
            logger.error(f"Failed to get dataset {dataset_id}: {e}")
            return None

    async def get_dataset_files(
        self,
        dataset_id: str,
        page: int = 0,
        size: int = 100,
        file_type: Optional[str] = None,
        status: Optional[str] = None
    ) -> Optional[PagedDatasetFileResponse]:
        """List a dataset's files, paginated (``page`` is zero-based).

        Args:
            dataset_id: Dataset to list files for.
            page: Zero-based page index.
            size: Page size.
            file_type: Optional file-type filter.
            status: Optional status filter.

        Returns:
            A paged response, or ``None`` on any error.
        """
        try:
            logger.info(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
            # Base query for the data page.
            query = select(DatasetFiles).where(DatasetFiles.dataset_id == dataset_id)
            # Optional filters applied to both the data and count queries.
            if file_type:
                query = query.where(DatasetFiles.file_type == file_type)
            if status:
                query = query.where(DatasetFiles.status == status)
            # Total count with the same filters (kept in sync manually).
            count_query = select(func.count()).select_from(DatasetFiles).where(
                DatasetFiles.dataset_id == dataset_id
            )
            if file_type:
                count_query = count_query.where(DatasetFiles.file_type == file_type)
            if status:
                count_query = count_query.where(DatasetFiles.status == status)
            count_result = await self.db.execute(count_query)
            total = count_result.scalar_one()
            # Page of rows, newest first (SQLAlchemy emits ORDER BY before
            # LIMIT/OFFSET regardless of method-call order).
            query = query.offset(page * size).limit(size).order_by(DatasetFiles.created_at.desc())
            result = await self.db.execute(query)
            files = result.scalars().all()
            # Map ORM rows to the response schema.
            # The `type: ignore` comments silence SQLAlchemy attribute typing.
            content = [
                DatasetFileResponse(
                    id=f.id,  # type: ignore
                    fileName=f.file_name,  # type: ignore
                    fileType=f.file_type or "",  # type: ignore
                    filePath=f.file_path,  # type: ignore
                    originalName=f.file_name,  # type: ignore
                    size=f.file_size,  # type: ignore
                    status=f.status,  # type: ignore
                    uploadedAt=f.upload_time,  # type: ignore
                    description=None,
                    uploadedBy=None,
                    lastAccessTime=f.last_access_time  # type: ignore
                )
                for f in files
            ]
            # Ceiling division for the page count.
            total_pages = (total + size - 1) // size if size > 0 else 0
            return PagedDatasetFileResponse(
                content=content,
                totalElements=total,
                totalPages=total_pages,
                page=page,
                size=size
            )
        except Exception as e:
            logger.error(f"Failed to get dataset files for {dataset_id}: {e}")
            return None

    async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]:
        """
        Download a file's content.

        Deprecated in database mode: always returns ``None``. Kept only for
        interface compatibility; use :meth:`get_file_download_url` instead.
        """
        logger.warning(f"download_file is deprecated when using database mode. Use get_file_download_url instead.")
        return None

    async def get_file_download_url(self, dataset_id: str, file_id: str) -> Optional[str]:
        """Return the stored file path/URL for a file, or ``None`` if not found."""
        try:
            result = await self.db.execute(
                select(DatasetFiles).where(
                    DatasetFiles.id == file_id,
                    DatasetFiles.dataset_id == dataset_id
                )
            )
            file = result.scalar_one_or_none()
            if not file:
                logger.error(f"File not found: {file_id} in dataset {dataset_id}")
                return None
            # The stored path may be a local filesystem path or an object-store URL.
            return file.file_path  # type: ignore
        except Exception as e:
            logger.error(f"Failed to get file path for {file_id}: {e}")
            return None

    async def close(self):
        """Close the client (no-op in database mode; the session is owned by the caller)."""
        logger.info("DM service client closed (Database mode)")

View File

@@ -0,0 +1,469 @@
import httpx
from typing import Optional, Dict, Any, List
import json
from app.core.config import settings
from app.core.logging import get_logger
from app.schemas.label_studio import (
LabelStudioProject,
LabelStudioCreateProjectRequest,
LabelStudioCreateTaskRequest
)
logger = get_logger(__name__)
class Client:
"""Label Studio服务客户端
使用 HTTP REST API 直接与 Label Studio 交互
认证方式:使用 Authorization: Token {token} 头部进行认证
"""
# 默认标注配置模板
DEFAULT_LABEL_CONFIGS = {
"image": """
<View>
<Image name="image" value="$image"/>
<RectangleLabels name="label" toName="image">
<Label value="Object" background="red"/>
</RectangleLabels>
</View>
""",
"text": """
<View>
<Text name="text" value="$text"/>
<Choices name="sentiment" toName="text">
<Choice value="positive"/>
<Choice value="negative"/>
<Choice value="neutral"/>
</Choices>
</View>
""",
"audio": """
<View>
<Audio name="audio" value="$audio"/>
<AudioRegionLabels name="label" toName="audio">
<Label value="Speech" background="red"/>
<Label value="Noise" background="blue"/>
</AudioRegionLabels>
</View>
""",
"video": """
<View>
<Video name="video" value="$video"/>
<VideoRegionLabels name="label" toName="video">
<Label value="Action" background="red"/>
</VideoRegionLabels>
</View>
"""
}
def __init__(
    self,
    base_url: Optional[str] = None,
    token: Optional[str] = None,
    timeout: float = 30.0
):
    """Initialize the Label Studio client.

    Args:
        base_url: Label Studio service address; falls back to settings.
        token: API token (sent as ``Authorization: Token {token}``); falls
            back to settings.
        timeout: Request timeout in seconds.

    Raises:
        ValueError: If no token is available from either source.
    """
    self.base_url = (base_url or settings.label_studio_base_url).rstrip("/")
    self.token = token or settings.label_studio_user_token
    self.timeout = timeout
    if not self.token:
        raise ValueError("Label Studio API token is required")
    # Shared async HTTP client; every request carries the auth header.
    self.client = httpx.AsyncClient(
        base_url=self.base_url,
        timeout=self.timeout,
        headers={
            "Authorization": f"Token {self.token}",
            "Content-Type": "application/json"
        }
    )
    logger.info(f"Label Studio client initialized: {self.base_url}")
def get_label_config_by_type(self, data_type: str) -> str:
"""根据数据类型获取标注配置"""
return self.DEFAULT_LABEL_CONFIGS.get(data_type.lower(), self.DEFAULT_LABEL_CONFIGS["image"])
async def create_project(
    self,
    title: str,
    description: str = "",
    label_config: Optional[str] = None,
    data_type: str = "image"
) -> Optional[Dict[str, Any]]:
    """Create a Label Studio project.

    Args:
        title: Project title.
        description: Optional project description.
        label_config: Labeling XML config; when omitted, the default
            template for ``data_type`` is used.
        data_type: Key into ``DEFAULT_LABEL_CONFIGS`` (image/text/audio/video).

    Returns:
        The project JSON returned by Label Studio, or ``None`` on any error.
    """
    try:
        logger.info(f"Creating Label Studio project: {title}")
        if not label_config:
            label_config = self.get_label_config_by_type(data_type)
        project_data = {
            "title": title,
            "description": description,
            "label_config": label_config.strip()
        }
        response = await self.client.post("/api/projects", json=project_data)
        response.raise_for_status()
        project = response.json()
        project_id = project.get("id")
        # Defensive: a 2xx response without an ID is treated as a failure.
        if not project_id:
            raise Exception("Label Studio response does not contain project ID")
        logger.info(f"Project created successfully, ID: {project_id}")
        return project
    except httpx.HTTPStatusError as e:
        logger.error(f"Create project failed HTTP {e.response.status_code}: {e.response.text}")
        return None
    except Exception as e:
        logger.error(f"Error while creating Label Studio project: {e}")
        return None
async def import_tasks(
    self,
    project_id: int,
    tasks: List[Dict[str, Any]],
    commit_to_project: bool = True,
    return_task_ids: bool = True
) -> Optional[Dict[str, Any]]:
    """Bulk-import tasks into a Label Studio project.

    Args:
        project_id: Numeric Label Studio project ID.
        tasks: Task payloads as accepted by the LS import endpoint.
        commit_to_project: Forwarded to the API as a query flag.
        return_task_ids: Forwarded to the API as a query flag.

    Returns:
        The import-result JSON, or ``None`` on any error.
    """
    try:
        logger.info(f"Importing {len(tasks)} tasks into project {project_id}")
        # Boolean flags must be lower-case strings in the query string.
        response = await self.client.post(
            f"/api/projects/{project_id}/import",
            json=tasks,
            params={
                "commit_to_project": str(commit_to_project).lower(),
                "return_task_ids": str(return_task_ids).lower()
            }
        )
        response.raise_for_status()
        result = response.json()
        task_count = result.get("task_count", len(tasks))
        logger.info(f"Tasks imported successfully: {task_count}")
        return result
    except httpx.HTTPStatusError as e:
        logger.error(f"Import tasks failed HTTP {e.response.status_code}: {e.response.text}")
        return None
    except Exception as e:
        logger.error(f"Error while importing tasks: {e}")
        return None
async def create_tasks_batch(
    self,
    project_id: str,
    tasks: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Convenience wrapper around ``import_tasks`` accepting a string project ID."""
    try:
        numeric_id = int(project_id)
    except ValueError as e:
        logger.error(f"Invalid project ID format: {project_id}, error: {e}")
        return None
    try:
        return await self.import_tasks(numeric_id, tasks)
    except Exception as e:
        logger.error(f"Error while creating tasks in batch: {e}")
        return None
async def create_task(
    self,
    project_id: str,
    data: Dict[str, Any],
    meta: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
    """Create a single task by delegating to the batch import path."""
    try:
        payload: Dict[str, Any] = {"data": data}
        if meta:
            payload["meta"] = meta
        return await self.create_tasks_batch(project_id, [payload])
    except Exception as e:
        logger.error(f"Error while creating single task: {e}")
        return None
async def get_project_tasks(
    self,
    project_id: str,
    page: Optional[int] = None,
    page_size: int = 1000
) -> Optional[Dict[str, Any]]:
    """Fetch a project's tasks.

    Args:
        project_id: Project ID (string form; converted to int).
        page: 1-based page number. When ``None``, all tasks are fetched by
            paging until exhausted.
        page_size: Page size used for both modes.

    Returns:
        When ``page`` is given, a dict with pagination info:
        {
            "count": total task count,
            "page": current page,
            "page_size": page size,
            "project_id": project ID,
            "tasks": tasks on this page
        }
        When ``page`` is ``None``, a dict with all tasks:
        {
            "count": total task count,
            "project_id": project ID,
            "tasks": all tasks
        }
        ``None`` on any error.
    """
    try:
        pid = int(project_id)
        # Single-page mode: fetch exactly the requested page.
        if page is not None:
            logger.info(f"Fetching tasks for project {pid}, page {page} (page_size={page_size})")
            response = await self.client.get(
                f"/api/projects/{pid}/tasks",
                params={
                    "page": page,
                    "page_size": page_size
                }
            )
            response.raise_for_status()
            result = response.json()
            # Return the page along with pagination metadata.
            return {
                "count": result.get("total", len(result.get("tasks", []))),
                "page": page,
                "page_size": page_size,
                "project_id": pid,
                "tasks": result.get("tasks", [])
            }
        # Exhaustive mode: page through until no tasks remain.
        logger.info(f"Start fetching all tasks for project {pid} (page_size={page_size})")
        all_tasks = []
        current_page = 1
        while True:
            try:
                response = await self.client.get(
                    f"/api/projects/{pid}/tasks",
                    params={
                        "page": current_page,
                        "page_size": page_size
                    }
                )
                response.raise_for_status()
                result = response.json()
                tasks = result.get("tasks", [])
                if not tasks:
                    logger.debug(f"No more tasks on page {current_page}")
                    break
                all_tasks.extend(tasks)
                logger.debug(f"Fetched page {current_page}, {len(tasks)} tasks")
                # Stop once the reported total has been collected.
                total = result.get("total", 0)
                if len(all_tasks) >= total:
                    break
                current_page += 1
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 404:
                    # Past the last page: treat as end of pagination.
                    logger.debug(f"Reached last page (page {current_page})")
                    break
                else:
                    raise
        logger.info(f"Fetched all tasks for project {pid}, total {len(all_tasks)}")
        # All tasks, without pagination metadata.
        return {
            "count": len(all_tasks),
            "project_id": pid,
            "tasks": all_tasks
        }
    except httpx.HTTPStatusError as e:
        logger.error(f"获取项目任务失败 HTTP {e.response.status_code}: {e.response.text}")
        return None
    except Exception as e:
        logger.error(f"获取项目任务时发生错误: {e}")
        return None
async def delete_task(
self,
task_id: int
) -> bool:
"""删除单个任务"""
try:
logger.info(f"Deleting task: {task_id}")
response = await self.client.delete(f"/api/tasks/{task_id}")
response.raise_for_status()
logger.info(f"Task deleted: {task_id}")
return True
except httpx.HTTPStatusError as e:
logger.error(f"Delete task {task_id} failed HTTP {e.response.status_code}: {e.response.text}")
return False
except Exception as e:
logger.error(f"Error while deleting task {task_id}: {e}")
return False
async def delete_tasks_batch(
self,
task_ids: List[int]
) -> Dict[str, int]:
"""批量删除任务"""
try:
logger.info(f"Deleting {len(task_ids)} tasks in batch")
successful_deletions = 0
failed_deletions = 0
for task_id in task_ids:
if await self.delete_task(task_id):
successful_deletions += 1
else:
failed_deletions += 1
logger.info(f"Batch deletion finished: success {successful_deletions}, failed {failed_deletions}")
return {
"successful": successful_deletions,
"failed": failed_deletions,
"total": len(task_ids)
}
except Exception as e:
logger.error(f"Error while deleting tasks in batch: {e}")
return {
"successful": 0,
"failed": len(task_ids),
"total": len(task_ids)
}
async def get_project(self, project_id: int) -> Optional[Dict[str, Any]]:
"""获取项目信息"""
try:
logger.info(f"Fetching project info: {project_id}")
response = await self.client.get(f"/api/projects/{project_id}")
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"Get project info failed HTTP {e.response.status_code}: {e.response.text}")
return None
except Exception as e:
logger.error(f"Error while getting project info: {e}")
return None
async def delete_project(self, project_id: int) -> bool:
"""删除项目"""
try:
logger.info(f"Deleting project: {project_id}")
response = await self.client.delete(f"/api/projects/{project_id}")
response.raise_for_status()
logger.info(f"Project deleted: {project_id}")
return True
except httpx.HTTPStatusError as e:
logger.error(f"Delete project {project_id} failed HTTP {e.response.status_code}: {e.response.text}")
return False
except Exception as e:
logger.error(f"Error while deleting project {project_id}: {e}")
return False
async def create_local_storage(
self,
project_id: int,
path: str,
title: str,
use_blob_urls: bool = True,
regex_filter: Optional[str] = None,
description: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""创建本地文件存储配置
Args:
project_id: Label Studio 项目 ID
path: 本地文件路径(在 Label Studio 容器中的路径)
title: 存储配置标题
use_blob_urls: 是否使用 blob URLs(建议 True)
regex_filter: 文件过滤正则表达式(可选)
description: 存储描述(可选)
Returns:
创建的存储配置信息,失败返回 None
"""
try:
logger.info(f"Creating local storage for project {project_id}: {path}")
storage_data = {
"project": project_id,
"path": path,
"title": title,
"use_blob_urls": use_blob_urls
}
if regex_filter:
storage_data["regex_filter"] = regex_filter
if description:
storage_data["description"] = description
response = await self.client.post(
"/api/storages/localfiles/",
json=storage_data
)
response.raise_for_status()
storage = response.json()
storage_id = storage.get("id")
logger.info(f"Local storage created successfully, ID: {storage_id}")
return storage
except httpx.HTTPStatusError as e:
logger.error(f"Create local storage failed HTTP {e.response.status_code}: {e.response.text}")
return None
except Exception as e:
logger.error(f"Error while creating local storage: {e}")
return None
async def close(self):
"""关闭客户端连接"""
try:
await self.client.aclose()
logger.info("Label Studio client closed")
except Exception as e:
logger.error(f"Error while closing Label Studio client: {e}")

View File

@@ -0,0 +1,160 @@
from fastapi import FastAPI, Request, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from contextlib import asynccontextmanager
from typing import Dict, Any
from .core.config import settings
from .core.logging import setup_logging, get_logger
from .infrastructure import LabelStudioClient
from .api import api_router
from .schemas import StandardResponse
# Configure application-wide logging before any module-level code logs.
setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build shared clients at startup, release them at shutdown.

    The Label Studio client (HTTP REST API + token auth) is stored on
    ``app.state`` so request handlers can reuse a single connection pool.
    """
    logger.info("Starting Label Studio Adapter...")
    ls_client = LabelStudioClient(
        base_url=settings.label_studio_base_url,
        token=settings.label_studio_user_token
    )
    # Fix: previously the client was created but never referenced or closed,
    # leaking the underlying HTTP connection pool on shutdown.
    app.state.ls_client = ls_client
    logger.info("Label Studio Adapter started")
    try:
        yield
    finally:
        logger.info("Shutting down Label Studio Adapter...")
        await ls_client.close()
        logger.info("Label Studio Adapter stopped")
# Create the FastAPI application; metadata and debug flag come from settings.
app = FastAPI(
    title=settings.app_name,
    description=settings.app_description,
    version=settings.app_version,
    debug=settings.debug,
    lifespan=lifespan
)
# CORS middleware; allowed origins/methods/headers are driven by settings.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_credentials=True,
    allow_methods=settings.allowed_methods,
    allow_headers=settings.allowed_headers,
)
# Custom exception handler: StarletteHTTPException (covers 404, method-not-allowed, etc.)
@app.exception_handler(StarletteHTTPException)
async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render Starlette HTTPExceptions in the project's standard response envelope."""
    payload = {
        "code": exc.status_code,
        "message": "error",
        "data": {"detail": exc.detail},
    }
    return JSONResponse(status_code=exc.status_code, content=payload)
# Custom exception handler: FastAPI HTTPException (same envelope as the Starlette one)
@app.exception_handler(HTTPException)
async def fastapi_http_exception_handler(request: Request, exc: HTTPException):
    """Render FastAPI HTTPExceptions in the project's standard response envelope."""
    payload = {
        "code": exc.status_code,
        "message": "error",
        "data": {"detail": exc.detail},
    }
    return JSONResponse(status_code=exc.status_code, content=payload)
# Custom exception handler: request-body/query validation failures
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Render request validation errors (HTTP 422) in the standard envelope."""
    detail = {
        "detail": "Validation error",
        "errors": exc.errors(),
    }
    return JSONResponse(
        status_code=422,
        content={"code": 422, "message": "error", "data": detail},
    )
# Custom exception handler: catch-all for uncaught exceptions
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log the uncaught exception and return a generic 500 in the standard envelope."""
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    payload = {
        "code": 500,
        "message": "error",
        "data": {"detail": "Internal server error"},
    }
    return JSONResponse(status_code=500, content=payload)
# Mount all API routes under the /api prefix.
app.include_router(api_router, prefix="/api")
# Test endpoints: verify the exception-handler chain end to end
@app.get("/test-404", include_in_schema=False)
async def test_404():
    """Trigger the HTTPException handler so the 404 envelope can be inspected."""
    raise HTTPException(detail="Test 404 error", status_code=404)
@app.get("/test-500", include_in_schema=False)
async def test_500():
"""测试500异常处理"""
raise Exception("Test uncaught exception")
# Root path: service banner (hidden from the OpenAPI schema)
@app.get("/", response_model=StandardResponse[Dict[str, Any]], include_in_schema=False)
async def root():
    """Return basic service information for the root path."""
    info = {
        "message": f"{settings.app_name} is running",
        "version": settings.app_version,
        "docs_url": "/docs",
        "label_studio_url": settings.label_studio_base_url,
    }
    return StandardResponse(code=200, message="success", data=info)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
reload=settings.debug,
log_level=settings.log_level.lower()
)

View File

@@ -0,0 +1,138 @@
# DataMate 数据模型结构
本文档列出了根据 `scripts/db` 中的 SQL 文件创建的所有 Python 数据模型。
## 模型组织结构
```
app/models/
├── __init__.py # 主模块导出文件
├── dm/ # 数据管理 (Data Management) 模块
│ ├── __init__.py
│ ├── annotation_template.py # 标注模板
│ ├── labeling_project.py # 标注项目
│ ├── dataset.py # 数据集
│ ├── dataset_files.py # 数据集文件
│ ├── dataset_statistics.py # 数据集统计
│ ├── dataset_tag.py # 数据集标签关联
│ ├── tag.py # 标签
│ └── user.py # 用户
├── cleaning/ # 数据清洗 (Data Cleaning) 模块
│ ├── __init__.py
│ ├── clean_template.py # 清洗模板
│ ├── clean_task.py # 清洗任务
│ ├── operator_instance.py # 算子实例
│ └── clean_result.py # 清洗结果
├── collection/ # 数据归集 (Data Collection) 模块
│ ├── __init__.py
│ ├── task_execution.py # 任务执行明细
│ ├── collection_task.py # 数据归集任务
│ ├── task_log.py # 任务执行记录
│ └── datax_template.py # DataX模板配置
├── common/ # 通用 (Common) 模块
│ ├── __init__.py
│ └── chunk_upload_request.py # 文件切片上传请求
└── operator/ # 算子 (Operator) 模块
├── __init__.py
├── operator.py # 算子
├── operator_category.py # 算子分类
└── operator_category_relation.py # 算子分类关联
```
## 模块详情
### 1. Data Management (DM) 模块
对应 SQL: `data-management-init.sql``data-annotation-init.sql`
#### 模型列表:
- **AnnotationTemplate** (`t_dm_annotation_templates`) - 标注模板
- **LabelingProject** (`t_dm_labeling_projects`) - 标注项目
- **Dataset** (`t_dm_datasets`) - 数据集(支持医学影像、文本、问答等多种类型)
- **DatasetFiles** (`t_dm_dataset_files`) - 数据集文件
- **DatasetStatistics** (`t_dm_dataset_statistics`) - 数据集统计信息
- **Tag** (`t_dm_tags`) - 标签
- **DatasetTag** (`t_dm_dataset_tags`) - 数据集标签关联
- **User** (`users`) - 用户
### 2. Data Cleaning 模块
对应 SQL: `data-cleaning-init.sql`
#### 模型列表:
- **CleanTemplate** (`t_clean_template`) - 清洗模板
- **CleanTask** (`t_clean_task`) - 清洗任务
- **OperatorInstance** (`t_operator_instance`) - 算子实例
- **CleanResult** (`t_clean_result`) - 清洗结果
### 3. Data Collection (DC) 模块
对应 SQL: `data-collection-init.sql`
#### 模型列表:
- **TaskExecution** (`t_dc_task_executions`) - 任务执行明细
- **CollectionTask** (`t_dc_collection_tasks`) - 数据归集任务
- **TaskLog** (`t_dc_task_log`) - 任务执行记录
- **DataxTemplate** (`t_dc_datax_templates`) - DataX模板配置
### 4. Common 模块
对应 SQL: `data-common-init.sql`
#### 模型列表:
- **ChunkUploadRequest** (`t_chunk_upload_request`) - 文件切片上传请求
### 5. Operator 模块
对应 SQL: `data-operator-init.sql`
#### 模型列表:
- **Operator** (`t_operator`) - 算子
- **OperatorCategory** (`t_operator_category`) - 算子分类
- **OperatorCategoryRelation** (`t_operator_category_relation`) - 算子分类关联
## 使用方式
```python
# 导入所有模型
from app.models import (
# DM 模块
AnnotationTemplate,
LabelingProject,
Dataset,
DatasetFiles,
DatasetStatistics,
DatasetTag,
Tag,
User,
# Cleaning 模块
CleanTemplate,
CleanTask,
OperatorInstance,
CleanResult,
# Collection 模块
TaskExecution,
CollectionTask,
TaskLog,
DataxTemplate,
# Common 模块
ChunkUploadRequest,
# Operator 模块
Operator,
OperatorCategory,
OperatorCategoryRelation
)
# 或者按模块导入
from app.models.dm import Dataset, DatasetFiles
from app.models.collection import CollectionTask
from app.models.operator import Operator
```
## 注意事项
1. **UUID 主键**: 大部分表使用 UUID (String(36)) 作为主键
2. **时间戳**: 使用 `TIMESTAMP` 类型,并配置自动更新
3. **软删除**: 部分模型(如 AnnotationTemplate, LabelingProject)支持软删除,包含 `deleted_at` 字段和 `is_deleted` 属性
4. **JSON 字段**: 配置信息、元数据等使用 JSON 类型存储
5. **字段一致性**: 所有模型字段都严格按照 SQL 定义创建,确保与数据库表结构完全一致
## 更新记录
- 2025-10-25: 根据 `scripts/db` 中的 SQL 文件创建所有数据模型
- 已更新现有的 `annotation_template.py``labeling_project.py``dataset_files.py` 以匹配 SQL 定义

View File

@@ -0,0 +1,69 @@
# app/models/__init__.py
# Data Management (DM) 模块
from .dm import (
AnnotationTemplate,
LabelingProject,
Dataset,
DatasetFiles,
DatasetStatistics,
DatasetTag,
Tag,
User
)
# Data Cleaning 模块
from .cleaning import (
CleanTemplate,
CleanTask,
OperatorInstance,
CleanResult
)
# Data Collection (DC) 模块
from .collection import (
TaskExecution,
CollectionTask,
TaskLog,
DataxTemplate
)
# Common 模块
from .common import (
ChunkUploadRequest
)
# Operator 模块
from .operator import (
Operator,
OperatorCategory,
OperatorCategoryRelation
)
__all__ = [
# DM 模块
"AnnotationTemplate",
"LabelingProject",
"Dataset",
"DatasetFiles",
"DatasetStatistics",
"DatasetTag",
"Tag",
"User",
# Cleaning 模块
"CleanTemplate",
"CleanTask",
"OperatorInstance",
"CleanResult",
# Collection 模块
"TaskExecution",
"CollectionTask",
"TaskLog",
"DataxTemplate",
# Common 模块
"ChunkUploadRequest",
# Operator 模块
"Operator",
"OperatorCategory",
"OperatorCategoryRelation"
]

View File

@@ -0,0 +1,13 @@
# app/models/cleaning/__init__.py
from .clean_template import CleanTemplate
from .clean_task import CleanTask
from .operator_instance import OperatorInstance
from .clean_result import CleanResult
__all__ = [
"CleanTemplate",
"CleanTask",
"OperatorInstance",
"CleanResult"
]

View File

@@ -0,0 +1,22 @@
from sqlalchemy import Column, String, BigInteger, Text
from app.db.database import Base
class CleanResult(Base):
    """Cleaning result: one processed file, mapped to table ``t_clean_result``."""
    __tablename__ = "t_clean_result"

    # Composite primary key: one row per (cleaning instance, destination file).
    instance_id = Column(String(64), primary_key=True, comment="实例ID")
    src_file_id = Column(String(64), nullable=True, comment="源文件ID")
    dest_file_id = Column(String(64), primary_key=True, comment="目标文件ID")
    src_name = Column(String(256), nullable=True, comment="源文件名")
    dest_name = Column(String(256), nullable=True, comment="目标文件名")
    src_type = Column(String(256), nullable=True, comment="源文件类型")
    dest_type = Column(String(256), nullable=True, comment="目标文件类型")
    src_size = Column(BigInteger, nullable=True, comment="源文件大小")
    dest_size = Column(BigInteger, nullable=True, comment="目标文件大小")
    status = Column(String(256), nullable=True, comment="处理状态")
    result = Column(Text, nullable=True, comment="处理结果")

    def __repr__(self):
        return f"<CleanResult(instance_id={self.instance_id}, dest_file_id={self.dest_file_id}, status={self.status})>"

View File

@@ -0,0 +1,27 @@
from sqlalchemy import Column, String, BigInteger, Integer, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class CleanTask(Base):
    """Cleaning task, mapped to table ``t_clean_task``."""
    __tablename__ = "t_clean_task"

    id = Column(String(64), primary_key=True, comment="任务ID")
    name = Column(String(64), nullable=True, comment="任务名称")
    description = Column(String(256), nullable=True, comment="任务描述")
    status = Column(String(256), nullable=True, comment="任务状态")
    src_dataset_id = Column(String(64), nullable=True, comment="源数据集ID")
    src_dataset_name = Column(String(64), nullable=True, comment="源数据集名称")
    dest_dataset_id = Column(String(64), nullable=True, comment="目标数据集ID")
    dest_dataset_name = Column(String(64), nullable=True, comment="目标数据集名称")
    # Byte sizes before/after cleaning, used for shrink-ratio reporting.
    before_size = Column(BigInteger, nullable=True, comment="清洗前大小")
    after_size = Column(BigInteger, nullable=True, comment="清洗后大小")
    file_count = Column(Integer, nullable=True, comment="文件数量")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    started_at = Column(TIMESTAMP, nullable=True, comment="开始时间")
    finished_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
    created_by = Column(String(256), nullable=True, comment="创建者")

    def __repr__(self):
        return f"<CleanTask(id={self.id}, name={self.name}, status={self.status})>"

View File

@@ -0,0 +1,18 @@
from sqlalchemy import Column, String, Text, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class CleanTemplate(Base):
    """Cleaning template, mapped to table ``t_clean_template``."""
    __tablename__ = "t_clean_template"

    id = Column(String(64), primary_key=True, unique=True, comment="模板ID")
    name = Column(String(64), nullable=True, comment="模板名称")
    description = Column(String(256), nullable=True, comment="模板描述")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(256), nullable=True, comment="创建者")

    def __repr__(self):
        return f"<CleanTemplate(id={self.id}, name={self.name})>"

View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, String, Integer, Text
from app.db.database import Base
class OperatorInstance(Base):
    """Operator instance within a cleaning pipeline, table ``t_operator_instance``."""
    __tablename__ = "t_operator_instance"

    # Composite primary key: an operator may appear multiple times in one
    # instance, distinguished by its position (op_index) in the pipeline.
    instance_id = Column(String(256), primary_key=True, comment="实例ID")
    operator_id = Column(String(256), primary_key=True, comment="算子ID")
    op_index = Column(Integer, primary_key=True, comment="算子索引")
    settings_override = Column(Text, nullable=True, comment="配置覆盖")

    def __repr__(self):
        return f"<OperatorInstance(instance_id={self.instance_id}, operator_id={self.operator_id}, index={self.op_index})>"

View File

@@ -0,0 +1,13 @@
# app/models/collection/__init__.py
from .task_execution import TaskExecution
from .collection_task import CollectionTask
from .task_log import TaskLog
from .datax_template import DataxTemplate
__all__ = [
"TaskExecution",
"CollectionTask",
"TaskLog",
"DataxTemplate"
]

View File

@@ -0,0 +1,28 @@
from sqlalchemy import Column, String, Text, Integer, BigInteger, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class CollectionTask(Base):
    """Data collection task, mapped to table ``t_dc_collection_tasks``."""
    __tablename__ = "t_dc_collection_tasks"

    id = Column(String(36), primary_key=True, comment="任务ID(UUID)")
    name = Column(String(255), nullable=False, comment="任务名称")
    description = Column(Text, nullable=True, comment="任务描述")
    sync_mode = Column(String(20), default='ONCE', comment="同步模式:ONCE/SCHEDULED")
    # Serialized DataX job config (source + target endpoint settings).
    config = Column(Text, nullable=False, comment="归集配置(DataX配置),包含源端和目标端配置信息")
    schedule_expression = Column(String(255), nullable=True, comment="Cron调度表达式")
    status = Column(String(20), default='DRAFT', comment="任务状态:DRAFT/READY/RUNNING/SUCCESS/FAILED/STOPPED")
    retry_count = Column(Integer, default=3, comment="重试次数")
    timeout_seconds = Column(Integer, default=3600, comment="超时时间(秒)")
    max_records = Column(BigInteger, nullable=True, comment="最大处理记录数")
    sort_field = Column(String(100), nullable=True, comment="增量字段")
    last_execution_id = Column(String(36), nullable=True, comment="最后执行ID(UUID)")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")

    def __repr__(self):
        return f"<CollectionTask(id={self.id}, name={self.name}, status={self.status})>"

View File

@@ -0,0 +1,23 @@
from sqlalchemy import Column, String, Text, Boolean, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class DataxTemplate(Base):
    """DataX job template, mapped to table ``t_dc_datax_templates``."""
    __tablename__ = "t_dc_datax_templates"

    id = Column(String(36), primary_key=True, comment="模板ID(UUID)")
    name = Column(String(255), nullable=False, unique=True, comment="模板名称")
    source_type = Column(String(50), nullable=False, comment="源数据源类型")
    target_type = Column(String(50), nullable=False, comment="目标数据源类型")
    template_content = Column(Text, nullable=False, comment="模板内容")
    description = Column(Text, nullable=True, comment="模板描述")
    version = Column(String(20), default='1.0.0', comment="版本号")
    # System templates are shipped with the product and not user-editable.
    is_system = Column(Boolean, default=False, comment="是否系统模板")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")

    def __repr__(self):
        return f"<DataxTemplate(id={self.id}, name={self.name}, source={self.source_type}, target={self.target_type})>"

View File

@@ -0,0 +1,34 @@
from sqlalchemy import Column, String, Text, Integer, BigInteger, DECIMAL, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class TaskExecution(Base):
    """Per-run execution detail of a collection task, table ``t_dc_task_executions``."""
    __tablename__ = "t_dc_task_executions"

    id = Column(String(36), primary_key=True, comment="执行记录ID(UUID)")
    task_id = Column(String(36), nullable=False, comment="任务ID")
    task_name = Column(String(255), nullable=False, comment="任务名称")
    status = Column(String(20), default='RUNNING', comment="执行状态:RUNNING/SUCCESS/FAILED/STOPPED")
    # Progress as a percentage with 2 decimal places (0.00-100.00).
    progress = Column(DECIMAL(5, 2), default=0.00, comment="进度百分比")
    records_total = Column(BigInteger, default=0, comment="总记录数")
    records_processed = Column(BigInteger, default=0, comment="已处理记录数")
    records_success = Column(BigInteger, default=0, comment="成功记录数")
    records_failed = Column(BigInteger, default=0, comment="失败记录数")
    throughput = Column(DECIMAL(10, 2), default=0.00, comment="吞吐量(条/秒)")
    data_size_bytes = Column(BigInteger, default=0, comment="数据量(字节)")
    started_at = Column(TIMESTAMP, nullable=True, comment="开始时间")
    completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
    duration_seconds = Column(Integer, default=0, comment="执行时长(秒)")
    config = Column(JSON, nullable=True, comment="执行配置")
    error_message = Column(Text, nullable=True, comment="错误信息")
    datax_job_id = Column(Text, nullable=True, comment="datax任务ID")
    result = Column(Text, nullable=True, comment="执行结果")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")

    def __repr__(self):
        return f"<TaskExecution(id={self.id}, task_id={self.task_id}, status={self.status})>"

View File

@@ -0,0 +1,26 @@
from sqlalchemy import Column, String, Text, Integer, BigInteger, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class TaskLog(Base):
    """Task execution log record, mapped to table ``t_dc_task_log``."""
    __tablename__ = "t_dc_task_log"

    id = Column(String(36), primary_key=True, comment="执行记录ID(UUID)")
    task_id = Column(String(36), nullable=False, comment="任务ID")
    task_name = Column(String(255), nullable=False, comment="任务名称")
    sync_mode = Column(String(20), default='FULL', comment="同步模式:FULL/INCREMENTAL")
    status = Column(String(20), default='RUNNING', comment="执行状态:RUNNING/SUCCESS/FAILED/STOPPED")
    start_time = Column(TIMESTAMP, nullable=True, comment="开始时间")
    end_time = Column(TIMESTAMP, nullable=True, comment="结束时间")
    # Duration is stored in milliseconds (unlike TaskExecution.duration_seconds).
    duration = Column(BigInteger, nullable=True, comment="执行时长(毫秒)")
    process_id = Column(String(50), nullable=True, comment="进程ID")
    log_path = Column(String(500), nullable=True, comment="日志文件路径")
    error_msg = Column(Text, nullable=True, comment="错误信息")
    result = Column(Text, nullable=True, comment="执行结果")
    retry_times = Column(Integer, default=0, comment="重试次数")
    create_time = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")

    def __repr__(self):
        return f"<TaskLog(id={self.id}, task_id={self.task_id}, status={self.status})>"

View File

@@ -0,0 +1,7 @@
# app/models/common/__init__.py
from .chunk_upload_request import ChunkUploadRequest
__all__ = [
"ChunkUploadRequest"
]

View File

@@ -0,0 +1,19 @@
from sqlalchemy import Column, String, Integer, Text, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class ChunkUploadRequest(Base):
    """Chunked file-upload request, mapped to table ``t_chunk_upload_request``."""
    __tablename__ = "t_chunk_upload_request"

    id = Column(String(36), primary_key=True, comment="UUID")
    total_file_num = Column(Integer, nullable=True, comment="总文件数")
    uploaded_file_num = Column(Integer, nullable=True, comment="已上传文件数")
    upload_path = Column(String(256), nullable=True, comment="文件路径")
    timeout = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="上传请求超时时间")
    service_id = Column(String(64), nullable=True, comment="上传请求所属服务:DATA-MANAGEMENT(数据管理)")
    check_info = Column(Text, nullable=True, comment="业务信息")

    def __repr__(self):
        return f"<ChunkUploadRequest(id={self.id}, service_id={self.service_id}, progress={self.uploaded_file_num}/{self.total_file_num})>"

View File

@@ -0,0 +1,21 @@
# app/models/dm/__init__.py
from .annotation_template import AnnotationTemplate
from .labeling_project import LabelingProject
from .dataset import Dataset
from .dataset_files import DatasetFiles
from .dataset_statistics import DatasetStatistics
from .dataset_tag import DatasetTag
from .tag import Tag
from .user import User
__all__ = [
"AnnotationTemplate",
"LabelingProject",
"Dataset",
"DatasetFiles",
"DatasetStatistics",
"DatasetTag",
"Tag",
"User"
]

View File

@@ -0,0 +1,24 @@
from sqlalchemy import Column, String, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class AnnotationTemplate(Base):
    """Annotation template, mapped to table ``t_dm_annotation_templates``.

    Soft-deleted rows are marked via ``deleted_at`` rather than removed.
    """
    __tablename__ = "t_dm_annotation_templates"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID主键ID")
    name = Column(String(32), nullable=False, comment="模板名称")
    description = Column(String(255), nullable=True, comment="模板描述")
    configuration = Column(JSON, nullable=True, comment="配置信息(JSON格式)")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self):
        return f"<AnnotationTemplate(id={self.id}, name={self.name})>"

    @property
    def is_deleted(self) -> bool:
        """Whether this row has been soft-deleted."""
        return self.deleted_at is not None

View File

@@ -0,0 +1,35 @@
from sqlalchemy import Column, String, Text, BigInteger, Integer, Boolean, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class Dataset(Base):
    """Dataset model (medical imaging, text, Q&A and other dataset types).

    Mapped to table ``t_dm_datasets``.

    Note: the physical column ``metadata`` is exposed as attribute
    ``metadata_``. ``metadata`` is a reserved attribute name on SQLAlchemy
    declarative classes (it refers to the MetaData registry), so declaring a
    mapped attribute with that name raises InvalidRequestError at
    class-definition time — the original definition could never be imported.
    """
    __tablename__ = "t_dm_datasets"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    name = Column(String(255), nullable=False, comment="数据集名称")
    description = Column(Text, nullable=True, comment="数据集描述")
    dataset_type = Column(String(50), nullable=False, comment="数据集类型:IMAGE/TEXT/QA/MULTIMODAL/OTHER")
    category = Column(String(100), nullable=True, comment="数据集分类:医学影像/问答/文献等")
    path = Column(String(500), nullable=True, comment="数据存储路径")
    format = Column(String(50), nullable=True, comment="数据格式:DCM/JPG/JSON/CSV等")
    schema_info = Column(JSON, nullable=True, comment="数据结构信息")
    size_bytes = Column(BigInteger, default=0, comment="数据大小(字节)")
    file_count = Column(BigInteger, default=0, comment="文件数量")
    record_count = Column(BigInteger, default=0, comment="记录数量")
    retention_days = Column(Integer, default=0, comment="数据保留天数(0表示长期保留)")
    tags = Column(JSON, nullable=True, comment="标签列表")
    # Bug fix: attribute renamed from "metadata" (reserved by the declarative
    # base); the database column name stays "metadata" via the explicit
    # first positional argument.
    metadata_ = Column("metadata", JSON, nullable=True, comment="元数据信息")
    status = Column(String(50), default='DRAFT', comment="状态:DRAFT/ACTIVE/ARCHIVED")
    is_public = Column(Boolean, default=False, comment="是否公开")
    is_featured = Column(Boolean, default=False, comment="是否推荐")
    version = Column(BigInteger, nullable=False, default=0, comment="版本号")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")

    def __repr__(self):
        return f"<Dataset(id={self.id}, name={self.name}, type={self.dataset_type})>"

View File

@@ -0,0 +1,27 @@
from sqlalchemy import Column, String, JSON, BigInteger, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class DatasetFiles(Base):
    """DM dataset file model, mapped to table ``t_dm_dataset_files``.

    Note: the physical column ``metadata`` is exposed as attribute
    ``metadata_`` because ``metadata`` is a reserved attribute name on
    SQLAlchemy declarative classes; declaring a mapped attribute with that
    name raises InvalidRequestError when the class is created.
    """
    __tablename__ = "t_dm_dataset_files"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    dataset_id = Column(String(36), nullable=False, comment="所属数据集ID(UUID)")
    file_name = Column(String(255), nullable=False, comment="文件名")
    file_path = Column(String(1000), nullable=False, comment="文件路径")
    file_type = Column(String(50), nullable=True, comment="文件格式:JPG/PNG/DCM/TXT等")
    file_size = Column(BigInteger, default=0, comment="文件大小(字节)")
    check_sum = Column(String(64), nullable=True, comment="文件校验和")
    tags = Column(JSON, nullable=True, comment="文件标签信息")
    # Bug fix: attribute renamed from "metadata" (reserved by the declarative
    # base); the database column name stays "metadata".
    metadata_ = Column("metadata", JSON, nullable=True, comment="文件元数据")
    status = Column(String(50), default='ACTIVE', comment="文件状态:ACTIVE/DELETED/PROCESSING")
    upload_time = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="上传时间")
    last_access_time = Column(TIMESTAMP, nullable=True, comment="最后访问时间")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")

    def __repr__(self):
        return f"<DatasetFiles(id={self.id}, dataset_id={self.dataset_id}, file_name={self.file_name})>"

View File

@@ -0,0 +1,25 @@
from sqlalchemy import Column, String, Date, BigInteger, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class DatasetStatistics(Base):
    """Daily dataset statistics, mapped to table ``t_dm_dataset_statistics``."""
    __tablename__ = "t_dm_dataset_statistics"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    dataset_id = Column(String(36), nullable=False, comment="数据集ID(UUID)")
    # One row per dataset per statistics date.
    stat_date = Column(Date, nullable=False, comment="统计日期")
    total_files = Column(BigInteger, default=0, comment="总文件数")
    total_size = Column(BigInteger, default=0, comment="总大小(字节)")
    processed_files = Column(BigInteger, default=0, comment="已处理文件数")
    error_files = Column(BigInteger, default=0, comment="错误文件数")
    download_count = Column(BigInteger, default=0, comment="下载次数")
    view_count = Column(BigInteger, default=0, comment="查看次数")
    quality_metrics = Column(JSON, nullable=True, comment="质量指标")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")

    def __repr__(self):
        return f"<DatasetStatistics(id={self.id}, dataset_id={self.dataset_id}, date={self.stat_date})>"

View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, String, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class DatasetTag(Base):
    """Dataset-to-tag association (join table ``t_dm_dataset_tags``)."""
    __tablename__ = "t_dm_dataset_tags"

    # Composite primary key over the two foreign identifiers.
    dataset_id = Column(String(36), primary_key=True, comment="数据集ID(UUID)")
    tag_id = Column(String(36), primary_key=True, comment="标签ID(UUID)")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")

    def __repr__(self):
        return f"<DatasetTag(dataset_id={self.dataset_id}, tag_id={self.tag_id})>"

View File

@@ -0,0 +1,26 @@
from sqlalchemy import Column, String, Integer, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class LabelingProject(Base):
    """DM labeling project (formerly DatasetMapping), table ``t_dm_labeling_projects``.

    Links a DM dataset to its Label Studio project; soft-deleted via ``deleted_at``.
    """
    __tablename__ = "t_dm_labeling_projects"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID主键ID")
    dataset_id = Column(String(36), nullable=False, comment="数据集ID")
    name = Column(String(32), nullable=False, comment="项目名称")
    # Integer ID of the corresponding project inside Label Studio.
    labeling_project_id = Column(Integer, nullable=False, comment="Label Studio项目ID")
    configuration = Column(JSON, nullable=True, comment="标签配置")
    progress = Column(JSON, nullable=True, comment="标注进度统计")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self):
        return f"<LabelingProject(id={self.id}, dataset_id={self.dataset_id}, name={self.name})>"

    @property
    def is_deleted(self) -> bool:
        """Whether this row has been soft-deleted."""
        return self.deleted_at is not None

View File

@@ -0,0 +1,21 @@
from sqlalchemy import Column, String, Text, BigInteger, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class Tag(Base):
    """Tag model, mapped to table ``t_dm_tags``."""
    __tablename__ = "t_dm_tags"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    name = Column(String(100), nullable=False, unique=True, comment="标签名称")
    description = Column(Text, nullable=True, comment="标签描述")
    category = Column(String(50), nullable=True, comment="标签分类")
    # Hex color string including the leading '#', e.g. "#ff0000".
    color = Column(String(7), nullable=True, comment="标签颜色(十六进制)")
    usage_count = Column(BigInteger, default=0, comment="使用次数")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")

    def __repr__(self):
        return f"<Tag(id={self.id}, name={self.name}, category={self.category})>"

View File

@@ -0,0 +1,24 @@
from sqlalchemy import Column, String, BigInteger, Boolean, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class User(Base):
    """User account model, mapped to table ``users``."""
    __tablename__ = "users"

    # Auto-increment integer PK — unlike most other tables, which use UUIDs.
    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="用户ID")
    username = Column(String(255), nullable=False, unique=True, comment="用户名")
    email = Column(String(255), nullable=False, unique=True, comment="邮箱")
    password_hash = Column(String(255), nullable=False, comment="密码哈希")
    full_name = Column(String(255), nullable=True, comment="真实姓名")
    avatar_url = Column(String(500), nullable=True, comment="头像URL")
    role = Column(String(50), nullable=False, default='USER', comment="角色:ADMIN/USER")
    organization = Column(String(255), nullable=True, comment="所属机构")
    enabled = Column(Boolean, nullable=False, default=True, comment="是否启用")
    last_login_at = Column(TIMESTAMP, nullable=True, comment="最后登录时间")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")

    def __repr__(self):
        return f"<User(id={self.id}, username={self.username}, role={self.role})>"

View File

@@ -0,0 +1,11 @@
# app/models/operator/__init__.py
# Aggregates the operator-related ORM models so callers can import them
# directly from ``app.models.operator``.
from .operator import Operator
from .operator_category import OperatorCategory
from .operator_category_relation import OperatorCategoryRelation
# Public API of this sub-package.
__all__ = [
    "Operator",
    "OperatorCategory",
    "OperatorCategoryRelation"
]

View File

@@ -0,0 +1,24 @@
from sqlalchemy import Column, String, Text, Boolean, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class Operator(Base):
    """Operator model (table ``t_operator``).

    Describes a data-processing operator: its input/output types, runtime
    and configuration blobs, plus bookkeeping fields.
    """
    __tablename__ = "t_operator"
    # Primary key is an externally assigned string id, not auto-generated.
    id = Column(String(64), primary_key=True, comment="算子ID")
    name = Column(String(64), nullable=True, comment="算子名称")
    description = Column(String(256), nullable=True, comment="算子描述")
    version = Column(String(256), nullable=True, comment="版本")
    inputs = Column(String(256), nullable=True, comment="输入类型")
    outputs = Column(String(256), nullable=True, comment="输出类型")
    # Free-form text blobs — presumably JSON-encoded; TODO confirm with writers.
    runtime = Column(Text, nullable=True, comment="运行时信息")
    settings = Column(Text, nullable=True, comment="配置信息")
    file_name = Column(Text, nullable=True, comment="文件名")
    is_star = Column(Boolean, nullable=True, comment="是否收藏")
    # Timestamps are maintained server-side by the database.
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    def __repr__(self):
        return f"<Operator(id={self.id}, name={self.name}, version={self.version})>"

View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, String, Integer
from app.db.database import Base
class OperatorCategory(Base):
    """Operator category model (table ``t_operator_category``)."""
    __tablename__ = "t_operator_category"
    id = Column(Integer, primary_key=True, autoincrement=True, comment="分类ID")
    name = Column(String(64), nullable=True, comment="分类名称")
    type = Column(String(64), nullable=True, comment="分类类型")
    # Self-referencing parent id, enabling a category tree; NULL for roots.
    parent_id = Column(Integer, nullable=True, comment="父分类ID")
    def __repr__(self):
        return f"<OperatorCategory(id={self.id}, name={self.name}, type={self.type})>"

View File

@@ -0,0 +1,13 @@
from sqlalchemy import Column, String, Integer
from app.db.database import Base
class OperatorCategoryRelation(Base):
    """Many-to-many link between operators and categories."""
    __tablename__ = "t_operator_category_relation"
    # Composite primary key: one row per (category, operator) pair.
    category_id = Column(Integer, primary_key=True, comment="分类ID")
    operator_id = Column(String(64), primary_key=True, comment="算子ID")
    def __repr__(self):
        return f"<OperatorCategoryRelation(category_id={self.category_id}, operator_id={self.operator_id})>"

View File

@@ -0,0 +1,29 @@
# app/schemas/__init__.py
# Central export point for all Pydantic schemas: re-exports the public
# schemas of every sub-module so callers can ``from app.schemas import X``.
from .common import *
from .dataset_mapping import *
from .dm_service import *
from .label_studio import *
# Fix: the previous __all__ omitted several schemas the sub-modules define
# (PaginatedData, SyncDatasetRequest, DatasetTypeResponse, DatasetResponse,
# LabelStudioTaskData and the create-request models), making wildcard
# imports from this package incomplete.  Adding names is backward-compatible.
__all__ = [
    # Common schemas
    "StandardResponse",
    "PaginatedData",
    # Dataset Mapping schemas
    "DatasetMappingBase",
    "DatasetMappingCreateRequest",
    "DatasetMappingUpdateRequest",
    "DatasetMappingResponse",
    "DatasetMappingCreateResponse",
    "SyncDatasetRequest",
    "SyncDatasetResponse",
    "DeleteDatasetResponse",
    # DM Service schemas
    "DatasetFileResponse",
    "PagedDatasetFileResponse",
    "DatasetTypeResponse",
    "DatasetResponse",
    # Label Studio schemas
    "LabelStudioProject",
    "LabelStudioTaskData",
    "LabelStudioTask",
    "LabelStudioCreateProjectRequest",
    "LabelStudioCreateTaskRequest"
]

View File

@@ -0,0 +1,63 @@
"""
通用响应模型
"""
from typing import Generic, TypeVar, Optional, List
from pydantic import BaseModel, Field
# Generic payload type used by the response containers below.
T = TypeVar('T')

# Helper that turns a snake_case field name into its camelCase alias.
def to_camel(string: str) -> str:
    """Convert a snake_case identifier to camelCase."""
    head, *rest = string.split('_')
    return head + ''.join(word.title() for word in rest)
class BaseResponseModel(BaseModel):
    """Base model that serialises snake_case fields under camelCase aliases.

    ``populate_by_name`` allows instances to be constructed with either the
    Python field name or the generated camelCase alias.
    """
    class Config:
        populate_by_name = True
        alias_generator = to_camel
class StandardResponse(BaseResponseModel, Generic[T]):
    """
    Standard API response envelope.

    All endpoints should return this shape to keep responses consistent:
    an HTTP-style code, a human-readable message and an optional payload.
    """
    code: int = Field(..., description="HTTP状态码")
    message: str = Field(..., description="响应消息")
    data: Optional[T] = Field(None, description="响应数据")
    # NOTE(review): populate_by_name/alias_generator repeat the settings on
    # BaseResponseModel; presumably only json_schema_extra is the intentional
    # addition here — confirm before simplifying.
    class Config:
        populate_by_name = True
        alias_generator = to_camel
        json_schema_extra = {
            "example": {
                "code": 200,
                "message": "success",
                "data": {}
            }
        }
class PaginatedData(BaseResponseModel, Generic[T]):
    """Generic page container used by list endpoints."""
    page: int = Field(..., description="当前页码(从1开始)")
    size: int = Field(..., description="页大小")
    total_elements: int = Field(..., description="总条数")
    total_pages: int = Field(..., description="总页数")
    content: List[T] = Field(..., description="当前页数据")
    class Config:
        # NOTE(review): this Config only adds the OpenAPI example; camelCase
        # aliasing is presumably inherited from BaseResponseModel — confirm
        # the parent Config is merged under the pydantic version in use.
        json_schema_extra = {
            "example": {
                "page": 1,
                "size": 20,
                "totalElements": 100,
                "totalPages": 5,
                "content": []
            }
        }

View File

@@ -0,0 +1,56 @@
from pydantic import Field
from typing import Optional
from datetime import datetime
from .common import BaseResponseModel
class DatasetMappingBase(BaseResponseModel):
    """Base schema shared by dataset↔labelling-project mapping models."""
    dataset_id: str = Field(..., description="源数据集ID")
class DatasetMappingCreateRequest(DatasetMappingBase):
    """Request schema for creating a mapping (only needs the dataset id)."""
    pass
class DatasetMappingCreateResponse(BaseResponseModel):
    """Response schema returned after a mapping has been created."""
    mapping_id: str = Field(..., description="映射UUID")
    labelling_project_id: str = Field(..., description="Label Studio项目ID")
    labelling_project_name: str = Field(..., description="Label Studio项目名称")
    message: str = Field(..., description="响应消息")
class DatasetMappingUpdateRequest(BaseResponseModel):
    """Request schema for partially updating a mapping; all fields optional."""
    dataset_id: Optional[str] = Field(None, description="源数据集ID")
class DatasetMappingResponse(DatasetMappingBase):
    """Full mapping record as returned by query endpoints."""
    mapping_id: str = Field(..., description="映射UUID")
    labelling_project_id: str = Field(..., description="标注项目ID")
    labelling_project_name: Optional[str] = Field(None, description="标注项目名称")
    created_at: datetime = Field(..., description="创建时间")
    last_updated_at: datetime = Field(..., description="最后更新时间")
    # deleted_at is non-None only for soft-deleted mappings.
    deleted_at: Optional[datetime] = Field(None, description="删除时间")
    class Config:
        # from_attributes lets model_validate() read ORM objects directly.
        from_attributes = True
        populate_by_name = True
class SyncDatasetRequest(BaseResponseModel):
    """Request schema for triggering a dataset sync."""
    mapping_id: str = Field(..., description="映射ID(mapping UUID)")
    # Bounded to 1..100; also used as the DM page size during sync.
    batch_size: int = Field(50, ge=1, le=100, description="批处理大小")
class SyncDatasetResponse(BaseResponseModel):
    """Response schema summarising one sync run."""
    mapping_id: str = Field(..., description="映射UUID")
    status: str = Field(..., description="同步状态")
    synced_files: int = Field(..., description="已同步文件数量")
    total_files: int = Field(0, description="总文件数量")
    message: str = Field(..., description="响应消息")
class DeleteDatasetResponse(BaseResponseModel):
    """Response schema for a mapping deletion."""
    mapping_id: str = Field(..., description="映射UUID")
    status: str = Field(..., description="删除状态")
    message: str = Field(..., description="响应消息")

View File

@@ -0,0 +1,58 @@
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime
class DatasetFileResponse(BaseModel):
    """A single dataset file as returned by the DM service.

    Field names are camelCase on purpose: these models parse the DM
    service's JSON responses rather than shaping this app's own output.
    """
    id: str = Field(..., description="文件ID")
    fileName: str = Field(..., description="文件名")
    fileType: str = Field(..., description="文件类型")
    filePath: str = Field(..., description="文件路径")
    originalName: Optional[str] = Field(None, description="原始文件名")
    size: Optional[int] = Field(None, description="文件大小(字节)")
    status: Optional[str] = Field(None, description="文件状态")
    uploadedAt: Optional[datetime] = Field(None, description="上传时间")
    description: Optional[str] = Field(None, description="文件描述")
    uploadedBy: Optional[str] = Field(None, description="上传者")
    lastAccessTime: Optional[datetime] = Field(None, description="最后访问时间")
class PagedDatasetFileResponse(BaseModel):
    """Paged list of dataset files from the DM service."""
    content: List[DatasetFileResponse] = Field(..., description="文件列表")
    totalElements: int = Field(..., description="总元素数")
    totalPages: int = Field(..., description="总页数")
    page: int = Field(..., description="当前页码")
    size: int = Field(..., description="每页大小")
class DatasetTypeResponse(BaseModel):
    """Dataset-type descriptor (used by DatasetResponse.type for compat)."""
    code: str = Field(..., description="类型编码")
    name: str = Field(..., description="类型名称")
    description: Optional[str] = Field(None, description="类型描述")
    supportedFormats: List[str] = Field(default_factory=list, description="支持的文件格式")
    icon: Optional[str] = Field(None, description="图标")
class DatasetResponse(BaseModel):
    """A dataset as returned by the DM service."""
    id: str = Field(..., description="数据集ID")
    name: str = Field(..., description="数据集名称")
    description: Optional[str] = Field(None, description="数据集描述")
    # NOTE(review): the alias equals the field name, so it looks redundant —
    # confirm it is not needed for a serializer edge case before removing.
    datasetType: str = Field(..., description="数据集类型", alias="datasetType")
    status: str = Field(..., description="数据集状态")
    fileCount: int = Field(..., description="文件数量")
    totalSize: int = Field(..., description="总大小(字节)")
    createdAt: Optional[datetime] = Field(None, description="创建时间")
    updatedAt: Optional[datetime] = Field(None, description="更新时间")
    createdBy: Optional[str] = Field(None, description="创建者")
    # Backward-compatibility helper: expose the type as an object.
    @property
    def type(self) -> DatasetTypeResponse:
        """Return a DatasetTypeResponse synthesised from ``datasetType``."""
        return DatasetTypeResponse(
            code=self.datasetType,
            name=self.datasetType,
            description=None,
            supportedFormats=[],
            icon=None
        )

View File

@@ -0,0 +1,38 @@
from pydantic import Field
from typing import Dict, Any, Optional, List
from datetime import datetime
from .common import BaseResponseModel
class LabelStudioProject(BaseResponseModel):
    """A Label Studio project as returned by its API."""
    id: int = Field(..., description="项目ID")
    title: str = Field(..., description="项目标题")
    description: Optional[str] = Field(None, description="项目描述")
    label_config: str = Field(..., description="标注配置")
    created_at: Optional[datetime] = Field(None, description="创建时间")
    updated_at: Optional[datetime] = Field(None, description="更新时间")
class LabelStudioTaskData(BaseResponseModel):
    """Payload of a Label Studio task.

    All fields are optional; presumably one media field is set per task
    depending on the data type — confirm against task construction.
    """
    image: Optional[str] = Field(None, description="图像URL")
    text: Optional[str] = Field(None, description="文本内容")
    audio: Optional[str] = Field(None, description="音频URL")
    video: Optional[str] = Field(None, description="视频URL")
    filename: Optional[str] = Field(None, description="文件名")
class LabelStudioTask(BaseResponseModel):
    """A Label Studio annotation task."""
    data: LabelStudioTaskData = Field(..., description="任务数据")
    project: Optional[int] = Field(None, description="项目ID")
    meta: Optional[Dict[str, Any]] = Field(None, description="元数据")
class LabelStudioCreateProjectRequest(BaseResponseModel):
    """Request body for creating a Label Studio project."""
    title: str = Field(..., description="项目标题")
    description: str = Field("", description="项目描述")
    label_config: str = Field(..., description="标注配置")
class LabelStudioCreateTaskRequest(BaseResponseModel):
    """Request body for creating a Label Studio task."""
    data: Dict[str, Any] = Field(..., description="任务数据")
    project: Optional[int] = Field(None, description="项目ID")

View File

@@ -0,0 +1,6 @@
# app/services/__init__.py
# Re-export the service classes for convenient package-level imports.
from .dataset_mapping_service import DatasetMappingService
from .sync_service import SyncService
__all__ = ["DatasetMappingService", "SyncService"]

View File

@@ -0,0 +1,298 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, func
from typing import Optional, List, Tuple
from datetime import datetime
import uuid
from app.models.dm.labeling_project import LabelingProject
from app.schemas.dataset_mapping import (
DatasetMappingCreateRequest,
DatasetMappingUpdateRequest,
DatasetMappingResponse
)
from app.core.logging import get_logger
logger = get_logger(__name__)
class DatasetMappingService:
    """Persistence service for dataset ↔ labelling-project mappings.

    Wraps the ``LabelingProject`` table with async SQLAlchemy operations:
    creation, lookups (by dataset id, labelling-project id or mapping
    UUID), updates, soft deletion and paginated listing.  Deletion is
    always a soft delete via the ``deleted_at`` column.
    """

    def __init__(self, db: AsyncSession):
        # Async session injected per request (see app.db.database.get_db).
        self.db = db

    async def create_mapping(
        self,
        mapping_data: DatasetMappingCreateRequest,
        labelling_project_id: str,
        labelling_project_name: str
    ) -> DatasetMappingResponse:
        """Create and persist a new dataset→labelling-project mapping.

        Args:
            mapping_data: request carrying the source dataset id.
            labelling_project_id: id of the Label Studio project.
            labelling_project_name: display name of that project.

        Returns:
            The stored mapping as a response schema.
        """
        logger.info(f"Create dataset mapping: {mapping_data.dataset_id} -> {labelling_project_id}")
        # NOTE(review): creation sets ``labelling_project_id`` (double "l")
        # while get_mapping_by_labelling_project_id filters on
        # ``labeling_project_id`` — confirm both attributes exist on the model,
        # otherwise one of the two spellings is a latent AttributeError.
        db_mapping = LabelingProject(
            mapping_id=str(uuid.uuid4()),
            dataset_id=mapping_data.dataset_id,
            labelling_project_id=labelling_project_id,
            labelling_project_name=labelling_project_name
        )
        self.db.add(db_mapping)
        await self.db.commit()
        await self.db.refresh(db_mapping)
        logger.info(f"Mapping created: {db_mapping.id}")
        return DatasetMappingResponse.model_validate(db_mapping)

    async def get_mapping_by_source_uuid(
        self,
        dataset_id: str
    ) -> Optional[DatasetMappingResponse]:
        """Return the first non-deleted mapping for *dataset_id*, if any.

        A dataset may own several labelling projects, so the query must
        tolerate multiple rows: ``scalars().first()`` is used instead of
        ``scalar_one_or_none()``, which raises MultipleResultsFound when
        more than one row matches.
        """
        logger.debug(f"Get mapping by source dataset id: {dataset_id}")
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.dataset_id == dataset_id,
                LabelingProject.deleted_at.is_(None)
            )
        )
        mapping = result.scalars().first()
        if mapping:
            logger.debug(f"Found mapping: {mapping.id}")
            return DatasetMappingResponse.model_validate(mapping)
        logger.debug(f"No mapping found for source dataset id: {dataset_id}")
        return None

    async def get_mappings_by_dataset_id(
        self,
        dataset_id: str,
        include_deleted: bool = False
    ) -> List[DatasetMappingResponse]:
        """Return all mappings for a dataset, newest first.

        Args:
            dataset_id: source dataset id.
            include_deleted: when True, soft-deleted rows are included too.
        """
        logger.debug(f"Get all mappings by source dataset id: {dataset_id}")
        query = select(LabelingProject).where(
            LabelingProject.dataset_id == dataset_id
        )
        if not include_deleted:
            query = query.where(LabelingProject.deleted_at.is_(None))
        result = await self.db.execute(
            query.order_by(LabelingProject.created_at.desc())
        )
        mappings = result.scalars().all()
        logger.debug(f"Found {len(mappings)} mappings")
        return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]

    async def get_mapping_by_labelling_project_id(
        self,
        labelling_project_id: str
    ) -> Optional[DatasetMappingResponse]:
        """Return the non-deleted mapping for a Label Studio project id."""
        logger.debug(f"Get mapping by Label Studio project id: {labelling_project_id}")
        # NOTE(review): filters on ``labeling_project_id`` (single "l") while
        # create_mapping sets ``labelling_project_id`` — see note there.
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.labeling_project_id == labelling_project_id,
                LabelingProject.deleted_at.is_(None)
            )
        )
        mapping = result.scalar_one_or_none()
        if mapping:
            logger.debug(f"Found mapping: {mapping.mapping_id}")
            return DatasetMappingResponse.model_validate(mapping)
        logger.debug(f"No mapping found for Label Studio project id: {labelling_project_id}")
        return None

    async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]:
        """Return the non-deleted mapping with the given UUID, or None.

        NOTE(review): lookups filter on ``LabelingProject.id`` while
        create_mapping sets ``mapping_id`` — presumably both columns exist
        and hold the same UUID; confirm against the model definition.
        """
        logger.debug(f"Get mapping: {mapping_id}")
        result = await self.db.execute(
            select(LabelingProject).where(
                LabelingProject.id == mapping_id,
                LabelingProject.deleted_at.is_(None)
            )
        )
        mapping = result.scalar_one_or_none()
        if mapping:
            logger.debug(f"Found mapping: {mapping.id}")
            return DatasetMappingResponse.model_validate(mapping)
        logger.debug(f"Mapping not found: {mapping_id}")
        return None

    async def update_mapping(
        self,
        mapping_id: str,
        update_data: DatasetMappingUpdateRequest
    ) -> Optional[DatasetMappingResponse]:
        """Apply a partial update to a mapping and return the fresh record.

        Returns None when the mapping does not exist (or is soft-deleted)
        or when the UPDATE affected no rows.
        """
        logger.info(f"Update mapping: {mapping_id}")
        # Existence check first so callers get a clean None for unknown ids.
        mapping = await self.get_mapping_by_uuid(mapping_id)
        if not mapping:
            return None
        # Only fields the caller explicitly set are written.
        update_values = update_data.model_dump(exclude_unset=True)
        update_values["last_updated_at"] = datetime.now()
        result = await self.db.execute(
            update(LabelingProject)
            .where(LabelingProject.id == mapping_id)
            .values(**update_values)
        )
        await self.db.commit()
        if result.rowcount > 0:
            return await self.get_mapping_by_uuid(mapping_id)
        return None

    async def update_last_updated_at(self, mapping_id: str) -> bool:
        """Touch a mapping's ``last_updated_at`` timestamp.

        Returns True when a row was updated.
        """
        logger.debug(f"Update mapping last updated at: {mapping_id}")
        result = await self.db.execute(
            update(LabelingProject)
            .where(
                LabelingProject.id == mapping_id,
                LabelingProject.deleted_at.is_(None)
            )
            # Fix: was datetime.utcnow(), which mixed UTC with the naive
            # local times written by update_mapping/soft_delete_mapping;
            # use datetime.now() consistently across all writers.
            .values(last_updated_at=datetime.now())
        )
        await self.db.commit()
        return result.rowcount > 0

    async def soft_delete_mapping(self, mapping_id: str) -> bool:
        """Soft-delete a mapping by stamping ``deleted_at``.

        Returns True when a row was marked; False when the mapping does
        not exist or was already deleted.
        """
        logger.info(f"Soft delete mapping: {mapping_id}")
        result = await self.db.execute(
            update(LabelingProject)
            .where(
                LabelingProject.id == mapping_id,
                LabelingProject.deleted_at.is_(None)
            )
            .values(deleted_at=datetime.now())
        )
        await self.db.commit()
        success = result.rowcount > 0
        if success:
            logger.info(f"Mapping soft-deleted: {mapping_id}")
        else:
            logger.warning(f"Mapping not exists or already deleted: {mapping_id}")
        return success

    async def get_all_mappings(
        self,
        skip: int = 0,
        limit: int = 100
    ) -> List[DatasetMappingResponse]:
        """Return a page of non-deleted mappings, newest first."""
        logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
        result = await self.db.execute(
            select(LabelingProject)
            .where(LabelingProject.deleted_at.is_(None))
            .order_by(LabelingProject.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        mappings = result.scalars().all()
        logger.debug(f"Found {len(mappings)} mappings")
        return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]

    async def count_mappings(self, include_deleted: bool = False) -> int:
        """Return the total number of mappings (optionally incl. deleted)."""
        query = select(func.count()).select_from(LabelingProject)
        if not include_deleted:
            query = query.where(LabelingProject.deleted_at.is_(None))
        result = await self.db.execute(query)
        return result.scalar_one()

    async def get_all_mappings_with_count(
        self,
        skip: int = 0,
        limit: int = 100,
        include_deleted: bool = False
    ) -> Tuple[List[DatasetMappingResponse], int]:
        """Return a page of mappings together with the total count.

        Intended for paginated list endpoints.
        """
        logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
        query = select(LabelingProject)
        if not include_deleted:
            query = query.where(LabelingProject.deleted_at.is_(None))
        # Reuse count_mappings instead of duplicating the count query.
        total = await self.count_mappings(include_deleted=include_deleted)
        result = await self.db.execute(
            query
            .order_by(LabelingProject.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        mappings = result.scalars().all()
        logger.debug(f"Found {len(mappings)} mappings, total: {total}")
        return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total

    async def get_mappings_by_source_with_count(
        self,
        dataset_id: str,
        skip: int = 0,
        limit: int = 100,
        include_deleted: bool = False
    ) -> Tuple[List[DatasetMappingResponse], int]:
        """Return a page of one dataset's mappings together with the total.

        Intended for paginated per-dataset list endpoints.
        """
        logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
        query = select(LabelingProject).where(
            LabelingProject.dataset_id == dataset_id
        )
        if not include_deleted:
            query = query.where(LabelingProject.deleted_at.is_(None))
        # Count with the same dataset filter as the data query.
        count_query = select(func.count()).select_from(LabelingProject).where(
            LabelingProject.dataset_id == dataset_id
        )
        if not include_deleted:
            count_query = count_query.where(LabelingProject.deleted_at.is_(None))
        count_result = await self.db.execute(count_query)
        total = count_result.scalar_one()
        result = await self.db.execute(
            query
            .order_by(LabelingProject.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        mappings = result.scalars().all()
        logger.debug(f"Found {len(mappings)} mappings, total: {total}")
        return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total

View File

@@ -0,0 +1,275 @@
from typing import Optional, List, Dict, Any, Tuple
from app.infrastructure import LabelStudioClient, DatamateClient
from app.services.dataset_mapping_service import DatasetMappingService
from app.schemas.dataset_mapping import SyncDatasetResponse
from app.core.logging import get_logger
from app.core.config import settings
from app.exceptions import NoDatasetInfoFoundError, DatasetMappingNotFoundError
logger = get_logger(__name__)
class SyncService:
    """Synchronises DM dataset files into Label Studio annotation tasks.

    The sync is incremental: files already present as tasks are skipped,
    new files are created in batches, and tasks whose DM file no longer
    exists are deleted.
    """
    def __init__(
        self,
        dm_client: DatamateClient,
        ls_client: LabelStudioClient,
        mapping_service: DatasetMappingService
    ):
        # Clients for the two external systems plus the mapping persistence layer.
        self.dm_client = dm_client
        self.ls_client = ls_client
        self.mapping_service = mapping_service
    def determine_data_type(self, file_type: str) -> str:
        """Map a DM file-type string to a Label Studio data key.

        Matching is a case-insensitive substring test, so both bare
        extensions ("jpg") and MIME-like strings ("image/jpeg") resolve.
        Unknown types fall back to 'image'.
        """
        file_type_lower = file_type.lower()
        if any(ext in file_type_lower for ext in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'svg', 'webp']):
            return 'image'
        elif any(ext in file_type_lower for ext in ['mp3', 'wav', 'flac', 'aac', 'ogg']):
            return 'audio'
        elif any(ext in file_type_lower for ext in ['mp4', 'avi', 'mov', 'wmv', 'flv', 'webm']):
            return 'video'
        elif any(ext in file_type_lower for ext in ['txt', 'doc', 'docx', 'pdf']):
            return 'text'
        else:
            return 'image'  # default to the image type
    async def get_existing_dm_file_mapping(self, project_id: str) -> Dict[str, int]:
        """
        Fetch the DM-file-id -> task-id mapping for tasks that already
        exist in a Label Studio project.
        Args:
            project_id: Label Studio project ID
        Returns:
            Dict mapping dm_file_id (str) to Label Studio task id
        """
        try:
            # NOTE(review): this log reads settings.ls_task_page_size directly
            # while the getattr below guards against it missing — confirm the
            # setting always exists, or this log line itself may raise.
            logger.info(f"Fetching existing task mappings for project {project_id} (page_size={settings.ls_task_page_size})")
            dm_file_to_task_mapping = {}
            # Use the Label Studio client wrapper to fetch the tasks.
            page_size = getattr(settings, 'ls_task_page_size', 1000)
            # page=None asks the wrapper for every task, not a single page.
            result = await self.ls_client.get_project_tasks(
                project_id=project_id,
                page=None,  # no page specified -> fetch all tasks
                page_size=page_size
            )
            if not result:
                logger.warning(f"Failed to fetch tasks for project {project_id}")
                return {}
            all_tasks = result.get("tasks", [])
            # Walk every task and build the mapping from its meta.dm_file_id.
            for task in all_tasks:
                # Tasks created by this service carry dm_file_id in their meta.
                meta = task.get('meta')
                if meta:
                    dm_file_id = meta.get('dm_file_id')
                    if dm_file_id:
                        task_id = task.get('id')
                        if task_id:
                            dm_file_to_task_mapping[str(dm_file_id)] = task_id
            logger.info(f"Found {len(dm_file_to_task_mapping)} existing task mappings")
            return dm_file_to_task_mapping
        except Exception as e:
            logger.error(f"Error while fetching existing tasks: {e}")
            return {}  # on error return an empty dict, so every file gets synced
    async def sync_dataset_files(
        self,
        mapping_id: str,
        batch_size: int = 50
    ) -> SyncDatasetResponse:
        """Synchronise one dataset's files into its Label Studio project.

        Creates tasks for files present in DM but missing in Label Studio,
        deletes tasks whose DM file has disappeared, and reports counters.
        Errors are caught and reported through the response status rather
        than raised.

        Args:
            mapping_id: UUID of the dataset↔project mapping.
            batch_size: DM page size and task-creation batch size.
        Returns:
            SyncDatasetResponse with status, counters and a message.
        """
        logger.info(f"Start syncing dataset by mapping: {mapping_id}")
        # Resolve the mapping record first.
        mapping = await self.mapping_service.get_mapping_by_uuid(mapping_id)
        if not mapping:
            logger.error(f"Dataset mapping not found: {mapping_id}")
            return SyncDatasetResponse(
                mapping_id="",
                status="error",
                synced_files=0,
                total_files=0,
                message=f"Dataset mapping not found: {mapping_id}"
            )
        try:
            # Fetch the dataset info from DM.
            dataset_info = await self.dm_client.get_dataset(mapping.dataset_id)
            if not dataset_info:
                raise NoDatasetInfoFoundError(mapping.dataset_id)
            synced_files = 0
            deleted_tasks = 0
            total_files = dataset_info.fileCount
            page = 0
            logger.info(f"Total files in dataset: {total_files}")
            # DM-file-id -> task-id pairs already present in Label Studio.
            existing_dm_file_mapping = await self.get_existing_dm_file_mapping(mapping.labelling_project_id)
            existing_dm_file_ids = set(existing_dm_file_mapping.keys())
            logger.info(f"{len(existing_dm_file_ids)} tasks already exist in Label Studio")
            # File ids DM currently knows about (filled while paging below).
            current_dm_file_ids = set()
            # Page through the DM files and sync each page.
            while True:
                files_response = await self.dm_client.get_dataset_files(
                    mapping.dataset_id,
                    page=page,
                    size=batch_size,
                    status="COMPLETED"  # only sync files whose upload completed
                )
                if not files_response or not files_response.content:
                    logger.info(f"No more files on page {page + 1}")
                    break
                logger.info(f"Processing page {page + 1}, total {len(files_response.content)} files")
                # Select the new files and build their task payloads.
                tasks = []
                new_files_count = 0
                existing_files_count = 0
                for file_info in files_response.content:
                    # Remember every file id DM still holds.
                    current_dm_file_ids.add(str(file_info.id))
                    # Skip files that already have a task.
                    if str(file_info.id) in existing_dm_file_ids:
                        existing_files_count += 1
                        logger.debug(f"Skip existing file: {file_info.originalName} (ID: {file_info.id})")
                        continue
                    new_files_count += 1
                    # Pick the Label Studio data key for this file type.
                    data_type = self.determine_data_type(file_info.fileType)
                    # Swap the path prefix: removeprefix only strips a leading
                    # match, so the same substring later in the path is safe.
                    file_path = file_info.filePath.removeprefix(settings.dm_file_path_prefix)
                    file_path = settings.label_studio_file_path_prefix + file_path
                    # Build the task payload; meta links the task back to DM.
                    task_data = {
                        "data": {
                            data_type: file_path
                        },
                        "meta": {
                            "file_size": file_info.size,
                            "file_type": file_info.fileType,
                            "dm_dataset_id": mapping.dataset_id,
                            "dm_file_id": file_info.id,
                            "original_name": file_info.originalName,
                        }
                    }
                    tasks.append(task_data)
                logger.info(f"Page {page + 1}: new files {new_files_count}, existing files {existing_files_count}")
                # Create the Label Studio tasks in one batch.
                if tasks:
                    batch_result = await self.ls_client.create_tasks_batch(
                        mapping.labelling_project_id,
                        tasks
                    )
                    if batch_result:
                        synced_files += len(tasks)
                        logger.info(f"Successfully synced {len(tasks)} files")
                    else:
                        logger.warning(f"Batch task creation failed, fallback to single creation")
                        # Batch creation failed; retry the tasks one by one.
                        for task_data in tasks:
                            task_result = await self.ls_client.create_task(
                                mapping.labelling_project_id,
                                task_data["data"],
                                task_data.get("meta")
                            )
                            if task_result:
                                synced_files += 1
                # Stop once the last DM page has been processed.
                if page >= files_response.totalPages - 1:
                    break
                page += 1
            # Remove tasks whose DM file no longer exists.
            tasks_to_delete = []
            for dm_file_id, task_id in existing_dm_file_mapping.items():
                if dm_file_id not in current_dm_file_ids:
                    tasks_to_delete.append(task_id)
                    logger.debug(f"Mark task for deletion: {task_id} (DM file ID: {dm_file_id})")
            if tasks_to_delete:
                logger.info(f"Deleting {len(tasks_to_delete)} tasks not present in DM")
                delete_result = await self.ls_client.delete_tasks_batch(tasks_to_delete)
                deleted_tasks = delete_result.get("successful", 0)
                logger.info(f"Successfully deleted {deleted_tasks} tasks")
            else:
                logger.info("No tasks to delete")
            # Touch the mapping's last-updated timestamp.
            await self.mapping_service.update_last_updated_at(mapping.mapping_id)
            logger.info(f"Sync completed: total_files={total_files}, created={synced_files}, deleted={deleted_tasks}")
            return SyncDatasetResponse(
                mapping_id=mapping.mapping_id,
                status="success",
                synced_files=synced_files,
                total_files=total_files,
                message=f"Sync completed: created {synced_files} files, deleted {deleted_tasks} tasks"
            )
        except Exception as e:
            logger.error(f"Error while syncing dataset: {e}")
            return SyncDatasetResponse(
                mapping_id=mapping.mapping_id,
                status="error",
                synced_files=0,
                total_files=0,
                message=f"Sync failed: {str(e)}"
            )
    async def get_sync_status(
        self,
        dataset_id: str
    ) -> Optional[Dict[str, Any]]:
        """Report sync progress for a dataset, or None if it has no mapping.

        Compares the DM file count with the Label Studio task count and
        includes their ratio.
        """
        mapping = await self.mapping_service.get_mapping_by_source_uuid(dataset_id)
        if not mapping:
            return None
        # Dataset info from DM.
        dataset_info = await self.dm_client.get_dataset(dataset_id)
        # Task count from the Label Studio project.
        tasks_info = await self.ls_client.get_project_tasks(mapping.labelling_project_id)
        return {
            "mapping_id": mapping.mapping_id,
            "dataset_id": dataset_id,
            "labelling_project_id": mapping.labelling_project_id,
            "last_updated_at": mapping.last_updated_at,
            "dm_total_files": dataset_info.fileCount if dataset_info else 0,
            "ls_total_tasks": tasks_info.get("count", 0) if tasks_info else 0,
            # Ratio guarded against a missing dataset or zero file count.
            "sync_ratio": (
                tasks_info.get("count", 0) / dataset_info.fileCount
                if dataset_info and dataset_info.fileCount > 0 and tasks_info else 0
            )
        }