Develop labeling module (#25)

* refactor: remove db table management from LS adapter (mv to scripts later); change adapter to use the same MySQL DB as other modules.

* refactor: Rename LS Adapter module to datamate-python
This commit is contained in:
Jinglong Wang
2025-10-27 16:16:14 +08:00
committed by GitHub
parent 46dfb389f1
commit 7f819563db
69 changed files with 1104 additions and 703 deletions

View File

@@ -53,9 +53,6 @@ LS_TASK_PAGE_SIZE=1000
# ========================= # =========================
# Data Management 服务配置 # Data Management 服务配置
# ========================= # =========================
# DM 服务地址
DM_SERVICE_BASE_URL=http://data-engine:8080
# DM 存储文件夹前缀(通常与 Label Studio 的 local-files 文件夹映射一致) # DM 存储文件夹前缀(通常与 Label Studio 的 local-files 文件夹映射一致)
DM_FILE_PATH_PREFIX=/ DM_FILE_PATH_PREFIX=/

View File

@@ -0,0 +1,86 @@
# Label Studio Adapter (DataMate)
这是 DataMate 的 Label Studio Adapter 服务,负责将 DataMate 的项目与 Label Studio 同步并提供对外的 HTTP API(基于 FastAPI)。
## 简要说明
- 框架:FastAPI
- 异步数据库/ORM:SQLAlchemy (async)
- 数据库迁移:Alembic
- 运行器:uvicorn
## 快速开始(开发)
1. 克隆仓库并进入项目目录
2. 创建并激活虚拟环境:
```bash
python -m venv .venv
source .venv/bin/activate
```
3. 安装依赖:
```bash
pip install -r requirements.txt
```
4. 准备环境变量(示例)
创建 `.env` 并设置必要的变量,例如:
- DATABASE_URL(或根据项目配置使用具体变量)
- LABEL_STUDIO_BASE_URL
- LABEL_STUDIO_USER_TOKEN
(具体变量请参考 `.env.example`)
5. 数据库迁移(开发环境):
```bash
alembic upgrade head
```
6. 启动开发服务器(示例与常用参数):
- 本地开发(默认 host/port,自动重载):
```bash
uvicorn app.main:app --reload
```
- 指定主机与端口并打开调试日志:
```bash
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-level debug
```
- 在生产环境使用多个 worker(不使用 --reload):
```bash
uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4 --log-level info --proxy-headers
```
- 使用环境变量启动(示例):
```bash
HOST=0.0.0.0 PORT=8000 uvicorn app.main:app --reload
```
注意:
- `--reload` 仅用于开发,会监视文件变化并重启进程;不要在生产中使用。
- `--workers` 提供并发处理能力,但会增加内存占用;生产时通常配合进程管理或容器编排(Kubernetes)使用。
- 若需要完整的生产部署建议使用 ASGI 服务器(如 gunicorn + uvicorn workers / 或直接使用 uvicorn 在容器中配合进程管理)。
访问 API 文档:
- Swagger UI: http://127.0.0.1:8000/docs
- ReDoc: http://127.0.0.1:8000/redoc (推荐使用)
## 使用(简要)
- 所有 API 路径以 `/api` 前缀注册(见 `app/main.py` 中的 `app.include_router(api_router, prefix="/api")`)。
- 根路径 `/` 返回服务信息和文档链接。
更多细节请查看 `doc/usage.md`(接口使用)和 `doc/development.md`(开发说明)。

View File

@@ -4,7 +4,7 @@ from typing import Optional
from app.db.database import get_db from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService from app.services.dataset_mapping_service import DatasetMappingService
from app.clients import get_clients from app.infrastructure import DatamateClient, LabelStudioClient
from app.schemas.dataset_mapping import ( from app.schemas.dataset_mapping import (
DatasetMappingCreateRequest, DatasetMappingCreateRequest,
DatasetMappingCreateResponse, DatasetMappingCreateResponse,
@@ -30,18 +30,19 @@ async def create_dataset_mapping(
注意一个数据集可以创建多个标注项目 注意一个数据集可以创建多个标注项目
""" """
try: try:
# 获取全局客户端实例 dm_client = DatamateClient(db)
dm_client_instance, ls_client_instance = get_clients() ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token)
service = DatasetMappingService(db) service = DatasetMappingService(db)
logger.info(f"Create dataset mapping request: {request.source_dataset_id}") logger.info(f"Create dataset mapping request: {request.dataset_id}")
# 从DM服务获取数据集信息 # 从DM服务获取数据集信息
dataset_info = await dm_client_instance.get_dataset(request.source_dataset_id) dataset_info = await dm_client.get_dataset(request.dataset_id)
if not dataset_info: if not dataset_info:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Dataset not found in DM service: {request.source_dataset_id}" detail=f"Dataset not found in DM service: {request.dataset_id}"
) )
# 确定数据类型(基于数据集类型) # 确定数据类型(基于数据集类型)
@@ -55,11 +56,10 @@ async def create_dataset_mapping(
elif "text" in type_code: elif "text" in type_code:
data_type = "text" data_type = "text"
# 生成项目名称
project_name = f"{dataset_info.name}" project_name = f"{dataset_info.name}"
# 在Label Studio中创建项目 # 在Label Studio中创建项目
project_data = await ls_client_instance.create_project( project_data = await ls_client.create_project(
title=project_name, title=project_name,
description=dataset_info.description or f"Imported from DM dataset {dataset_info.id}", description=dataset_info.description or f"Imported from DM dataset {dataset_info.id}",
data_type=data_type data_type=data_type
@@ -74,8 +74,8 @@ async def create_dataset_mapping(
project_id = project_data["id"] project_id = project_data["id"]
# 配置本地存储:dataset/<id> # 配置本地存储:dataset/<id>
local_storage_path = f"{settings.label_studio_local_storage_dataset_base_path}/{request.source_dataset_id}" local_storage_path = f"{settings.label_studio_local_storage_dataset_base_path}/{request.dataset_id}"
storage_result = await ls_client_instance.create_local_storage( storage_result = await ls_client.create_local_storage(
project_id=project_id, project_id=project_id,
path=local_storage_path, path=local_storage_path,
title="Dataset_BLOB", title="Dataset_BLOB",
@@ -85,7 +85,7 @@ async def create_dataset_mapping(
# 配置本地存储:upload # 配置本地存储:upload
local_storage_path = f"{settings.label_studio_local_storage_upload_base_path}" local_storage_path = f"{settings.label_studio_local_storage_upload_base_path}"
storage_result = await ls_client_instance.create_local_storage( storage_result = await ls_client.create_local_storage(
project_id=project_id, project_id=project_id,
path=local_storage_path, path=local_storage_path,
title="Upload_BLOB", title="Upload_BLOB",
@@ -107,7 +107,7 @@ async def create_dataset_mapping(
) )
logger.debug( logger.debug(
f"Dataset mapping created: {mapping.mapping_id} -> S {mapping.source_dataset_id} <> L {mapping.labelling_project_id}" f"Dataset mapping created: {mapping.mapping_id} -> S {mapping.dataset_id} <> L {mapping.labelling_project_id}"
) )
response_data = DatasetMappingCreateResponse( response_data = DatasetMappingCreateResponse(

View File

@@ -1,13 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional from typing import Optional
from app.db.database import get_db from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService from app.services.dataset_mapping_service import DatasetMappingService
from app.clients import get_clients from app.infrastructure import DatamateClient, LabelStudioClient
from app.schemas.dataset_mapping import DeleteDatasetResponse from app.schemas.dataset_mapping import DeleteDatasetResponse
from app.schemas import StandardResponse from app.schemas import StandardResponse
from app.core.logging import get_logger from app.core.logging import get_logger
from app.core.config import settings
from . import project_router from . import project_router
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -37,39 +39,39 @@ async def delete_mapping(
status_code=400, status_code=400,
detail="Either 'm' (mapping UUID) or 'proj' (project ID) must be provided" detail="Either 'm' (mapping UUID) or 'proj' (project ID) must be provided"
) )
# 获取全局客户端实例 ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
dm_client_instance, ls_client_instance = get_clients() token=settings.label_studio_user_token)
service = DatasetMappingService(db) service = DatasetMappingService(db)
mapping = None
# 优先使用 mapping_id 查询 # 优先使用 mapping_id 查询
if m: if m:
logger.info(f"Deleting by mapping UUID: {m}") logger.debug(f"Deleting by mapping UUID: {m}")
mapping = await service.get_mapping_by_uuid(m) mapping = await service.get_mapping_by_uuid(m)
# 如果没有提供 m,使用 proj 查询 # 如果没有提供 m,使用 proj 查询
elif proj: elif proj:
logger.info(f"Deleting by project ID: {proj}") logger.debug(f"Deleting by project ID: {proj}")
mapping = await service.get_mapping_by_labelling_project_id(proj) mapping = await service.get_mapping_by_labelling_project_id(proj)
else:
mapping = None
if not mapping: if not mapping:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Mapping not found" detail=f"Mapping either not found or not specified."
) )
mapping_id = mapping.mapping_id mapping_id = mapping.mapping_id
labelling_project_id = mapping.labelling_project_id labelling_project_id = mapping.labelling_project_id
labelling_project_name = mapping.labelling_project_name labelling_project_name = mapping.labelling_project_name
logger.info(f"Found mapping: {mapping_id}, Label Studio project ID: {labelling_project_id}") logger.debug(f"Found mapping: {mapping_id}, Label Studio project ID: {labelling_project_id}")
# 1. 删除 Label Studio 项目 # 1. 删除 Label Studio 项目
try: try:
delete_success = await ls_client_instance.delete_project(int(labelling_project_id)) delete_success = await ls_client.delete_project(int(labelling_project_id))
if delete_success: if delete_success:
logger.info(f"Successfully deleted Label Studio project: {labelling_project_id}") logger.debug(f"Successfully deleted Label Studio project: {labelling_project_id}")
else: else:
logger.warning(f"Failed to delete Label Studio project or project not found: {labelling_project_id}") logger.warning(f"Failed to delete Label Studio project or project not found: {labelling_project_id}")
except Exception as e: except Exception as e:
@@ -84,19 +86,17 @@ async def delete_mapping(
status_code=500, status_code=500,
detail="Failed to delete mapping record" detail="Failed to delete mapping record"
) )
logger.info(f"Successfully deleted mapping: {mapping_id}") logger.info(f"Successfully deleted mapping: {mapping_id}, Label Studio project: {labelling_project_id}")
response_data = DeleteDatasetResponse(
mapping_id=mapping_id,
status="success",
message=f"Successfully deleted mapping and Label Studio project '{labelling_project_name}'"
)
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
data=response_data data=DeleteDatasetResponse(
mapping_id=mapping_id,
status="success",
message=f"Successfully deleted mapping and Label Studio project '{labelling_project_name}'"
)
) )
except HTTPException: except HTTPException:

View File

@@ -98,9 +98,9 @@ async def get_mapping(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]]) @project_router.get("/mappings/by-source/{dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
async def get_mappings_by_source( async def get_mappings_by_source(
source_dataset_id: str, dataset_id: str,
page: int = Query(1, ge=1, description="页码(从1开始)"), page: int = Query(1, ge=1, description="页码(从1开始)"),
page_size: int = Query(20, ge=1, le=100, description="每页记录数"), page_size: int = Query(20, ge=1, le=100, description="每页记录数"),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
@@ -116,11 +116,11 @@ async def get_mappings_by_source(
# 计算 skip # 计算 skip
skip = (page - 1) * page_size skip = (page - 1) * page_size
logger.info(f"Get mappings by source dataset id: {source_dataset_id}, page={page}, page_size={page_size}") logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}")
# 获取数据和总数 # 获取数据和总数
mappings, total = await service.get_mappings_by_source_with_count( mappings, total = await service.get_mappings_by_source_with_count(
source_dataset_id=source_dataset_id, dataset_id=dataset_id,
skip=skip, skip=skip,
limit=page_size limit=page_size
) )

View File

@@ -5,7 +5,7 @@ from typing import List, Optional
from app.db.database import get_db from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService from app.services.dataset_mapping_service import DatasetMappingService
from app.services.sync_service import SyncService from app.services.sync_service import SyncService
from app.clients import get_clients from app.infrastructure import DatamateClient, LabelStudioClient
from app.exceptions import NoDatasetInfoFoundError, DatasetMappingNotFoundError from app.exceptions import NoDatasetInfoFoundError, DatasetMappingNotFoundError
from app.schemas.dataset_mapping import ( from app.schemas.dataset_mapping import (
DatasetMappingResponse, DatasetMappingResponse,
@@ -14,6 +14,7 @@ from app.schemas.dataset_mapping import (
) )
from app.schemas import StandardResponse from app.schemas import StandardResponse
from app.core.logging import get_logger from app.core.logging import get_logger
from app.core.config import settings
from . import project_router from . import project_router
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -30,10 +31,12 @@ async def sync_dataset_content(
在数据库中记录更新时间返回更新状态 在数据库中记录更新时间返回更新状态
""" """
try: try:
dm_client_instance, ls_client_instance = get_clients() ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token)
dm_client = DatamateClient(db)
mapping_service = DatasetMappingService(db) mapping_service = DatasetMappingService(db)
sync_service = SyncService(dm_client_instance, ls_client_instance, mapping_service) sync_service = SyncService(dm_client, ls_client, mapping_service)
logger.info(f"Sync dataset content request: mapping_id={request.mapping_id}") logger.info(f"Sync dataset content request: mapping_id={request.mapping_id}")
# 根据 mapping_id 获取映射关系 # 根据 mapping_id 获取映射关系

View File

@@ -27,7 +27,6 @@ async def get_config():
data={ data={
"app_name": settings.app_name, "app_name": settings.app_name,
"version": settings.app_version, "version": settings.app_version,
"dm_service_url": settings.dm_service_base_url,
"label_studio_url": settings.label_studio_base_url, "label_studio_url": settings.label_studio_base_url,
"debug": settings.debug "debug": settings.debug
} }

View File

@@ -73,7 +73,6 @@ class Settings(BaseSettings):
# ========================= # =========================
# Data Management 服务配置 # Data Management 服务配置
# ========================= # =========================
dm_service_base_url: str = "http://data-engine"
dm_file_path_prefix: str = "/" # DM存储文件夹前缀 dm_file_path_prefix: str = "/" # DM存储文件夹前缀

View File

@@ -0,0 +1,6 @@
# app/clients/__init__.py
from .label_studio import Client as LabelStudioClient
from .datamate import Client as DatamateClient
__all__ = ["LabelStudioClient", "DatamateClient"]

View File

@@ -0,0 +1,159 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import func
from typing import Optional
from app.core.config import settings
from app.core.logging import get_logger
from app.schemas.dm_service import DatasetResponse, PagedDatasetFileResponse, DatasetFileResponse
from app.models.dm.dataset import Dataset
from app.models.dm.dataset_files import DatasetFiles
logger = get_logger(__name__)
class Client:
    """Data Management (DM) service client — direct database access.

    Reads dataset and file metadata straight from the shared DM database
    (via an async SQLAlchemy session) instead of calling the DM HTTP
    service. All methods follow a best-effort contract: failures are
    logged and reported as ``None`` rather than raised.
    """
    def __init__(self, db: AsyncSession):
        """
        Initialize the DM client.

        Args:
            db: Async database session; its lifecycle is owned by the caller.
        """
        self.db = db
        logger.info("Initialize DM service client (Database mode)")
    @staticmethod
    def _file_conditions(dataset_id: str, file_type: Optional[str], status: Optional[str]) -> list:
        """Build the shared WHERE conditions for dataset-file queries."""
        conditions = [DatasetFiles.dataset_id == dataset_id]
        if file_type:
            conditions.append(DatasetFiles.file_type == file_type)
        if status:
            conditions.append(DatasetFiles.status == status)
        return conditions
    async def get_dataset(self, dataset_id: str) -> Optional[DatasetResponse]:
        """Fetch dataset details; return None if absent or on query failure."""
        try:
            logger.info(f"Getting dataset detail: {dataset_id} ...")
            result = await self.db.execute(
                select(Dataset).where(Dataset.id == dataset_id)
            )
            dataset = result.scalar_one_or_none()
            if not dataset:
                logger.error(f"Dataset not found: {dataset_id}")
                return None
            # Map the ORM row onto the API response schema.
            # type: ignore silences SQLAlchemy instrumented-attribute typing.
            return DatasetResponse(
                id=dataset.id,  # type: ignore
                name=dataset.name,  # type: ignore
                description=dataset.description or "",  # type: ignore
                datasetType=dataset.dataset_type,  # type: ignore
                status=dataset.status,  # type: ignore
                fileCount=dataset.file_count or 0,  # type: ignore
                totalSize=dataset.size_bytes or 0,  # type: ignore
                createdAt=dataset.created_at,  # type: ignore
                updatedAt=dataset.updated_at,  # type: ignore
                createdBy=dataset.created_by  # type: ignore
            )
        except Exception as e:
            logger.error(f"Failed to get dataset {dataset_id}: {e}")
            return None
    async def get_dataset_files(
        self,
        dataset_id: str,
        page: int = 0,
        size: int = 100,
        file_type: Optional[str] = None,
        status: Optional[str] = None
    ) -> Optional[PagedDatasetFileResponse]:
        """Fetch one page of a dataset's files (``page`` is 0-based).

        Args:
            dataset_id: Dataset whose files are listed.
            page: 0-based page index.
            size: Page size (rows per page).
            file_type: Optional file-type filter.
            status: Optional status filter.

        Returns:
            A paged response, or None on query failure.
        """
        try:
            logger.info(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
            # Single source of truth for filters — previously the count query
            # re-declared them, risking drift between count and page contents.
            conditions = self._file_conditions(dataset_id, file_type, status)
            count_result = await self.db.execute(
                select(func.count()).select_from(DatasetFiles).where(*conditions)
            )
            total = count_result.scalar_one()
            # Order before paginating so page boundaries are deterministic
            # (newest files first).
            result = await self.db.execute(
                select(DatasetFiles)
                .where(*conditions)
                .order_by(DatasetFiles.created_at.desc())
                .offset(page * size)
                .limit(size)
            )
            files = result.scalars().all()
            # type: ignore silences SQLAlchemy instrumented-attribute typing.
            content = [
                DatasetFileResponse(
                    id=f.id,  # type: ignore
                    fileName=f.file_name,  # type: ignore
                    fileType=f.file_type or "",  # type: ignore
                    filePath=f.file_path,  # type: ignore
                    originalName=f.file_name,  # type: ignore
                    size=f.file_size,  # type: ignore
                    status=f.status,  # type: ignore
                    uploadedAt=f.upload_time,  # type: ignore
                    description=None,
                    uploadedBy=None,
                    lastAccessTime=f.last_access_time  # type: ignore
                )
                for f in files
            ]
            # Ceiling division; guard size <= 0 to avoid ZeroDivisionError.
            total_pages = (total + size - 1) // size if size > 0 else 0
            return PagedDatasetFileResponse(
                content=content,
                totalElements=total,
                totalPages=total_pages,
                page=page,
                size=size
            )
        except Exception as e:
            logger.error(f"Failed to get dataset files for {dataset_id}: {e}")
            return None
    async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]:
        """
        Download file content.

        Deprecated in database mode — kept only for interface compatibility
        and always returns None; use ``get_file_download_url`` instead.
        """
        logger.warning("download_file is deprecated when using database mode. Use get_file_download_url instead.")
        return None
    async def get_file_download_url(self, dataset_id: str, file_id: str) -> Optional[str]:
        """Return the stored file path (or URL) for a file, or None if absent/on error."""
        try:
            result = await self.db.execute(
                select(DatasetFiles).where(
                    DatasetFiles.id == file_id,
                    DatasetFiles.dataset_id == dataset_id
                )
            )
            file = result.scalar_one_or_none()
            if not file:
                logger.error(f"File not found: {file_id} in dataset {dataset_id}")
                return None
            # May be a local filesystem path or an object-storage URL —
            # whatever was recorded at upload time.
            return file.file_path  # type: ignore
        except Exception as e:
            logger.error(f"Failed to get file path for {file_id}: {e}")
            return None
    async def close(self):
        """Close the client (no-op: the DB session is owned by the caller)."""
        logger.info("DM service client closed (Database mode)")

View File

@@ -12,7 +12,7 @@ from app.schemas.label_studio import (
logger = get_logger(__name__) logger = get_logger(__name__)
class LabelStudioClient: class Client:
"""Label Studio服务客户端 """Label Studio服务客户端
使用 HTTP REST API 直接与 Label Studio 交互 使用 HTTP REST API 直接与 Label Studio 交互

View File

@@ -8,7 +8,7 @@ from typing import Dict, Any
from .core.config import settings from .core.config import settings
from .core.logging import setup_logging, get_logger from .core.logging import setup_logging, get_logger
from .clients import DMServiceClient, LabelStudioClient, set_clients from .infrastructure import LabelStudioClient
from .api import api_router from .api import api_router
from .schemas import StandardResponse from .schemas import StandardResponse
@@ -23,23 +23,12 @@ async def lifespan(app: FastAPI):
# 启动时初始化 # 启动时初始化
logger.info("Starting Label Studio Adapter...") logger.info("Starting Label Studio Adapter...")
# 初始化客户端
dm_client = DMServiceClient()
# 初始化 Label Studio 客户端,使用 HTTP REST API + Token 认证 # 初始化 Label Studio 客户端,使用 HTTP REST API + Token 认证
ls_client = LabelStudioClient( ls_client = LabelStudioClient(
base_url=settings.label_studio_base_url, base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token token=settings.label_studio_user_token
) )
# 设置全局客户端
set_clients(dm_client, ls_client)
# 数据库初始化由 Alembic 管理
# 在 Docker 环境中,entrypoint.sh 会在启动前运行: alembic upgrade head
# 在开发环境中,手动运行: alembic upgrade head
logger.info("Database schema managed by Alembic")
logger.info("Label Studio Adapter started") logger.info("Label Studio Adapter started")
yield yield
@@ -155,7 +144,6 @@ async def root():
"message": f"{settings.app_name} is running", "message": f"{settings.app_name} is running",
"version": settings.app_version, "version": settings.app_version,
"docs_url": "/docs", "docs_url": "/docs",
"dm_service_url": settings.dm_service_base_url,
"label_studio_url": settings.label_studio_base_url "label_studio_url": settings.label_studio_base_url
} }
) )

View File

@@ -0,0 +1,138 @@
# DataMate 数据模型结构
本文档列出了根据 `scripts/db` 中的 SQL 文件创建的所有 Python 数据模型。
## 模型组织结构
```
app/models/
├── __init__.py # 主模块导出文件
├── dm/ # 数据管理 (Data Management) 模块
│ ├── __init__.py
│ ├── annotation_template.py # 标注模板
│ ├── labeling_project.py # 标注项目
│ ├── dataset.py # 数据集
│ ├── dataset_files.py # 数据集文件
│ ├── dataset_statistics.py # 数据集统计
│ ├── dataset_tag.py # 数据集标签关联
│ ├── tag.py # 标签
│ └── user.py # 用户
├── cleaning/ # 数据清洗 (Data Cleaning) 模块
│ ├── __init__.py
│ ├── clean_template.py # 清洗模板
│ ├── clean_task.py # 清洗任务
│ ├── operator_instance.py # 算子实例
│ └── clean_result.py # 清洗结果
├── collection/ # 数据归集 (Data Collection) 模块
│ ├── __init__.py
│ ├── task_execution.py # 任务执行明细
│ ├── collection_task.py # 数据归集任务
│ ├── task_log.py # 任务执行记录
│ └── datax_template.py # DataX模板配置
├── common/ # 通用 (Common) 模块
│ ├── __init__.py
│ └── chunk_upload_request.py # 文件切片上传请求
└── operator/ # 算子 (Operator) 模块
├── __init__.py
├── operator.py # 算子
├── operator_category.py # 算子分类
└── operator_category_relation.py # 算子分类关联
```
## 模块详情
### 1. Data Management (DM) 模块
对应 SQL: `data-management-init.sql` 和 `data-annotation-init.sql`
#### 模型列表:
- **AnnotationTemplate** (`t_dm_annotation_templates`) - 标注模板
- **LabelingProject** (`t_dm_labeling_projects`) - 标注项目
- **Dataset** (`t_dm_datasets`) - 数据集(支持医学影像、文本、问答等多种类型)
- **DatasetFiles** (`t_dm_dataset_files`) - 数据集文件
- **DatasetStatistics** (`t_dm_dataset_statistics`) - 数据集统计信息
- **Tag** (`t_dm_tags`) - 标签
- **DatasetTag** (`t_dm_dataset_tags`) - 数据集标签关联
- **User** (`users`) - 用户
### 2. Data Cleaning 模块
对应 SQL: `data-cleaning-init.sql`
#### 模型列表:
- **CleanTemplate** (`t_clean_template`) - 清洗模板
- **CleanTask** (`t_clean_task`) - 清洗任务
- **OperatorInstance** (`t_operator_instance`) - 算子实例
- **CleanResult** (`t_clean_result`) - 清洗结果
### 3. Data Collection (DC) 模块
对应 SQL: `data-collection-init.sql`
#### 模型列表:
- **TaskExecution** (`t_dc_task_executions`) - 任务执行明细
- **CollectionTask** (`t_dc_collection_tasks`) - 数据归集任务
- **TaskLog** (`t_dc_task_log`) - 任务执行记录
- **DataxTemplate** (`t_dc_datax_templates`) - DataX模板配置
### 4. Common 模块
对应 SQL: `data-common-init.sql`
#### 模型列表:
- **ChunkUploadRequest** (`t_chunk_upload_request`) - 文件切片上传请求
### 5. Operator 模块
对应 SQL: `data-operator-init.sql`
#### 模型列表:
- **Operator** (`t_operator`) - 算子
- **OperatorCategory** (`t_operator_category`) - 算子分类
- **OperatorCategoryRelation** (`t_operator_category_relation`) - 算子分类关联
## 使用方式
```python
# 导入所有模型
from app.models import (
# DM 模块
AnnotationTemplate,
LabelingProject,
Dataset,
DatasetFiles,
DatasetStatistics,
DatasetTag,
Tag,
User,
# Cleaning 模块
CleanTemplate,
CleanTask,
OperatorInstance,
CleanResult,
# Collection 模块
TaskExecution,
CollectionTask,
TaskLog,
DataxTemplate,
# Common 模块
ChunkUploadRequest,
# Operator 模块
Operator,
OperatorCategory,
OperatorCategoryRelation
)
# 或者按模块导入
from app.models.dm import Dataset, DatasetFiles
from app.models.collection import CollectionTask
from app.models.operator import Operator
```
## 注意事项
1. **UUID 主键**: 大部分表使用 UUID (String(36)) 作为主键
2. **时间戳**: 使用 `TIMESTAMP` 类型,并配置自动更新
3. **软删除**: 部分模型(如 AnnotationTemplate, LabelingProject)支持软删除,包含 `deleted_at` 字段和 `is_deleted` 属性
4. **JSON 字段**: 配置信息、元数据等使用 JSON 类型存储
5. **字段一致性**: 所有模型字段都严格按照 SQL 定义创建,确保与数据库表结构完全一致
## 更新记录
- 2025-10-25: 根据 `scripts/db` 中的 SQL 文件创建所有数据模型
- 已更新现有的 `annotation_template.py`、`labeling_project.py`、`dataset_files.py` 以匹配 SQL 定义

View File

@@ -0,0 +1,69 @@
# app/models/__init__.py
# Data Management (DM) 模块
from .dm import (
AnnotationTemplate,
LabelingProject,
Dataset,
DatasetFiles,
DatasetStatistics,
DatasetTag,
Tag,
User
)
# Data Cleaning 模块
from .cleaning import (
CleanTemplate,
CleanTask,
OperatorInstance,
CleanResult
)
# Data Collection (DC) 模块
from .collection import (
TaskExecution,
CollectionTask,
TaskLog,
DataxTemplate
)
# Common 模块
from .common import (
ChunkUploadRequest
)
# Operator 模块
from .operator import (
Operator,
OperatorCategory,
OperatorCategoryRelation
)
__all__ = [
# DM 模块
"AnnotationTemplate",
"LabelingProject",
"Dataset",
"DatasetFiles",
"DatasetStatistics",
"DatasetTag",
"Tag",
"User",
# Cleaning 模块
"CleanTemplate",
"CleanTask",
"OperatorInstance",
"CleanResult",
# Collection 模块
"TaskExecution",
"CollectionTask",
"TaskLog",
"DataxTemplate",
# Common 模块
"ChunkUploadRequest",
# Operator 模块
"Operator",
"OperatorCategory",
"OperatorCategoryRelation"
]

View File

@@ -0,0 +1,13 @@
# app/models/cleaning/__init__.py
from .clean_template import CleanTemplate
from .clean_task import CleanTask
from .operator_instance import OperatorInstance
from .clean_result import CleanResult
__all__ = [
"CleanTemplate",
"CleanTask",
"OperatorInstance",
"CleanResult"
]

View File

@@ -0,0 +1,22 @@
from sqlalchemy import Column, String, BigInteger, Text
from app.db.database import Base
class CleanResult(Base):
    """Cleaning result model (table ``t_clean_result``).

    One row per (cleaning instance, destination file) pair: records how a
    source file was transformed into a destination file by a cleaning run.
    """
    __tablename__ = "t_clean_result"
    # Composite primary key: cleaning-instance ID + destination-file ID.
    instance_id = Column(String(64), primary_key=True, comment="实例ID")
    src_file_id = Column(String(64), nullable=True, comment="源文件ID")
    dest_file_id = Column(String(64), primary_key=True, comment="目标文件ID")
    # Source vs. destination metadata captured around the cleaning step.
    src_name = Column(String(256), nullable=True, comment="源文件名")
    dest_name = Column(String(256), nullable=True, comment="目标文件名")
    src_type = Column(String(256), nullable=True, comment="源文件类型")
    dest_type = Column(String(256), nullable=True, comment="目标文件类型")
    src_size = Column(BigInteger, nullable=True, comment="源文件大小")
    dest_size = Column(BigInteger, nullable=True, comment="目标文件大小")
    status = Column(String(256), nullable=True, comment="处理状态")
    result = Column(Text, nullable=True, comment="处理结果")
    def __repr__(self):
        """Debug representation keyed on the composite primary key."""
        return f"<CleanResult(instance_id={self.instance_id}, dest_file_id={self.dest_file_id}, status={self.status})>"

View File

@@ -0,0 +1,27 @@
from sqlalchemy import Column, String, BigInteger, Integer, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class CleanTask(Base):
    """Cleaning task model (table ``t_clean_task``) — one data-cleaning job."""
    __tablename__ = "t_clean_task"
    id = Column(String(64), primary_key=True, comment="任务ID")
    name = Column(String(64), nullable=True, comment="任务名称")
    description = Column(String(256), nullable=True, comment="任务描述")
    status = Column(String(256), nullable=True, comment="任务状态")
    # Source dataset (input) and destination dataset (cleaned output).
    src_dataset_id = Column(String(64), nullable=True, comment="源数据集ID")
    src_dataset_name = Column(String(64), nullable=True, comment="源数据集名称")
    dest_dataset_id = Column(String(64), nullable=True, comment="目标数据集ID")
    dest_dataset_name = Column(String(64), nullable=True, comment="目标数据集名称")
    # Sizes before/after cleaning and number of files processed.
    before_size = Column(BigInteger, nullable=True, comment="清洗前大小")
    after_size = Column(BigInteger, nullable=True, comment="清洗后大小")
    file_count = Column(Integer, nullable=True, comment="文件数量")
    # created_at is set by the DB; started_at/finished_at mark the run window.
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    started_at = Column(TIMESTAMP, nullable=True, comment="开始时间")
    finished_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
    created_by = Column(String(256), nullable=True, comment="创建者")
    def __repr__(self):
        """Debug representation showing id, name and status."""
        return f"<CleanTask(id={self.id}, name={self.name}, status={self.status})>"

View File

@@ -0,0 +1,18 @@
from sqlalchemy import Column, String, Text, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class CleanTemplate(Base):
    """Cleaning template model (table ``t_clean_template``)."""
    __tablename__ = "t_clean_template"
    # NOTE(review): unique=True is redundant on a primary-key column —
    # confirm against the SQL DDL in scripts/db before removing.
    id = Column(String(64), primary_key=True, unique=True, comment="模板ID")
    name = Column(String(64), nullable=True, comment="模板名称")
    description = Column(String(256), nullable=True, comment="模板描述")
    # created_at/updated_at are DB-managed; updated_at refreshes on UPDATE.
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(256), nullable=True, comment="创建者")
    def __repr__(self):
        """Debug representation showing id and name."""
        return f"<CleanTemplate(id={self.id}, name={self.name})>"

View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, String, Integer, Text
from app.db.database import Base
class OperatorInstance(Base):
    """Operator instance model (table ``t_operator_instance``).

    Binds an operator to a cleaning instance at a given pipeline position,
    optionally overriding the operator's default settings.
    """
    __tablename__ = "t_operator_instance"
    # Composite primary key: (instance, operator, position in the pipeline).
    instance_id = Column(String(256), primary_key=True, comment="实例ID")
    operator_id = Column(String(256), primary_key=True, comment="算子ID")
    op_index = Column(Integer, primary_key=True, comment="算子索引")
    settings_override = Column(Text, nullable=True, comment="配置覆盖")
    def __repr__(self):
        """Debug representation keyed on the composite primary key."""
        return f"<OperatorInstance(instance_id={self.instance_id}, operator_id={self.operator_id}, index={self.op_index})>"

View File

@@ -0,0 +1,13 @@
# app/models/collection/__init__.py
from .task_execution import TaskExecution
from .collection_task import CollectionTask
from .task_log import TaskLog
from .datax_template import DataxTemplate
__all__ = [
"TaskExecution",
"CollectionTask",
"TaskLog",
"DataxTemplate"
]

View File

@@ -0,0 +1,28 @@
from sqlalchemy import Column, String, Text, Integer, BigInteger, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class CollectionTask(Base):
    """Data-collection task model (table ``t_dc_collection_tasks``).

    A DataX-based ingestion job, either run once or on a cron schedule.
    """
    __tablename__ = "t_dc_collection_tasks"
    id = Column(String(36), primary_key=True, comment="任务ID(UUID)")
    name = Column(String(255), nullable=False, comment="任务名称")
    description = Column(Text, nullable=True, comment="任务描述")
    sync_mode = Column(String(20), default='ONCE', comment="同步模式:ONCE/SCHEDULED")
    # DataX job config (source + target endpoints), stored as serialized text.
    config = Column(Text, nullable=False, comment="归集配置(DataX配置),包含源端和目标端配置信息")
    schedule_expression = Column(String(255), nullable=True, comment="Cron调度表达式")
    status = Column(String(20), default='DRAFT', comment="任务状态:DRAFT/READY/RUNNING/SUCCESS/FAILED/STOPPED")
    retry_count = Column(Integer, default=3, comment="重试次数")
    timeout_seconds = Column(Integer, default=3600, comment="超时时间(秒)")
    max_records = Column(BigInteger, nullable=True, comment="最大处理记录数")
    sort_field = Column(String(100), nullable=True, comment="增量字段")
    # Points at the most recent row in t_dc_task_executions for this task.
    last_execution_id = Column(String(36), nullable=True, comment="最后执行ID(UUID)")
    # created_at/updated_at are DB-managed; updated_at refreshes on UPDATE.
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")
    def __repr__(self):
        """Debug representation showing id, name and status."""
        return f"<CollectionTask(id={self.id}, name={self.name}, status={self.status})>"

View File

@@ -0,0 +1,23 @@
from sqlalchemy import Column, String, Text, Boolean, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class DataxTemplate(Base):
    """DataX template model (table ``t_dc_datax_templates``).

    A reusable DataX job template for a given (source type, target type)
    pair; system templates are flagged via ``is_system``.
    """
    __tablename__ = "t_dc_datax_templates"
    id = Column(String(36), primary_key=True, comment="模板ID(UUID)")
    name = Column(String(255), nullable=False, unique=True, comment="模板名称")
    source_type = Column(String(50), nullable=False, comment="源数据源类型")
    target_type = Column(String(50), nullable=False, comment="目标数据源类型")
    template_content = Column(Text, nullable=False, comment="模板内容")
    description = Column(Text, nullable=True, comment="模板描述")
    version = Column(String(20), default='1.0.0', comment="版本号")
    is_system = Column(Boolean, default=False, comment="是否系统模板")
    # created_at/updated_at are DB-managed; updated_at refreshes on UPDATE.
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    def __repr__(self):
        """Debug representation showing id, name and the type pair."""
        return f"<DataxTemplate(id={self.id}, name={self.name}, source={self.source_type}, target={self.target_type})>"

View File

@@ -0,0 +1,34 @@
from sqlalchemy import Column, String, Text, Integer, BigInteger, DECIMAL, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class TaskExecution(Base):
    """Task execution detail model.

    One row per execution of a collection task, tracking progress, record
    counters, throughput and the final result / error information.
    """

    __tablename__ = "t_dc_task_executions"

    id = Column(String(36), primary_key=True, comment="执行记录ID(UUID)")
    task_id = Column(String(36), nullable=False, comment="任务ID")
    task_name = Column(String(255), nullable=False, comment="任务名称")
    status = Column(String(20), default='RUNNING', comment="执行状态:RUNNING/SUCCESS/FAILED/STOPPED")
    progress = Column(DECIMAL(5, 2), default=0.00, comment="进度百分比")
    # Record counters; records_processed = records_success + records_failed
    # presumably — TODO confirm against the writer of these fields.
    records_total = Column(BigInteger, default=0, comment="总记录数")
    records_processed = Column(BigInteger, default=0, comment="已处理记录数")
    records_success = Column(BigInteger, default=0, comment="成功记录数")
    records_failed = Column(BigInteger, default=0, comment="失败记录数")
    throughput = Column(DECIMAL(10, 2), default=0.00, comment="吞吐量(条/秒)")
    data_size_bytes = Column(BigInteger, default=0, comment="数据量(字节)")
    started_at = Column(TIMESTAMP, nullable=True, comment="开始时间")
    completed_at = Column(TIMESTAMP, nullable=True, comment="完成时间")
    duration_seconds = Column(Integer, default=0, comment="执行时长(秒)")
    config = Column(JSON, nullable=True, comment="执行配置")
    error_message = Column(Text, nullable=True, comment="错误信息")
    datax_job_id = Column(Text, nullable=True, comment="datax任务ID")
    result = Column(Text, nullable=True, comment="执行结果")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")

    def __repr__(self) -> str:
        return f"<TaskExecution(id={self.id}, task_id={self.task_id}, status={self.status})>"

View File

@@ -0,0 +1,26 @@
from sqlalchemy import Column, String, Text, Integer, BigInteger, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class TaskLog(Base):
    """Task execution log model.

    Lightweight per-run record of a sync task: timing, process/log location,
    error message and retry counter. Unlike TaskExecution, duration here is
    stored in milliseconds.
    """

    __tablename__ = "t_dc_task_log"

    id = Column(String(36), primary_key=True, comment="执行记录ID(UUID)")
    task_id = Column(String(36), nullable=False, comment="任务ID")
    task_name = Column(String(255), nullable=False, comment="任务名称")
    sync_mode = Column(String(20), default='FULL', comment="同步模式:FULL/INCREMENTAL")
    status = Column(String(20), default='RUNNING', comment="执行状态:RUNNING/SUCCESS/FAILED/STOPPED")
    start_time = Column(TIMESTAMP, nullable=True, comment="开始时间")
    end_time = Column(TIMESTAMP, nullable=True, comment="结束时间")
    duration = Column(BigInteger, nullable=True, comment="执行时长(毫秒)")
    process_id = Column(String(50), nullable=True, comment="进程ID")
    log_path = Column(String(500), nullable=True, comment="日志文件路径")
    error_msg = Column(Text, nullable=True, comment="错误信息")
    result = Column(Text, nullable=True, comment="执行结果")
    retry_times = Column(Integer, default=0, comment="重试次数")
    create_time = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")

    def __repr__(self) -> str:
        return f"<TaskLog(id={self.id}, task_id={self.task_id}, status={self.status})>"

View File

@@ -0,0 +1,7 @@
# app/models/common/__init__.py
# Package exports for models shared across modules.
from .chunk_upload_request import ChunkUploadRequest

# Explicit public API of this package.
__all__ = [
    "ChunkUploadRequest"
]

View File

@@ -0,0 +1,19 @@
from sqlalchemy import Column, String, Integer, Text, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class ChunkUploadRequest(Base):
    """Chunked file upload request model.

    Tracks the progress of a multi-file chunked upload
    (``uploaded_file_num`` / ``total_file_num``) together with the owning
    service and an opaque business-info payload.
    """

    __tablename__ = "t_chunk_upload_request"

    id = Column(String(36), primary_key=True, comment="UUID")
    total_file_num = Column(Integer, nullable=True, comment="总文件数")
    uploaded_file_num = Column(Integer, nullable=True, comment="已上传文件数")
    upload_path = Column(String(256), nullable=True, comment="文件路径")
    # NOTE(review): despite the name, this defaults to the row creation time,
    # not a future deadline — confirm how callers compute expiry.
    timeout = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="上传请求超时时间")
    service_id = Column(String(64), nullable=True, comment="上传请求所属服务:DATA-MANAGEMENT(数据管理)")
    check_info = Column(Text, nullable=True, comment="业务信息")

    def __repr__(self) -> str:
        return f"<ChunkUploadRequest(id={self.id}, service_id={self.service_id}, progress={self.uploaded_file_num}/{self.total_file_num})>"

View File

@@ -0,0 +1,21 @@
# app/models/dm/__init__.py
# Package exports for the Data Management (DM) ORM models.
from .annotation_template import AnnotationTemplate
from .labeling_project import LabelingProject
from .dataset import Dataset
from .dataset_files import DatasetFiles
from .dataset_statistics import DatasetStatistics
from .dataset_tag import DatasetTag
from .tag import Tag
from .user import User

# Explicit public API of this package.
__all__ = [
    "AnnotationTemplate",
    "LabelingProject",
    "Dataset",
    "DatasetFiles",
    "DatasetStatistics",
    "DatasetTag",
    "Tag",
    "User"
]

View File

@@ -0,0 +1,24 @@
from sqlalchemy import Column, String, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class AnnotationTemplate(Base):
    """Annotation template model.

    Reusable labeling-configuration template; rows are soft-deleted via
    ``deleted_at`` rather than removed.
    """

    __tablename__ = "t_dm_annotation_templates"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID主键ID")
    name = Column(String(32), nullable=False, comment="模板名称")
    description = Column(String(255), nullable=True, comment="模板描述")
    configuration = Column(JSON, nullable=True, comment="配置信息(JSON格式)")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self) -> str:
        return f"<AnnotationTemplate(id={self.id}, name={self.name})>"

    @property
    def is_deleted(self) -> bool:
        """Return True if the row has been soft-deleted."""
        return self.deleted_at is not None

View File

@@ -0,0 +1,35 @@
from sqlalchemy import Column, String, Text, BigInteger, Integer, Boolean, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class Dataset(Base):
    """Dataset model (supports medical imaging, text, QA and other dataset types).

    Describes a logical dataset: its type/category, storage path and format,
    size/count statistics, visibility flags and lifecycle status.
    """

    __tablename__ = "t_dm_datasets"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    name = Column(String(255), nullable=False, comment="数据集名称")
    description = Column(Text, nullable=True, comment="数据集描述")
    dataset_type = Column(String(50), nullable=False, comment="数据集类型:IMAGE/TEXT/QA/MULTIMODAL/OTHER")
    category = Column(String(100), nullable=True, comment="数据集分类:医学影像/问答/文献等")
    path = Column(String(500), nullable=True, comment="数据存储路径")
    format = Column(String(50), nullable=True, comment="数据格式:DCM/JPG/JSON/CSV等")
    schema_info = Column(JSON, nullable=True, comment="数据结构信息")
    size_bytes = Column(BigInteger, default=0, comment="数据大小(字节)")
    file_count = Column(BigInteger, default=0, comment="文件数量")
    record_count = Column(BigInteger, default=0, comment="记录数量")
    retention_days = Column(Integer, default=0, comment="数据保留天数(0表示长期保留)")
    tags = Column(JSON, nullable=True, comment="标签列表")
    # BUG FIX: "metadata" is a reserved attribute name on SQLAlchemy
    # declarative classes (``Base.metadata`` is the MetaData registry), so
    # mapping an attribute literally named ``metadata`` raises
    # InvalidRequestError the moment this module is imported. Expose the
    # column as ``metadata_info`` in Python while keeping the physical
    # column name "metadata" so the table schema is unchanged.
    metadata_info = Column("metadata", JSON, nullable=True, comment="元数据信息")
    status = Column(String(50), default='DRAFT', comment="状态:DRAFT/ACTIVE/ARCHIVED")
    is_public = Column(Boolean, default=False, comment="是否公开")
    is_featured = Column(Boolean, default=False, comment="是否推荐")
    version = Column(BigInteger, nullable=False, default=0, comment="版本号")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")
    created_by = Column(String(255), nullable=True, comment="创建者")
    updated_by = Column(String(255), nullable=True, comment="更新者")

    def __repr__(self) -> str:
        return f"<Dataset(id={self.id}, name={self.name}, type={self.dataset_type})>"

