feature: LabelStudio jumps without login (#201)

This commit is contained in:
hefanli
2025-12-25 16:49:06 +08:00
committed by GitHub
parent 87e73d3bf7
commit 29e4a333a9
4 changed files with 230 additions and 157 deletions

View File

@@ -2,7 +2,7 @@ from typing import Optional
import math
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Path
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Response
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db
@@ -29,6 +29,39 @@ router = APIRouter(
)
logger = get_logger(__name__)
@router.get("/{mapping_id}/login")
async def list_mappings(
    db: AsyncSession = Depends(get_db)
):
    """
    Log in to Label Studio with the configured user token and relay the
    upstream response (status code, body, headers and cookies) to the caller.

    Each upstream ``Set-Cookie`` header is forwarded as its own header line
    rather than as one merged value, so every cookie reaches the browser.

    Raises:
        HTTPException: re-raised unchanged if raised inside the handler;
            any other error is logged and reported as a 500.
    """
    # NOTE(review): this function shares its name with the paginated list
    # endpoint defined later in this module, and the ``mapping_id`` path
    # segment is not consumed here. Both routes are still registered because
    # the decorator runs at definition time, but a distinct name (for
    # example ``login_label_studio``) would be clearer. ``db`` is unused and
    # kept only so the signature stays unchanged.
    try:
        ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
                                      token=settings.label_studio_user_token)
        target_response = await ls_client.login_label_studio()

        # Read the cookies individually before flattening the headers:
        # a plain dict can hold only one value per header name.
        set_cookies = target_response.headers.get_list("set-cookie")

        # Headers that must not be copied verbatim:
        # - set-cookie is re-added below, one header line per cookie;
        # - content-length / content-encoding / transfer-encoding describe
        #   the upstream wire format. The body is re-sent from ``.content``
        #   (assumed to be an httpx response, whose content is already
        #   decoded), so those values may no longer match; the new Response
        #   computes its own Content-Length.
        skipped_headers = {
            "set-cookie",
            "content-length",
            "content-encoding",
            "transfer-encoding",
        }
        headers = {
            name: value
            for name, value in target_response.headers.items()
            if name.lower() not in skipped_headers
        }

        # Build the relayed response from the upstream body and status.
        response = Response(
            content=target_response.content,
            status_code=target_response.status_code,
            headers=headers
        )

        # Append each Set-Cookie separately so none of them is lost.
        for cookie in set_cookies:
            response.headers.append("set-cookie", cookie)

        return response
    except HTTPException:
        raise
    except Exception as e:
        # Pass only the formatted message, as the other handlers in this
        # module do; an extra positional argument without a placeholder
        # breaks log-record formatting.
        logger.error(f"Error while logining in LabelStudio: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
@router.post("", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
async def create_mapping(
request: DatasetMappingCreateRequest,
@@ -36,12 +69,12 @@ async def create_mapping(
):
"""
创建数据集映射
根据指定的DM程序中的数据集,创建Label Studio中的数据集,
在数据库中记录这一关联关系,返回Label Studio数据集的ID
注意:一个数据集可以创建多个标注项目
支持通过 template_id 指定标注模板,如果提供了模板ID,则使用模板的配置
"""
try:
@@ -51,9 +84,9 @@ async def create_mapping(
mapping_service = DatasetMappingService(db)
sync_service = SyncService(dm_client, ls_client, mapping_service)
template_service = AnnotationTemplateService()
logger.info(f"Create dataset mapping request: {request.dataset_id}")
# 从DM服务获取数据集信息
dataset_info = await dm_client.get_dataset(request.dataset_id)
if not dataset_info:
@@ -61,11 +94,11 @@ async def create_mapping(
status_code=404,
detail=f"Dataset not found in DM service: {request.dataset_id}"
)
project_name = request.name or \
dataset_info.name or \
"A new project from DataMate"
project_description = request.description or \
dataset_info.description or \
f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
@@ -89,15 +122,15 @@ async def create_mapping(
description=project_description,
label_config=label_config # 传递模板配置
)
if not project_data:
raise HTTPException(
status_code=500,
detail="Fail to create Label Studio project."
)
project_id = project_data["id"]
# 配置本地存储:dataset/<id>
local_storage_path = f"{settings.label_studio_local_document_root}/{request.dataset_id}"
storage_result = await ls_client.create_local_storage(
@@ -107,7 +140,7 @@ async def create_mapping(
use_blob_urls=True,
description=f"Local storage for dataset {dataset_info.name}"
)
if not storage_result:
# 本地存储配置失败,记录警告但不中断流程
logger.warning(f"Failed to configure local storage for project {project_id}")
@@ -124,28 +157,28 @@ async def create_mapping(
# 创建映射关系,包含项目名称(先持久化映射以获得 mapping.id)
mapping = await mapping_service.create_mapping(labeling_project)
# 进行一次同步,使用创建后的 mapping.id
await sync_service.sync_dataset_files(mapping.id, 100)
response_data = DatasetMappingCreateResponse(
id=mapping.id,
labeling_project_id=str(mapping.labeling_project_id),
labeling_project_name=mapping.name or project_name
)
return StandardResponse(
code=201,
message="success",
data=response_data
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error while creating dataset mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
async def list_mappings(
page: int = Query(1, ge=1, description="页码(从1开始)"),
@@ -155,10 +188,10 @@ async def list_mappings(
):
"""
查询所有映射关系(分页)
返回所有有效的数据集映射关系(未被软删除的),支持分页查询。
可选择是否包含完整的标注模板信息(默认不包含,以提高列表查询性能)。
参数:
- page: 页码(从1开始)
- pageSize: 每页记录数
@@ -166,12 +199,12 @@ async def list_mappings(
"""
try:
service = DatasetMappingService(db)
# 计算 skip
skip = (page - 1) * page_size
logger.info(f"List mappings: page={page}, page_size={page_size}, include_template={include_template}")
# 获取数据和总数
mappings, total = await service.get_all_mappings_with_count(
skip=skip,
@@ -179,10 +212,10 @@ async def list_mappings(
include_deleted=False,
include_template=include_template
)
# 计算总页数
total_pages = math.ceil(total / page_size) if total > 0 else 0
# 构造分页响应
paginated_data = PaginatedData(
page=page,
@@ -191,15 +224,15 @@ async def list_mappings(
total_pages=total_pages,
content=mappings
)
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
return StandardResponse(
code=200,
message="success",
data=paginated_data
)
except Exception as e:
logger.error(f"Error listing mappings: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@@ -211,7 +244,7 @@ async def get_mapping(
):
"""
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
返回数据集映射关系以及关联的完整标注模板信息,包括:
- 映射基本信息
- 数据集信息
@@ -220,26 +253,26 @@ async def get_mapping(
"""
try:
service = DatasetMappingService(db)
logger.info(f"Get mapping with template details: {mapping_id}")
# 获取映射,并包含完整的模板信息
mapping = await service.get_mapping_by_uuid(mapping_id, include_template=True)
if not mapping:
raise HTTPException(
status_code=404,
detail=f"Mapping not found: {mapping_id}"
)
logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}")
return StandardResponse(
code=200,
message="success",
data=mapping
)
except HTTPException:
raise
except Exception as e:
@@ -256,10 +289,10 @@ async def get_mappings_by_source(
):
"""
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询。
默认包含关联的完整标注模板信息。
参数:
- dataset_id: 数据集ID
- page: 页码(从1开始)
@@ -268,12 +301,12 @@ async def get_mappings_by_source(
"""
try:
service = DatasetMappingService(db)
# 计算 skip
skip = (page - 1) * page_size
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}, include_template={include_template}")
# 获取数据和总数(包含模板信息)
mappings, total = await service.get_mappings_by_source_with_count(
dataset_id=dataset_id,
@@ -281,10 +314,10 @@ async def get_mappings_by_source(
limit=page_size,
include_template=include_template
)
# 计算总页数
total_pages = math.ceil(total / page_size) if total > 0 else 0
# 构造分页响应
paginated_data = PaginatedData(
page=page,
@@ -293,15 +326,15 @@ async def get_mappings_by_source(
total_pages=total_pages,
content=mappings
)
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}")
return StandardResponse(
code=200,
message="success",
data=paginated_data
)
except HTTPException:
raise
except Exception as e:
@@ -328,24 +361,24 @@ async def delete_mapping(
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token)
service = DatasetMappingService(db)
# 使用 mapping UUID 查询映射记录
logger.debug(f"Deleting by mapping UUID: {project_id}")
mapping = await service.get_mapping_by_uuid(project_id)
logger.debug(f"Mapping lookup result: {mapping}")
if not mapping:
raise HTTPException(
status_code=404,
detail=f"Mapping either not found or not specified."
)
id = mapping.id
labeling_project_id = mapping.labeling_project_id
logger.debug(f"Found mapping: {id}, Label Studio project ID: {labeling_project_id}")
# 1. 删除 Label Studio 项目
try:
logger.debug(f"Deleting Label Studio project: {labeling_project_id}")
@@ -357,11 +390,11 @@ async def delete_mapping(
except Exception as e:
logger.error(f"Error deleting Label Studio project: {e}")
# 继续执行,即使 Label Studio 项目删除失败也要删除映射记录
# 2. 软删除映射记录
soft_delete_success = await service.soft_delete_mapping(id)
logger.debug(f"Soft delete result for mapping {id}: {soft_delete_success}")
if not soft_delete_success:
raise HTTPException(
status_code=500,
@@ -378,7 +411,7 @@ async def delete_mapping(
status="success"
)
)
except HTTPException:
raise
except Exception as e: