From c6401053334f314ea6061c1f40c802d01c52cea0 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Wed, 22 Oct 2025 15:14:01 +0800 Subject: [PATCH 1/4] Add Label Studio adapter module and its build scipts. --- Makefile | 4 + runtime/label-studio-adapter/.env.example | 128 +++++ runtime/label-studio-adapter/.gitignore | 6 + runtime/label-studio-adapter/alembic.ini | 148 ++++++ runtime/label-studio-adapter/alembic/README | 1 + runtime/label-studio-adapter/alembic/env.py | 145 ++++++ .../alembic/script.py.mako | 28 ++ .../versions/755dc1afb8ad_initiation.py | 41 ++ runtime/label-studio-adapter/app/__init__.py | 1 + .../label-studio-adapter/app/api/__init__.py | 19 + .../app/api/project/__init__.py | 11 + .../app/api/project/create.py | 130 +++++ .../app/api/project/delete.py | 106 ++++ .../app/api/project/list.py | 110 ++++ .../app/api/project/sync.py | 68 +++ .../label-studio-adapter/app/api/system.py | 34 ++ .../app/clients/__init__.py | 8 + .../app/clients/client_manager.py | 34 ++ .../app/clients/dm_client.py | 138 ++++++ .../app/clients/label_studio_client.py | 469 ++++++++++++++++++ .../label-studio-adapter/app/core/__init__.py | 1 + .../label-studio-adapter/app/core/config.py | 146 ++++++ .../label-studio-adapter/app/core/logging.py | 53 ++ .../label-studio-adapter/app/db/__init__.py | 1 + .../label-studio-adapter/app/db/database.py | 39 ++ .../label-studio-adapter/app/exceptions.py | 31 ++ runtime/label-studio-adapter/app/main.py | 172 +++++++ .../app/models/__init__.py | 5 + .../app/models/dataset_mapping.py | 25 + .../app/schemas/__init__.py | 29 ++ .../app/schemas/common.py | 27 + .../app/schemas/dataset_mapping.py | 53 ++ .../app/schemas/dm_service.py | 58 +++ .../app/schemas/label_studio.py | 37 ++ .../app/services/__init__.py | 6 + .../app/services/dataset_mapping_service.py | 223 +++++++++ .../app/services/sync_service.py | 276 +++++++++++ .../deploy/docker-entrypoint.sh | 64 +++ runtime/label-studio-adapter/requirements.txt | Bin 0 -> 942 bytes 
.../images/label-studio-adapter/Dockerfile | 27 + 40 files changed, 2902 insertions(+) create mode 100644 runtime/label-studio-adapter/.env.example create mode 100644 runtime/label-studio-adapter/.gitignore create mode 100644 runtime/label-studio-adapter/alembic.ini create mode 100644 runtime/label-studio-adapter/alembic/README create mode 100644 runtime/label-studio-adapter/alembic/env.py create mode 100644 runtime/label-studio-adapter/alembic/script.py.mako create mode 100644 runtime/label-studio-adapter/alembic/versions/755dc1afb8ad_initiation.py create mode 100644 runtime/label-studio-adapter/app/__init__.py create mode 100644 runtime/label-studio-adapter/app/api/__init__.py create mode 100644 runtime/label-studio-adapter/app/api/project/__init__.py create mode 100644 runtime/label-studio-adapter/app/api/project/create.py create mode 100644 runtime/label-studio-adapter/app/api/project/delete.py create mode 100644 runtime/label-studio-adapter/app/api/project/list.py create mode 100644 runtime/label-studio-adapter/app/api/project/sync.py create mode 100644 runtime/label-studio-adapter/app/api/system.py create mode 100644 runtime/label-studio-adapter/app/clients/__init__.py create mode 100644 runtime/label-studio-adapter/app/clients/client_manager.py create mode 100644 runtime/label-studio-adapter/app/clients/dm_client.py create mode 100644 runtime/label-studio-adapter/app/clients/label_studio_client.py create mode 100644 runtime/label-studio-adapter/app/core/__init__.py create mode 100644 runtime/label-studio-adapter/app/core/config.py create mode 100644 runtime/label-studio-adapter/app/core/logging.py create mode 100644 runtime/label-studio-adapter/app/db/__init__.py create mode 100644 runtime/label-studio-adapter/app/db/database.py create mode 100644 runtime/label-studio-adapter/app/exceptions.py create mode 100644 runtime/label-studio-adapter/app/main.py create mode 100644 runtime/label-studio-adapter/app/models/__init__.py create mode 100644 
runtime/label-studio-adapter/app/models/dataset_mapping.py create mode 100644 runtime/label-studio-adapter/app/schemas/__init__.py create mode 100644 runtime/label-studio-adapter/app/schemas/common.py create mode 100644 runtime/label-studio-adapter/app/schemas/dataset_mapping.py create mode 100644 runtime/label-studio-adapter/app/schemas/dm_service.py create mode 100644 runtime/label-studio-adapter/app/schemas/label_studio.py create mode 100644 runtime/label-studio-adapter/app/services/__init__.py create mode 100644 runtime/label-studio-adapter/app/services/dataset_mapping_service.py create mode 100644 runtime/label-studio-adapter/app/services/sync_service.py create mode 100755 runtime/label-studio-adapter/deploy/docker-entrypoint.sh create mode 100644 runtime/label-studio-adapter/requirements.txt create mode 100644 scripts/images/label-studio-adapter/Dockerfile diff --git a/Makefile b/Makefile index 6c453f1..65590db 100644 --- a/Makefile +++ b/Makefile @@ -81,6 +81,10 @@ frontend-docker-build: runtime-docker-build: docker build -t runtime:$(VERSION) . -f scripts/images/runtime/Dockerfile +.PHONY: label-studio-adapter-docker-build +label-studio-adapter-docker-build: + docker build -t label-studio-adapter:$(VERSION) . 
-f scripts/images/label-studio-adapter/Dockerfile + .PHONY: backend-docker-install backend-docker-install: cd deployment/docker/data-mate && docker-compose up -d backend diff --git a/runtime/label-studio-adapter/.env.example b/runtime/label-studio-adapter/.env.example new file mode 100644 index 0000000..3c8222e --- /dev/null +++ b/runtime/label-studio-adapter/.env.example @@ -0,0 +1,128 @@ +# ==================================== +# Label Studio Adapter Configuration +# ==================================== + +# ========================= +# 应用程序配置 +# ========================= +APP_NAME="Label Studio Adapter" +APP_VERSION="1.0.0" +APP_DESCRIPTION="Adapter for integrating Data Management System with Label Studio" +DEBUG=true + +# ========================= +# 服务器配置 +# ========================= +HOST=0.0.0.0 +PORT=18000 + +# ========================= +# 日志配置 +# ========================= +LOG_LEVEL=INFO + +# ========================= +# Label Studio 服务配置 +# ========================= +# Label Studio 服务地址(根据部署方式调整) +# Docker 环境:http://label-studio:8080 +# 本地开发:http://127.0.0.1:8000 +LABEL_STUDIO_BASE_URL=http://label-studio:8080 + +# Label Studio 用户名和密码(用于自动创建用户) +LABEL_STUDIO_USERNAME=admin@example.com +LABEL_STUDIO_PASSWORD=password + +# Label Studio API 认证 Token(Legacy Token,推荐使用) +# 从 Label Studio UI 的 Account & Settings > Access Token 获取 +LABEL_STUDIO_USER_TOKEN=your-label-studio-token-here + +# Label Studio 本地文件存储基础路径(容器内路径,用于 Docker 部署时的权限检查) +LABEL_STUDIO_LOCAL_BASE=/label-studio/local_files + +# Label Studio 本地文件服务路径前缀(任务数据中的文件路径前缀) +LABEL_STUDIO_FILE_PATH_PREFIX=/data/local-files/?d= + +# Label Studio 容器中的本地存储路径(用于配置 Local Storage) +LABEL_STUDIO_LOCAL_STORAGE_DATASET_BASE_PATH=/label-studio/local_files/dataset +LABEL_STUDIO_LOCAL_STORAGE_UPLOAD_BASE_PATH=/label-studio/local_files/upload + +# Label Studio 任务列表分页大小 +LS_TASK_PAGE_SIZE=1000 + +# ========================= +# Data Management 服务配置 +# ========================= +# DM 服务地址 
+DM_SERVICE_BASE_URL=http://data-engine:8080 + +# DM 存储文件夹前缀(通常与 Label Studio 的 local-files 文件夹映射一致) +DM_FILE_PATH_PREFIX=/ + +# ========================= +# Adapter 数据库配置 (MySQL) +# ========================= +# 优先级1:如果配置了 MySQL,将优先使用 MySQL 数据库 +MYSQL_HOST=adapter-db +MYSQL_PORT=3306 +MYSQL_USER=label_studio_user +MYSQL_PASSWORD=user_password +MYSQL_DATABASE=label_studio_adapter + +# ========================= +# Label Studio 数据库配置 (PostgreSQL) +# ========================= +# 仅在使用 docker-compose.label-studio.yml 启动 Label Studio 时需要配置 +POSTGRES_HOST=label-studio-db +POSTGRES_PORT=5432 +POSTGRES_USER=labelstudio +POSTGRES_PASSWORD=labelstudio@4321 +POSTGRES_DATABASE=labelstudio + +# ========================= +# SQLite 数据库配置(兜底选项) +# ========================= +# 优先级3:如果没有配置 MySQL/PostgreSQL,将使用 SQLite +SQLITE_PATH=./data/labelstudio_adapter.db + +# ========================= +# 可选:直接指定数据库 URL +# ========================= +# 如果设置了此项,将覆盖上面的 MySQL/PostgreSQL/SQLite 配置 +# DATABASE_URL=postgresql+asyncpg://user:password@host:port/database + +# ========================= +# 安全配置 +# ========================= +# 密钥(生产环境务必修改) +SECRET_KEY=your-secret-key-change-this-in-production + +# Token 过期时间(分钟) +ACCESS_TOKEN_EXPIRE_MINUTES=30 + +# ========================= +# CORS 配置 +# ========================= +# 允许的来源(生产环境建议配置具体域名) +ALLOWED_ORIGINS=["*"] + +# 允许的 HTTP 方法 +ALLOWED_METHODS=["*"] + +# 允许的请求头 +ALLOWED_HEADERS=["*"] + +# ========================= +# Docker Compose 配置 +# ========================= +# Docker Compose 项目名称前缀 +COMPOSE_PROJECT_NAME=ls-adapter + +# ========================= +# 同步配置(未来扩展) +# ========================= +# 批量同步任务的批次大小 +SYNC_BATCH_SIZE=100 + +# 同步失败时的最大重试次数 +MAX_RETRIES=3 \ No newline at end of file diff --git a/runtime/label-studio-adapter/.gitignore b/runtime/label-studio-adapter/.gitignore new file mode 100644 index 0000000..4670a93 --- /dev/null +++ b/runtime/label-studio-adapter/.gitignore @@ -0,0 +1,6 @@ +# Local Development Environment Files +.env 
+.dev.env + +# logs +logs/ \ No newline at end of file diff --git a/runtime/label-studio-adapter/alembic.ini b/runtime/label-studio-adapter/alembic.ini new file mode 100644 index 0000000..6753110 --- /dev/null +++ b/runtime/label-studio-adapter/alembic.ini @@ -0,0 +1,148 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. 
When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. 
+# sqlalchemy.url = driver://user:pass@localhost/dbname +# 注释掉默认 URL,我们将在 env.py 中从应用配置读取 + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/runtime/label-studio-adapter/alembic/README b/runtime/label-studio-adapter/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/runtime/label-studio-adapter/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. 
\ No newline at end of file diff --git a/runtime/label-studio-adapter/alembic/env.py b/runtime/label-studio-adapter/alembic/env.py new file mode 100644 index 0000000..eccb1ab --- /dev/null +++ b/runtime/label-studio-adapter/alembic/env.py @@ -0,0 +1,145 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool +from sqlalchemy import create_engine, text + +from alembic import context +import os +from urllib.parse import quote_plus + +# 导入应用配置和模型 +from app.core.config import settings +from app.db.database import Base +# 导入所有模型,以便 autogenerate 能够检测到它们 +from app.models import dataset_mapping # noqa + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + + +def ensure_database_and_user(): + """ + 确保数据库和用户存在 + 使用 root 用户连接 MySQL,创建数据库和应用用户 + """ + # 只在 MySQL 配置时执行 + if not settings.mysql_host: + return + + mysql_root_password = os.getenv('MYSQL_ROOT_PASSWORD', 'Huawei@123') + + # URL 编码密码以处理特殊字符 + encoded_password = quote_plus(mysql_root_password) + + # 使用 root 用户连接(不指定数据库) + root_url = f"mysql+pymysql://root:{encoded_password}@{settings.mysql_host}:{settings.mysql_port}/" + + try: + root_engine = create_engine(root_url, poolclass=pool.NullPool) + with root_engine.connect() as conn: + # 创建数据库(如果不存在) + conn.execute(text( + f"CREATE DATABASE IF NOT EXISTS `{settings.mysql_database}` " + f"CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" + )) + conn.commit() + + # 创建用户(如果不存在)- 使用 MySQL 8 默认的 caching_sha2_password + conn.execute(text( + f"CREATE USER IF NOT EXISTS '{settings.mysql_user}'@'%' " + f"IDENTIFIED BY '{settings.mysql_password}'" + )) + conn.commit() + + # 授予权限 + conn.execute(text( + f"GRANT ALL PRIVILEGES ON `{settings.mysql_database}`.* TO '{settings.mysql_user}'@'%'" + )) + conn.commit() + + # 刷新权限 + conn.execute(text("FLUSH PRIVILEGES")) + conn.commit() + + root_engine.dispose() + print(f"✓ Database 
'{settings.mysql_database}' and user '{settings.mysql_user}' are ready") + except Exception as e: + print(f"⚠️ Warning: Could not ensure database and user: {e}") + print(f" This may be expected if database already exists or permissions are set") + + +# 从应用配置设置数据库 URL +config.set_main_option('sqlalchemy.url', settings.sync_database_url) + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. 
+ + """ + # 先确保数据库和用户存在 + ensure_database_and_user() + + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/runtime/label-studio-adapter/alembic/script.py.mako b/runtime/label-studio-adapter/alembic/script.py.mako new file mode 100644 index 0000000..1101630 --- /dev/null +++ b/runtime/label-studio-adapter/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/runtime/label-studio-adapter/alembic/versions/755dc1afb8ad_initiation.py b/runtime/label-studio-adapter/alembic/versions/755dc1afb8ad_initiation.py new file mode 100644 index 0000000..1659bd5 --- /dev/null +++ b/runtime/label-studio-adapter/alembic/versions/755dc1afb8ad_initiation.py @@ -0,0 +1,41 @@ +"""Initiation + +Revision ID: 755dc1afb8ad +Revises: +Create Date: 2025-10-20 19:34:20.258554 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '755dc1afb8ad' +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('mapping', + sa.Column('mapping_id', sa.String(length=36), nullable=False), + sa.Column('source_dataset_id', sa.String(length=36), nullable=False, comment='源数据集ID'), + sa.Column('labelling_project_id', sa.String(length=36), nullable=False, comment='标注项目ID'), + sa.Column('labelling_project_name', sa.String(length=255), nullable=True, comment='标注项目名称'), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='创建时间'), + sa.Column('last_updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True, comment='最后更新时间'), + sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True, comment='删除时间'), + sa.PrimaryKeyConstraint('mapping_id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('mapping') + # ### end Alembic commands ### diff --git a/runtime/label-studio-adapter/app/__init__.py b/runtime/label-studio-adapter/app/__init__.py new file mode 100644 index 0000000..4ec4b56 --- /dev/null +++ b/runtime/label-studio-adapter/app/__init__.py @@ -0,0 +1 @@ +# app/__init__.py \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/api/__init__.py b/runtime/label-studio-adapter/app/api/__init__.py new file mode 100644 index 0000000..ca138c2 --- /dev/null +++ b/runtime/label-studio-adapter/app/api/__init__.py @@ -0,0 +1,19 @@ +""" +API 路由模块 + +集中管理所有API路由的组织结构 +""" +from fastapi import APIRouter + +from .system import router as system_router +from .project import project_router + +# 创建主API路由器 +api_router = APIRouter() + +# 注册到主路由器 +api_router.include_router(system_router, tags=["系统"]) +api_router.include_router(project_router, prefix="/project", tags=["项目"]) + +# 导出路由器供 main.py 使用 +__all__ = ["api_router"] \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/api/project/__init__.py 
b/runtime/label-studio-adapter/app/api/project/__init__.py new file mode 100644 index 0000000..f499596 --- /dev/null +++ b/runtime/label-studio-adapter/app/api/project/__init__.py @@ -0,0 +1,11 @@ +""" +标注工程相关API路由模块 +""" +from fastapi import APIRouter + +project_router = APIRouter() + +from . import create +from . import sync +from . import list +from . import delete \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/api/project/create.py b/runtime/label-studio-adapter/app/api/project/create.py new file mode 100644 index 0000000..0d3d51d --- /dev/null +++ b/runtime/label-studio-adapter/app/api/project/create.py @@ -0,0 +1,130 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from typing import Optional + +from app.db.database import get_db +from app.services.dataset_mapping_service import DatasetMappingService +from app.clients import get_clients +from app.schemas.dataset_mapping import ( + DatasetMappingCreateRequest, + DatasetMappingCreateResponse, +) +from app.schemas import StandardResponse +from app.core.logging import get_logger +from app.core.config import settings +from . 
import project_router + +logger = get_logger(__name__) + +@project_router.post("/create", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201) +async def create_dataset_mapping( + request: DatasetMappingCreateRequest, + db: AsyncSession = Depends(get_db) +): + """ + 创建数据集映射 + + 根据指定的DM程序中的数据集,创建Label Studio中的数据集, + 在数据库中记录这一关联关系,返回Label Studio数据集的ID + + 注意:一个数据集可以创建多个标注项目 + """ + try: + # 获取全局客户端实例 + dm_client_instance, ls_client_instance = get_clients() + service = DatasetMappingService(db) + + logger.info(f"Create dataset mapping request: {request.source_dataset_id}") + + # 从DM服务获取数据集信息 + dataset_info = await dm_client_instance.get_dataset(request.source_dataset_id) + if not dataset_info: + raise HTTPException( + status_code=404, + detail=f"Dataset not found in DM service: {request.source_dataset_id}" + ) + + # 确定数据类型(基于数据集类型) + data_type = "image" # 默认值 + if dataset_info.type and dataset_info.type.code: + type_code = dataset_info.type.code.lower() + if "audio" in type_code: + data_type = "audio" + elif "video" in type_code: + data_type = "video" + elif "text" in type_code: + data_type = "text" + + # 生成项目名称 + project_name = f"{dataset_info.name}" + + # 在Label Studio中创建项目 + project_data = await ls_client_instance.create_project( + title=project_name, + description=dataset_info.description or f"Imported from DM dataset {dataset_info.id}", + data_type=data_type + ) + + if not project_data: + raise HTTPException( + status_code=500, + detail="Fail to create Label Studio project." 
+ ) + + project_id = project_data["id"] + + # 配置本地存储:dataset/ + local_storage_path = f"{settings.label_studio_local_storage_dataset_base_path}/{request.source_dataset_id}" + storage_result = await ls_client_instance.create_local_storage( + project_id=project_id, + path=local_storage_path, + title="Dataset_BLOB", + use_blob_urls=True, + description=f"Local storage for dataset {dataset_info.name}" + ) + + # 配置本地存储:upload + local_storage_path = f"{settings.label_studio_local_storage_upload_base_path}" + storage_result = await ls_client_instance.create_local_storage( + project_id=project_id, + path=local_storage_path, + title="Upload_BLOB", + use_blob_urls=True, + description=f"Local storage for dataset {dataset_info.name}" + ) + + if not storage_result: + # 本地存储配置失败,记录警告但不中断流程 + logger.warning(f"Failed to configure local storage for project {project_id}") + else: + logger.info(f"Local storage configured for project {project_id}: {local_storage_path}") + + # 创建映射关系,包含项目名称 + mapping = await service.create_mapping( + request, + str(project_id), + project_name + ) + + logger.debug( + f"Dataset mapping created: {mapping.mapping_id} -> S {mapping.source_dataset_id} <> L {mapping.labelling_project_id}" + ) + + response_data = DatasetMappingCreateResponse( + mapping_id=mapping.mapping_id, + labelling_project_id=mapping.labelling_project_id, + labelling_project_name=mapping.labelling_project_name or project_name, + message="Dataset mapping created successfully" + ) + + return StandardResponse( + code=201, + message="success", + data=response_data + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error while creating dataset mapping: {e}") + raise HTTPException(status_code=500, detail="Internal server error") \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/api/project/delete.py b/runtime/label-studio-adapter/app/api/project/delete.py new file mode 100644 index 0000000..f2861ce --- /dev/null +++ 
b/runtime/label-studio-adapter/app/api/project/delete.py @@ -0,0 +1,106 @@ +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.ext.asyncio import AsyncSession +from typing import Optional + +from app.db.database import get_db +from app.services.dataset_mapping_service import DatasetMappingService +from app.clients import get_clients +from app.schemas.dataset_mapping import DeleteDatasetResponse +from app.schemas import StandardResponse +from app.core.logging import get_logger +from . import project_router + +logger = get_logger(__name__) + +@project_router.delete("/mappings", response_model=StandardResponse[DeleteDatasetResponse]) +async def delete_mapping( + m: Optional[str] = Query(None, description="映射UUID"), + proj: Optional[str] = Query(None, description="Label Studio项目ID"), + db: AsyncSession = Depends(get_db) +): + """ + 删除映射关系和对应的 Label Studio 项目 + + 可以通过以下任一方式指定要删除的映射: + - m: 映射UUID + - proj: Label Studio项目ID + - 两者都提供(优先使用 m) + + 此操作会: + 1. 删除 Label Studio 中的项目 + 2. 软删除数据库中的映射记录 + """ + try: + # 至少需要提供一个参数 + if not m and not proj: + raise HTTPException( + status_code=400, + detail="Either 'm' (mapping UUID) or 'proj' (project ID) must be provided" + ) + + # 获取全局客户端实例 + dm_client_instance, ls_client_instance = get_clients() + service = DatasetMappingService(db) + + mapping = None + + # 优先使用 mapping_id 查询 + if m: + logger.info(f"Deleting by mapping UUID: {m}") + mapping = await service.get_mapping_by_uuid(m) + # 如果没有提供 m,使用 proj 查询 + elif proj: + logger.info(f"Deleting by project ID: {proj}") + mapping = await service.get_mapping_by_labelling_project_id(proj) + + if not mapping: + raise HTTPException( + status_code=404, + detail=f"Mapping not found" + ) + + mapping_id = mapping.mapping_id + labelling_project_id = mapping.labelling_project_id + labelling_project_name = mapping.labelling_project_name + + logger.info(f"Found mapping: {mapping_id}, Label Studio project ID: {labelling_project_id}") + + # 1. 
删除 Label Studio 项目 + try: + delete_success = await ls_client_instance.delete_project(int(labelling_project_id)) + if delete_success: + logger.info(f"Successfully deleted Label Studio project: {labelling_project_id}") + else: + logger.warning(f"Failed to delete Label Studio project or project not found: {labelling_project_id}") + except Exception as e: + logger.error(f"Error deleting Label Studio project: {e}") + # 继续执行,即使 Label Studio 项目删除失败也要删除映射记录 + + # 2. 软删除映射记录 + soft_delete_success = await service.soft_delete_mapping(mapping_id) + + if not soft_delete_success: + raise HTTPException( + status_code=500, + detail="Failed to delete mapping record" + ) + + logger.info(f"Successfully deleted mapping: {mapping_id}") + + response_data = DeleteDatasetResponse( + mapping_id=mapping_id, + status="success", + message=f"Successfully deleted mapping and Label Studio project '{labelling_project_name}'" + ) + + return StandardResponse( + code=200, + message="success", + data=response_data + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error deleting mapping: {e}") + raise HTTPException(status_code=500, detail="Internal server error") diff --git a/runtime/label-studio-adapter/app/api/project/list.py b/runtime/label-studio-adapter/app/api/project/list.py new file mode 100644 index 0000000..c3bcafa --- /dev/null +++ b/runtime/label-studio-adapter/app/api/project/list.py @@ -0,0 +1,110 @@ +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.ext.asyncio import AsyncSession +from typing import List + +from app.db.database import get_db +from app.services.dataset_mapping_service import DatasetMappingService +from app.schemas.dataset_mapping import DatasetMappingResponse +from app.schemas import StandardResponse +from app.core.logging import get_logger +from . 
# runtime/label-studio-adapter/app/api/project/list.py
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List