View File

@@ -0,0 +1,27 @@
from sqlalchemy import Column, String, JSON, BigInteger, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class DatasetFiles(Base):
    """DM dataset file model.

    One row per file belonging to a dataset (``dataset_id``), with storage
    location, checksum, lifecycle status and access timestamps.
    """

    __tablename__ = "t_dm_dataset_files"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    dataset_id = Column(String(36), nullable=False, comment="所属数据集ID(UUID)")
    file_name = Column(String(255), nullable=False, comment="文件名")
    file_path = Column(String(1000), nullable=False, comment="文件路径")
    file_type = Column(String(50), nullable=True, comment="文件格式:JPG/PNG/DCM/TXT等")
    file_size = Column(BigInteger, default=0, comment="文件大小(字节)")
    check_sum = Column(String(64), nullable=True, comment="文件校验和")
    tags = Column(JSON, nullable=True, comment="文件标签信息")
    # BUG FIX: "metadata" is a reserved attribute name on SQLAlchemy
    # declarative classes (``Base.metadata`` is the MetaData registry), so
    # mapping an attribute literally named ``metadata`` raises
    # InvalidRequestError at import time. Expose it as ``metadata_info``
    # while keeping the physical column name "metadata".
    metadata_info = Column("metadata", JSON, nullable=True, comment="文件元数据")
    status = Column(String(50), default='ACTIVE', comment="文件状态:ACTIVE/DELETED/PROCESSING")
    upload_time = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="上传时间")
    last_access_time = Column(TIMESTAMP, nullable=True, comment="最后访问时间")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")

    def __repr__(self) -> str:
        return f"<DatasetFiles(id={self.id}, dataset_id={self.dataset_id}, file_name={self.file_name})>"

