Add Label Studio adapter module and its build scipts.

This commit is contained in:
Jason Wang
2025-10-22 15:14:01 +08:00
parent 1c97afed7d
commit c640105333
40 changed files with 2902 additions and 0 deletions

View File

@@ -0,0 +1,128 @@
# ====================================
# Label Studio Adapter Configuration
# ====================================
# =========================
# 应用程序配置
# =========================
APP_NAME="Label Studio Adapter"
APP_VERSION="1.0.0"
APP_DESCRIPTION="Adapter for integrating Data Management System with Label Studio"
DEBUG=true
# =========================
# 服务器配置
# =========================
HOST=0.0.0.0
PORT=18000
# =========================
# 日志配置
# =========================
LOG_LEVEL=INFO
# =========================
# Label Studio 服务配置
# =========================
# Label Studio 服务地址(根据部署方式调整)
# Docker 环境:http://label-studio:8080
# 本地开发:http://127.0.0.1:8000
LABEL_STUDIO_BASE_URL=http://label-studio:8080
# Label Studio 用户名和密码(用于自动创建用户)
LABEL_STUDIO_USERNAME=admin@example.com
LABEL_STUDIO_PASSWORD=password
# Label Studio API 认证 Token(Legacy Token,推荐使用)
# 从 Label Studio UI 的 Account & Settings > Access Token 获取
LABEL_STUDIO_USER_TOKEN=your-label-studio-token-here
# Label Studio 本地文件存储基础路径(容器内路径,用于 Docker 部署时的权限检查)
LABEL_STUDIO_LOCAL_BASE=/label-studio/local_files
# Label Studio 本地文件服务路径前缀(任务数据中的文件路径前缀)
LABEL_STUDIO_FILE_PATH_PREFIX=/data/local-files/?d=
# Label Studio 容器中的本地存储路径(用于配置 Local Storage)
LABEL_STUDIO_LOCAL_STORAGE_DATASET_BASE_PATH=/label-studio/local_files/dataset
LABEL_STUDIO_LOCAL_STORAGE_UPLOAD_BASE_PATH=/label-studio/local_files/upload
# Label Studio 任务列表分页大小
LS_TASK_PAGE_SIZE=1000
# =========================
# Data Management 服务配置
# =========================
# DM 服务地址
DM_SERVICE_BASE_URL=http://data-engine:8080
# DM 存储文件夹前缀(通常与 Label Studio 的 local-files 文件夹映射一致)
DM_FILE_PATH_PREFIX=/
# =========================
# Adapter 数据库配置 (MySQL)
# =========================
# 优先级1:如果配置了 MySQL,将优先使用 MySQL 数据库
MYSQL_HOST=adapter-db
MYSQL_PORT=3306
MYSQL_USER=label_studio_user
MYSQL_PASSWORD=user_password
MYSQL_DATABASE=label_studio_adapter
# =========================
# Label Studio 数据库配置 (PostgreSQL)
# =========================
# 仅在使用 docker-compose.label-studio.yml 启动 Label Studio 时需要配置
POSTGRES_HOST=label-studio-db
POSTGRES_PORT=5432
POSTGRES_USER=labelstudio
POSTGRES_PASSWORD=labelstudio@4321
POSTGRES_DATABASE=labelstudio
# =========================
# SQLite 数据库配置(兜底选项)
# =========================
# 优先级3:如果没有配置 MySQL/PostgreSQL,将使用 SQLite
SQLITE_PATH=./data/labelstudio_adapter.db
# =========================
# 可选:直接指定数据库 URL
# =========================
# 如果设置了此项,将覆盖上面的 MySQL/PostgreSQL/SQLite 配置
# DATABASE_URL=postgresql+asyncpg://user:password@host:port/database
# =========================
# 安全配置
# =========================
# 密钥(生产环境务必修改)
SECRET_KEY=your-secret-key-change-this-in-production
# Token 过期时间(分钟)
ACCESS_TOKEN_EXPIRE_MINUTES=30
# =========================
# CORS 配置
# =========================
# 允许的来源(生产环境建议配置具体域名)
ALLOWED_ORIGINS=["*"]
# 允许的 HTTP 方法
ALLOWED_METHODS=["*"]
# 允许的请求头
ALLOWED_HEADERS=["*"]
# =========================
# Docker Compose 配置
# =========================
# Docker Compose 项目名称前缀
COMPOSE_PROJECT_NAME=ls-adapter
# =========================
# 同步配置(未来扩展)
# =========================
# 批量同步任务的批次大小
SYNC_BATCH_SIZE=100
# 同步失败时的最大重试次数
MAX_RETRIES=3

View File

@@ -0,0 +1,6 @@
# Local Development Environment Files
.env
.dev.env
# logs
logs/

View File

@@ -0,0 +1,148 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
# sqlalchemy.url = driver://user:pass@localhost/dbname
# 注释掉默认 URL,我们将在 env.py 中从应用配置读取
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -0,0 +1 @@
Generic single-database configuration.

View File