from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.schemas.dataset_mapping import DatasetMappingResponse
from app.schemas import StandardResponse
from app.core.logging import get_logger
from . import project_router

logger = get_logger(__name__)


def _ok(payload):
    """Wrap a payload in the standard success envelope."""
    return StandardResponse(code=200, message="success", data=payload)


@project_router.get("/mappings/list", response_model=StandardResponse[List[DatasetMappingResponse]])
async def list_mappings(
    skip: int = Query(0, ge=0, description="Number of records to skip"),
    limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"),
    db: AsyncSession = Depends(get_db)
):
    """Return every active (non-soft-deleted) dataset mapping, paginated."""
    try:
        logger.info(f"Listing mappings, skip={skip}, limit={limit}")
        rows = await DatasetMappingService(db).get_all_mappings(skip=skip, limit=limit)
        logger.info(f"Found {len(rows)} mappings")
        return _ok(rows)
    except Exception as e:
        logger.error(f"Error listing mappings: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")


@project_router.get("/mappings/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping(
    mapping_id: str,
    db: AsyncSession = Depends(get_db)
):
    """Look up a single mapping by its UUID; 404 if it does not exist."""
    try:
        logger.info(f"Get mapping: {mapping_id}")
        row = await DatasetMappingService(db).get_mapping_by_uuid(mapping_id)
        if row is None:
            raise HTTPException(
                status_code=404,
                detail=f"Mapping not found: {mapping_id}"
            )
        logger.info(f"Found mapping: {row.mapping_id}")
        return _ok(row)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting mapping: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")


@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[List[DatasetMappingResponse]])
async def get_mappings_by_source(
    source_dataset_id: str,
    db: AsyncSession = Depends(get_db)
):
    """List every labelling project created from the given source dataset
    (soft-deleted mappings included)."""
    try:
        logger.info(f"Get mappings by source dataset id: {source_dataset_id}")
        rows = await DatasetMappingService(db).get_mappings_by_source_dataset_id(source_dataset_id)
        logger.info(f"Found {len(rows)} mappings")
        return _ok(rows)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting mappings: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