View File

@@ -0,0 +1,25 @@
from sqlalchemy import Column, String, Date, BigInteger, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class DatasetStatistics(Base):
    """Dataset statistics model.

    Aggregated counters for a dataset on a given ``stat_date`` (file counts,
    sizes, downloads/views and optional quality metrics).
    """

    __tablename__ = "t_dm_dataset_statistics"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    dataset_id = Column(String(36), nullable=False, comment="数据集ID(UUID)")
    stat_date = Column(Date, nullable=False, comment="统计日期")
    total_files = Column(BigInteger, default=0, comment="总文件数")
    total_size = Column(BigInteger, default=0, comment="总大小(字节)")
    processed_files = Column(BigInteger, default=0, comment="已处理文件数")
    error_files = Column(BigInteger, default=0, comment="错误文件数")
    download_count = Column(BigInteger, default=0, comment="下载次数")
    view_count = Column(BigInteger, default=0, comment="查看次数")
    quality_metrics = Column(JSON, nullable=True, comment="质量指标")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")

    def __repr__(self) -> str:
        return f"<DatasetStatistics(id={self.id}, dataset_id={self.dataset_id}, date={self.stat_date})>"

View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, String, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class DatasetTag(Base):
    """Dataset-tag association model.

    Many-to-many link table between datasets and tags; the composite
    primary key (dataset_id, tag_id) prevents duplicate links.
    """

    __tablename__ = "t_dm_dataset_tags"

    dataset_id = Column(String(36), primary_key=True, comment="数据集ID(UUID)")
    tag_id = Column(String(36), primary_key=True, comment="标签ID(UUID)")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")

    def __repr__(self) -> str:
        return f"<DatasetTag(dataset_id={self.dataset_id}, tag_id={self.tag_id})>"

View File

@@ -0,0 +1,26 @@
from sqlalchemy import Column, String, Integer, JSON, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class LabelingProject(Base):
    """DM labeling project model (formerly ``DatasetMapping``).

    Maps a DM dataset (``dataset_id``) to a Label Studio project
    (``labeling_project_id``) and stores the label configuration plus
    progress statistics. Rows are soft-deleted via ``deleted_at``.
    """

    __tablename__ = "t_dm_labeling_projects"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID主键ID")
    dataset_id = Column(String(36), nullable=False, comment="数据集ID")
    name = Column(String(32), nullable=False, comment="项目名称")
    labeling_project_id = Column(Integer, nullable=False, comment="Label Studio项目ID")
    configuration = Column(JSON, nullable=True, comment="标签配置")
    progress = Column(JSON, nullable=True, comment="标注进度统计")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    deleted_at = Column(TIMESTAMP, nullable=True, comment="删除时间(软删除)")

    def __repr__(self) -> str:
        return f"<LabelingProject(id={self.id}, dataset_id={self.dataset_id}, name={self.name})>"

    @property
    def is_deleted(self) -> bool:
        """Return True if the row has been soft-deleted."""
        return self.deleted_at is not None

View File

@@ -0,0 +1,21 @@
from sqlalchemy import Column, String, Text, BigInteger, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class Tag(Base):
    """Tag model.

    Uniquely named label with optional category and display color;
    ``usage_count`` tracks how often the tag is attached.
    """

    __tablename__ = "t_dm_tags"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID")
    name = Column(String(100), nullable=False, unique=True, comment="标签名称")
    description = Column(Text, nullable=True, comment="标签描述")
    category = Column(String(50), nullable=True, comment="标签分类")
    # 7 chars fits a "#RRGGBB" hex color string.
    color = Column(String(7), nullable=True, comment="标签颜色(十六进制)")
    usage_count = Column(BigInteger, default=0, comment="使用次数")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")

    def __repr__(self) -> str:
        return f"<Tag(id={self.id}, name={self.name}, category={self.category})>"

View File

@@ -0,0 +1,24 @@
from sqlalchemy import Column, String, BigInteger, Boolean, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class User(Base):
    """User account model.

    Uses an auto-increment integer primary key (unlike the UUID keys of the
    DM models) and stores only a password hash, never the plain password.
    """

    __tablename__ = "users"

    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="用户ID")
    username = Column(String(255), nullable=False, unique=True, comment="用户名")
    email = Column(String(255), nullable=False, unique=True, comment="邮箱")
    password_hash = Column(String(255), nullable=False, comment="密码哈希")
    full_name = Column(String(255), nullable=True, comment="真实姓名")
    avatar_url = Column(String(500), nullable=True, comment="头像URL")
    role = Column(String(50), nullable=False, default='USER', comment="角色:ADMIN/USER")
    organization = Column(String(255), nullable=True, comment="所属机构")
    enabled = Column(Boolean, nullable=False, default=True, comment="是否启用")
    last_login_at = Column(TIMESTAMP, nullable=True, comment="最后登录时间")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")

    def __repr__(self) -> str:
        return f"<User(id={self.id}, username={self.username}, role={self.role})>"

View File

@@ -0,0 +1,11 @@
# app/models/operator/__init__.py
# Package exports for the operator (算子) ORM models.
from .operator import Operator
from .operator_category import OperatorCategory
from .operator_category_relation import OperatorCategoryRelation

# Explicit public API of this package.
__all__ = [
    "Operator",
    "OperatorCategory",
    "OperatorCategoryRelation"
]

View File

@@ -0,0 +1,24 @@
from sqlalchemy import Column, String, Text, Boolean, TIMESTAMP
from sqlalchemy.sql import func
from app.db.database import Base
class Operator(Base):
    """Operator model.

    Describes a data-processing operator: its declared input/output types,
    runtime and settings blobs, the uploaded file name and a favorite flag.
    """

    __tablename__ = "t_operator"

    id = Column(String(64), primary_key=True, comment="算子ID")
    name = Column(String(64), nullable=True, comment="算子名称")
    description = Column(String(256), nullable=True, comment="算子描述")
    version = Column(String(256), nullable=True, comment="版本")
    inputs = Column(String(256), nullable=True, comment="输入类型")
    outputs = Column(String(256), nullable=True, comment="输出类型")
    runtime = Column(Text, nullable=True, comment="运行时信息")
    settings = Column(Text, nullable=True, comment="配置信息")
    file_name = Column(Text, nullable=True, comment="文件名")
    is_star = Column(Boolean, nullable=True, comment="是否收藏")
    created_at = Column(TIMESTAMP, server_default=func.current_timestamp(), comment="创建时间")
    updated_at = Column(TIMESTAMP, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), comment="更新时间")

    def __repr__(self) -> str:
        return f"<Operator(id={self.id}, name={self.name}, version={self.version})>"