@@ -0,0 +1,145 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy import create_engine, text
from alembic import context
import os
from urllib.parse import quote_plus
# 导入应用配置和模型
from app.core.config import settings
from app.db.database import Base
# 导入所有模型,以便 autogenerate 能够检测到它们
from app.models import dataset_mapping # noqa
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
def ensure_database_and_user():
"""
确保数据库和用户存在
使用 root 用户连接 MySQL,创建数据库和应用用户
"""
# 只在 MySQL 配置时执行
if not settings.mysql_host:
return
mysql_root_password = os.getenv('MYSQL_ROOT_PASSWORD', 'Huawei@123')
# URL 编码密码以处理特殊字符
encoded_password = quote_plus(mysql_root_password)
# 使用 root 用户连接(不指定数据库)
root_url = f"mysql+pymysql://root:{encoded_password}@{settings.mysql_host}:{settings.mysql_port}/"
try:
root_engine = create_engine(root_url, poolclass=pool.NullPool)
with root_engine.connect() as conn:
# 创建数据库(如果不存在)
conn.execute(text(
f"CREATE DATABASE IF NOT EXISTS `{settings.mysql_database}` "
f"CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
))
conn.commit()
# 创建用户(如果不存在)- 使用 MySQL 8 默认的 caching_sha2_password
conn.execute(text(
f"CREATE USER IF NOT EXISTS '{settings.mysql_user}'@'%' "
f"IDENTIFIED BY '{settings.mysql_password}'"
))
conn.commit()
# 授予权限
conn.execute(text(
f"GRANT ALL PRIVILEGES ON `{settings.mysql_database}`.* TO '{settings.mysql_user}'@'%'"
))
conn.commit()
# 刷新权限
conn.execute(text("FLUSH PRIVILEGES"))
conn.commit()
root_engine.dispose()
print(f"✓ Database '{settings.mysql_database}' and user '{settings.mysql_user}' are ready")
except Exception as e:
print(f"⚠️ Warning: Could not ensure database and user: {e}")
print(f" This may be expected if database already exists or permissions are set")
# 从应用配置设置数据库 URL
config.set_main_option('sqlalchemy.url', settings.sync_database_url)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
# 先确保数据库和用户存在
ensure_database_and_user()
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,41 @@
"""Initiation
Revision ID: 755dc1afb8ad
Revises:
Create Date: 2025-10-20 19:34:20.258554
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '755dc1afb8ad'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('mapping',
sa.Column('mapping_id', sa.String(length=36), nullable=False),
sa.Column('source_dataset_id', sa.String(length=36), nullable=False, comment='源数据集ID'),
sa.Column('labelling_project_id', sa.String(length=36), nullable=False, comment='标注项目ID'),
sa.Column('labelling_project_name', sa.String(length=255), nullable=True, comment='标注项目名称'),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='创建时间'),
sa.Column('last_updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='最后更新时间'),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True, comment='删除时间'),
sa.PrimaryKeyConstraint('mapping_id')
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('mapping')
# ### end Alembic commands ###

View File

@@ -0,0 +1 @@
# app/__init__.py

View File

@@ -0,0 +1,19 @@
"""
API 路由模块
集中管理所有API路由的组织结构
"""
from fastapi import APIRouter
from .system import router as system_router
from .project import project_router
# 创建主API路由器
api_router = APIRouter()
# 注册到主路由器
api_router.include_router(system_router, tags=["系统"])
api_router.include_router(project_router, prefix="/project", tags=["项目"])
# 导出路由器供 main.py 使用
__all__ = ["api_router"]

View File

@@ -0,0 +1,11 @@
"""
标注工程相关API路由模块
"""
from fastapi import APIRouter
project_router = APIRouter()
from . import create
from . import sync
from . import list
from . import delete

View File

@@ -0,0 +1,130 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.clients import get_clients
from app.schemas.dataset_mapping import (
DatasetMappingCreateRequest,
DatasetMappingCreateResponse,
)
from app.schemas import StandardResponse
from app.core.logging import get_logger
from app.core.config import settings
from . import project_router
logger = get_logger(__name__)
@project_router.post("/create", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
async def create_dataset_mapping(
request: DatasetMappingCreateRequest,
db: AsyncSession = Depends(get_db)
):
"""
创建数据集映射
根据指定的DM程序中的数据集,创建Label Studio中的数据集,
在数据库中记录这一关联关系,返回Label Studio数据集的ID
注意:一个数据集可以创建多个标注项目
"""
try:
# 获取全局客户端实例
dm_client_instance, ls_client_instance = get_clients()
service = DatasetMappingService(db)
logger.info(f"Create dataset mapping request: {request.source_dataset_id}")
# 从DM服务获取数据集信息
dataset_info = await dm_client_instance.get_dataset(request.source_dataset_id)
if not dataset_info:
raise HTTPException(
status_code=404,
detail=f"Dataset not found in DM service: {request.source_dataset_id}"
)
# 确定数据类型(基于数据集类型)
data_type = "image" # 默认值
if dataset_info.type and dataset_info.type.code:
type_code = dataset_info.type.code.lower()
if "audio" in type_code:
data_type = "audio"
elif "video" in type_code:
data_type = "video"
elif "text" in type_code:
data_type = "text"
# 生成项目名称
project_name = f"{dataset_info.name}"
# 在Label Studio中创建项目
project_data = await ls_client_instance.create_project(
title=project_name,
description=dataset_info.description or f"Imported from DM dataset {dataset_info.id}",
data_type=data_type
)
if not project_data:
raise HTTPException(
status_code=500,
detail="Fail to create Label Studio project."
)
project_id = project_data["id"]
# 配置本地存储:dataset/<id>
local_storage_path = f"{settings.label_studio_local_storage_dataset_base_path}/{request.source_dataset_id}"
storage_result = await ls_client_instance.create_local_storage(
project_id=project_id,
path=local_storage_path,
title="Dataset_BLOB",
use_blob_urls=True,
description=f"Local storage for dataset {dataset_info.name}"
)
# 配置本地存储:upload
local_storage_path = f"{settings.label_studio_local_storage_upload_base_path}"
storage_result = await ls_client_instance.create_local_storage(
project_id=project_id,
path=local_storage_path,
title="Upload_BLOB",
use_blob_urls=True,
description=f"Local storage for dataset {dataset_info.name}"
)
if not storage_result:
# 本地存储配置失败,记录警告但不中断流程
logger.warning(f"Failed to configure local storage for project {project_id}")
else:
logger.info(f"Local storage configured for project {project_id}: {local_storage_path}")
# 创建映射关系,包含项目名称
mapping = await service.create_mapping(
request,
str(project_id),
project_name
)
logger.debug(
f"Dataset mapping created: {mapping.mapping_id} -> S {mapping.source_dataset_id} <> L {mapping.labelling_project_id}"
)
response_data = DatasetMappingCreateResponse(
mapping_id=mapping.mapping_id,
labelling_project_id=mapping.labelling_project_id,
labelling_project_name=mapping.labelling_project_name or project_name,
message="Dataset mapping created successfully"
)
return StandardResponse(
code=201,
message="success",
data=response_data
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error while creating dataset mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,106 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.clients import get_clients
from app.schemas.dataset_mapping import DeleteDatasetResponse
from app.schemas import StandardResponse
from app.core.logging import get_logger
from . import project_router
logger = get_logger(__name__)
@project_router.delete("/mappings", response_model=StandardResponse[DeleteDatasetResponse])
async def delete_mapping(
m: Optional[str] = Query(None, description="映射UUID"),
proj: Optional[str] = Query(None, description="Label Studio项目ID"),
db: AsyncSession = Depends(get_db)
):
"""
删除映射关系和对应的 Label Studio 项目
可以通过以下任一方式指定要删除的映射:
- m: 映射UUID
- proj: Label Studio项目ID
- 两者都提供(优先使用 m)
此操作会:
1. 删除 Label Studio 中的项目
2. 软删除数据库中的映射记录
"""
try:
# 至少需要提供一个参数
if not m and not proj:
raise HTTPException(
status_code=400,
detail="Either 'm' (mapping UUID) or 'proj' (project ID) must be provided"
)
# 获取全局客户端实例
dm_client_instance, ls_client_instance = get_clients()
service = DatasetMappingService(db)
mapping = None
# 优先使用 mapping_id 查询
if m:
logger.info(f"Deleting by mapping UUID: {m}")
mapping = await service.get_mapping_by_uuid(m)
# 如果没有提供 m,使用 proj 查询
elif proj:
logger.info(f"Deleting by project ID: {proj}")
mapping = await service.get_mapping_by_labelling_project_id(proj)
if not mapping:
raise HTTPException(
status_code=404,
detail=f"Mapping not found"
)
mapping_id = mapping.mapping_id
labelling_project_id = mapping.labelling_project_id
labelling_project_name = mapping.labelling_project_name
logger.info(f"Found mapping: {mapping_id}, Label Studio project ID: {labelling_project_id}")
# 1. 删除 Label Studio 项目
try:
delete_success = await ls_client_instance.delete_project(int(labelling_project_id))
if delete_success:
logger.info(f"Successfully deleted Label Studio project: {labelling_project_id}")
else:
logger.warning(f"Failed to delete Label Studio project or project not found: {labelling_project_id}")
except Exception as e:
logger.error(f"Error deleting Label Studio project: {e}")
# 继续执行,即使 Label Studio 项目删除失败也要删除映射记录
# 2. 软删除映射记录
soft_delete_success = await service.soft_delete_mapping(mapping_id)
if not soft_delete_success:
raise HTTPException(
status_code=500,
detail="Failed to delete mapping record"
)
logger.info(f"Successfully deleted mapping: {mapping_id}")
response_data = DeleteDatasetResponse(
mapping_id=mapping_id,
status="success",
message=f"Successfully deleted mapping and Label Studio project '{labelling_project_name}'"
)
return StandardResponse(
code=200,
message="success",
data=response_data
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error deleting mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,110 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.schemas.dataset_mapping import DatasetMappingResponse
from app.schemas import StandardResponse
from app.core.logging import get_logger
from . import project_router
logger = get_logger(__name__)
@project_router.get("/mappings/list", response_model=StandardResponse[List[DatasetMappingResponse]])
async def list_mappings(
skip: int = Query(0, ge=0, description="Number of records to skip"),
limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"),
db: AsyncSession = Depends(get_db)
):
"""
查询所有映射关系
返回所有有效的数据集映射关系(未被软删除的)
"""
try:
service = DatasetMappingService(db)
logger.info(f"Listing mappings, skip={skip}, limit={limit}")
mappings = await service.get_all_mappings(skip=skip, limit=limit)
logger.info(f"Found {len(mappings)} mappings")
return StandardResponse(
code=200,
message="success",
data=mappings
)
except Exception as e:
logger.error(f"Error listing mappings: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@project_router.get("/mappings/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping(
mapping_id: str,
db: AsyncSession = Depends(get_db)
):
"""
根据 UUID 查询单个映射关系
"""
try:
service = DatasetMappingService(db)
logger.info(f"Get mapping: {mapping_id}")
mapping = await service.get_mapping_by_uuid(mapping_id)
if not mapping:
raise HTTPException(
status_code=404,
detail=f"Mapping not found: {mapping_id}"
)
logger.info(f"Found mapping: {mapping.mapping_id}")
return StandardResponse(
code=200,
message="success",
data=mapping
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[List[DatasetMappingResponse]])
async def get_mappings_by_source(
source_dataset_id: str,
db: AsyncSession = Depends(get_db)
):
"""
根据源数据集 ID 查询所有映射关系
返回该数据集创建的所有标注项目(包括已删除的)
"""
try:
service = DatasetMappingService(db)
logger.info(f"Get mappings by source dataset id: {source_dataset_id}")
mappings = await service.get_mappings_by_source_dataset_id(source_dataset_id)
logger.info(f"Found {len(mappings)} mappings")
return StandardResponse(
code=200,
message="success",
data=mappings
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting mappings: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,68 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.services.sync_service import SyncService
from app.clients import get_clients
from app.exceptions import NoDatasetInfoFoundError, DatasetMappingNotFoundError
from app.schemas.dataset_mapping import (
DatasetMappingResponse,
SyncDatasetRequest,
SyncDatasetResponse,
)
from app.schemas import StandardResponse
from app.core.logging import get_logger
from . import project_router
logger = get_logger(__name__)
@project_router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
async def sync_dataset_content(
request: SyncDatasetRequest,
db: AsyncSession = Depends(get_db)
):
"""
同步数据集内容
根据指定的mapping ID,同步DM程序数据集中的内容到Label Studio数据集中,
在数据库中记录更新时间,返回更新状态
"""
try:
dm_client_instance, ls_client_instance = get_clients()
mapping_service = DatasetMappingService(db)
sync_service = SyncService(dm_client_instance, ls_client_instance, mapping_service)
logger.info(f"Sync dataset content request: mapping_id={request.mapping_id}")
# 根据 mapping_id 获取映射关系
mapping = await mapping_service.get_mapping_by_uuid(request.mapping_id)
if not mapping:
raise HTTPException(
status_code=404,
detail=f"Mapping not found: {request.mapping_id}"
)
# 执行同步(使用映射中的源数据集UUID)
result = await sync_service.sync_dataset_files(request.mapping_id, request.batch_size)
logger.info(f"Sync completed: {result.synced_files}/{result.total_files} files")
return StandardResponse(
code=200,
message="success",
data=result
)
except HTTPException:
raise
except NoDatasetInfoFoundError as e:
logger.error(f"Failed to get dataset info: {e}")
raise HTTPException(status_code=404, detail=str(e))
except DatasetMappingNotFoundError as e:
logger.error(f"Mapping not found: {e}")
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f"Error syncing dataset content: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -0,0 +1,34 @@
from fastapi import APIRouter
from typing import Dict, Any
from app.core.config import settings
from app.schemas import StandardResponse
router = APIRouter()
@router.get("/health", response_model=StandardResponse[Dict[str, Any]])
async def health_check():
"""健康检查端点"""
return StandardResponse(
code=200,
message="success",
data={
"status": "healthy",
"service": "Label Studio Adapter",
"version": settings.app_version
}
)
@router.get("/config", response_model=StandardResponse[Dict[str, Any]])
async def get_config():
"""获取配置信息"""
return StandardResponse(
code=200,
message="success",
data={
"app_name": settings.app_name,
"version": settings.app_version,
"dm_service_url": settings.dm_service_base_url,
"label_studio_url": settings.label_studio_base_url,
"debug": settings.debug
}
)

View File

@@ -0,0 +1,8 @@
# app/clients/__init__.py
from .dm_client import DMServiceClient
from .label_studio_client import LabelStudioClient
from .client_manager import get_clients, set_clients, get_dm_client, get_ls_client
__all__ = ["DMServiceClient", "LabelStudioClient", "get_clients", "set_clients", "get_dm_client", "get_ls_client"]

View File

@@ -0,0 +1,34 @@
from typing import Optional
from fastapi import HTTPException
from .dm_client import DMServiceClient
from .label_studio_client import LabelStudioClient
# 全局客户端实例(将在main.py中初始化)
dm_client: Optional[DMServiceClient] = None
ls_client: Optional[LabelStudioClient] = None
def get_clients() -> tuple[DMServiceClient, LabelStudioClient]:
"""获取客户端实例"""
global dm_client, ls_client
if not dm_client or not ls_client:
raise HTTPException(status_code=500, detail="客户端未初始化")
return dm_client, ls_client
def set_clients(dm_client_instance: DMServiceClient, ls_client_instance: LabelStudioClient) -> None:
"""设置全局客户端实例"""
global dm_client, ls_client
dm_client = dm_client_instance
ls_client = ls_client_instance
def get_dm_client() -> DMServiceClient:
"""获取DM服务客户端"""
if not dm_client:
raise HTTPException(status_code=500, detail="DM客户端未初始化")
return dm_client
def get_ls_client() -> LabelStudioClient:
"""获取Label Studio客户端"""
if not ls_client:
raise HTTPException(status_code=500, detail="Label Studio客户端未初始化")
return ls_client

View File

@@ -0,0 +1,138 @@
import httpx
from typing import Optional
from app.core.config import settings
from app.core.logging import get_logger
from app.schemas.dm_service import DatasetResponse, PagedDatasetFileResponse
logger = get_logger(__name__)
class DMServiceClient:
"""数据管理服务客户端"""
def __init__(self, base_url: str|None = None, timeout: float = 30.0):
self.base_url = base_url or settings.dm_service_base_url
self.timeout = timeout
self.client = httpx.AsyncClient(
base_url=self.base_url,
timeout=self.timeout
)
logger.info(f"Initialize DM service client, base url: {self.base_url}")
@staticmethod
def _unwrap_payload(data):
"""Unwrap common envelope shapes like {'code': ..., 'message': ..., 'data': {...}}."""
if isinstance(data, dict) and 'data' in data and isinstance(data['data'], (dict, list)):
return data['data']
return data
@staticmethod
def _is_error_payload(data) -> bool:
"""Detect error-shaped payloads returned with HTTP 200."""
if not isinstance(data, dict):
return False
# Common patterns: {error, message, ...} or {code, message, ...} without data
if 'error' in data and 'message' in data:
return True
if 'code' in data and 'message' in data and 'data' not in data:
return True
return False
@staticmethod
def _keys(d):
return list(d.keys()) if isinstance(d, dict) else []
async def get_dataset(self, dataset_id: str) -> Optional[DatasetResponse]:
"""获取数据集详情"""
try:
logger.info(f"Getting dataset detail: {dataset_id} ...")
response = await self.client.get(f"/data-management/datasets/{dataset_id}")
response.raise_for_status()
raw = response.json()
data = self._unwrap_payload(raw)
if self._is_error_payload(data):
logger.error(f"DM service returned error for dataset {dataset_id}: {data}")
return None
if not isinstance(data, dict):
logger.error(f"Unexpected dataset payload type for {dataset_id}: {type(data).__name__}")
return None
required = ["id", "name", "description", "datasetType", "status", "fileCount", "totalSize"]
if not all(k in data for k in required):
logger.error(f"Dataset payload missing required fields for {dataset_id}. Keys: {self._keys(data)}")
return None
return DatasetResponse(**data)
except httpx.HTTPError as e:
logger.error(f"Failed to get dataset {dataset_id}: {e}")
return None
except Exception as e:
logger.error(f"[Unexpected] [GET] dataset {dataset_id}: \n{e}\nRaw JSON received: \n{raw}")
return None
async def get_dataset_files(
self,
dataset_id: str,
page: int = 0,
size: int = 100,
file_type: Optional[str] = None,
status: Optional[str] = None
) -> Optional[PagedDatasetFileResponse]:
"""获取数据集文件列表"""
try:
logger.info(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
params: dict = {"page": page, "size": size}
if file_type:
params["fileType"] = file_type
if status:
params["status"] = status
response = await self.client.get(
f"/data-management/datasets/{dataset_id}/files",
params=params
)
response.raise_for_status()
raw = response.json()
data = self._unwrap_payload(raw)
if self._is_error_payload(data):
logger.error(f"DM service returned error for dataset files {dataset_id}: {data}")
return None
if not isinstance(data, dict):
logger.error(f"Unexpected dataset files payload type for {dataset_id}: {type(data).__name__}")
return None
required = ["content", "totalElements", "totalPages", "page", "size"]
if not all(k in data for k in required):
logger.error(f"Files payload missing required fields for {dataset_id}. Keys: {self._keys(data)}")
return None
return PagedDatasetFileResponse(**data)
except httpx.HTTPError as e:
logger.error(f"Failed to get dataset files for {dataset_id}: {e}")
return None
except Exception as e:
logger.error(f"[Unexpected] [GET] dataset files {dataset_id}: \n{e}\nRaw JSON received: \n{raw}")
return None
async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]:
"""下载文件内容"""
try:
logger.info(f"Download file: dataset={dataset_id}, file={file_id}")
response = await self.client.get(
f"/data-management/datasets/{dataset_id}/files/{file_id}/download"
)
response.raise_for_status()
return response.content
except httpx.HTTPError as e:
logger.error(f"Failed to download file {file_id}: {e}")
return None
except Exception as e:
logger.error(f"Unexpected error while downloading file {file_id}: {e}")
return None
async def get_file_download_url(self, dataset_id: str, file_id: str) -> str:
"""获取文件下载URL"""
return f"{self.base_url}/data-management/datasets/{dataset_id}/files/{file_id}/download"
async def close(self):
"""关闭客户端连接"""
await self.client.aclose()
logger.info("DM service client connection closed")

View File

@@ -0,0 +1,469 @@
import httpx
from typing import Optional, Dict, Any, List
import json
from app.core.config import settings
from app.core.logging import get_logger
from app.schemas.label_studio import (
LabelStudioProject,
LabelStudioCreateProjectRequest,
LabelStudioCreateTaskRequest
)
logger = get_logger(__name__)
class LabelStudioClient:
"""Label Studio服务客户端
使用 HTTP REST API 直接与 Label Studio 交互
认证方式:使用 Authorization: Token {token} 头部进行认证
"""
# 默认标注配置模板
DEFAULT_LABEL_CONFIGS = {
"image": """
<View>
<Image name="image" value="$image"/>
<RectangleLabels name="label" toName="image">
<Label value="Object" background="red"/>
</RectangleLabels>
</View>
""",
"text": """
<View>
<Text name="text" value="$text"/>
<Choices name="sentiment" toName="text">
<Choice value="positive"/>
<Choice value="negative"/>
<Choice value="neutral"/>
</Choices>
</View>
""",
"audio": """
<View>
<Audio name="audio" value="$audio"/>
<AudioRegionLabels name="label" toName="audio">
<Label value="Speech" background="red"/>
<Label value="Noise" background="blue"/>
</AudioRegionLabels>
</View>
""",
"video": """
<View>
<Video name="video" value="$video"/>
<VideoRegionLabels name="label" toName="video">
<Label value="Action" background="red"/>
</VideoRegionLabels>
</View>
"""
}
def __init__(
self,
base_url: Optional[str] = None,
token: Optional[str] = None,
timeout: float = 30.0
):
"""初始化 Label Studio 客户端
Args:
base_url: Label Studio 服务地址
token: API Token(使用 Authorization: Token {token} 头部)
timeout: 请求超时时间(秒)
"""
self.base_url = (base_url or settings.label_studio_base_url).rstrip("/")
self.token = token or settings.label_studio_user_token
self.timeout = timeout
if not self.token:
raise ValueError("Label Studio API token is required")
# 初始化 HTTP 客户端
self.client = httpx.AsyncClient(
base_url=self.base_url,
timeout=self.timeout,
headers={
"Authorization": f"Token {self.token}",
"Content-Type": "application/json"
}
)
logger.info(f"Label Studio client initialized: {self.base_url}")
def get_label_config_by_type(self, data_type: str) -> str:
"""根据数据类型获取标注配置"""
return self.DEFAULT_LABEL_CONFIGS.get(data_type.lower(), self.DEFAULT_LABEL_CONFIGS["image"])
async def create_project(
self,
title: str,
description: str = "",
label_config: Optional[str] = None,
data_type: str = "image"
) -> Optional[Dict[str, Any]]:
"""创建Label Studio项目"""
try:
logger.info(f"Creating Label Studio project: {title}")
if not label_config:
label_config = self.get_label_config_by_type(data_type)
project_data = {
"title": title,
"description": description,
"label_config": label_config.strip()
}
response = await self.client.post("/api/projects", json=project_data)
response.raise_for_status()
project = response.json()
project_id = project.get("id")
if not project_id:
raise Exception("Label Studio response does not contain project ID")
logger.info(f"Project created successfully, ID: {project_id}")
return project
except httpx.HTTPStatusError as e:
logger.error(f"Create project failed HTTP {e.response.status_code}: {e.response.text}")
return None
except Exception as e:
logger.error(f"Error while creating Label Studio project: {e}")
return None
async def import_tasks(
self,
project_id: int,
tasks: List[Dict[str, Any]],
commit_to_project: bool = True,
return_task_ids: bool = True
) -> Optional[Dict[str, Any]]:
"""批量导入任务到Label Studio项目"""
try:
logger.info(f"Importing {len(tasks)} tasks into project {project_id}")
response = await self.client.post(
f"/api/projects/{project_id}/import",
json=tasks,
params={
"commit_to_project": str(commit_to_project).lower(),
"return_task_ids": str(return_task_ids).lower()
}
)
response.raise_for_status()
result = response.json()
task_count = result.get("task_count", len(tasks))
logger.info(f"Tasks imported successfully: {task_count}")
return result
except httpx.HTTPStatusError as e:
logger.error(f"Import tasks failed HTTP {e.response.status_code}: {e.response.text}")
return None
except Exception as e:
logger.error(f"Error while importing tasks: {e}")
return None
async def create_tasks_batch(
self,
project_id: str,
tasks: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""批量创建任务的便利方法"""
try:
pid = int(project_id)
return await self.import_tasks(pid, tasks)
except ValueError as e:
logger.error(f"Invalid project ID format: {project_id}, error: {e}")
return None
except Exception as e:
logger.error(f"Error while creating tasks in batch: {e}")
return None
async def create_task(
self,
project_id: str,
data: Dict[str, Any],
meta: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
"""创建单个任务"""
try:
task = {"data": data}
if meta:
task["meta"] = meta
return await self.create_tasks_batch(project_id, [task])
except Exception as e:
logger.error(f"Error while creating single task: {e}")
return None
async def get_project_tasks(
self,
project_id: str,
page: Optional[int] = None,
page_size: int = 1000
) -> Optional[Dict[str, Any]]:
"""获取项目任务信息
Args:
project_id: 项目ID
page: 页码(从1开始)。如果为None,则获取所有任务
page_size: 每页大小
Returns:
如果指定了page参数,返回包含分页信息的字典:
{
"count": 总任务数,
"page": 当前页码,
"page_size": 每页大小,
"project_id": 项目ID,
"tasks": 当前页的任务列表
}
如果page为None,返回包含所有任务的字典:
"count": 总任务数,
"project_id": 项目ID,
"tasks": 所有任务列表
}
"""
try:
pid = int(project_id)
# 如果指定了page,直接获取单页任务
if page is not None:
logger.info(f"Fetching tasks for project {pid}, page {page} (page_size={page_size})")
response = await self.client.get(
f"/api/projects/{pid}/tasks",
params={
"page": page,
"page_size": page_size
}
)
response.raise_for_status()
result = response.json()
# 返回单页结果,包含分页信息
return {
"count": result.get("total", len(result.get("tasks", []))),
"page": page,
"page_size": page_size,
"project_id": pid,
"tasks": result.get("tasks", [])
}
# 如果未指定page,获取所有任务
logger.info(f"Start fetching all tasks for project {pid} (page_size={page_size})")
all_tasks = []
current_page = 1
while True:
try:
response = await self.client.get(
f"/api/projects/{pid}/tasks",
params={
"page": current_page,
"page_size": page_size
}
)
response.raise_for_status()
result = response.json()
tasks = result.get("tasks", [])
if not tasks:
logger.debug(f"No more tasks on page {current_page}")
break
all_tasks.extend(tasks)
logger.debug(f"Fetched page {current_page}, {len(tasks)} tasks")
# 检查是否还有更多页
total = result.get("total", 0)
if len(all_tasks) >= total:
break
current_page += 1
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
# 超出页数范围,结束分页
logger.debug(f"Reached last page (page {current_page})")
break
else:
raise
logger.info(f"Fetched all tasks for project {pid}, total {len(all_tasks)}")
# 返回所有任务,不包含分页信息
return {
"count": len(all_tasks),
"project_id": pid,
"tasks": all_tasks
}
except httpx.HTTPStatusError as e:
logger.error(f"获取项目任务失败 HTTP {e.response.status_code}: {e.response.text}")
return None
except Exception as e:
logger.error(f"获取项目任务时发生错误: {e}")
return None
async def delete_task(
self,
task_id: int
) -> bool:
"""删除单个任务"""
try:
logger.info(f"Deleting task: {task_id}")
response = await self.client.delete(f"/api/tasks/{task_id}")
response.raise_for_status()
logger.info(f"Task deleted: {task_id}")
return True
except httpx.HTTPStatusError as e:
logger.error(f"Delete task {task_id} failed HTTP {e.response.status_code}: {e.response.text}")
return False
except Exception as e:
logger.error(f"Error while deleting task {task_id}: {e}")
return False
async def delete_tasks_batch(
self,
task_ids: List[int]
) -> Dict[str, int]:
"""批量删除任务"""
try:
logger.info(f"Deleting {len(task_ids)} tasks in batch")
successful_deletions = 0
failed_deletions = 0
for task_id in task_ids:
if await self.delete_task(task_id):
successful_deletions += 1
else:
failed_deletions += 1
logger.info(f"Batch deletion finished: success {successful_deletions}, failed {failed_deletions}")
return {
"successful": successful_deletions,
"failed": failed_deletions,
"total": len(task_ids)
}
except Exception as e:
logger.error(f"Error while deleting tasks in batch: {e}")
return {
"successful": 0,
"failed": len(task_ids),
"total": len(task_ids)
}
async def get_project(self, project_id: int) -> Optional[Dict[str, Any]]:
"""获取项目信息"""
try:
logger.info(f"Fetching project info: {project_id}")
response = await self.client.get(f"/api/projects/{project_id}")
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"Get project info failed HTTP {e.response.status_code}: {e.response.text}")
return None
except Exception as e:
logger.error(f"Error while getting project info: {e}")
return None
async def delete_project(self, project_id: int) -> bool:
"""删除项目"""
try:
logger.info(f"Deleting project: {project_id}")
response = await self.client.delete(f"/api/projects/{project_id}")
response.raise_for_status()
logger.info(f"Project deleted: {project_id}")
return True
except httpx.HTTPStatusError as e:
logger.error(f"Delete project {project_id} failed HTTP {e.response.status_code}: {e.response.text}")
return False
except Exception as e:
logger.error(f"Error while deleting project {project_id}: {e}")
return False
async def create_local_storage(
self,
project_id: int,
path: str,
title: str,
use_blob_urls: bool = True,
regex_filter: Optional[str] = None,
description: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""创建本地文件存储配置
Args:
project_id: Label Studio 项目 ID
path: 本地文件路径(在 Label Studio 容器中的路径)
title: 存储配置标题
use_blob_urls: 是否使用 blob URLs(建议 True)
regex_filter: 文件过滤正则表达式(可选)
description: 存储描述(可选)
Returns:
创建的存储配置信息,失败返回 None
"""
try:
logger.info(f"Creating local storage for project {project_id}: {path}")
storage_data = {
"project": project_id,
"path": path,
"title": title,
"use_blob_urls": use_blob_urls
}
if regex_filter:
storage_data["regex_filter"] = regex_filter
if description:
storage_data["description"] = description
response = await self.client.post(
"/api/storages/localfiles/",
json=storage_data
)
response.raise_for_status()
storage = response.json()
storage_id = storage.get("id")
logger.info(f"Local storage created successfully, ID: {storage_id}")
return storage
except httpx.HTTPStatusError as e:
logger.error(f"Create local storage failed HTTP {e.response.status_code}: {e.response.text}")
return None
except Exception as e:
logger.error(f"Error while creating local storage: {e}")
return None
async def close(self):
"""关闭客户端连接"""
try:
await self.client.aclose()
logger.info("Label Studio client closed")
except Exception as e:
logger.error(f"Error while closing Label Studio client: {e}")

View File

@@ -0,0 +1 @@
# app/core/__init__.py

View File

@@ -0,0 +1,146 @@
from pydantic_settings import BaseSettings
from typing import Optional
import os
from pathlib import Path
class Settings(BaseSettings):
"""应用程序配置"""
class Config:
env_file = ".env"
case_sensitive = False
extra = 'ignore' # 允许额外字段(如 Shell 脚本专用的环境变量)
# =========================
# Adapter 服务配置
# =========================
app_name: str = "Label Studio Adapter"
app_version: str = "1.0.0"
app_description: str = "Adapter for integrating Data Management System with Label Studio"
debug: bool = True
# 服务器配置
host: str = "0.0.0.0"
port: int = 8000
# CORS配置
allowed_origins: list = ["*"]
allowed_methods: list = ["*"]
allowed_headers: list = ["*"]
# MySQL数据库配置 (优先级1)
mysql_host: Optional[str] = None
mysql_port: int = 3306
mysql_user: Optional[str] = None
mysql_password: Optional[str] = None
mysql_database: Optional[str] = None
# PostgreSQL数据库配置 (优先级2)
postgres_host: Optional[str] = None
postgres_port: int = 5432
postgres_user: Optional[str] = None
postgres_password: Optional[str] = None
postgres_database: Optional[str] = None
# SQLite数据库配置 (优先级3 - 兜底)
sqlite_path: str = "data/labelstudio_adapter.db"
# 直接数据库URL配置(如果提供,将覆盖上述配置)
database_url: Optional[str] = None
# 日志配置
log_level: str = "INFO"
# 安全配置
secret_key: str = "your-secret-key-change-this-in-production"
access_token_expire_minutes: int = 30
# =========================
# Label Studio 服务配置
# =========================
label_studio_base_url: str = "http://label-studio:8080"
label_studio_username: Optional[str] = None # Label Studio 用户名(用于登录)
label_studio_password: Optional[str] = None # Label Studio 密码(用于登录)
label_studio_user_token: Optional[str] = None # Legacy Token
label_studio_local_storage_dataset_base_path: str = "/label-studio/local_files/dataset" # Label Studio容器中的本地存储基础路径
label_studio_local_storage_upload_base_path: str = "/label-studio/local_files/upload" # Label Studio容器中的本地存储基础路径
label_studio_file_path_prefix: str = "/data/local-files/?d=" # Label Studio本地文件服务路径前缀
ls_task_page_size: int = 1000
# =========================
# Data Management 服务配置
# =========================
dm_service_base_url: str = "http://data-engine"
dm_file_path_prefix: str = "/" # DM存储文件夹前缀
@property
def computed_database_url(self) -> str:
"""
根据优先级自动选择数据库连接URL
优先级:MySQL > PostgreSQL > SQLite3
"""
# 如果直接提供了database_url,优先使用
if self.database_url:
return self.database_url
# 优先级1: MySQL
if all([self.mysql_host, self.mysql_user, self.mysql_password, self.mysql_database]):
return f"mysql+aiomysql://{self.mysql_user}:{self.mysql_password}@{self.mysql_host}:{self.mysql_port}/{self.mysql_database}"
# 优先级2: PostgreSQL
if all([self.postgres_host, self.postgres_user, self.postgres_password, self.postgres_database]):
return f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}@{self.postgres_host}:{self.postgres_port}/{self.postgres_database}"
# 优先级3: SQLite (兜底)
sqlite_full_path = Path(self.sqlite_path).absolute()
# 确保目录存在
sqlite_full_path.parent.mkdir(parents=True, exist_ok=True)
return f"sqlite+aiosqlite:///{sqlite_full_path}"
@property
def sync_database_url(self) -> str:
"""
用于数据库迁移的同步连接URL
将异步驱动替换为同步驱动
"""
async_url = self.computed_database_url
# 替换异步驱动为同步驱动
sync_replacements = {
"mysql+aiomysql://": "mysql+pymysql://",
"postgresql+asyncpg://": "postgresql+psycopg2://",
"sqlite+aiosqlite:///": "sqlite:///"
}
for async_driver, sync_driver in sync_replacements.items():
if async_url.startswith(async_driver):
return async_url.replace(async_driver, sync_driver)
return async_url
def get_database_info(self) -> dict:
"""获取数据库配置信息"""
url = self.computed_database_url
if url.startswith("mysql"):
db_type = "MySQL"
elif url.startswith("postgresql"):
db_type = "PostgreSQL"
elif url.startswith("sqlite"):
db_type = "SQLite"
else:
db_type = "Unknown"
return {
"type": db_type,
"url": url,
"sync_url": self.sync_database_url
}
# 全局设置实例
settings = Settings()

View File

@@ -0,0 +1,53 @@
import logging
import sys
from pathlib import Path
from app.core.config import settings
def setup_logging():
"""配置应用程序日志"""
# 创建logs目录
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
# 配置日志格式
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
date_format = "%Y-%m-%d %H:%M:%S"
# 创建处理器
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(getattr(logging, settings.log_level.upper()))
file_handler = logging.FileHandler(
log_dir / "app.log",
encoding="utf-8"
)
file_handler.setLevel(getattr(logging, settings.log_level.upper()))
error_handler = logging.FileHandler(
log_dir / "error.log",
encoding="utf-8"
)
error_handler.setLevel(logging.ERROR)
# 设置格式
formatter = logging.Formatter(log_format, date_format)
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
error_handler.setFormatter(formatter)
# 配置根日志器
root_logger = logging.getLogger()
root_logger.setLevel(getattr(logging, settings.log_level.upper()))
root_logger.addHandler(console_handler)
root_logger.addHandler(file_handler)
root_logger.addHandler(error_handler)
# 配置第三方库日志级别(减少详细日志)
logging.getLogger("uvicorn").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR) # 隐藏SQL查询日志
logging.getLogger("httpx").setLevel(logging.WARNING)
def get_logger(name: str) -> logging.Logger:
"""获取指定名称的日志器"""
return logging.getLogger(name)

View File

@@ -0,0 +1 @@
# app/db/__init__.py

View File

@@ -0,0 +1,39 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base
from app.core.config import settings
from app.core.logging import get_logger
from typing import AsyncGenerator
logger = get_logger(__name__)
# 获取数据库配置信息
db_info = settings.get_database_info()
logger.info(f"使用数据库: {db_info['type']}")
logger.info(f"连接URL: {db_info['url']}")
# 创建数据库引擎
engine = create_async_engine(
settings.computed_database_url,
echo=False, # 关闭SQL调试日志以减少输出
future=True,
# SQLite特殊配置
connect_args={"check_same_thread": False} if "sqlite" in settings.computed_database_url else {}
)
# 创建会话工厂
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
# 创建基础模型类
Base = declarative_base()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()

View File

@@ -0,0 +1,31 @@
"""
自定义异常类定义
"""
class LabelStudioAdapterException(Exception):
"""Label Studio Adapter 基础异常类"""
pass
class DatasetMappingNotFoundError(LabelStudioAdapterException):
"""数据集映射未找到异常"""
def __init__(self, mapping_id: str):
self.mapping_id = mapping_id
super().__init__(f"Dataset mapping not found: {mapping_id}")
class NoDatasetInfoFoundError(LabelStudioAdapterException):
"""无法获取数据集信息异常"""
def __init__(self, dataset_uuid: str):
self.dataset_uuid = dataset_uuid
super().__init__(f"Failed to get dataset info: {dataset_uuid}")
class LabelStudioClientError(LabelStudioAdapterException):
"""Label Studio 客户端错误"""
pass
class DMServiceClientError(LabelStudioAdapterException):
"""DM 服务客户端错误"""
pass
class SyncServiceError(LabelStudioAdapterException):
"""同步服务错误"""
pass

View File

@@ -0,0 +1,172 @@
from fastapi import FastAPI, Request, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from contextlib import asynccontextmanager
from typing import Dict, Any
from .core.config import settings
from .core.logging import setup_logging, get_logger
from .clients import DMServiceClient, LabelStudioClient, set_clients
from .api import api_router
from .schemas import StandardResponse
# 设置日志
setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用程序生命周期管理"""
# 启动时初始化
logger.info("Starting Label Studio Adapter...")
# 初始化客户端
dm_client = DMServiceClient()
# 初始化 Label Studio 客户端,使用 HTTP REST API + Token 认证
ls_client = LabelStudioClient(
base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token
)
# 设置全局客户端
set_clients(dm_client, ls_client)
# 数据库初始化由 Alembic 管理
# 在 Docker 环境中,entrypoint.sh 会在启动前运行: alembic upgrade head
# 在开发环境中,手动运行: alembic upgrade head
logger.info("Database schema managed by Alembic")
logger.info("Label Studio Adapter started")
yield
# 关闭时清理
logger.info("Shutting down Label Studio Adapter...")
# 客户端清理会在客户端管理器中处理
logger.info("Label Studio Adapter stopped")
# 创建FastAPI应用
app = FastAPI(
title=settings.app_name,
description=settings.app_description,
version=settings.app_version,
debug=settings.debug,
lifespan=lifespan
)
# 配置CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=settings.allowed_origins,
allow_credentials=True,
allow_methods=settings.allowed_methods,
allow_headers=settings.allowed_headers,
)
# 自定义异常处理器:StarletteHTTPException (包括404等)
@app.exception_handler(StarletteHTTPException)
async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException):
"""将Starlette的HTTPException转换为标准响应格式"""
return JSONResponse(
status_code=exc.status_code,
content={
"code": exc.status_code,
"message": "error",
"data": {
"detail": exc.detail
}
}
)
# 自定义异常处理器:FastAPI HTTPException
@app.exception_handler(HTTPException)
async def fastapi_http_exception_handler(request: Request, exc: HTTPException):
"""将FastAPI的HTTPException转换为标准响应格式"""
return JSONResponse(
status_code=exc.status_code,
content={
"code": exc.status_code,
"message": "error",
"data": {
"detail": exc.detail
}
}
)
# 自定义异常处理器:RequestValidationError
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""将请求验证错误转换为标准响应格式"""
return JSONResponse(
status_code=422,
content={
"code": 422,
"message": "error",
"data": {
"detail": "Validation error",
"errors": exc.errors()
}
}
)
# 自定义异常处理器:未捕获的异常
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""将未捕获的异常转换为标准响应格式"""
logger.error(f"Unhandled exception: {exc}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"code": 500,
"message": "error",
"data": {
"detail": "Internal server error"
}
}
)
# 注册路由
app.include_router(api_router, prefix="/api")
# 测试端点:验证异常处理
@app.get("/test-404", include_in_schema=False)
async def test_404():
"""测试404异常处理"""
raise HTTPException(status_code=404, detail="Test 404 error")
@app.get("/test-500", include_in_schema=False)
async def test_500():
"""测试500异常处理"""
raise Exception("Test uncaught exception")
# 根路径重定向到文档
@app.get("/", response_model=StandardResponse[Dict[str, Any]], include_in_schema=False)
async def root():
"""根路径,返回服务信息"""
return StandardResponse(
code=200,
message="success",
data={
"message": f"{settings.app_name} is running",
"version": settings.app_version,
"docs_url": "/docs",
"dm_service_url": settings.dm_service_base_url,
"label_studio_url": settings.label_studio_base_url
}
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
reload=settings.debug,
log_level=settings.log_level.lower()
)