# runtime/label-studio-adapter/app/api/project/sync.py
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional

from app.db.database import get_db
from app.services.dataset_mapping_service import DatasetMappingService
from app.services.sync_service import SyncService
from app.clients import get_clients
from app.exceptions import NoDatasetInfoFoundError, DatasetMappingNotFoundError
from app.schemas.dataset_mapping import (
    DatasetMappingResponse,
    SyncDatasetRequest,
    SyncDatasetResponse,
)
from app.schemas import StandardResponse
from app.core.logging import get_logger
from . import project_router

logger = get_logger(__name__)


@project_router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
async def sync_dataset_content(
    request: SyncDatasetRequest,
    db: AsyncSession = Depends(get_db)
):
    """Synchronize DM dataset content into the mapped Label Studio project.

    Resolves the mapping identified by ``request.mapping_id``, pushes the DM
    dataset's files into the corresponding Label Studio project, records the
    update time in the database, and returns the sync outcome.
    """
    try:
        dm_cli, ls_cli = get_clients()
        mapping_service = DatasetMappingService(db)
        sync_service = SyncService(dm_cli, ls_cli, mapping_service)

        logger.info(f"Sync dataset content request: mapping_id={request.mapping_id}")

        # Fail fast with a 404 if the mapping does not exist.
        mapping = await mapping_service.get_mapping_by_uuid(request.mapping_id)
        if mapping is None:
            raise HTTPException(
                status_code=404,
                detail=f"Mapping not found: {request.mapping_id}"
            )

        # Run the sync using the source-dataset UUID stored on the mapping.
        result = await sync_service.sync_dataset_files(request.mapping_id, request.batch_size)
        logger.info(f"Sync completed: {result.synced_files}/{result.total_files} files")

        return StandardResponse(code=200, message="success", data=result)

    except HTTPException:
        raise
    except NoDatasetInfoFoundError as e:
        logger.error(f"Failed to get dataset info: {e}")
        raise HTTPException(status_code=404, detail=str(e))
    except DatasetMappingNotFoundError as e:
        logger.error(f"Mapping not found: {e}")
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logger.error(f"Error syncing dataset content: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
# runtime/label-studio-adapter/app/api/system.py
from fastapi import APIRouter
from typing import Dict, Any
from app.core.config import settings
from app.schemas import StandardResponse

router = APIRouter()


@router.get("/health", response_model=StandardResponse[Dict[str, Any]])
async def health_check():
    """Liveness probe: reports service identity and version."""
    return StandardResponse(
        code=200,
        message="success",
        data={
            "status": "healthy",
            "service": "Label Studio Adapter",
            "version": settings.app_version
        }
    )


@router.get("/config", response_model=StandardResponse[Dict[str, Any]])
async def get_config():
    """Expose non-sensitive runtime configuration for diagnostics."""
    return StandardResponse(
        code=200,
        message="success",
        data={
            "app_name": settings.app_name,
            "version": settings.app_version,
            "dm_service_url": settings.dm_service_base_url,
            "label_studio_url": settings.label_studio_base_url,
            "debug": settings.debug
        }
    )


# runtime/label-studio-adapter/app/clients/__init__.py
from .dm_client import DMServiceClient
from .label_studio_client import LabelStudioClient
from .client_manager import get_clients, set_clients, get_dm_client, get_ls_client

__all__ = ["DMServiceClient", "LabelStudioClient", "get_clients", "set_clients", "get_dm_client", "get_ls_client"]


# runtime/label-studio-adapter/app/clients/client_manager.py
from typing import Optional
from fastapi import HTTPException

from .dm_client import DMServiceClient
from .label_studio_client import LabelStudioClient

# Global client instances (initialized in main.py during application startup).
dm_client: Optional[DMServiceClient] = None
ls_client: Optional[LabelStudioClient] = None


def get_clients() -> tuple[DMServiceClient, LabelStudioClient]:
    """Return both client instances; 500 if startup has not initialized them."""
    global dm_client, ls_client
    if not dm_client or not ls_client:
        raise HTTPException(status_code=500, detail="客户端未初始化")
    return dm_client, ls_client


def set_clients(dm_client_instance: DMServiceClient, ls_client_instance: LabelStudioClient) -> None:
    """Install the global client instances (called once from the lifespan hook)."""
    global dm_client, ls_client
    dm_client = dm_client_instance
    ls_client = ls_client_instance


def get_dm_client() -> DMServiceClient:
    """Return the DM service client; 500 if not initialized."""
    if not dm_client:
        raise HTTPException(status_code=500, detail="DM客户端未初始化")
    return dm_client


def get_ls_client() -> LabelStudioClient:
    """Return the Label Studio client; 500 if not initialized."""
    if not ls_client:
        raise HTTPException(status_code=500, detail="Label Studio客户端未初始化")
    return ls_client


# runtime/label-studio-adapter/app/clients/dm_client.py
import httpx
from typing import Optional
from app.core.config import settings
from app.core.logging import get_logger
from app.schemas.dm_service import DatasetResponse, PagedDatasetFileResponse

logger = get_logger(__name__)


class DMServiceClient:
    """Async HTTP client for the Data Management (DM) service."""

    def __init__(self, base_url: str | None = None, timeout: float = 30.0):
        self.base_url = base_url or settings.dm_service_base_url
        self.timeout = timeout
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout
        )
        logger.info(f"Initialize DM service client, base url: {self.base_url}")

    @staticmethod
    def _unwrap_payload(data):
        """Unwrap common envelope shapes like {'code': ..., 'message': ..., 'data': {...}}."""
        if isinstance(data, dict) and 'data' in data and isinstance(data['data'], (dict, list)):
            return data['data']
        return data

    @staticmethod
    def _is_error_payload(data) -> bool:
        """Detect error-shaped payloads returned with HTTP 200."""
        if not isinstance(data, dict):
            return False
        # Common patterns: {error, message, ...} or {code, message, ...} without data
        if 'error' in data and 'message' in data:
            return True
        if 'code' in data and 'message' in data and 'data' not in data:
            return True
        return False

    @staticmethod
    def _keys(d):
        return list(d.keys()) if isinstance(d, dict) else []

    async def get_dataset(self, dataset_id: str) -> Optional[DatasetResponse]:
        """Fetch dataset details; returns None on any failure (logged)."""
        # BUG FIX: 'raw' must be bound before the try, otherwise the generic
        # except handler below raises NameError when the failure happens before
        # response.json() (e.g. a JSON ValueError, which is not an httpx.HTTPError).
        raw = None
        try:
            logger.info(f"Getting dataset detail: {dataset_id} ...")
            response = await self.client.get(f"/data-management/datasets/{dataset_id}")
            response.raise_for_status()
            raw = response.json()

            data = self._unwrap_payload(raw)

            if self._is_error_payload(data):
                logger.error(f"DM service returned error for dataset {dataset_id}: {data}")
                return None
            if not isinstance(data, dict):
                logger.error(f"Unexpected dataset payload type for {dataset_id}: {type(data).__name__}")
                return None
            required = ["id", "name", "description", "datasetType", "status", "fileCount", "totalSize"]
            if not all(k in data for k in required):
                logger.error(f"Dataset payload missing required fields for {dataset_id}. Keys: {self._keys(data)}")
                return None
            return DatasetResponse(**data)
        except httpx.HTTPError as e:
            logger.error(f"Failed to get dataset {dataset_id}: {e}")
            return None
        except Exception as e:
            logger.error(f"[Unexpected] [GET] dataset {dataset_id}: \n{e}\nRaw JSON received: \n{raw}")
            return None

    async def get_dataset_files(
        self,
        dataset_id: str,
        page: int = 0,
        size: int = 100,
        file_type: Optional[str] = None,
        status: Optional[str] = None
    ) -> Optional[PagedDatasetFileResponse]:
        """Fetch one page of a dataset's file list; returns None on failure (logged)."""
        raw = None  # see get_dataset: keep in scope for the generic handler
        try:
            logger.info(f"Get dataset files: dataset={dataset_id}, page={page}, size={size}")
            params: dict = {"page": page, "size": size}
            if file_type:
                params["fileType"] = file_type
            if status:
                params["status"] = status

            response = await self.client.get(
                f"/data-management/datasets/{dataset_id}/files",
                params=params
            )
            response.raise_for_status()
            raw = response.json()
            data = self._unwrap_payload(raw)
            if self._is_error_payload(data):
                logger.error(f"DM service returned error for dataset files {dataset_id}: {data}")
                return None
            if not isinstance(data, dict):
                logger.error(f"Unexpected dataset files payload type for {dataset_id}: {type(data).__name__}")
                return None
            required = ["content", "totalElements", "totalPages", "page", "size"]
            if not all(k in data for k in required):
                logger.error(f"Files payload missing required fields for {dataset_id}. Keys: {self._keys(data)}")
                return None
            return PagedDatasetFileResponse(**data)
        except httpx.HTTPError as e:
            logger.error(f"Failed to get dataset files for {dataset_id}: {e}")
            return None
        except Exception as e:
            logger.error(f"[Unexpected] [GET] dataset files {dataset_id}: \n{e}\nRaw JSON received: \n{raw}")
            return None

    async def download_file(self, dataset_id: str, file_id: str) -> Optional[bytes]:
        """Download a file's raw bytes; returns None on failure (logged)."""
        try:
            logger.info(f"Download file: dataset={dataset_id}, file={file_id}")
            response = await self.client.get(
                f"/data-management/datasets/{dataset_id}/files/{file_id}/download"
            )
            response.raise_for_status()
            return response.content
        except httpx.HTTPError as e:
            logger.error(f"Failed to download file {file_id}: {e}")
            return None
        except Exception as e:
            logger.error(f"Unexpected error while downloading file {file_id}: {e}")
            return None

    async def get_file_download_url(self, dataset_id: str, file_id: str) -> str:
        """Build the direct download URL for a file (no request is made)."""
        return f"{self.base_url}/data-management/datasets/{dataset_id}/files/{file_id}/download"

    async def close(self):
        """Close the underlying HTTP connection pool."""
        await self.client.aclose()
        logger.info("DM service client connection closed")
# runtime/label-studio-adapter/app/clients/label_studio_client.py
import httpx
from typing import Optional, Dict, Any, List
import json

from app.core.config import settings
from app.core.logging import get_logger
from app.schemas.label_studio import (
    LabelStudioProject,
    LabelStudioCreateProjectRequest,
    LabelStudioCreateTaskRequest
)

logger = get_logger(__name__)


class LabelStudioClient:
    """Label Studio service client.

    Talks to Label Studio directly over its HTTP REST API.
    Authentication: every request carries an ``Authorization: Token {token}`` header.
    """

    # Default labeling-config templates keyed by data type.
    # NOTE(review): the original XML literals were lost during extraction
    # (the markup was stripped, leaving empty strings — which Label Studio
    # rejects). The templates below are the standard Label Studio starter
    # configs for each modality; confirm against the original source.
    DEFAULT_LABEL_CONFIGS = {
        "image": """
<View>
  <Image name="image" value="$image"/>
  <RectangleLabels name="label" toName="image">
    <Label value="Object"/>
  </RectangleLabels>
</View>
""",
        "text": """
<View>
  <Text name="text" value="$text"/>
  <Choices name="choice" toName="text" choice="single">
    <Choice value="Positive"/>
    <Choice value="Negative"/>
    <Choice value="Neutral"/>
  </Choices>
</View>
""",
        "audio": """
<View>
  <Audio name="audio" value="$audio"/>
  <Labels name="label" toName="audio">
    <Label value="Speech"/>
  </Labels>
</View>
""",
        "video": """
<View>
  <Video name="video" value="$video"/>
  <Labels name="label" toName="video">
    <Label value="Event"/>
  </Labels>
</View>
"""
    }

    def __init__(
        self,
        base_url: Optional[str] = None,
        token: Optional[str] = None,
        timeout: float = 30.0
    ):
        """Initialize the Label Studio client.

        Args:
            base_url: Label Studio service address.
            token: API token (sent as ``Authorization: Token {token}``).
            timeout: request timeout in seconds.

        Raises:
            ValueError: no token was supplied and none is configured.
        """
        self.base_url = (base_url or settings.label_studio_base_url).rstrip("/")
        self.token = token or settings.label_studio_user_token
        self.timeout = timeout

        if not self.token:
            raise ValueError("Label Studio API token is required")

        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout,
            headers={
                "Authorization": f"Token {self.token}",
                "Content-Type": "application/json"
            }
        )

        logger.info(f"Label Studio client initialized: {self.base_url}")

    def get_label_config_by_type(self, data_type: str) -> str:
        """Return the labeling config for a data type; falls back to the image config."""
        return self.DEFAULT_LABEL_CONFIGS.get(data_type.lower(), self.DEFAULT_LABEL_CONFIGS["image"])

    async def create_project(
        self,
        title: str,
        description: str = "",
        label_config: Optional[str] = None,
        data_type: str = "image"
    ) -> Optional[Dict[str, Any]]:
        """Create a Label Studio project; returns the project JSON or None on failure."""
        try:
            logger.info(f"Creating Label Studio project: {title}")

            if not label_config:
                label_config = self.get_label_config_by_type(data_type)

            project_data = {
                "title": title,
                "description": description,
                "label_config": label_config.strip()
            }

            response = await self.client.post("/api/projects", json=project_data)
            response.raise_for_status()

            project = response.json()
            project_id = project.get("id")

            if not project_id:
                raise Exception("Label Studio response does not contain project ID")

            logger.info(f"Project created successfully, ID: {project_id}")
            return project

        except httpx.HTTPStatusError as e:
            logger.error(f"Create project failed HTTP {e.response.status_code}: {e.response.text}")
            return None
        except Exception as e:
            logger.error(f"Error while creating Label Studio project: {e}")
            return None

    async def import_tasks(
        self,
        project_id: int,
        tasks: List[Dict[str, Any]],
        commit_to_project: bool = True,
        return_task_ids: bool = True
    ) -> Optional[Dict[str, Any]]:
        """Bulk-import tasks into a project; returns the import result or None."""
        try:
            logger.info(f"Importing {len(tasks)} tasks into project {project_id}")

            response = await self.client.post(
                f"/api/projects/{project_id}/import",
                json=tasks,
                params={
                    "commit_to_project": str(commit_to_project).lower(),
                    "return_task_ids": str(return_task_ids).lower()
                }
            )
            response.raise_for_status()

            result = response.json()
            task_count = result.get("task_count", len(tasks))

            logger.info(f"Tasks imported successfully: {task_count}")
            return result

        except httpx.HTTPStatusError as e:
            logger.error(f"Import tasks failed HTTP {e.response.status_code}: {e.response.text}")
            return None
        except Exception as e:
            logger.error(f"Error while importing tasks: {e}")
            return None

    async def create_tasks_batch(
        self,
        project_id: str,
        tasks: List[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Convenience wrapper around import_tasks for string project IDs."""
        try:
            pid = int(project_id)
            return await self.import_tasks(pid, tasks)
        except ValueError as e:
            logger.error(f"Invalid project ID format: {project_id}, error: {e}")
            return None
        except Exception as e:
            logger.error(f"Error while creating tasks in batch: {e}")
            return None

    async def create_task(
        self,
        project_id: str,
        data: Dict[str, Any],
        meta: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Create a single task (delegates to create_tasks_batch)."""
        try:
            task = {"data": data}
            if meta:
                task["meta"] = meta

            return await self.create_tasks_batch(project_id, [task])

        except Exception as e:
            logger.error(f"Error while creating single task: {e}")
            return None

    async def get_project_tasks(
        self,
        project_id: str,
        page: Optional[int] = None,
        page_size: int = 1000
    ) -> Optional[Dict[str, Any]]:
        """Fetch project tasks.

        Args:
            project_id: project ID.
            page: 1-based page number; when None, ALL tasks are fetched.
            page_size: page size per request.

        Returns:
            With ``page`` set, a dict with pagination info:
            {
                "count": total task count,
                "page": current page,
                "page_size": page size,
                "project_id": project ID,
                "tasks": tasks on this page
            }

            With ``page`` None, a dict covering every task:
            {
                "count": total task count,
                "project_id": project ID,
                "tasks": all tasks
            }
        """
        try:
            pid = int(project_id)

            # Single-page fetch when an explicit page was requested.
            if page is not None:
                logger.info(f"Fetching tasks for project {pid}, page {page} (page_size={page_size})")

                response = await self.client.get(
                    f"/api/projects/{pid}/tasks",
                    params={
                        "page": page,
                        "page_size": page_size
                    }
                )
                response.raise_for_status()

                result = response.json()

                return {
                    "count": result.get("total", len(result.get("tasks", []))),
                    "page": page,
                    "page_size": page_size,
                    "project_id": pid,
                    "tasks": result.get("tasks", [])
                }

            # Otherwise page through everything.
            logger.info(f"Start fetching all tasks for project {pid} (page_size={page_size})")
            all_tasks = []
            current_page = 1

            while True:
                try:
                    response = await self.client.get(
                        f"/api/projects/{pid}/tasks",
                        params={
                            "page": current_page,
                            "page_size": page_size
                        }
                    )
                    response.raise_for_status()

                    result = response.json()
                    tasks = result.get("tasks", [])

                    if not tasks:
                        logger.debug(f"No more tasks on page {current_page}")
                        break

                    all_tasks.extend(tasks)
                    logger.debug(f"Fetched page {current_page}, {len(tasks)} tasks")

                    # Stop once we have collected everything the server reports.
                    total = result.get("total", 0)
                    if len(all_tasks) >= total:
                        break

                    current_page += 1

                except httpx.HTTPStatusError as e:
                    if e.response.status_code == 404:
                        # Past the last page: Label Studio 404s out-of-range pages.
                        logger.debug(f"Reached last page (page {current_page})")
                        break
                    else:
                        raise

            logger.info(f"Fetched all tasks for project {pid}, total {len(all_tasks)}")

            return {
                "count": len(all_tasks),
                "project_id": pid,
                "tasks": all_tasks
            }

        except httpx.HTTPStatusError as e:
            # Translated from Chinese for consistency with the rest of this module's logs.
            logger.error(f"Get project tasks failed HTTP {e.response.status_code}: {e.response.text}")
            return None
        except Exception as e:
            logger.error(f"Error while getting project tasks: {e}")
            return None

    async def delete_task(
        self,
        task_id: int
    ) -> bool:
        """Delete one task; returns True on success."""
        try:
            logger.info(f"Deleting task: {task_id}")

            response = await self.client.delete(f"/api/tasks/{task_id}")
            response.raise_for_status()

            logger.info(f"Task deleted: {task_id}")
            return True

        except httpx.HTTPStatusError as e:
            logger.error(f"Delete task {task_id} failed HTTP {e.response.status_code}: {e.response.text}")
            return False
        except Exception as e:
            logger.error(f"Error while deleting task {task_id}: {e}")
            return False

    async def delete_tasks_batch(
        self,
        task_ids: List[int]
    ) -> Dict[str, int]:
        """Delete tasks one by one; returns {'successful', 'failed', 'total'} counts."""
        try:
            logger.info(f"Deleting {len(task_ids)} tasks in batch")

            successful_deletions = 0
            failed_deletions = 0

            for task_id in task_ids:
                if await self.delete_task(task_id):
                    successful_deletions += 1
                else:
                    failed_deletions += 1

            logger.info(f"Batch deletion finished: success {successful_deletions}, failed {failed_deletions}")

            return {
                "successful": successful_deletions,
                "failed": failed_deletions,
                "total": len(task_ids)
            }

        except Exception as e:
            logger.error(f"Error while deleting tasks in batch: {e}")
            return {
                "successful": 0,
                "failed": len(task_ids),
                "total": len(task_ids)
            }

    async def get_project(self, project_id: int) -> Optional[Dict[str, Any]]:
        """Fetch project info; returns None on failure."""
        try:
            logger.info(f"Fetching project info: {project_id}")

            response = await self.client.get(f"/api/projects/{project_id}")
            response.raise_for_status()

            return response.json()

        except httpx.HTTPStatusError as e:
            logger.error(f"Get project info failed HTTP {e.response.status_code}: {e.response.text}")
            return None
        except Exception as e:
            logger.error(f"Error while getting project info: {e}")
            return None

    async def delete_project(self, project_id: int) -> bool:
        """Delete a project; returns True on success."""
        try:
            logger.info(f"Deleting project: {project_id}")

            response = await self.client.delete(f"/api/projects/{project_id}")
            response.raise_for_status()

            logger.info(f"Project deleted: {project_id}")
            return True

        except httpx.HTTPStatusError as e:
            logger.error(f"Delete project {project_id} failed HTTP {e.response.status_code}: {e.response.text}")
            return False
        except Exception as e:
            logger.error(f"Error while deleting project {project_id}: {e}")
            return False

    async def create_local_storage(
        self,
        project_id: int,
        path: str,
        title: str,
        use_blob_urls: bool = True,
        regex_filter: Optional[str] = None,
        description: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Create a local-files storage for a project.

        Args:
            project_id: Label Studio project ID.
            path: local file path (as seen inside the Label Studio container).
            title: storage title.
            use_blob_urls: serve files as blob URLs (recommended True).
            regex_filter: optional filename filter regex.
            description: optional storage description.

        Returns:
            The created storage config, or None on failure.
        """
        try:
            logger.info(f"Creating local storage for project {project_id}: {path}")

            storage_data = {
                "project": project_id,
                "path": path,
                "title": title,
                "use_blob_urls": use_blob_urls
            }

            if regex_filter:
                storage_data["regex_filter"] = regex_filter
            if description:
                storage_data["description"] = description

            response = await self.client.post(
                "/api/storages/localfiles/",
                json=storage_data
            )
            response.raise_for_status()

            storage = response.json()
            storage_id = storage.get("id")

            logger.info(f"Local storage created successfully, ID: {storage_id}")
            return storage

        except httpx.HTTPStatusError as e:
            logger.error(f"Create local storage failed HTTP {e.response.status_code}: {e.response.text}")
            return None
        except Exception as e:
            logger.error(f"Error while creating local storage: {e}")
            return None

    async def close(self):
        """Close the underlying HTTP connection pool."""
        try:
            await self.client.aclose()
            logger.info("Label Studio client closed")
        except Exception as e:
            logger.error(f"Error while closing Label Studio client: {e}")
# runtime/label-studio-adapter/app/core/config.py
from pydantic_settings import BaseSettings
from typing import Optional
import os
from pathlib import Path


class Settings(BaseSettings):
    """Application configuration, loaded from environment / .env."""

    class Config:
        env_file = ".env"
        case_sensitive = False
        extra = 'ignore'  # tolerate extra env vars (e.g. shell-script-only variables)

    # =========================
    # Adapter service settings
    # =========================
    app_name: str = "Label Studio Adapter"
    app_version: str = "1.0.0"
    app_description: str = "Adapter for integrating Data Management System with Label Studio"
    debug: bool = True

    # Server
    host: str = "0.0.0.0"
    port: int = 8000

    # CORS
    allowed_origins: list = ["*"]
    allowed_methods: list = ["*"]
    allowed_headers: list = ["*"]

    # MySQL database settings (priority 1)
    mysql_host: Optional[str] = None
    mysql_port: int = 3306
    mysql_user: Optional[str] = None
    mysql_password: Optional[str] = None
    mysql_database: Optional[str] = None

    # PostgreSQL database settings (priority 2)
    postgres_host: Optional[str] = None
    postgres_port: int = 5432
    postgres_user: Optional[str] = None
    postgres_password: Optional[str] = None
    postgres_database: Optional[str] = None

    # SQLite database settings (priority 3 — fallback)
    sqlite_path: str = "data/labelstudio_adapter.db"

    # Explicit database URL (overrides everything above when set)
    database_url: Optional[str] = None

    # Logging
    log_level: str = "INFO"

    # Security
    secret_key: str = "your-secret-key-change-this-in-production"
    access_token_expire_minutes: int = 30

    # =========================
    # Label Studio service settings
    # =========================
    label_studio_base_url: str = "http://label-studio:8080"
    label_studio_username: Optional[str] = None  # Label Studio login user name
    label_studio_password: Optional[str] = None  # Label Studio login password
    label_studio_user_token: Optional[str] = None  # legacy API token

    label_studio_local_storage_dataset_base_path: str = "/label-studio/local_files/dataset"  # local-storage base path inside the LS container
    label_studio_local_storage_upload_base_path: str = "/label-studio/local_files/upload"  # local-storage base path inside the LS container
    label_studio_file_path_prefix: str = "/data/local-files/?d="  # LS local-file serving URL prefix

    ls_task_page_size: int = 1000

    # =========================
    # Data Management service settings
    # =========================
    dm_service_base_url: str = "http://data-engine"
    dm_file_path_prefix: str = "/"  # DM storage folder prefix

    @property
    def computed_database_url(self) -> str:
        """Pick the database URL by priority: explicit URL > MySQL > PostgreSQL > SQLite."""
        if self.database_url:
            return self.database_url

        # Priority 1: MySQL
        if all([self.mysql_host, self.mysql_user, self.mysql_password, self.mysql_database]):
            return f"mysql+aiomysql://{self.mysql_user}:{self.mysql_password}@{self.mysql_host}:{self.mysql_port}/{self.mysql_database}"

        # Priority 2: PostgreSQL
        if all([self.postgres_host, self.postgres_user, self.postgres_password, self.postgres_database]):
            return f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}@{self.postgres_host}:{self.postgres_port}/{self.postgres_database}"

        # Priority 3: SQLite fallback — ensure the directory exists.
        sqlite_full_path = Path(self.sqlite_path).absolute()
        sqlite_full_path.parent.mkdir(parents=True, exist_ok=True)
        return f"sqlite+aiosqlite:///{sqlite_full_path}"

    @property
    def sync_database_url(self) -> str:
        """Synchronous variant of the database URL, for Alembic migrations."""
        async_url = self.computed_database_url

        # Map async drivers to their synchronous counterparts.
        sync_replacements = {
            "mysql+aiomysql://": "mysql+pymysql://",
            "postgresql+asyncpg://": "postgresql+psycopg2://",
            "sqlite+aiosqlite:///": "sqlite:///"
        }

        for async_driver, sync_driver in sync_replacements.items():
            if async_url.startswith(async_driver):
                return async_url.replace(async_driver, sync_driver)

        return async_url

    def get_database_info(self) -> dict:
        """Describe the selected database: {'type', 'url', 'sync_url'}."""
        url = self.computed_database_url

        if url.startswith("mysql"):
            db_type = "MySQL"
        elif url.startswith("postgresql"):
            db_type = "PostgreSQL"
        elif url.startswith("sqlite"):
            db_type = "SQLite"
        else:
            db_type = "Unknown"

        return {
            "type": db_type,
            "url": url,
            "sync_url": self.sync_database_url
        }


# Global settings instance
settings = Settings()


# runtime/label-studio-adapter/app/core/logging.py
import logging
import sys
from pathlib import Path
from app.core.config import settings

# Guard so repeated setup_logging() calls do not stack duplicate handlers.
_LOGGING_CONFIGURED = False


def setup_logging():
    """Configure application logging (console + app.log + error.log).

    Idempotent: calling it more than once is a no-op, so log lines are not
    duplicated when several entry points import and initialize the app.
    """
    global _LOGGING_CONFIGURED
    if _LOGGING_CONFIGURED:
        return
    _LOGGING_CONFIGURED = True

    # Make sure the logs directory exists.
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)

    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"

    level = getattr(logging, settings.log_level.upper())

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(level)

    file_handler = logging.FileHandler(log_dir / "app.log", encoding="utf-8")
    file_handler.setLevel(level)

    error_handler = logging.FileHandler(log_dir / "error.log", encoding="utf-8")
    error_handler.setLevel(logging.ERROR)

    formatter = logging.Formatter(log_format, date_format)
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    error_handler.setFormatter(formatter)

    # Root logger gets all three handlers.
    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    root_logger.addHandler(console_handler)
    root_logger.addHandler(file_handler)
    root_logger.addHandler(error_handler)

    # Quiet noisy third-party loggers.
    logging.getLogger("uvicorn").setLevel(logging.WARNING)
    logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)  # hide SQL query logs
    logging.getLogger("httpx").setLevel(logging.WARNING)


def get_logger(name: str) -> logging.Logger:
    """Return the logger for the given name."""
    return logging.getLogger(name)


# runtime/label-studio-adapter/app/db/__init__.py


# runtime/label-studio-adapter/app/db/database.py
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base
from app.core.config import settings
from app.core.logging import get_logger
from typing import AsyncGenerator

logger = get_logger(__name__)

# Log which backend was selected (MySQL / PostgreSQL / SQLite).
db_info = settings.get_database_info()
logger.info(f"使用数据库: {db_info['type']}")
logger.info(f"连接URL: {db_info['url']}")

# Async engine; SQLite needs check_same_thread=False for async access.
engine = create_async_engine(
    settings.computed_database_url,
    echo=False,  # keep SQL debug logging off to reduce noise
    future=True,
    connect_args={"check_same_thread": False} if "sqlite" in settings.computed_database_url else {}
)

# Session factory.
AsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False
)

# Declarative base for ORM models.
Base = declarative_base()


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency yielding a database session, closed after use."""
    async with AsyncSessionLocal() as session:
        try:
            yield session
        finally:
            # Redundant with the context manager, but harmless and explicit.
            await session.close()
# runtime/label-studio-adapter/app/exceptions.py
"""Custom exception hierarchy for the Label Studio Adapter."""


class LabelStudioAdapterException(Exception):
    """Base class for all adapter-specific errors."""
    pass


class DatasetMappingNotFoundError(LabelStudioAdapterException):
    """Raised when a dataset mapping cannot be found."""

    def __init__(self, mapping_id: str):
        self.mapping_id = mapping_id
        super().__init__(f"Dataset mapping not found: {mapping_id}")


class NoDatasetInfoFoundError(LabelStudioAdapterException):
    """Raised when dataset info cannot be retrieved from the DM service."""

    def __init__(self, dataset_uuid: str):
        self.dataset_uuid = dataset_uuid
        super().__init__(f"Failed to get dataset info: {dataset_uuid}")


class LabelStudioClientError(LabelStudioAdapterException):
    """Error raised by the Label Studio client."""
    pass


class DMServiceClientError(LabelStudioAdapterException):
    """Error raised by the DM service client."""
    pass


class SyncServiceError(LabelStudioAdapterException):
    """Error raised by the sync service."""
    pass
set_clients +from .api import api_router +from .schemas import StandardResponse + +# 设置日志 +setup_logging() +logger = get_logger(__name__) + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用程序生命周期管理""" + + # 启动时初始化 + logger.info("Starting Label Studio Adapter...") + + # 初始化客户端 + dm_client = DMServiceClient() + + # 初始化 Label Studio 客户端,使用 HTTP REST API + Token 认证 + ls_client = LabelStudioClient( + base_url=settings.label_studio_base_url, + token=settings.label_studio_user_token + ) + + # 设置全局客户端 + set_clients(dm_client, ls_client) + + # 数据库初始化由 Alembic 管理 + # 在 Docker 环境中,entrypoint.sh 会在启动前运行: alembic upgrade head + # 在开发环境中,手动运行: alembic upgrade head + logger.info("Database schema managed by Alembic") + + logger.info("Label Studio Adapter started") + + yield + + # 关闭时清理 + logger.info("Shutting down Label Studio Adapter...") + + # 客户端清理会在客户端管理器中处理 + logger.info("Label Studio Adapter stopped") + +# 创建FastAPI应用 +app = FastAPI( + title=settings.app_name, + description=settings.app_description, + version=settings.app_version, + debug=settings.debug, + lifespan=lifespan +) + +# 配置CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=settings.allowed_origins, + allow_credentials=True, + allow_methods=settings.allowed_methods, + allow_headers=settings.allowed_headers, +) + +# 自定义异常处理器:StarletteHTTPException (包括404等) +@app.exception_handler(StarletteHTTPException) +async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException): + """将Starlette的HTTPException转换为标准响应格式""" + return JSONResponse( + status_code=exc.status_code, + content={ + "code": exc.status_code, + "message": "error", + "data": { + "detail": exc.detail + } + } + ) + +# 自定义异常处理器:FastAPI HTTPException +@app.exception_handler(HTTPException) +async def fastapi_http_exception_handler(request: Request, exc: HTTPException): + """将FastAPI的HTTPException转换为标准响应格式""" + return JSONResponse( + status_code=exc.status_code, + content={ + "code": exc.status_code, + "message": 
"error", + "data": { + "detail": exc.detail + } + } + ) + +# 自定义异常处理器:RequestValidationError +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """将请求验证错误转换为标准响应格式""" + return JSONResponse( + status_code=422, + content={ + "code": 422, + "message": "error", + "data": { + "detail": "Validation error", + "errors": exc.errors() + } + } + ) + +# 自定义异常处理器:未捕获的异常 +@app.exception_handler(Exception) +async def general_exception_handler(request: Request, exc: Exception): + """将未捕获的异常转换为标准响应格式""" + logger.error(f"Unhandled exception: {exc}", exc_info=True) + return JSONResponse( + status_code=500, + content={ + "code": 500, + "message": "error", + "data": { + "detail": "Internal server error" + } + } + ) + +# 注册路由 +app.include_router(api_router, prefix="/api") + +# 测试端点:验证异常处理 +@app.get("/test-404", include_in_schema=False) +async def test_404(): + """测试404异常处理""" + raise HTTPException(status_code=404, detail="Test 404 error") + +@app.get("/test-500", include_in_schema=False) +async def test_500(): + """测试500异常处理""" + raise Exception("Test uncaught exception") + +# 根路径重定向到文档 +@app.get("/", response_model=StandardResponse[Dict[str, Any]], include_in_schema=False) +async def root(): + """根路径,返回服务信息""" + return StandardResponse( + code=200, + message="success", + data={ + "message": f"{settings.app_name} is running", + "version": settings.app_version, + "docs_url": "/docs", + "dm_service_url": settings.dm_service_base_url, + "label_studio_url": settings.label_studio_base_url + } + ) + +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "app.main:app", + host=settings.host, + port=settings.port, + reload=settings.debug, + log_level=settings.log_level.lower() + ) \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/models/__init__.py b/runtime/label-studio-adapter/app/models/__init__.py new file mode 100644 index 0000000..db4bb17 --- /dev/null +++ 
b/runtime/label-studio-adapter/app/models/__init__.py @@ -0,0 +1,5 @@ +# app/models/__init__.py + +from .dataset_mapping import DatasetMapping + +__all__ = ["DatasetMapping"] \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/models/dataset_mapping.py b/runtime/label-studio-adapter/app/models/dataset_mapping.py new file mode 100644 index 0000000..1dcd206 --- /dev/null +++ b/runtime/label-studio-adapter/app/models/dataset_mapping.py @@ -0,0 +1,25 @@ +from sqlalchemy import Column, String, DateTime, Boolean, Text +from sqlalchemy.sql import func +from app.db.database import Base +import uuid + +class DatasetMapping(Base): + """数据集映射模型""" + + __tablename__ = "mapping" + + mapping_id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + source_dataset_id = Column(String(36), nullable=False, comment="源数据集ID") + labelling_project_id = Column(String(36), nullable=False, comment="标注项目ID") + labelling_project_name = Column(String(255), nullable=True, comment="标注项目名称") + created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间") + last_updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), comment="最后更新时间") + deleted_at = Column(DateTime(timezone=True), nullable=True, comment="删除时间") + + def __repr__(self): + return f"" + + @property + def is_deleted(self) -> bool: + """检查是否已被软删除""" + return self.deleted_at is not None \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/schemas/__init__.py b/runtime/label-studio-adapter/app/schemas/__init__.py new file mode 100644 index 0000000..7941ab7 --- /dev/null +++ b/runtime/label-studio-adapter/app/schemas/__init__.py @@ -0,0 +1,29 @@ +# app/schemas/__init__.py + +from .common import * +from .dataset_mapping import * +from .dm_service import * +from .label_studio import * + +__all__ = [ + # Common schemas + "StandardResponse", + + # Dataset Mapping schemas + "DatasetMappingBase", + 
"DatasetMappingCreateRequest", + "DatasetMappingUpdateRequest", + "DatasetMappingResponse", + "DatasetMappingCreateResponse", + "SyncDatasetResponse", + "DeleteDatasetResponse", + + # DM Service schemas + "DatasetFileResponse", + "PagedDatasetFileResponse", + "DatasetResponse", + + # Label Studio schemas + "LabelStudioProject", + "LabelStudioTask" +] \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/schemas/common.py b/runtime/label-studio-adapter/app/schemas/common.py new file mode 100644 index 0000000..f931844 --- /dev/null +++ b/runtime/label-studio-adapter/app/schemas/common.py @@ -0,0 +1,27 @@ +""" +通用响应模型 +""" +from typing import Generic, TypeVar, Optional +from pydantic import BaseModel, Field + +# 定义泛型类型变量 +T = TypeVar('T') + +class StandardResponse(BaseModel, Generic[T]): + """ + 标准API响应格式 + + 所有API端点应返回此格式,确保响应的一致性 + """ + code: int = Field(..., description="HTTP状态码") + message: str = Field(..., description="响应消息") + data: Optional[T] = Field(None, description="响应数据") + + class Config: + json_schema_extra = { + "example": { + "code": 200, + "message": "success", + "data": {} + } + } diff --git a/runtime/label-studio-adapter/app/schemas/dataset_mapping.py b/runtime/label-studio-adapter/app/schemas/dataset_mapping.py new file mode 100644 index 0000000..9806c77 --- /dev/null +++ b/runtime/label-studio-adapter/app/schemas/dataset_mapping.py @@ -0,0 +1,53 @@ +from pydantic import BaseModel, Field +from typing import Optional +from datetime import datetime + +class DatasetMappingBase(BaseModel): + """数据集映射 基础模型""" + source_dataset_id: str = Field(..., description="源数据集ID") + +class DatasetMappingCreateRequest(DatasetMappingBase): + """数据集映射 创建 请求模型""" + pass + +class DatasetMappingCreateResponse(BaseModel): + """数据集映射 创建 响应模型""" + mapping_id: str = Field(..., description="映射UUID") + labelling_project_id: str = Field(..., description="Label Studio项目ID") + labelling_project_name: str = Field(..., description="Label Studio项目名称") + message: 
str = Field(..., description="响应消息") + +class DatasetMappingUpdateRequest(BaseModel): + """数据集映射 更新 请求模型""" + source_dataset_id: Optional[str] = Field(None, description="源数据集ID") + +class DatasetMappingResponse(DatasetMappingBase): + """数据集映射 查询 响应模型""" + mapping_id: str = Field(..., description="映射UUID") + labelling_project_id: str = Field(..., description="标注项目ID") + labelling_project_name: Optional[str] = Field(None, description="标注项目名称") + created_at: datetime = Field(..., description="创建时间") + last_updated_at: datetime = Field(..., description="最后更新时间") + deleted_at: Optional[datetime] = Field(None, description="删除时间") + + class Config: + from_attributes = True + +class SyncDatasetRequest(BaseModel): + """同步数据集请求模型""" + mapping_id: str = Field(..., description="映射ID(mapping UUID)") + batch_size: int = Field(50, ge=1, le=100, description="批处理大小") + +class SyncDatasetResponse(BaseModel): + """同步数据集响应模型""" + mapping_id: str = Field(..., description="映射UUID") + status: str = Field(..., description="同步状态") + synced_files: int = Field(..., description="已同步文件数量") + total_files: int = Field(0, description="总文件数量") + message: str = Field(..., description="响应消息") + +class DeleteDatasetResponse(BaseModel): + """删除数据集响应模型""" + mapping_id: str = Field(..., description="映射UUID") + status: str = Field(..., description="删除状态") + message: str = Field(..., description="响应消息") \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/schemas/dm_service.py b/runtime/label-studio-adapter/app/schemas/dm_service.py new file mode 100644 index 0000000..aca3c40 --- /dev/null +++ b/runtime/label-studio-adapter/app/schemas/dm_service.py @@ -0,0 +1,58 @@ +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Any +from datetime import datetime + +class DatasetFileResponse(BaseModel): + """DM服务数据集文件响应模型""" + id: str = Field(..., description="文件ID") + fileName: str = Field(..., description="文件名") + fileType: str = Field(..., description="文件类型") + 
filePath: str = Field(..., description="文件路径") + originalName: Optional[str] = Field(None, description="原始文件名") + size: Optional[int] = Field(None, description="文件大小(字节)") + status: Optional[str] = Field(None, description="文件状态") + uploadedAt: Optional[datetime] = Field(None, description="上传时间") + description: Optional[str] = Field(None, description="文件描述") + uploadedBy: Optional[str] = Field(None, description="上传者") + lastAccessTime: Optional[datetime] = Field(None, description="最后访问时间") + +class PagedDatasetFileResponse(BaseModel): + """DM服务分页文件响应模型""" + content: List[DatasetFileResponse] = Field(..., description="文件列表") + totalElements: int = Field(..., description="总元素数") + totalPages: int = Field(..., description="总页数") + page: int = Field(..., description="当前页码") + size: int = Field(..., description="每页大小") + +class DatasetTypeResponse(BaseModel): + """数据集类型响应模型""" + code: str = Field(..., description="类型编码") + name: str = Field(..., description="类型名称") + description: Optional[str] = Field(None, description="类型描述") + supportedFormats: List[str] = Field(default_factory=list, description="支持的文件格式") + icon: Optional[str] = Field(None, description="图标") + +class DatasetResponse(BaseModel): + """DM服务数据集响应模型""" + id: str = Field(..., description="数据集ID") + name: str = Field(..., description="数据集名称") + description: Optional[str] = Field(None, description="数据集描述") + datasetType: str = Field(..., description="数据集类型", alias="datasetType") + status: str = Field(..., description="数据集状态") + fileCount: int = Field(..., description="文件数量") + totalSize: int = Field(..., description="总大小(字节)") + createdAt: Optional[datetime] = Field(None, description="创建时间") + updatedAt: Optional[datetime] = Field(None, description="更新时间") + createdBy: Optional[str] = Field(None, description="创建者") + + # 为了向后兼容,添加一个属性方法返回类型对象 + @property + def type(self) -> DatasetTypeResponse: + """兼容属性:返回类型对象""" + return DatasetTypeResponse( + code=self.datasetType, + name=self.datasetType, + 
description=None, + supportedFormats=[], + icon=None + ) \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/schemas/label_studio.py b/runtime/label-studio-adapter/app/schemas/label_studio.py new file mode 100644 index 0000000..66f5e71 --- /dev/null +++ b/runtime/label-studio-adapter/app/schemas/label_studio.py @@ -0,0 +1,37 @@ +from pydantic import BaseModel, Field +from typing import Dict, Any, Optional, List +from datetime import datetime + +class LabelStudioProject(BaseModel): + """Label Studio项目模型""" + id: int = Field(..., description="项目ID") + title: str = Field(..., description="项目标题") + description: Optional[str] = Field(None, description="项目描述") + label_config: str = Field(..., description="标注配置") + created_at: Optional[datetime] = Field(None, description="创建时间") + updated_at: Optional[datetime] = Field(None, description="更新时间") + +class LabelStudioTaskData(BaseModel): + """Label Studio任务数据模型""" + image: Optional[str] = Field(None, description="图像URL") + text: Optional[str] = Field(None, description="文本内容") + audio: Optional[str] = Field(None, description="音频URL") + video: Optional[str] = Field(None, description="视频URL") + filename: Optional[str] = Field(None, description="文件名") + +class LabelStudioTask(BaseModel): + """Label Studio任务模型""" + data: LabelStudioTaskData = Field(..., description="任务数据") + project: Optional[int] = Field(None, description="项目ID") + meta: Optional[Dict[str, Any]] = Field(None, description="元数据") + +class LabelStudioCreateProjectRequest(BaseModel): + """创建Label Studio项目请求模型""" + title: str = Field(..., description="项目标题") + description: str = Field("", description="项目描述") + label_config: str = Field(..., description="标注配置") + +class LabelStudioCreateTaskRequest(BaseModel): + """创建Label Studio任务请求模型""" + data: Dict[str, Any] = Field(..., description="任务数据") + project: Optional[int] = Field(None, description="项目ID") \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/services/__init__.py 
b/runtime/label-studio-adapter/app/services/__init__.py new file mode 100644 index 0000000..7818db1 --- /dev/null +++ b/runtime/label-studio-adapter/app/services/__init__.py @@ -0,0 +1,6 @@ +# app/services/__init__.py + +from .dataset_mapping_service import DatasetMappingService +from .sync_service import SyncService + +__all__ = ["DatasetMappingService", "SyncService"] \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/services/dataset_mapping_service.py b/runtime/label-studio-adapter/app/services/dataset_mapping_service.py new file mode 100644 index 0000000..a97e165 --- /dev/null +++ b/runtime/label-studio-adapter/app/services/dataset_mapping_service.py @@ -0,0 +1,223 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import update +from typing import Optional, List +from datetime import datetime +import uuid + +from app.models.dataset_mapping import DatasetMapping +from app.schemas.dataset_mapping import ( + DatasetMappingCreateRequest, + DatasetMappingUpdateRequest, + DatasetMappingResponse +) +from app.core.logging import get_logger + +logger = get_logger(__name__) + +class DatasetMappingService: + """数据集映射服务""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def create_mapping( + self, + mapping_data: DatasetMappingCreateRequest, + labelling_project_id: str, + labelling_project_name: str + ) -> DatasetMappingResponse: + """创建数据集映射""" + logger.info(f"Create dataset mapping: {mapping_data.source_dataset_id} -> {labelling_project_id}") + + db_mapping = DatasetMapping( + mapping_id=str(uuid.uuid4()), + source_dataset_id=mapping_data.source_dataset_id, + labelling_project_id=labelling_project_id, + labelling_project_name=labelling_project_name + ) + + self.db.add(db_mapping) + await self.db.commit() + await self.db.refresh(db_mapping) + + logger.info(f"Mapping created: {db_mapping.mapping_id}") + return DatasetMappingResponse.model_validate(db_mapping) + + async def 
get_mapping_by_source_uuid( + self, + source_dataset_id: str + ) -> Optional[DatasetMappingResponse]: + """根据源数据集ID获取映射(返回第一个未删除的)""" + logger.debug(f"Get mapping by source dataset id: {source_dataset_id}") + + result = await self.db.execute( + select(DatasetMapping).where( + DatasetMapping.source_dataset_id == source_dataset_id, + DatasetMapping.deleted_at.is_(None) + ) + ) + mapping = result.scalar_one_or_none() + + if mapping: + logger.debug(f"Found mapping: {mapping.mapping_id}") + return DatasetMappingResponse.model_validate(mapping) + + logger.debug(f"No mapping found for source dataset id: {source_dataset_id}") + return None + + async def get_mappings_by_source_dataset_id( + self, + source_dataset_id: str, + include_deleted: bool = False + ) -> List[DatasetMappingResponse]: + """根据源数据集ID获取所有映射关系""" + logger.debug(f"Get all mappings by source dataset id: {source_dataset_id}") + + query = select(DatasetMapping).where( + DatasetMapping.source_dataset_id == source_dataset_id + ) + + if not include_deleted: + query = query.where(DatasetMapping.deleted_at.is_(None)) + + result = await self.db.execute( + query.order_by(DatasetMapping.created_at.desc()) + ) + mappings = result.scalars().all() + + logger.debug(f"Found {len(mappings)} mappings") + return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings] + + async def get_mapping_by_labelling_project_id( + self, + labelling_project_id: str + ) -> Optional[DatasetMappingResponse]: + """根据Label Studio项目ID获取映射""" + logger.debug(f"Get mapping by Label Studio project id: {labelling_project_id}") + + result = await self.db.execute( + select(DatasetMapping).where( + DatasetMapping.labelling_project_id == labelling_project_id, + DatasetMapping.deleted_at.is_(None) + ) + ) + mapping = result.scalar_one_or_none() + + if mapping: + logger.debug(f"Found mapping: {mapping.mapping_id}") + return DatasetMappingResponse.model_validate(mapping) + + logger.debug(f"No mapping found for Label Studio project id: 
{labelling_project_id}") + return None + + async def get_mapping_by_uuid(self, mapping_id: str) -> Optional[DatasetMappingResponse]: + """根据映射UUID获取映射""" + logger.debug(f"Get mapping: {mapping_id}") + + result = await self.db.execute( + select(DatasetMapping).where( + DatasetMapping.mapping_id == mapping_id, + DatasetMapping.deleted_at.is_(None) + ) + ) + mapping = result.scalar_one_or_none() + + if mapping: + logger.debug(f"Found mapping: {mapping.mapping_id}") + return DatasetMappingResponse.model_validate(mapping) + + logger.debug(f"Mapping not found: {mapping_id}") + return None + + async def update_mapping( + self, + mapping_id: str, + update_data: DatasetMappingUpdateRequest + ) -> Optional[DatasetMappingResponse]: + """更新映射信息""" + logger.info(f"Update mapping: {mapping_id}") + + mapping = await self.get_mapping_by_uuid(mapping_id) + if not mapping: + return None + + update_values = update_data.model_dump(exclude_unset=True) + update_values["last_updated_at"] = datetime.utcnow() + + result = await self.db.execute( + update(DatasetMapping) + .where(DatasetMapping.mapping_id == mapping_id) + .values(**update_values) + ) + await self.db.commit() + + if result.rowcount > 0: + return await self.get_mapping_by_uuid(mapping_id) + return None + + async def update_last_updated_at(self, mapping_id: str) -> bool: + """更新最后更新时间""" + logger.debug(f"Update mapping last updated at: {mapping_id}") + + result = await self.db.execute( + update(DatasetMapping) + .where( + DatasetMapping.mapping_id == mapping_id, + DatasetMapping.deleted_at.is_(None) + ) + .values(last_updated_at=datetime.utcnow()) + ) + await self.db.commit() + return result.rowcount > 0 + + async def soft_delete_mapping(self, mapping_id: str) -> bool: + """软删除映射""" + logger.info(f"Soft delete mapping: {mapping_id}") + + result = await self.db.execute( + update(DatasetMapping) + .where( + DatasetMapping.mapping_id == mapping_id, + DatasetMapping.deleted_at.is_(None) + ) + .values(deleted_at=datetime.utcnow()) + 
) + await self.db.commit() + + success = result.rowcount > 0 + if success: + logger.info(f"Mapping soft-deleted: {mapping_id}") + else: + logger.warning(f"Mapping not exists or already deleted: {mapping_id}") + + return success + + async def get_all_mappings( + self, + skip: int = 0, + limit: int = 100 + ) -> List[DatasetMappingResponse]: + """获取所有有效映射""" + logger.debug(f"List all mappings, skip: {skip}, limit: {limit}") + + result = await self.db.execute( + select(DatasetMapping) + .where(DatasetMapping.deleted_at.is_(None)) + .offset(skip) + .limit(limit) + .order_by(DatasetMapping.created_at.desc()) + ) + mappings = result.scalars().all() + + logger.debug(f"Found {len(mappings)} mappings") + return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings] + + async def count_mappings(self) -> int: + """统计映射总数""" + result = await self.db.execute( + select(DatasetMapping) + .where(DatasetMapping.deleted_at.is_(None)) + ) + mappings = result.scalars().all() + return len(mappings) \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/services/sync_service.py b/runtime/label-studio-adapter/app/services/sync_service.py new file mode 100644 index 0000000..60f2a36 --- /dev/null +++ b/runtime/label-studio-adapter/app/services/sync_service.py @@ -0,0 +1,276 @@ +from typing import Optional, List, Dict, Any, Tuple +from app.clients.dm_client import DMServiceClient +from app.clients.label_studio_client import LabelStudioClient +from app.services.dataset_mapping_service import DatasetMappingService +from app.schemas.dataset_mapping import SyncDatasetResponse +from app.core.logging import get_logger +from app.core.config import settings +from app.exceptions import NoDatasetInfoFoundError, DatasetMappingNotFoundError + +logger = get_logger(__name__) + +class SyncService: + """数据同步服务""" + + def __init__( + self, + dm_client: DMServiceClient, + ls_client: LabelStudioClient, + mapping_service: DatasetMappingService + ): + self.dm_client = 
dm_client + self.ls_client = ls_client + self.mapping_service = mapping_service + + def determine_data_type(self, file_type: str) -> str: + """根据文件类型确定数据类型""" + file_type_lower = file_type.lower() + + if any(ext in file_type_lower for ext in ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'svg', 'webp']): + return 'image' + elif any(ext in file_type_lower for ext in ['mp3', 'wav', 'flac', 'aac', 'ogg']): + return 'audio' + elif any(ext in file_type_lower for ext in ['mp4', 'avi', 'mov', 'wmv', 'flv', 'webm']): + return 'video' + elif any(ext in file_type_lower for ext in ['txt', 'doc', 'docx', 'pdf']): + return 'text' + else: + return 'image' # 默认为图像类型 + + async def get_existing_dm_file_mapping(self, project_id: str) -> Dict[str, int]: + """ + 获取Label Studio项目中已存在的DM文件ID到任务ID的映射 + + Args: + project_id: Label Studio项目ID + + Returns: + dm_file_id到task_id的映射字典 + """ + try: + logger.info(f"Fetching existing task mappings for project {project_id} (page_size={settings.ls_task_page_size})") + dm_file_to_task_mapping = {} + + # 使用Label Studio客户端封装的方法获取所有任务 + page_size = getattr(settings, 'ls_task_page_size', 1000) + + # 调用封装好的方法获取所有任务,page=None表示获取全部 + result = await self.ls_client.get_project_tasks( + project_id=project_id, + page=None, # 不指定page,获取所有任务 + page_size=page_size + ) + + if not result: + logger.warning(f"Failed to fetch tasks for project {project_id}") + return {} + + all_tasks = result.get("tasks", []) + + # 遍历所有任务,构建映射 + for task in all_tasks: + # 检查任务的meta字段中是否有dm_file_id + meta = task.get('meta') + if meta: + dm_file_id = meta.get('dm_file_id') + if dm_file_id: + task_id = task.get('id') + if task_id: + dm_file_to_task_mapping[str(dm_file_id)] = task_id + + logger.info(f"Found {len(dm_file_to_task_mapping)} existing task mappings") + return dm_file_to_task_mapping + + except Exception as e: + logger.error(f"Error while fetching existing tasks: {e}") + return {} # 发生错误时返回空字典,会同步所有文件 + + async def sync_dataset_files( + self, + mapping_id: str, + batch_size: int = 50 + 
) -> SyncDatasetResponse: + """同步数据集文件到Label Studio""" + logger.info(f"Start syncing dataset by mapping: {mapping_id}") + + # 获取映射关系 + mapping = await self.mapping_service.get_mapping_by_uuid(mapping_id) + if not mapping: + logger.error(f"Dataset mapping not found: {mapping_id}") + return SyncDatasetResponse( + mapping_id="", + status="error", + synced_files=0, + total_files=0, + message=f"Dataset mapping not found: {mapping_id}" + ) + + try: + # 获取数据集信息 + dataset_info = await self.dm_client.get_dataset(mapping.source_dataset_id) + if not dataset_info: + raise NoDatasetInfoFoundError(mapping.source_dataset_id) + + synced_files = 0 + deleted_tasks = 0 + total_files = dataset_info.fileCount + page = 0 + + logger.info(f"Total files in dataset: {total_files}") + + # 获取Label Studio中已存在的DM文件ID到任务ID的映射 + existing_dm_file_mapping = await self.get_existing_dm_file_mapping(mapping.labelling_project_id) + existing_dm_file_ids = set(existing_dm_file_mapping.keys()) + logger.info(f"{len(existing_dm_file_ids)} tasks already exist in Label Studio") + + # 收集DM中当前存在的所有文件ID + current_dm_file_ids = set() + + # 分页获取并同步文件 + while True: + files_response = await self.dm_client.get_dataset_files( + mapping.source_dataset_id, + page=page, + size=batch_size, + status="COMPLETED" # 只同步已完成的文件 + ) + + if not files_response or not files_response.content: + logger.info(f"No more files on page {page + 1}") + break + + logger.info(f"Processing page {page + 1}, total {len(files_response.content)} files") + + # 筛选出新文件并批量创建任务 + tasks = [] + new_files_count = 0 + existing_files_count = 0 + + for file_info in files_response.content: + # 记录当前DM中存在的文件ID + current_dm_file_ids.add(str(file_info.id)) + + # 检查文件是否已存在 + if str(file_info.id) in existing_dm_file_ids: + existing_files_count += 1 + logger.debug(f"Skip existing file: {file_info.originalName} (ID: {file_info.id})") + continue + + new_files_count += 1 + + # 确定数据类型 + data_type = self.determine_data_type(file_info.fileType) + + # 
替换文件路径前缀:只替换开头的前缀,不影响路径中间可能出现的相同字符串 + file_path = file_info.filePath.removeprefix(settings.dm_file_path_prefix) + file_path = settings.label_studio_file_path_prefix + file_path + + # 构造任务数据 + task_data = { + "data": { + data_type: file_path + }, + "meta": { + "file_size": file_info.size, + "file_type": file_info.fileType, + "dm_dataset_id": mapping.source_dataset_id, + "dm_file_id": file_info.id, + "original_name": file_info.originalName, + } + } + tasks.append(task_data) + + logger.info(f"Page {page + 1}: new files {new_files_count}, existing files {existing_files_count}") + + # 批量创建Label Studio任务 + if tasks: + batch_result = await self.ls_client.create_tasks_batch( + mapping.labelling_project_id, + tasks + ) + + if batch_result: + synced_files += len(tasks) + logger.info(f"Successfully synced {len(tasks)} files") + else: + logger.warning(f"Batch task creation failed, fallback to single creation") + # 如果批量创建失败,尝试单个创建 + for task_data in tasks: + task_result = await self.ls_client.create_task( + mapping.labelling_project_id, + task_data["data"], + task_data.get("meta") + ) + if task_result: + synced_files += 1 + + # 检查是否还有更多页面 + if page >= files_response.totalPages - 1: + break + page += 1 + + # 清理在DM中不存在但在Label Studio中存在的任务 + tasks_to_delete = [] + for dm_file_id, task_id in existing_dm_file_mapping.items(): + if dm_file_id not in current_dm_file_ids: + tasks_to_delete.append(task_id) + logger.debug(f"Mark task for deletion: {task_id} (DM file ID: {dm_file_id})") + + if tasks_to_delete: + logger.info(f"Deleting {len(tasks_to_delete)} tasks not present in DM") + delete_result = await self.ls_client.delete_tasks_batch(tasks_to_delete) + deleted_tasks = delete_result.get("successful", 0) + logger.info(f"Successfully deleted {deleted_tasks} tasks") + else: + logger.info("No tasks to delete") + + # 更新映射的最后更新时间 + await self.mapping_service.update_last_updated_at(mapping.mapping_id) + + logger.info(f"Sync completed: total_files={total_files}, created={synced_files}, 
deleted={deleted_tasks}") + + return SyncDatasetResponse( + mapping_id=mapping.mapping_id, + status="success", + synced_files=synced_files, + total_files=total_files, + message=f"Sync completed: created {synced_files} files, deleted {deleted_tasks} tasks" + ) + + except Exception as e: + logger.error(f"Error while syncing dataset: {e}") + return SyncDatasetResponse( + mapping_id=mapping.mapping_id, + status="error", + synced_files=0, + total_files=0, + message=f"Sync failed: {str(e)}" + ) + + async def get_sync_status( + self, + source_dataset_id: str + ) -> Optional[Dict[str, Any]]: + """获取同步状态""" + mapping = await self.mapping_service.get_mapping_by_source_uuid(source_dataset_id) + if not mapping: + return None + + # 获取DM数据集信息 + dataset_info = await self.dm_client.get_dataset(source_dataset_id) + + # 获取Label Studio项目任务数量 + tasks_info = await self.ls_client.get_project_tasks(mapping.labelling_project_id) + + return { + "mapping_id": mapping.mapping_id, + "source_dataset_id": source_dataset_id, + "labelling_project_id": mapping.labelling_project_id, + "last_updated_at": mapping.last_updated_at, + "dm_total_files": dataset_info.fileCount if dataset_info else 0, + "ls_total_tasks": tasks_info.get("count", 0) if tasks_info else 0, + "sync_ratio": ( + tasks_info.get("count", 0) / dataset_info.fileCount + if dataset_info and dataset_info.fileCount > 0 and tasks_info else 0 + ) + } \ No newline at end of file diff --git a/runtime/label-studio-adapter/deploy/docker-entrypoint.sh b/runtime/label-studio-adapter/deploy/docker-entrypoint.sh new file mode 100755 index 0000000..7487951 --- /dev/null +++ b/runtime/label-studio-adapter/deploy/docker-entrypoint.sh @@ -0,0 +1,64 @@ +#!/bin/bash +set -e + +echo "==========================================" +echo "Label Studio Adapter Starting..." 
+echo "==========================================" + +# Label Studio 本地存储基础路径(从环境变量获取,默认值) +LABEL_STUDIO_LOCAL_BASE="${LABEL_STUDIO_LOCAL_BASE:-/label-studio/local_files}" + +echo "==========================================" +echo "Ensuring Label Studio local storage directories exist..." +echo "Base path: ${LABEL_STUDIO_LOCAL_BASE}" +echo "==========================================" + +# 创建必要的目录 +mkdir -p "${LABEL_STUDIO_LOCAL_BASE}/dataset" +mkdir -p "${LABEL_STUDIO_LOCAL_BASE}/upload" + +echo "✓ Directory 'dataset' ready: ${LABEL_STUDIO_LOCAL_BASE}/dataset" +echo "✓ Directory 'upload' ready: ${LABEL_STUDIO_LOCAL_BASE}/upload" + +echo "==========================================" +echo "Directory initialization completed" +echo "==========================================" + +# 等待数据库就绪(如果配置了数据库) +if [ -n "$MYSQL_HOST" ] || [ -n "$POSTGRES_HOST" ]; then + echo "Waiting for database to be ready..." + sleep 5 +fi + +# 运行数据库迁移 +echo "==========================================" +echo "Running database migrations..." +echo "==========================================" +alembic upgrade head + +if [ $? -eq 0 ]; then + echo "✓ Database migrations completed successfully" +else + echo "⚠️ WARNING: Database migrations failed" + echo " The application may not work correctly" +fi + +echo "==========================================" + +# 启动应用 +echo "Starting Label Studio Adapter..." 
+echo "Host: ${HOST:-0.0.0.0}" +echo "Port: ${PORT:-18000}" +echo "Debug: ${DEBUG:-false}" +echo "Label Studio URL: ${LABEL_STUDIO_BASE_URL}" +echo "==========================================" + +# 转换 LOG_LEVEL 为小写(uvicorn 要求小写) +LOG_LEVEL_LOWER=$(echo "${LOG_LEVEL:-info}" | tr '[:upper:]' '[:lower:]') + +# 使用 uvicorn 启动应用 +exec uvicorn app.main:app \ + --host "${HOST:-0.0.0.0}" \ + --port "${PORT:-18000}" \ + --log-level "${LOG_LEVEL_LOWER}" \ + ${DEBUG:+--reload} diff --git a/runtime/label-studio-adapter/requirements.txt b/runtime/label-studio-adapter/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..16aea99c9d573ff9351d05a2518be345d185f28c GIT binary patch literal 942 zcmaJ=%TB^j6r8vb4DMZ$xWNU47M~0EE_@+wj8P+{0s+gz*2qt&i95f*-|!Rs5tEd^ADlDUjp&xip=cDJBqasS!Bf5b) zcCmD@Q(rO2OS%{s9jl$VwTeY*NGD^?UHCoeslt$Ya*#0{)2Dj;$t)aOW-cy8_xrC8 zACr8X_&lTsCDvx|{?wzYHDB>smQ!>YBKmdPdEr zy-u~Ybm|?}3fa|NTF*6G=NN>?d?)-HvU6pv4ejuY^plBd0|O$GD%u2x86`WYsn24{ zoQQ9#D#142v_qC9hjw$t3R!DFOqpFo>_#<)Ci5_Lt2w&Aa$4jn$ABkC6_Q=Eoh6&3 a)6pcV&nnueh+ehfdTXI)n6?*tBEJCQN~u)< literal 0 HcmV?d00001 diff --git a/scripts/images/label-studio-adapter/Dockerfile b/scripts/images/label-studio-adapter/Dockerfile new file mode 100644 index 0000000..b240214 --- /dev/null +++ b/scripts/images/label-studio-adapter/Dockerfile @@ -0,0 +1,27 @@ +FROM python:3.11-slim + +WORKDIR /app + +# 安装系统依赖 +RUN apt-get update && apt-get install -y \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# 复制requirements文件 +COPY runtime/label-studio-adapter/requirements.txt . + +# 安装Python依赖 +RUN pip install --no-cache-dir -r requirements.txt + +# 复制应用代码 +COPY runtime/label-studio-adapter . 
+ +# 复制并设置 entrypoint 脚本权限 +COPY runtime/label-studio-adapter/deploy/docker-entrypoint.sh /docker-entrypoint.sh +RUN chmod +x /docker-entrypoint.sh + +# 暴露端口 +EXPOSE 8088 + +# 使用 entrypoint 脚本启动 +ENTRYPOINT ["/docker-entrypoint.sh"] \ No newline at end of file From e8e2c1a96bbbd4303193b3b79f9c05ecdc0666bd Mon Sep 17 00:00:00 2001 From: chenghh-9609 <55340429+chenghh-9609@users.noreply.github.com> Date: Wed, 22 Oct 2025 16:09:03 +0800 Subject: [PATCH 2/4] =?UTF-8?q?refactor:=20=E4=BF=AE=E5=A4=8D=E6=A0=87?= =?UTF-8?q?=E7=AD=BE=E7=AE=A1=E7=90=86=E5=8A=9F=E8=83=BD=E3=80=81=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E6=95=B0=E6=8D=AE=E9=80=89=E6=8B=A9=E9=A1=B9=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E3=80=81=E5=B1=8F=E8=94=BD=E5=BC=80=E5=8F=91=E4=B8=AD?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=20(#12)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: clean up tag management and dataset handling, update API endpoints * feat: add showTime prop to DevelopmentInProgress component across multiple pages * refactor: update component styles and improve layout with new utility classes --- frontend/src/components/CardView.tsx | 6 +- frontend/src/components/DetailHeader.tsx | 2 +- .../src/components/DevelopmentInProgress.tsx | 13 ++- frontend/src/components/RadioCard.tsx | 2 +- frontend/src/components/TagManagement.tsx | 105 +++++++----------- frontend/src/components/TaskPopover.tsx | 2 +- frontend/src/index.css | 26 ++++- .../Annotate/components/ImageAnnotation.tsx | 4 +- .../DataAnnotation/Create/CreateTask.tsx | 4 +- .../components/CreateAnnptationTaskDialog.tsx | 35 +++--- .../DataAnnotation/Home/DataAnnotation.tsx | 2 + .../pages/DataAnnotation/annotation.api.ts | 2 +- .../pages/DataCleansing/Create/CreateTask.tsx | 4 +- .../DataCleansing/Create/CreateTempate.tsx | 4 +- .../Create/components/CreateTaskStepOne.tsx | 31 +++--- .../components/OperatorOrchestration.tsx | 2 +- .../Home/components/ProcessFlowDiagram.tsx | 2 +- 
.../pages/DataCleansing/cleansing.model.ts | 6 +- .../DataCollection/Create/CreateTask.tsx | 2 +- .../DataCollection/Home/DataCollection.tsx | 2 +- .../DataEvaluation/Home/DataEvaluation.tsx | 4 +- .../DataManagement/Create/CreateDataset.tsx | 4 +- .../Create/components/BasicInformation.tsx | 18 +-- .../DataManagement/Detail/DatasetDetail.tsx | 4 +- .../DataManagement/Home/DataManagement.tsx | 8 +- .../src/pages/DataManagement/dataset.api.ts | 8 +- .../pages/DataManagement/dataset.const.tsx | 26 +++-- .../Create/KnowledgeBaseCreate.tsx | 2 +- .../FileDetail/KnowledgeBaseFileDetail.tsx | 24 ++-- .../Home/KnowledgeGeneration.tsx | 2 +- frontend/src/pages/Layout/MainLayout.tsx | 4 +- frontend/src/pages/Layout/Sidebar.tsx | 6 +- .../Create/OperatorPluginCreate.tsx | 6 +- .../OperatorMarket/Home/OperatorMarket.tsx | 9 +- .../Home/components/Filters.tsx | 4 +- .../OperatorMarket/Home/components/List.tsx | 16 +-- .../src/pages/RatioTask/CreateRatioTask.tsx | 4 +- frontend/src/pages/RatioTask/RatioTask.tsx | 2 +- .../src/pages/SynthesisTask/CreateTask.tsx | 2 +- .../pages/SynthesisTask/CreateTemplate.tsx | 49 ++++---- .../src/pages/SynthesisTask/DataSynthesis.tsx | 4 +- 41 files changed, 224 insertions(+), 238 deletions(-) diff --git a/frontend/src/components/CardView.tsx b/frontend/src/components/CardView.tsx index f230280..92363dd 100644 --- a/frontend/src/components/CardView.tsx +++ b/frontend/src/components/CardView.tsx @@ -168,12 +168,12 @@ function CardView(props: CardViewProps) { const ops = (item) => typeof operations === "function" ? operations(item) : operations; return ( -
-
+
+
{data.map((item) => (
{/* Header */} diff --git a/frontend/src/components/DetailHeader.tsx b/frontend/src/components/DetailHeader.tsx index dc14b63..b3c910e 100644 --- a/frontend/src/components/DetailHeader.tsx +++ b/frontend/src/components/DetailHeader.tsx @@ -48,7 +48,7 @@ function DetailHeader({
{ +const DevelopmentInProgress = ({ showHome = true, showTime = "" }) => { return (
🚧

功能开发中

-

- 为了给您带来更好的体验,我们计划2025.10.30 - 开放此功能 -

+ {showTime && ( +

+ 为了给您带来更好的体验,我们计划{showTime} + 开放此功能 +

+ )} {showHome && ( +
+ setNewTag(e.target.value)} + onKeyPress={(e) => { + if (e.key === "Enter") { + addTag(e.target.value); + } + }} + /> + +
+ +
+
+ {tags.map((tag) => ( + + ))}
- -

预置标签

-
- {preparedTags.length > 0 && - preparedTags.map((tag) => )} -
- -

自定义标签

-
- {tags.map((tag) => ( - - ))} -
diff --git a/frontend/src/components/TaskPopover.tsx b/frontend/src/components/TaskPopover.tsx index 768c796..2a0e746 100644 --- a/frontend/src/components/TaskPopover.tsx +++ b/frontend/src/components/TaskPopover.tsx @@ -89,7 +89,7 @@ export default function TaskPopover() { {tasks.map((task) => (
diff --git a/frontend/src/index.css b/frontend/src/index.css index 5b2646d..e5fca1b 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -42,4 +42,28 @@ opacity: 100%; visibility: visible; transform: translateX(0); -} \ No newline at end of file +} + +@layer components { + .flex-center { + @apply flex items-center justify-center; + } + .flex-overflow-auto { + @apply flex-1 flex flex-col overflow-auto h-full; + } + .flex-overflow-hidden { + @apply flex flex-col h-full overflow-hidden; + } + .border-card { + @apply border border-[#f0f0f0] rounded-lg bg-white; + } + .border { + @apply border border-gray-100; + } + .border-bottom { + @apply border-b border-gray-100; + } + .border-top { + @apply border-t border-gray-100; + } +} diff --git a/frontend/src/pages/DataAnnotation/Annotate/components/ImageAnnotation.tsx b/frontend/src/pages/DataAnnotation/Annotate/components/ImageAnnotation.tsx index 0e9090e..8d95e2b 100644 --- a/frontend/src/pages/DataAnnotation/Annotate/components/ImageAnnotation.tsx +++ b/frontend/src/pages/DataAnnotation/Annotate/components/ImageAnnotation.tsx @@ -429,7 +429,7 @@ export default function ImageAnnotationWorkspace({ }`} onClick={() => setSelectedImageIndex(index)} > -
+
{index + 1}
-
+
+
{/* Header */}
@@ -134,7 +134,7 @@ export default function AnnotationTaskCreate() {

创建标注任务

-
+
({ - label: ( -
-
- - {dataset.icon || } - - {dataset.name} + options={datasets.map((dataset) => { + return { + label: ( +
+
+ {dataset.icon} + {dataset.name} +
+
{dataset.size}
-
- {datasetTypeMap[dataset?.datasetType]?.label} -
-
- ), - value: dataset.id, - }))} + ), + value: dataset.id, + }; + })} /> diff --git a/frontend/src/pages/DataCleansing/Create/components/OperatorOrchestration.tsx b/frontend/src/pages/DataCleansing/Create/components/OperatorOrchestration.tsx index 6bae1b8..eaea432 100644 --- a/frontend/src/pages/DataCleansing/Create/components/OperatorOrchestration.tsx +++ b/frontend/src/pages/DataCleansing/Create/components/OperatorOrchestration.tsx @@ -105,7 +105,7 @@ const OperatorFlow: React.FC = ({
{/* 编排区域 */}
e.preventDefault()} onDragLeave={handleContainerDragLeave} onDrop={handleDropToContainer} diff --git a/frontend/src/pages/DataCleansing/Home/components/ProcessFlowDiagram.tsx b/frontend/src/pages/DataCleansing/Home/components/ProcessFlowDiagram.tsx index 426012c..c307057 100644 --- a/frontend/src/pages/DataCleansing/Home/components/ProcessFlowDiagram.tsx +++ b/frontend/src/pages/DataCleansing/Home/components/ProcessFlowDiagram.tsx @@ -56,7 +56,7 @@ export default function ProcessFlowDiagram() { ]; return ( -
+
{flowSteps.map((step, index) => { diff --git a/frontend/src/pages/DataCleansing/cleansing.model.ts b/frontend/src/pages/DataCleansing/cleansing.model.ts index c8fa829..58cdce0 100644 --- a/frontend/src/pages/DataCleansing/cleansing.model.ts +++ b/frontend/src/pages/DataCleansing/cleansing.model.ts @@ -16,7 +16,11 @@ export interface CleansingTask { color: string; }; startedAt: string; - progress: number; + progress: { + finishedFileNum: number; + process: 100, + totalFileNum: number; + }; operators: OperatorI[]; createdAt: string; updatedAt: string; diff --git a/frontend/src/pages/DataCollection/Create/CreateTask.tsx b/frontend/src/pages/DataCollection/Create/CreateTask.tsx index 0b6c357..2a1135b 100644 --- a/frontend/src/pages/DataCollection/Create/CreateTask.tsx +++ b/frontend/src/pages/DataCollection/Create/CreateTask.tsx @@ -59,7 +59,7 @@ const defaultTemplates = [ ]; export default function CollectionTaskCreate() { - return ; + return ; const navigate = useNavigate(); const [form] = Form.useForm(); diff --git a/frontend/src/pages/DataCollection/Home/DataCollection.tsx b/frontend/src/pages/DataCollection/Home/DataCollection.tsx index 88dbe3f..8cdd66a 100644 --- a/frontend/src/pages/DataCollection/Home/DataCollection.tsx +++ b/frontend/src/pages/DataCollection/Home/DataCollection.tsx @@ -10,7 +10,7 @@ export default function DataCollection() { const navigate = useNavigate(); const [activeTab, setActiveTab] = useState("task-management"); - return ; + return ; return (
diff --git a/frontend/src/pages/DataEvaluation/Home/DataEvaluation.tsx b/frontend/src/pages/DataEvaluation/Home/DataEvaluation.tsx index 87e09b1..d589b37 100644 --- a/frontend/src/pages/DataEvaluation/Home/DataEvaluation.tsx +++ b/frontend/src/pages/DataEvaluation/Home/DataEvaluation.tsx @@ -180,9 +180,7 @@ export default function DataEvaluationPage() { setTasks(tasks.filter((task) => task.id !== taskId)); }; - return ( - - ); + return ; // 主列表界面 return (
diff --git a/frontend/src/pages/DataManagement/Create/CreateDataset.tsx b/frontend/src/pages/DataManagement/Create/CreateDataset.tsx index 610404a..4042630 100644 --- a/frontend/src/pages/DataManagement/Create/CreateDataset.tsx +++ b/frontend/src/pages/DataManagement/Create/CreateDataset.tsx @@ -56,7 +56,7 @@ export default function DatasetCreate() {
{/* form */} -
+
-
+
- 筛选器 +

筛选器

{hasActiveFilters && ( { - const iconMap = { - preprocessing: Code, - training: Brain, - inference: Cpu, - postprocessing: Package, - }; - const IconComponent = iconMap[type as keyof typeof iconMap] || Code; - return ; - }; return ( ( diff --git a/frontend/src/pages/RatioTask/CreateRatioTask.tsx b/frontend/src/pages/RatioTask/CreateRatioTask.tsx index fe9a13b..4353e95 100644 --- a/frontend/src/pages/RatioTask/CreateRatioTask.tsx +++ b/frontend/src/pages/RatioTask/CreateRatioTask.tsx @@ -31,8 +31,8 @@ const { TextArea } = Input; const { Option } = Select; export default function CreateRatioTask() { - return ; - + return ; + const navigate = useNavigate(); const [form] = Form.useForm(); // 配比任务相关状态 diff --git a/frontend/src/pages/RatioTask/RatioTask.tsx b/frontend/src/pages/RatioTask/RatioTask.tsx index 03a9e8f..c5e1231 100644 --- a/frontend/src/pages/RatioTask/RatioTask.tsx +++ b/frontend/src/pages/RatioTask/RatioTask.tsx @@ -28,7 +28,7 @@ import { SearchControls } from "@/components/SearchControls"; import DevelopmentInProgress from "@/components/DevelopmentInProgress"; export default function RatioTasksPage() { - return ; + return ; const navigate = useNavigate(); const [searchQuery, setSearchQuery] = useState(""); const [filterStatus, setFilterStatus] = useState("all"); diff --git a/frontend/src/pages/SynthesisTask/CreateTask.tsx b/frontend/src/pages/SynthesisTask/CreateTask.tsx index e243468..6eaf791 100644 --- a/frontend/src/pages/SynthesisTask/CreateTask.tsx +++ b/frontend/src/pages/SynthesisTask/CreateTask.tsx @@ -41,7 +41,7 @@ import DevelopmentInProgress from "@/components/DevelopmentInProgress"; const { TextArea } = Input; export default function SynthesisTaskCreate() { - return ; + return ; const navigate = useNavigate(); const [form] = Form.useForm(); const [searchQuery, setSearchQuery] = useState(""); diff --git a/frontend/src/pages/SynthesisTask/CreateTemplate.tsx b/frontend/src/pages/SynthesisTask/CreateTemplate.tsx index e80379f..f196a23 100644 --- 
a/frontend/src/pages/SynthesisTask/CreateTemplate.tsx +++ b/frontend/src/pages/SynthesisTask/CreateTemplate.tsx @@ -1,15 +1,15 @@ import { useState, useRef } from "react"; -import { Card, Select, Input, Button, Badge, Divider, Form, message } from "antd"; import { - Plus, - ArrowLeft, - Play, - Save, - RefreshCw, - FileText, - Code, - X, -} from "lucide-react"; + Card, + Select, + Input, + Button, + Badge, + Divider, + Form, + message, +} from "antd"; +import { Plus, ArrowLeft, Play, Save, RefreshCw, Code, X } from "lucide-react"; import { useNavigate } from "react-router"; import { mockTemplates } from "@/mock/annotation"; import DevelopmentInProgress from "@/components/DevelopmentInProgress"; @@ -17,9 +17,11 @@ import DevelopmentInProgress from "@/components/DevelopmentInProgress"; const { TextArea } = Input; export default function InstructionTemplateCreate() { - return ; + return ; const navigate = useNavigate(); - const [selectedTemplate, setSelectedTemplate] = useState