View File

@@ -0,0 +1,15 @@
from sqlalchemy import Column, String, Integer
from app.db.database import Base
class OperatorCategory(Base):
    """Operator category model.

    Categories form a tree via the self-referencing ``parent_id``
    (no FK constraint is declared here).
    """

    __tablename__ = "t_operator_category"

    id = Column(Integer, primary_key=True, autoincrement=True, comment="分类ID")
    name = Column(String(64), nullable=True, comment="分类名称")
    type = Column(String(64), nullable=True, comment="分类类型")
    parent_id = Column(Integer, nullable=True, comment="父分类ID")

    def __repr__(self) -> str:
        return f"<OperatorCategory(id={self.id}, name={self.name}, type={self.type})>"

View File

@@ -0,0 +1,13 @@
from sqlalchemy import Column, String, Integer
from app.db.database import Base
class OperatorCategoryRelation(Base):
    """Operator-category association model.

    Many-to-many link table between categories and operators; the composite
    primary key (category_id, operator_id) prevents duplicate links.
    """

    __tablename__ = "t_operator_category_relation"

    category_id = Column(Integer, primary_key=True, comment="分类ID")
    operator_id = Column(String(64), primary_key=True, comment="算子ID")

    def __repr__(self) -> str:
        return f"<OperatorCategoryRelation(category_id={self.category_id}, operator_id={self.operator_id})>"

View File

@@ -6,7 +6,7 @@ from .common import BaseResponseModel
class DatasetMappingBase(BaseResponseModel): class DatasetMappingBase(BaseResponseModel):
"""数据集映射 基础模型""" """数据集映射 基础模型"""
source_dataset_id: str = Field(..., description="源数据集ID") dataset_id: str = Field(..., description="源数据集ID")
class DatasetMappingCreateRequest(DatasetMappingBase): class DatasetMappingCreateRequest(DatasetMappingBase):
"""数据集映射 创建 请求模型""" """数据集映射 创建 请求模型"""
@@ -21,7 +21,7 @@ class DatasetMappingCreateResponse(BaseResponseModel):
class DatasetMappingUpdateRequest(BaseResponseModel): class DatasetMappingUpdateRequest(BaseResponseModel):
"""数据集映射 更新 请求模型""" """数据集映射 更新 请求模型"""
source_dataset_id: Optional[str] = Field(None, description="源数据集ID") dataset_id: Optional[str] = Field(None, description="源数据集ID")
class DatasetMappingResponse(DatasetMappingBase): class DatasetMappingResponse(DatasetMappingBase):
"""数据集映射 查询 响应模型""" """数据集映射 查询 响应模型"""

View File

@@ -5,7 +5,7 @@ from typing import Optional, List, Tuple
from datetime import datetime from datetime import datetime
import uuid import uuid
from app.models.dataset_mapping import DatasetMapping from app.models.dm.labeling_project import LabelingProject
from app.schemas.dataset_mapping import ( from app.schemas.dataset_mapping import (
DatasetMappingCreateRequest, DatasetMappingCreateRequest,
DatasetMappingUpdateRequest, DatasetMappingUpdateRequest,
@@ -28,11 +28,11 @@ class DatasetMappingService:
labelling_project_name: str labelling_project_name: str
) -> DatasetMappingResponse: ) -> DatasetMappingResponse:
"""创建数据集映射""" """创建数据集映射"""
logger.info(f"Create dataset mapping: {mapping_data.source_dataset_id} -> {labelling_project_id}") logger.info(f"Create dataset mapping: {mapping_data.dataset_id} -> {labelling_project_id}")
db_mapping = DatasetMapping( db_mapping = LabelingProject(
mapping_id=str(uuid.uuid4()), mapping_id=str(uuid.uuid4()),
source_dataset_id=mapping_data.source_dataset_id, dataset_id=mapping_data.dataset_id,
labelling_project_id=labelling_project_id, labelling_project_id=labelling_project_id,
labelling_project_name=labelling_project_name labelling_project_name=labelling_project_name
) )
@@ -41,48 +41,48 @@ class DatasetMappingService:
await self.db.commit() await self.db.commit()
await self.db.refresh(db_mapping) await self.db.refresh(db_mapping)
logger.info(f"Mapping created: {db_mapping.mapping_id}") logger.info(f"Mapping created: {db_mapping.id}")
return DatasetMappingResponse.model_validate(db_mapping) return DatasetMappingResponse.model_validate(db_mapping)
async def get_mapping_by_source_uuid( async def get_mapping_by_source_uuid(
self, self,
source_dataset_id: str dataset_id: str
) -> Optional[DatasetMappingResponse]: ) -> Optional[DatasetMappingResponse]:
"""根据源数据集ID获取映射(返回第一个未删除的)""" """根据源数据集ID获取映射(返回第一个未删除的)"""
logger.debug(f"Get mapping by source dataset id: {source_dataset_id}") logger.debug(f"Get mapping by source dataset id: {dataset_id}")
result = await self.db.execute( result = await self.db.execute(
select(DatasetMapping).where( select(LabelingProject).where(
DatasetMapping.source_dataset_id == source_dataset_id, LabelingProject.dataset_id == dataset_id,
DatasetMapping.deleted_at.is_(None) LabelingProject.deleted_at.is_(None)
) )
) )
mapping = result.scalar_one_or_none() mapping = result.scalar_one_or_none()
if mapping: if mapping:
logger.debug(f"Found mapping: {mapping.mapping_id}") logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping) return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"No mapping found for source dataset id: {source_dataset_id}") logger.debug(f"No mapping found for source dataset id: {dataset_id}")
return None return None
async def get_mappings_by_source_dataset_id( async def get_mappings_by_dataset_id(
self, self,
source_dataset_id: str, dataset_id: str,
include_deleted: bool = False include_deleted: bool = False
) -> List[DatasetMappingResponse]: ) -> List[DatasetMappingResponse]:
"""根据源数据集ID获取所有映射关系""" """根据源数据集ID获取所有映射关系"""
logger.debug(f"Get all mappings by source dataset id: {source_dataset_id}") logger.debug(f"Get all mappings by source dataset id: {dataset_id}")
query = select(DatasetMapping).where( query = select(LabelingProject).where(
DatasetMapping.source_dataset_id == source_dataset_id LabelingProject.dataset_id == dataset_id
) )
if not include_deleted: if not include_deleted:
query = query.where(DatasetMapping.deleted_at.is_(None)) query = query.where(LabelingProject.deleted_at.is_(None))
result = await self.db.execute( result = await self.db.execute(
query.order_by(DatasetMapping.created_at.desc()) query.order_by(LabelingProject.created_at.desc())
) )
mappings = result.scalars().all() mappings = result.scalars().all()
@@ -97,9 +97,9 @@ class DatasetMappingService:
logger.debug(f"Get mapping by Label Studio project id: {labelling_project_id}") logger.debug(f"Get mapping by Label Studio project id: {labelling_project_id}")
result = await self.db.execute( result = await self.db.execute(
select(DatasetMapping).where( select(LabelingProject).where(
DatasetMapping.labelling_project_id == labelling_project_id, LabelingProject.labeling_project_id == labelling_project_id,
DatasetMapping.deleted_at.is_(None) LabelingProject.deleted_at.is_(None)
) )
) )
mapping = result.scalar_one_or_none() mapping = result.scalar_one_or_none()
@@ -116,15 +116,15 @@ class DatasetMappingService:
logger.debug(f"Get mapping: {mapping_id}") logger.debug(f"Get mapping: {mapping_id}")
result = await self.db.execute( result = await self.db.execute(
select(DatasetMapping).where( select(LabelingProject).where(
DatasetMapping.mapping_id == mapping_id, LabelingProject.id == mapping_id,
DatasetMapping.deleted_at.is_(None) LabelingProject.deleted_at.is_(None)
) )
) )
mapping = result.scalar_one_or_none() mapping = result.scalar_one_or_none()
if mapping: if mapping:
logger.debug(f"Found mapping: {mapping.mapping_id}") logger.debug(f"Found mapping: {mapping.id}")
return DatasetMappingResponse.model_validate(mapping) return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"Mapping not found: {mapping_id}") logger.debug(f"Mapping not found: {mapping_id}")
@@ -143,11 +143,11 @@ class DatasetMappingService:
return None return None
update_values = update_data.model_dump(exclude_unset=True) update_values = update_data.model_dump(exclude_unset=True)
update_values["last_updated_at"] = datetime.utcnow() update_values["last_updated_at"] = datetime.now()
result = await self.db.execute( result = await self.db.execute(
update(DatasetMapping) update(LabelingProject)
.where(DatasetMapping.mapping_id == mapping_id) .where(LabelingProject.id == mapping_id)
.values(**update_values) .values(**update_values)
) )
await self.db.commit() await self.db.commit()
@@ -161,10 +161,10 @@ class DatasetMappingService:
logger.debug(f"Update mapping last updated at: {mapping_id}") logger.debug(f"Update mapping last updated at: {mapping_id}")
result = await self.db.execute( result = await self.db.execute(
update(DatasetMapping) update(LabelingProject)
.where( .where(
DatasetMapping.mapping_id == mapping_id, LabelingProject.id == mapping_id,
DatasetMapping.deleted_at.is_(None) LabelingProject.deleted_at.is_(None)
) )
.values(last_updated_at=datetime.utcnow()) .values(last_updated_at=datetime.utcnow())
) )
@@ -176,12 +176,12 @@ class DatasetMappingService:
logger.info(f"Soft delete mapping: {mapping_id}") logger.info(f"Soft delete mapping: {mapping_id}")
result = await self.db.execute( result = await self.db.execute(
update(DatasetMapping) update(LabelingProject)
.where( .where(
DatasetMapping.mapping_id == mapping_id, LabelingProject.id == mapping_id,
DatasetMapping.deleted_at.is_(None) LabelingProject.deleted_at.is_(None)
) )
.values(deleted_at=datetime.utcnow()) .values(deleted_at=datetime.now())
) )
await self.db.commit() await self.db.commit()
@@ -202,11 +202,11 @@ class DatasetMappingService:
logger.debug(f"List all mappings, skip: {skip}, limit: {limit}") logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
result = await self.db.execute( result = await self.db.execute(
select(DatasetMapping) select(LabelingProject)
.where(DatasetMapping.deleted_at.is_(None)) .where(LabelingProject.deleted_at.is_(None))
.offset(skip) .offset(skip)
.limit(limit) .limit(limit)
.order_by(DatasetMapping.created_at.desc()) .order_by(LabelingProject.created_at.desc())
) )
mappings = result.scalars().all() mappings = result.scalars().all()
@@ -215,10 +215,10 @@ class DatasetMappingService:
async def count_mappings(self, include_deleted: bool = False) -> int: async def count_mappings(self, include_deleted: bool = False) -> int:
"""统计映射总数""" """统计映射总数"""
query = select(func.count()).select_from(DatasetMapping) query = select(func.count()).select_from(LabelingProject)
if not include_deleted: if not include_deleted:
query = query.where(DatasetMapping.deleted_at.is_(None)) query = query.where(LabelingProject.deleted_at.is_(None))
result = await self.db.execute(query) result = await self.db.execute(query)
return result.scalar_one() return result.scalar_one()
@@ -233,14 +233,14 @@ class DatasetMappingService:
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}") logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
# 构建查询 # 构建查询
query = select(DatasetMapping) query = select(LabelingProject)
if not include_deleted: if not include_deleted:
query = query.where(DatasetMapping.deleted_at.is_(None)) query = query.where(LabelingProject.deleted_at.is_(None))
# 获取总数 # 获取总数
count_query = select(func.count()).select_from(DatasetMapping) count_query = select(func.count()).select_from(LabelingProject)
if not include_deleted: if not include_deleted:
count_query = count_query.where(DatasetMapping.deleted_at.is_(None)) count_query = count_query.where(LabelingProject.deleted_at.is_(None))
count_result = await self.db.execute(count_query) count_result = await self.db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
@@ -250,7 +250,7 @@ class DatasetMappingService:
query query
.offset(skip) .offset(skip)
.limit(limit) .limit(limit)
.order_by(DatasetMapping.created_at.desc()) .order_by(LabelingProject.created_at.desc())
) )
mappings = result.scalars().all() mappings = result.scalars().all()
@@ -259,28 +259,28 @@ class DatasetMappingService:
async def get_mappings_by_source_with_count( async def get_mappings_by_source_with_count(
self, self,
source_dataset_id: str, dataset_id: str,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
include_deleted: bool = False include_deleted: bool = False
) -> Tuple[List[DatasetMappingResponse], int]: ) -> Tuple[List[DatasetMappingResponse], int]:
"""根据源数据集ID获取映射关系及总数(用于分页)""" """根据源数据集ID获取映射关系及总数(用于分页)"""
logger.debug(f"Get mappings by source dataset id with count: {source_dataset_id}") logger.debug(f"Get mappings by source dataset id with count: {dataset_id}")
# 构建查询 # 构建查询
query = select(DatasetMapping).where( query = select(LabelingProject).where(
DatasetMapping.source_dataset_id == source_dataset_id LabelingProject.dataset_id == dataset_id
) )
if not include_deleted: if not include_deleted:
query = query.where(DatasetMapping.deleted_at.is_(None)) query = query.where(LabelingProject.deleted_at.is_(None))
# 获取总数 # 获取总数
count_query = select(func.count()).select_from(DatasetMapping).where( count_query = select(func.count()).select_from(LabelingProject).where(
DatasetMapping.source_dataset_id == source_dataset_id LabelingProject.dataset_id == dataset_id
) )
if not include_deleted: if not include_deleted:
count_query = count_query.where(DatasetMapping.deleted_at.is_(None)) count_query = count_query.where(LabelingProject.deleted_at.is_(None))
count_result = await self.db.execute(count_query) count_result = await self.db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
@@ -290,7 +290,7 @@ class DatasetMappingService:
query query
.offset(skip) .offset(skip)
.limit(limit) .limit(limit)
.order_by(DatasetMapping.created_at.desc()) .order_by(LabelingProject.created_at.desc())
) )
mappings = result.scalars().all() mappings = result.scalars().all()