View File

@@ -0,0 +1,5 @@
# app/models/__init__.py
from .dataset_mapping import DatasetMapping
__all__ = ["DatasetMapping"]

View File

@@ -0,0 +1,25 @@
from sqlalchemy import Column, String, DateTime, Boolean, Text
from sqlalchemy.sql import func
from app.db.database import Base
import uuid
class DatasetMapping(Base):
"""数据集映射模型"""
__tablename__ = "mapping"
mapping_id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
source_dataset_id = Column(String(36), nullable=False, comment="源数据集ID")
labelling_project_id = Column(String(36), nullable=False, comment="标注项目ID")
labelling_project_name = Column(String(255), nullable=True, comment="标注项目名称")
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
last_updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="最后更新时间")
deleted_at = Column(DateTime(timezone=True), nullable=True, comment="删除时间")
def __repr__(self):
return f"<DatasetMapping(uuid={self.mapping_id}, source={self.source_dataset_id}, labelling={self.labelling_project_id})>"
@property
def is_deleted(self) -> bool:
"""检查是否已被软删除"""
return self.deleted_at is not None

View File

@@ -0,0 +1,29 @@
# app/schemas/__init__.py
from .common import *
from .dataset_mapping import *
from .dm_service import *
from .label_studio import *
__all__ = [
# Common schemas
"StandardResponse",
# Dataset Mapping schemas
"DatasetMappingBase",
"DatasetMappingCreateRequest",
"DatasetMappingUpdateRequest",
"DatasetMappingResponse",
"DatasetMappingCreateResponse",
"SyncDatasetResponse",
"DeleteDatasetResponse",
# DM Service schemas
"DatasetFileResponse",
"PagedDatasetFileResponse",
"DatasetResponse",
# Label Studio schemas
"LabelStudioProject",
"LabelStudioTask"
]

View File

@@ -0,0 +1,27 @@
"""
通用响应模型
"""
from typing import Generic, TypeVar, Optional
from pydantic import BaseModel, Field
# 定义泛型类型变量
T = TypeVar('T')
class StandardResponse(BaseModel, Generic[T]):
"""
标准API响应格式
所有API端点应返回此格式,确保响应的一致性
"""
code: int = Field(..., description="HTTP状态码")
message: str = Field(..., description="响应消息")
data: Optional[T] = Field(None, description="响应数据")
class Config:
json_schema_extra = {
"example": {
"code": 200,
"message": "success",
"data": {}
}
}

View File

@@ -0,0 +1,53 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class DatasetMappingBase(BaseModel):
"""数据集映射 基础模型"""
source_dataset_id: str = Field(..., description="源数据集ID")
class DatasetMappingCreateRequest(DatasetMappingBase):
"""数据集映射 创建 请求模型"""
pass
class DatasetMappingCreateResponse(BaseModel):
"""数据集映射 创建 响应模型"""
mapping_id: str = Field(..., description="映射UUID")
labelling_project_id: str = Field(..., description="Label Studio项目ID")
labelling_project_name: str = Field(..., description="Label Studio项目名称")
message: str = Field(..., description="响应消息")
class DatasetMappingUpdateRequest(BaseModel):
"""数据集映射 更新 请求模型"""
source_dataset_id: Optional[str] = Field(None, description="源数据集ID")
class DatasetMappingResponse(DatasetMappingBase):
"""数据集映射 查询 响应模型"""
mapping_id: str = Field(..., description="映射UUID")
labelling_project_id: str = Field(..., description="标注项目ID")
labelling_project_name: Optional[str] = Field(None, description="标注项目名称")
created_at: datetime = Field(..., description="创建时间")
last_updated_at: datetime = Field(..., description="最后更新时间")
deleted_at: Optional[datetime] = Field(None, description="删除时间")
class Config:
from_attributes = True
class SyncDatasetRequest(BaseModel):
"""同步数据集请求模型"""
mapping_id: str = Field(..., description="映射ID(mapping UUID)")
batch_size: int = Field(50, ge=1, le=100, description="批处理大小")
class SyncDatasetResponse(BaseModel):
"""同步数据集响应模型"""
mapping_id: str = Field(..., description="映射UUID")
status: str = Field(..., description="同步状态")
synced_files: int = Field(..., description="已同步文件数量")
total_files: int = Field(0, description="总文件数量")
message: str = Field(..., description="响应消息")
class DeleteDatasetResponse(BaseModel):
"""删除数据集响应模型"""
mapping_id: str = Field(..., description="映射UUID")
status: str = Field(..., description="删除状态")
message: str = Field(..., description="响应消息")

View File

@@ -0,0 +1,58 @@
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime
class DatasetFileResponse(BaseModel):
"""DM服务数据集文件响应模型"""
id: str = Field(..., description="文件ID")
fileName: str = Field(..., description="文件名")
fileType: str = Field(..., description="文件类型")
filePath: str = Field(..., description="文件路径")
originalName: Optional[str] = Field(None, description="原始文件名")
size: Optional[int] = Field(None, description="文件大小(字节)")
status: Optional[str] = Field(None, description="文件状态")
uploadedAt: Optional[datetime] = Field(None, description="上传时间")
description: Optional[str] = Field(None, description="文件描述")
uploadedBy: Optional[str] = Field(None, description="上传者")
lastAccessTime: Optional[datetime] = Field(None, description="最后访问时间")
class PagedDatasetFileResponse(BaseModel):
"""DM服务分页文件响应模型"""
content: List[DatasetFileResponse] = Field(..., description="文件列表")
totalElements: int = Field(..., description="总元素数")
totalPages: int = Field(..., description="总页数")
page: int = Field(..., description="当前页码")
size: int = Field(..., description="每页大小")
class DatasetTypeResponse(BaseModel):
"""数据集类型响应模型"""
code: str = Field(..., description="类型编码")
name: str = Field(..., description="类型名称")
description: Optional[str] = Field(None, description="类型描述")
supportedFormats: List[str] = Field(default_factory=list, description="支持的文件格式")
icon: Optional[str] = Field(None, description="图标")
class DatasetResponse(BaseModel):
"""DM服务数据集响应模型"""
id: str = Field(..., description="数据集ID")
name: str = Field(..., description="数据集名称")
description: Optional[str] = Field(None, description="数据集描述")
datasetType: str = Field(..., description="数据集类型", alias="datasetType")
status: str = Field(..., description="数据集状态")
fileCount: int = Field(..., description="文件数量")
totalSize: int = Field(..., description="总大小(字节)")
createdAt: Optional[datetime] = Field(None, description="创建时间")
updatedAt: Optional[datetime] = Field(None, description="更新时间")
createdBy: Optional[str] = Field(None, description="创建者")
# 为了向后兼容,添加一个属性方法返回类型对象
@property
def type(self) -> DatasetTypeResponse:
"""兼容属性:返回类型对象"""
return DatasetTypeResponse(
code=self.datasetType,
name=self.datasetType,
description=None,
supportedFormats=[],
icon=None
)