View File

@@ -1,6 +1,5 @@
from typing import Optional, List, Dict, Any, Tuple from typing import Optional, List, Dict, Any, Tuple
from app.clients.dm_client import DMServiceClient from app.infrastructure import LabelStudioClient, DatamateClient
from app.clients.label_studio_client import LabelStudioClient
from app.services.dataset_mapping_service import DatasetMappingService from app.services.dataset_mapping_service import DatasetMappingService
from app.schemas.dataset_mapping import SyncDatasetResponse from app.schemas.dataset_mapping import SyncDatasetResponse
from app.core.logging import get_logger from app.core.logging import get_logger
@@ -14,7 +13,7 @@ class SyncService:
def __init__( def __init__(
self, self,
dm_client: DMServiceClient, dm_client: DatamateClient,
ls_client: LabelStudioClient, ls_client: LabelStudioClient,
mapping_service: DatasetMappingService mapping_service: DatasetMappingService
): ):
@@ -107,9 +106,9 @@ class SyncService:
try: try:
# 获取数据集信息 # 获取数据集信息
dataset_info = await self.dm_client.get_dataset(mapping.source_dataset_id) dataset_info = await self.dm_client.get_dataset(mapping.dataset_id)
if not dataset_info: if not dataset_info:
raise NoDatasetInfoFoundError(mapping.source_dataset_id) raise NoDatasetInfoFoundError(mapping.dataset_id)
synced_files = 0 synced_files = 0
deleted_tasks = 0 deleted_tasks = 0
@@ -129,7 +128,7 @@ class SyncService:
# 分页获取并同步文件 # 分页获取并同步文件
while True: while True:
files_response = await self.dm_client.get_dataset_files( files_response = await self.dm_client.get_dataset_files(
mapping.source_dataset_id, mapping.dataset_id,
page=page, page=page,
size=batch_size, size=batch_size,
status="COMPLETED" # 只同步已完成的文件 status="COMPLETED" # 只同步已完成的文件
@@ -173,7 +172,7 @@ class SyncService:
"meta": { "meta": {
"file_size": file_info.size, "file_size": file_info.size,
"file_type": file_info.fileType, "file_type": file_info.fileType,
"dm_dataset_id": mapping.source_dataset_id, "dm_dataset_id": mapping.dataset_id,
"dm_file_id": file_info.id, "dm_file_id": file_info.id,
"original_name": file_info.originalName, "original_name": file_info.originalName,
} }
@@ -249,22 +248,22 @@ class SyncService:
async def get_sync_status( async def get_sync_status(
self, self,
source_dataset_id: str dataset_id: str
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""获取同步状态""" """获取同步状态"""
mapping = await self.mapping_service.get_mapping_by_source_uuid(source_dataset_id) mapping = await self.mapping_service.get_mapping_by_source_uuid(dataset_id)
if not mapping: if not mapping:
return None return None
# 获取DM数据集信息 # 获取DM数据集信息
dataset_info = await self.dm_client.get_dataset(source_dataset_id) dataset_info = await self.dm_client.get_dataset(dataset_id)
# 获取Label Studio项目任务数量 # 获取Label Studio项目任务数量
tasks_info = await self.ls_client.get_project_tasks(mapping.labelling_project_id) tasks_info = await self.ls_client.get_project_tasks(mapping.labelling_project_id)
return { return {
"mapping_id": mapping.mapping_id, "mapping_id": mapping.mapping_id,
"source_dataset_id": source_dataset_id, "dataset_id": dataset_id,
"labelling_project_id": mapping.labelling_project_id, "labelling_project_id": mapping.labelling_project_id,
"last_updated_at": mapping.last_updated_at, "last_updated_at": mapping.last_updated_at,
"dm_total_files": dataset_info.fileCount if dataset_info else 0, "dm_total_files": dataset_info.fileCount if dataset_info else 0,

View File

@@ -0,0 +1,5 @@
#!/usr/bin/env bash
# Development server launcher for the Label Studio adapter.
# NOTE: --reload is for development only; do not use it in production.
set -euo pipefail

# exec replaces the shell so signals (SIGTERM/SIGINT) reach uvicorn directly.
exec uvicorn app.main:app \
    --host 0.0.0.0 \
    --port 18000 \
    --reload \
    --log-level debug

View File

@@ -1,148 +0,0 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
# sqlalchemy.url = driver://user:pass@localhost/dbname
# Default URL intentionally left unset; env.py reads it from the application settings.
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -1 +0,0 @@
Generic single-database configuration.

View File

@@ -1,145 +0,0 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy import create_engine, text
from alembic import context
import os
from urllib.parse import quote_plus
# Import the application configuration and models
from app.core.config import settings
from app.db.database import Base
# Import all models so that autogenerate can detect them
from app.models import dataset_mapping # noqa
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
def ensure_database_and_user():
    """
    Ensure the application database and user exist before migrating.

    Connects to MySQL as root (password taken from the MYSQL_ROOT_PASSWORD
    environment variable, defaulting to 'password') and idempotently creates
    the database, the application user, and its grants. Any failure is
    reported as a warning and swallowed: it is expected when the database
    already exists or root access is restricted.
    """
    # Only run when a MySQL host is configured.
    if not settings.mysql_host:
        return

    mysql_root_password = os.getenv('MYSQL_ROOT_PASSWORD', 'password')

    # URL-encode the root password so special characters survive the DSN.
    encoded_password = quote_plus(mysql_root_password)

    # Connect as root without selecting a database.
    root_url = f"mysql+pymysql://root:{encoded_password}@{settings.mysql_host}:{settings.mysql_port}/"

    # CREATE USER / GRANT are DDL and cannot take bound parameters, so the
    # values are embedded in the SQL text. Escape single quotes so a quote
    # in the configured user name or password cannot break the statement.
    safe_user = settings.mysql_user.replace("'", "''")
    safe_password = settings.mysql_password.replace("'", "''")

    root_engine = None
    try:
        root_engine = create_engine(root_url, poolclass=pool.NullPool)
        with root_engine.connect() as conn:
            # Create the database if it does not exist.
            conn.execute(text(
                f"CREATE DATABASE IF NOT EXISTS `{settings.mysql_database}` "
                f"CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
            ))
            conn.commit()
            # Create the user if it does not exist — uses MySQL 8's default
            # caching_sha2_password authentication plugin.
            conn.execute(text(
                f"CREATE USER IF NOT EXISTS '{safe_user}'@'%' "
                f"IDENTIFIED BY '{safe_password}'"
            ))
            conn.commit()
            # Grant full privileges on the application database only.
            conn.execute(text(
                f"GRANT ALL PRIVILEGES ON `{settings.mysql_database}`.* TO '{safe_user}'@'%'"
            ))
            conn.commit()
            # Flush the privilege cache so the grants take effect immediately.
            conn.execute(text("FLUSH PRIVILEGES"))
            conn.commit()
        print(f"✓ Database '{settings.mysql_database}' and user '{settings.mysql_user}' are ready")
    except Exception as e:
        print(f"⚠️ Warning: Could not ensure database and user: {e}")
        print(f"   This may be expected if database already exists or permissions are set")
    finally:
        # Always release the temporary root engine, even on failure
        # (the original leaked it when any statement raised).
        if root_engine is not None:
            root_engine.dispose()
# Set the database URL from the application settings (overrides alembic.ini).
config.set_main_option('sqlalchemy.url', settings.sync_database_url)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    Configures the Alembic context with just a database URL instead of an
    Engine, so no DBAPI needs to be installed. context.execute() emits the
    generated SQL to the script output rather than running it.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Builds an Engine from the alembic configuration, opens a connection,
    and executes the migrations against it.
    """
    # Provision the database and user first, before connecting with the
    # application credentials.
    ensure_database_and_user()

    engine = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with engine.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
# Entry point: Alembic decides offline vs. online mode before loading env.py.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

View File

@@ -1,28 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -1,41 +0,0 @@
"""Initiation
Revision ID: 755dc1afb8ad
Revises:
Create Date: 2025-10-20 19:34:20.258554
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '755dc1afb8ad'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the `mapping` table linking a DM dataset to its labelling project."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('mapping',
    sa.Column('mapping_id', sa.String(length=36), nullable=False),
    sa.Column('source_dataset_id', sa.String(length=36), nullable=False, comment='源数据集ID'),
    sa.Column('labelling_project_id', sa.String(length=36), nullable=False, comment='标注项目ID'),
    sa.Column('labelling_project_name', sa.String(length=255), nullable=True, comment='标注项目名称'),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
    sa.Column('last_updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='最后更新时间'),
    sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True, comment='删除时间'),
    sa.PrimaryKeyConstraint('mapping_id')
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the `mapping` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('mapping')
    # ### end Alembic commands ###

View File

@@ -1,8 +0,0 @@
# app/clients/__init__.py
from .dm_client import DMServiceClient
from .label_studio_client import LabelStudioClient
from .client_manager import get_clients, set_clients, get_dm_client, get_ls_client
__all__ = ["DMServiceClient", "LabelStudioClient", "get_clients", "set_clients", "get_dm_client", "get_ls_client"]

View File

@@ -1,34 +0,0 @@
from typing import Optional
from fastapi import HTTPException
from .dm_client import DMServiceClient
from .label_studio_client import LabelStudioClient
# Module-level client singletons (initialized in main.py at application startup).
dm_client: Optional[DMServiceClient] = None
ls_client: Optional[LabelStudioClient] = None
def get_clients() -> tuple[DMServiceClient, LabelStudioClient]:
    """Return the (DM, Label Studio) client pair.

    Raises:
        HTTPException: 500 when either client has not been initialized yet
            (set_clients has not run).
    """
    # Read-only access to module globals — the original's `global`
    # declaration was unnecessary and has been removed.
    if not dm_client or not ls_client:
        raise HTTPException(status_code=500, detail="客户端未初始化")
    return dm_client, ls_client
def set_clients(dm_client_instance: DMServiceClient, ls_client_instance: LabelStudioClient) -> None:
    """Install the module-level client singletons (called once from main.py)."""
    global dm_client, ls_client
    dm_client, ls_client = dm_client_instance, ls_client_instance
def get_dm_client() -> DMServiceClient:
    """Return the DM service client, failing fast if it is not ready."""
    if dm_client is None:
        raise HTTPException(status_code=500, detail="DM客户端未初始化")
    return dm_client
def get_ls_client() -> LabelStudioClient:
    """Return the Label Studio client, failing fast if it is not ready."""
    if ls_client is None:
        raise HTTPException(status_code=500, detail="Label Studio客户端未初始化")
    return ls_client

View File

@@ -1,138 +0,0 @@
import httpx
from typing import Optional
from app.core.config import settings
from app.core.logging import get_logger
from app.schemas.dm_service import DatasetResponse, PagedDatasetFileResponse
logger = get_logger(__name__)
class DMServiceClient:
    """Async HTTP client for the Data Management (DM) service.

    Wraps the DM REST API (dataset detail, dataset file listing, file
    download) and normalizes its response envelope. All getters return
    None on failure instead of raising, logging the cause.
    """

    def __init__(self, base_url: str | None = None, timeout: float = 30.0):
        """Create the client.

        Args:
            base_url: DM service root URL; defaults to settings.dm_service_base_url.
            timeout: per-request timeout in seconds.
        """
        self.base_url = base_url or settings.dm_service_base_url
        self.timeout = timeout
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout
        )
        logger.info(f"Initialize DM service client, base url: {self.base_url}")

    @staticmethod
    def _unwrap_payload(data):
        """Unwrap common envelope shapes like {'code': ..., 'message': ..., 'data': {...}}."""
        if isinstance(data, dict) and 'data' in data and isinstance(data['data'], (dict, list)):
            return data['data']
        return data

    @staticmethod
    def _is_error_payload(data) -> bool:
        """Detect error-shaped payloads returned with HTTP 200."""
        if not isinstance(data, dict):
            return False
        # Common patterns: {error, message, ...} or {code, message, ...} without data
        if 'error' in data and 'message' in data:
            return True
        if 'code' in data and 'message' in data and 'data' not in data:
            return True
        return False

    @staticmethod
    def _keys(d):
        """Return a dict's keys as a list for logging; [] for non-dicts."""
        return list(d.keys()) if isinstance(d, dict) else []

    async def get_dataset(self, dataset_id: str) -> Optional[DatasetResponse]:
        """Fetch dataset detail; return None on any HTTP or payload error."""
        # Bug fix: pre-bind `raw` so the generic handler below cannot hit a
        # NameError when the request or JSON decoding fails before assignment.
        raw = None
        try:
            logger.info(f"Getting dataset detail: {dataset_id} ...")
            response = await self.client.get(f"/data-management/datasets/{dataset_id}")
            response.raise_for_status()
            raw = response.json()
            data = self._unwrap_payload(raw)
            if self._is_error_payload(data):
                logger.error(f"DM service returned error for dataset {dataset_id}: {data}")
                return None
            if not isinstance(data, dict):
                logger.error(f"Unexpected dataset payload type for {dataset_id}: {type(data).__name__}")
                return None
            required = ["id", "name", "description", "datasetType", "status", "fileCount", "totalSize"]
            if not all(k in data for k in required):
                logger.error(f"Dataset payload missing required fields for {dataset_id}. Keys: {self._keys(data)}")
                return None
            return DatasetResponse(**data)
        except httpx.HTTPError as e:
            logger.error(f"Failed to get dataset {dataset_id}: {e}")
            return None
        except Exception as e:
            # `raw` may still be None if decoding failed before assignment.
            logger.error(f"[Unexpected] [GET] dataset {dataset_id}: \n{e}\nRaw JSON received: \n{raw}")
            return None

    async def get_dataset_files(
        self,
        dataset_id: str,
        page: int = 0,
        size: int = 100,
        file_type: Optional[str] = None,
        status: Optional[str] = None
    ) -> Optional[PagedDatasetFileResponse]:
        """Fetch one page of a dataset's files; return None on any error.

        Args:
            dataset_id: dataset to list files for.
            page: zero-based page index.
            size: page size.
            file_type: optional fileType filter.
            status: optional status filter (e.g. "COMPLETED").
        """
        # Bug fix: pre-bind `raw` for the generic handler (see get_dataset).
        raw = None
        try:
            logger.info(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
            params: dict = {"page": page, "size": size}
            if file_type:
                params["fileType"] = file_type
            if status:
                params["status"] = status
            response = await self.client.get(
                f"/data-management/datasets/{dataset_id}/files",
                params=params
            )
            response.raise_for_status()
            raw = response.json()
            data = self._unwrap_payload(raw)
            if self._is_error_payload(data):
                logger.error(f"DM service returned error for dataset files {dataset_id}: {data}")
                return None
            if not isinstance(data, dict):
                logger.error(f"Unexpected dataset files payload type for {dataset_id}: {type(data).__name__}")
                return None
            required = ["content", "totalElements", "totalPages", "page", "size"]
            if not all(k in data for k in required):
                logger.error(f"Files payload missing required fields for {dataset_id}. Keys: {self._keys(data)}")
                return None
            return PagedDatasetFileResponse(**data)
        except httpx.HTTPError as e:
            logger.error(f"Failed to get dataset files for {dataset_id}: {e}")
            return None
        except Exception as e:
            logger.error(f"[Unexpected] [GET] dataset files {dataset_id}: \n{e}\nRaw JSON received: \n{raw}")
            return None

    async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]:
        """Download a file's raw bytes; return None on any error."""
        try:
            logger.info(f"Download file: dataset={dataset_id}, file={file_id}")
            response = await self.client.get(
                f"/data-management/datasets/{dataset_id}/files/{file_id}/download"
            )
            response.raise_for_status()
            return response.content
        except httpx.HTTPError as e:
            logger.error(f"Failed to download file {file_id}: {e}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error while downloading file {file_id}: {e}")
            return None

    async def get_file_download_url(self, dataset_id: str, file_id: str) -> str:
        """Build the absolute download URL for a file.

        NOTE(review): performs no I/O; kept async for caller compatibility.
        """
        return f"{self.base_url}/data-management/datasets/{dataset_id}/files/{file_id}/download"

    async def close(self):
        """Close the underlying HTTP connection pool."""
        await self.client.aclose()
        logger.info("DM service client connection closed")

View File

@@ -1,5 +0,0 @@
# app/models/__init__.py
from .dataset_mapping import DatasetMapping
__all__ = ["DatasetMapping"]

View File

@@ -1,25 +0,0 @@
from sqlalchemy import Column, String, DateTime, Boolean, Text
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class DatasetMapping(Base):
    """Dataset-to-labelling-project mapping model.

    One row links a DM source dataset to the Label Studio project that
    annotates it. Rows are soft-deleted via `deleted_at`.
    """
    __tablename__ = "mapping"

    # Surrogate primary key: a UUID4 string generated on insert.
    mapping_id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # ID of the source dataset in the DM service.
    source_dataset_id = Column(String(36), nullable=False, comment="源数据集ID")
    # ID of the corresponding Label Studio project.
    labelling_project_id = Column(String(36), nullable=False, comment="标注项目ID")
    # Human-readable Label Studio project name (optional).
    labelling_project_name = Column(String(255), nullable=True, comment="标注项目名称")
    # Creation timestamp, set by the database.
    created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
    # Last-update timestamp, refreshed by the database on UPDATE.
    last_updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="最后更新时间")
    # Soft-delete marker: non-NULL means the row is logically deleted.
    deleted_at = Column(DateTime(timezone=True), nullable=True, comment="删除时间")

    def __repr__(self):
        return f"<DatasetMapping(uuid={self.mapping_id}, source={self.source_dataset_id}, labelling={self.labelling_project_id})>"

    @property
    def is_deleted(self) -> bool:
        """Whether this row has been soft-deleted (deleted_at is set)."""
        return self.deleted_at is not None