View File

@@ -0,0 +1,37 @@
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional, List
from datetime import datetime
class LabelStudioProject(BaseModel):
"""Label Studio项目模型"""
id: int = Field(..., description="项目ID")
title: str = Field(..., description="项目标题")
description: Optional[str] = Field(None, description="项目描述")
label_config: str = Field(..., description="标注配置")
created_at: Optional[datetime] = Field(None, description="创建时间")
updated_at: Optional[datetime] = Field(None, description="更新时间")
class LabelStudioTaskData(BaseModel):
"""Label Studio任务数据模型"""
image: Optional[str] = Field(None, description="图像URL")
text: Optional[str] = Field(None, description="文本内容")
audio: Optional[str] = Field(None, description="音频URL")
video: Optional[str] = Field(None, description="视频URL")
filename: Optional[str] = Field(None, description="文件名")
class LabelStudioTask(BaseModel):
"""Label Studio任务模型"""
data: LabelStudioTaskData = Field(..., description="任务数据")
project: Optional[int] = Field(None, description="项目ID")
meta: Optional[Dict[str, Any]] = Field(None, description="元数据")
class LabelStudioCreateProjectRequest(BaseModel):
"""创建Label Studio项目请求模型"""
title: str = Field(..., description="项目标题")
description: str = Field("", description="项目描述")
label_config: str = Field(..., description="标注配置")
class LabelStudioCreateTaskRequest(BaseModel):
"""创建Label Studio任务请求模型"""
data: Dict[str, Any] = Field(..., description="任务数据")
project: Optional[int] = Field(None, description="项目ID")

View File

@@ -0,0 +1,6 @@
# app/services/__init__.py
from .dataset_mapping_service import DatasetMappingService
from .sync_service import SyncService
__all__ = ["DatasetMappingService", "SyncService"]

View File

@@ -0,0 +1,223 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update
from typing import Optional, List
from datetime import datetime
import uuid
from app.models.dataset_mapping import DatasetMapping
from app.schemas.dataset_mapping import (
DatasetMappingCreateRequest,
DatasetMappingUpdateRequest,
DatasetMappingResponse
)
from app.core.logging import get_logger
logger = get_logger(__name__)
class DatasetMappingService:
"""数据集映射服务"""
def __init__(self, db: AsyncSession):
self.db = db
async def create_mapping(
self,
mapping_data: DatasetMappingCreateRequest,
labelling_project_id: str,
labelling_project_name: str
) -> DatasetMappingResponse:
"""创建数据集映射"""
logger.info(f"Create dataset mapping: {mapping_data.source_dataset_id} -> {labelling_project_id}")
db_mapping = DatasetMapping(
mapping_id=str(uuid.uuid4()),
source_dataset_id=mapping_data.source_dataset_id,
labelling_project_id=labelling_project_id,
labelling_project_name=labelling_project_name
)
self.db.add(db_mapping)
await self.db.commit()
await self.db.refresh(db_mapping)
logger.info(f"Mapping created: {db_mapping.mapping_id}")
return DatasetMappingResponse.model_validate(db_mapping)
async def get_mapping_by_source_uuid(
self,
source_dataset_id: str
) -> Optional[DatasetMappingResponse]:
"""根据源数据集ID获取映射(返回第一个未删除的)"""
logger.debug(f"Get mapping by source dataset id: {source_dataset_id}")
result = await self.db.execute(
select(DatasetMapping).where(
DatasetMapping.source_dataset_id == source_dataset_id,
DatasetMapping.deleted_at.is_(None)
)
)
mapping = result.scalar_one_or_none()
if mapping:
logger.debug(f"Found mapping: {mapping.mapping_id}")
return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"No mapping found for source dataset id: {source_dataset_id}")
return None
async def get_mappings_by_source_dataset_id(
self,
source_dataset_id: str,
include_deleted: bool = False
) -> List[DatasetMappingResponse]:
"""根据源数据集ID获取所有映射关系"""
logger.debug(f"Get all mappings by source dataset id: {source_dataset_id}")
query = select(DatasetMapping).where(
DatasetMapping.source_dataset_id == source_dataset_id
)
if not include_deleted:
query = query.where(DatasetMapping.deleted_at.is_(None))
result = await self.db.execute(
query.order_by(DatasetMapping.created_at.desc())
)
mappings = result.scalars().all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
async def get_mapping_by_labelling_project_id(
self,
labelling_project_id: str
) -> Optional[DatasetMappingResponse]:
"""根据Label Studio项目ID获取映射"""
logger.debug(f"Get mapping by Label Studio project id: {labelling_project_id}")
result = await self.db.execute(
select(DatasetMapping).where(
DatasetMapping.labelling_project_id == labelling_project_id,
DatasetMapping.deleted_at.is_(None)
)
)
mapping = result.scalar_one_or_none()
if mapping:
logger.debug(f"Found mapping: {mapping.mapping_id}")
return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"No mapping found for Label Studio project id: {labelling_project_id}")
return None
async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]:
"""根据映射UUID获取映射"""
logger.debug(f"Get mapping: {mapping_id}")
result = await self.db.execute(
select(DatasetMapping).where(
DatasetMapping.mapping_id == mapping_id,
DatasetMapping.deleted_at.is_(None)
)
)
mapping = result.scalar_one_or_none()
if mapping:
logger.debug(f"Found mapping: {mapping.mapping_id}")
return DatasetMappingResponse.model_validate(mapping)
logger.debug(f"Mapping not found: {mapping_id}")
return None
async def update_mapping(
self,
mapping_id: str,
update_data: DatasetMappingUpdateRequest
) -> Optional[DatasetMappingResponse]:
"""更新映射信息"""
logger.info(f"Update mapping: {mapping_id}")
mapping = await self.get_mapping_by_uuid(mapping_id)
if not mapping:
return None
update_values = update_data.model_dump(exclude_unset=True)
update_values["last_updated_at"] = datetime.utcnow()
result = await self.db.execute(
update(DatasetMapping)
.where(DatasetMapping.mapping_id == mapping_id)
.values(**update_values)
)
await self.db.commit()
if result.rowcount > 0:
return await self.get_mapping_by_uuid(mapping_id)
return None
async def update_last_updated_at(self, mapping_id: str) -> bool:
"""更新最后更新时间"""
logger.debug(f"Update mapping last updated at: {mapping_id}")
result = await self.db.execute(
update(DatasetMapping)
.where(
DatasetMapping.mapping_id == mapping_id,
DatasetMapping.deleted_at.is_(None)
)
.values(last_updated_at=datetime.utcnow())
)
await self.db.commit()
return result.rowcount > 0
async def soft_delete_mapping(self, mapping_id: str) -> bool:
"""软删除映射"""
logger.info(f"Soft delete mapping: {mapping_id}")
result = await self.db.execute(
update(DatasetMapping)
.where(
DatasetMapping.mapping_id == mapping_id,
DatasetMapping.deleted_at.is_(None)
)
.values(deleted_at=datetime.utcnow())
)
await self.db.commit()
success = result.rowcount > 0
if success:
logger.info(f"Mapping soft-deleted: {mapping_id}")
else:
logger.warning(f"Mapping not exists or already deleted: {mapping_id}")
return success
async def get_all_mappings(
self,
skip: int = 0,
limit: int = 100
) -> List[DatasetMappingResponse]:
"""获取所有有效映射"""
logger.debug(f"List all mappings, skip: {skip}, limit: {limit}")
result = await self.db.execute(
select(DatasetMapping)
.where(DatasetMapping.deleted_at.is_(None))
.offset(skip)
.limit(limit)
.order_by(DatasetMapping.created_at.desc())
)
mappings = result.scalars().all()
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
async def count_mappings(self) -> int:
"""统计映射总数"""
result = await self.db.execute(
select(DatasetMapping)
.where(DatasetMapping.deleted_at.is_(None))
)
mappings = result.scalars().all()
return len(mappings)

View File

@@ -0,0 +1,276 @@
from typing import Optional, List, Dict, Any, Tuple
from app.clients.dm_client import DMServiceClient
from app.clients.label_studio_client import LabelStudioClient
from app.services.dataset_mapping_service import DatasetMappingService
from app.schemas.dataset_mapping import SyncDatasetResponse
from app.core.logging import get_logger
from app.core.config import settings
from app.exceptions import NoDatasetInfoFoundError, DatasetMappingNotFoundError
logger = get_logger(__name__)
class SyncService:
"""数据同步服务"""
def __init__(
self,
dm_client: DMServiceClient,
ls_client: LabelStudioClient,
mapping_service: DatasetMappingService
):
self.dm_client = dm_client
self.ls_client = ls_client
self.mapping_service = mapping_service
def determine_data_type(self, file_type: str) -> str:
"""根据文件类型确定数据类型"""
file_type_lower = file_type.lower()
if any(ext in file_type_lower for ext in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'svg', 'webp']):
return 'image'
elif any(ext in file_type_lower for ext in ['mp3', 'wav', 'flac', 'aac', 'ogg']):
return 'audio'
elif any(ext in file_type_lower for ext in ['mp4', 'avi', 'mov', 'wmv', 'flv', 'webm']):
return 'video'
elif any(ext in file_type_lower for ext in ['txt', 'doc', 'docx', 'pdf']):
return 'text'
else:
return 'image' # 默认为图像类型
async def get_existing_dm_file_mapping(self, project_id: str) -> Dict[str, int]:
"""
获取Label Studio项目中已存在的DM文件ID到任务ID的映射
Args:
project_id: Label Studio项目ID
Returns:
dm_file_id到task_id的映射字典
"""
try:
logger.info(f"Fetching existing task mappings for project {project_id} (page_size={settings.ls_task_page_size})")
dm_file_to_task_mapping = {}
# 使用Label Studio客户端封装的方法获取所有任务
page_size = getattr(settings, 'ls_task_page_size', 1000)
# 调用封装好的方法获取所有任务,page=None表示获取全部
result = await self.ls_client.get_project_tasks(
project_id=project_id,
page=None, # 不指定page,获取所有任务
page_size=page_size
)
if not result:
logger.warning(f"Failed to fetch tasks for project {project_id}")
return {}
all_tasks = result.get("tasks", [])
# 遍历所有任务,构建映射
for task in all_tasks:
# 检查任务的meta字段中是否有dm_file_id
meta = task.get('meta')
if meta:
dm_file_id = meta.get('dm_file_id')
if dm_file_id:
task_id = task.get('id')
if task_id:
dm_file_to_task_mapping[str(dm_file_id)] = task_id
logger.info(f"Found {len(dm_file_to_task_mapping)} existing task mappings")
return dm_file_to_task_mapping
except Exception as e:
logger.error(f"Error while fetching existing tasks: {e}")
return {} # 发生错误时返回空字典,会同步所有文件
async def sync_dataset_files(
self,
mapping_id: str,
batch_size: int = 50
) -> SyncDatasetResponse:
"""同步数据集文件到Label Studio"""
logger.info(f"Start syncing dataset by mapping: {mapping_id}")
# 获取映射关系
mapping = await self.mapping_service.get_mapping_by_uuid(mapping_id)
if not mapping:
logger.error(f"Dataset mapping not found: {mapping_id}")
return SyncDatasetResponse(
mapping_id="",
status="error",
synced_files=0,
total_files=0,
message=f"Dataset mapping not found: {mapping_id}"
)
try:
# 获取数据集信息
dataset_info = await self.dm_client.get_dataset(mapping.source_dataset_id)
if not dataset_info:
raise NoDatasetInfoFoundError(mapping.source_dataset_id)
synced_files = 0
deleted_tasks = 0
total_files = dataset_info.fileCount
page = 0
logger.info(f"Total files in dataset: {total_files}")
# 获取Label Studio中已存在的DM文件ID到任务ID的映射
existing_dm_file_mapping = await self.get_existing_dm_file_mapping(mapping.labelling_project_id)
existing_dm_file_ids = set(existing_dm_file_mapping.keys())
logger.info(f"{len(existing_dm_file_ids)} tasks already exist in Label Studio")
# 收集DM中当前存在的所有文件ID
current_dm_file_ids = set()
# 分页获取并同步文件
while True:
files_response = await self.dm_client.get_dataset_files(
mapping.source_dataset_id,
page=page,
size=batch_size,
status="COMPLETED" # 只同步已完成的文件
)
if not files_response or not files_response.content:
logger.info(f"No more files on page {page + 1}")
break
logger.info(f"Processing page {page + 1}, total {len(files_response.content)} files")
# 筛选出新文件并批量创建任务
tasks = []
new_files_count = 0
existing_files_count = 0
for file_info in files_response.content:
# 记录当前DM中存在的文件ID
current_dm_file_ids.add(str(file_info.id))
# 检查文件是否已存在
if str(file_info.id) in existing_dm_file_ids:
existing_files_count += 1
logger.debug(f"Skip existing file: {file_info.originalName} (ID: {file_info.id})")
continue
new_files_count += 1
# 确定数据类型
data_type = self.determine_data_type(file_info.fileType)
# 替换文件路径前缀:只替换开头的前缀,不影响路径中间可能出现的相同字符串
file_path = file_info.filePath.removeprefix(settings.dm_file_path_prefix)
file_path = settings.label_studio_file_path_prefix + file_path
# 构造任务数据
task_data = {
"data": {
data_type: file_path
},
"meta": {
"file_size": file_info.size,
"file_type": file_info.fileType,
"dm_dataset_id": mapping.source_dataset_id,
"dm_file_id": file_info.id,
"original_name": file_info.originalName,
}
}
tasks.append(task_data)
logger.info(f"Page {page + 1}: new files {new_files_count}, existing files {existing_files_count}")
# 批量创建Label Studio任务
if tasks:
batch_result = await self.ls_client.create_tasks_batch(
mapping.labelling_project_id,
tasks
)
if batch_result:
synced_files += len(tasks)
logger.info(f"Successfully synced {len(tasks)} files")
else:
logger.warning(f"Batch task creation failed, fallback to single creation")
# 如果批量创建失败,尝试单个创建
for task_data in tasks:
task_result = await self.ls_client.create_task(
mapping.labelling_project_id,
task_data["data"],
task_data.get("meta")
)
if task_result:
synced_files += 1
# 检查是否还有更多页面
if page >= files_response.totalPages - 1:
break
page += 1
# 清理在DM中不存在但在Label Studio中存在的任务
tasks_to_delete = []
for dm_file_id, task_id in existing_dm_file_mapping.items():
if dm_file_id not in current_dm_file_ids:
tasks_to_delete.append(task_id)
logger.debug(f"Mark task for deletion: {task_id} (DM file ID: {dm_file_id})")
if tasks_to_delete:
logger.info(f"Deleting {len(tasks_to_delete)} tasks not present in DM")
delete_result = await self.ls_client.delete_tasks_batch(tasks_to_delete)
deleted_tasks = delete_result.get("successful", 0)
logger.info(f"Successfully deleted {deleted_tasks} tasks")
else:
logger.info("No tasks to delete")
# 更新映射的最后更新时间
await self.mapping_service.update_last_updated_at(mapping.mapping_id)
logger.info(f"Sync completed: total_files={total_files}, created={synced_files}, deleted={deleted_tasks}")
return SyncDatasetResponse(
mapping_id=mapping.mapping_id,
status="success",
synced_files=synced_files,
total_files=total_files,
message=f"Sync completed: created {synced_files} files, deleted {deleted_tasks} tasks"
)
except Exception as e:
logger.error(f"Error while syncing dataset: {e}")
return SyncDatasetResponse(
mapping_id=mapping.mapping_id,
status="error",
synced_files=0,
total_files=0,
message=f"Sync failed: {str(e)}"
)
async def get_sync_status(
self,
source_dataset_id: str
) -> Optional[Dict[str, Any]]:
"""获取同步状态"""
mapping = await self.mapping_service.get_mapping_by_source_uuid(source_dataset_id)
if not mapping:
return None
# 获取DM数据集信息
dataset_info = await self.dm_client.get_dataset(source_dataset_id)
# 获取Label Studio项目任务数量
tasks_info = await self.ls_client.get_project_tasks(mapping.labelling_project_id)
return {
"mapping_id": mapping.mapping_id,
"source_dataset_id": source_dataset_id,
"labelling_project_id": mapping.labelling_project_id,
"last_updated_at": mapping.last_updated_at,
"dm_total_files": dataset_info.fileCount if dataset_info else 0,
"ls_total_tasks": tasks_info.get("count", 0) if tasks_info else 0,
"sync_ratio": (
tasks_info.get("count", 0) / dataset_info.fileCount
if dataset_info and dataset_info.fileCount > 0 and tasks_info else 0
)
}

View File

@@ -0,0 +1,64 @@
#!/bin/bash
set -e
echo "=========================================="
echo "Label Studio Adapter Starting..."
echo "=========================================="
# Label Studio 本地存储基础路径(从环境变量获取,默认值)
LABEL_STUDIO_LOCAL_BASE="${LABEL_STUDIO_LOCAL_BASE:-/label-studio/local_files}"
echo "=========================================="
echo "Ensuring Label Studio local storage directories exist..."
echo "Base path: ${LABEL_STUDIO_LOCAL_BASE}"
echo "=========================================="
# 创建必要的目录
mkdir -p "${LABEL_STUDIO_LOCAL_BASE}/dataset"
mkdir -p "${LABEL_STUDIO_LOCAL_BASE}/upload"
echo "✓ Directory 'dataset' ready: ${LABEL_STUDIO_LOCAL_BASE}/dataset"
echo "✓ Directory 'upload' ready: ${LABEL_STUDIO_LOCAL_BASE}/upload"
echo "=========================================="
echo "Directory initialization completed"
echo "=========================================="
# 等待数据库就绪(如果配置了数据库)
if [ -n "$MYSQL_HOST" ] || [ -n "$POSTGRES_HOST" ]; then
echo "Waiting for database to be ready..."
sleep 5
fi
# 运行数据库迁移
echo "=========================================="
echo "Running database migrations..."
echo "=========================================="
alembic upgrade head
if [ $? -eq 0 ]; then
echo "✓ Database migrations completed successfully"
else
echo "⚠️ WARNING: Database migrations failed"
echo " The application may not work correctly"
fi
echo "=========================================="
# 启动应用
echo "Starting Label Studio Adapter..."
echo "Host: ${HOST:-0.0.0.0}"
echo "Port: ${PORT:-18000}"
echo "Debug: ${DEBUG:-false}"
echo "Label Studio URL: ${LABEL_STUDIO_BASE_URL}"
echo "=========================================="
# 转换 LOG_LEVEL 为小写(uvicorn 要求小写)
LOG_LEVEL_LOWER=$(echo "${LOG_LEVEL:-info}" | tr '[:upper:]' '[:lower:]')
# 使用 uvicorn 启动应用
exec uvicorn app.main:app \
--host "${HOST:-0.0.0.0}" \
--port "${PORT:-18000}" \
--log-level "${LOG_LEVEL_LOWER}" \
${DEBUG:+--reload}

Binary file not shown.