You've already forked DataMate
feature: LabelStudio jumps without login (#201)
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
import math
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.session import get_db
|
||||
@@ -29,6 +29,39 @@ router = APIRouter(
|
||||
)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@router.get("/{mapping_id}/login")
|
||||
async def list_mappings(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
try:
|
||||
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
|
||||
token=settings.label_studio_user_token)
|
||||
target_response = await ls_client.login_label_studio()
|
||||
headers = dict(target_response.headers)
|
||||
set_cookies = target_response.headers.get_list("set-cookie")
|
||||
|
||||
# 删除合并的 Set-Cookie
|
||||
if "set-cookie" in headers:
|
||||
del headers["set-cookie"]
|
||||
|
||||
# 创建新响应,添加多个 Set-Cookie
|
||||
response = Response(
|
||||
content=target_response.content,
|
||||
status_code=target_response.status_code,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 分别添加每个 Set-Cookie
|
||||
for cookie in set_cookies:
|
||||
response.headers.append("set-cookie", cookie)
|
||||
|
||||
return response
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error while logining in LabelStudio: {e}", e)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.post("", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
|
||||
async def create_mapping(
|
||||
request: DatasetMappingCreateRequest,
|
||||
@@ -36,12 +69,12 @@ async def create_mapping(
|
||||
):
|
||||
"""
|
||||
创建数据集映射
|
||||
|
||||
|
||||
根据指定的DM程序中的数据集,创建Label Studio中的数据集,
|
||||
在数据库中记录这一关联关系,返回Label Studio数据集的ID
|
||||
|
||||
|
||||
注意:一个数据集可以创建多个标注项目
|
||||
|
||||
|
||||
支持通过 template_id 指定标注模板,如果提供了模板ID,则使用模板的配置
|
||||
"""
|
||||
try:
|
||||
@@ -51,9 +84,9 @@ async def create_mapping(
|
||||
mapping_service = DatasetMappingService(db)
|
||||
sync_service = SyncService(dm_client, ls_client, mapping_service)
|
||||
template_service = AnnotationTemplateService()
|
||||
|
||||
|
||||
logger.info(f"Create dataset mapping request: {request.dataset_id}")
|
||||
|
||||
|
||||
# 从DM服务获取数据集信息
|
||||
dataset_info = await dm_client.get_dataset(request.dataset_id)
|
||||
if not dataset_info:
|
||||
@@ -61,11 +94,11 @@ async def create_mapping(
|
||||
status_code=404,
|
||||
detail=f"Dataset not found in DM service: {request.dataset_id}"
|
||||
)
|
||||
|
||||
|
||||
project_name = request.name or \
|
||||
dataset_info.name or \
|
||||
"A new project from DataMate"
|
||||
|
||||
|
||||
project_description = request.description or \
|
||||
dataset_info.description or \
|
||||
f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
|
||||
@@ -89,15 +122,15 @@ async def create_mapping(
|
||||
description=project_description,
|
||||
label_config=label_config # 传递模板配置
|
||||
)
|
||||
|
||||
|
||||
if not project_data:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Fail to create Label Studio project."
|
||||
)
|
||||
|
||||
|
||||
project_id = project_data["id"]
|
||||
|
||||
|
||||
# 配置本地存储:dataset/<id>
|
||||
local_storage_path = f"{settings.label_studio_local_document_root}/{request.dataset_id}"
|
||||
storage_result = await ls_client.create_local_storage(
|
||||
@@ -107,7 +140,7 @@ async def create_mapping(
|
||||
use_blob_urls=True,
|
||||
description=f"Local storage for dataset {dataset_info.name}"
|
||||
)
|
||||
|
||||
|
||||
if not storage_result:
|
||||
# 本地存储配置失败,记录警告但不中断流程
|
||||
logger.warning(f"Failed to configure local storage for project {project_id}")
|
||||
@@ -124,28 +157,28 @@ async def create_mapping(
|
||||
|
||||
# 创建映射关系,包含项目名称(先持久化映射以获得 mapping.id)
|
||||
mapping = await mapping_service.create_mapping(labeling_project)
|
||||
|
||||
|
||||
# 进行一次同步,使用创建后的 mapping.id
|
||||
await sync_service.sync_dataset_files(mapping.id, 100)
|
||||
|
||||
|
||||
response_data = DatasetMappingCreateResponse(
|
||||
id=mapping.id,
|
||||
labeling_project_id=str(mapping.labeling_project_id),
|
||||
labeling_project_name=mapping.name or project_name
|
||||
)
|
||||
|
||||
|
||||
return StandardResponse(
|
||||
code=201,
|
||||
message="success",
|
||||
data=response_data
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error while creating dataset mapping: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
|
||||
async def list_mappings(
|
||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||
@@ -155,10 +188,10 @@ async def list_mappings(
|
||||
):
|
||||
"""
|
||||
查询所有映射关系(分页)
|
||||
|
||||
|
||||
返回所有有效的数据集映射关系(未被软删除的),支持分页查询。
|
||||
可选择是否包含完整的标注模板信息(默认不包含,以提高列表查询性能)。
|
||||
|
||||
|
||||
参数:
|
||||
- page: 页码(从1开始)
|
||||
- pageSize: 每页记录数
|
||||
@@ -166,12 +199,12 @@ async def list_mappings(
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
|
||||
|
||||
# 计算 skip
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
|
||||
logger.info(f"List mappings: page={page}, page_size={page_size}, include_template={include_template}")
|
||||
|
||||
|
||||
# 获取数据和总数
|
||||
mappings, total = await service.get_all_mappings_with_count(
|
||||
skip=skip,
|
||||
@@ -179,10 +212,10 @@ async def list_mappings(
|
||||
include_deleted=False,
|
||||
include_template=include_template
|
||||
)
|
||||
|
||||
|
||||
# 计算总页数
|
||||
total_pages = math.ceil(total / page_size) if total > 0 else 0
|
||||
|
||||
|
||||
# 构造分页响应
|
||||
paginated_data = PaginatedData(
|
||||
page=page,
|
||||
@@ -191,15 +224,15 @@ async def list_mappings(
|
||||
total_pages=total_pages,
|
||||
content=mappings
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
|
||||
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=paginated_data
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing mappings: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -211,7 +244,7 @@ async def get_mapping(
|
||||
):
|
||||
"""
|
||||
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
|
||||
|
||||
|
||||
返回数据集映射关系以及关联的完整标注模板信息,包括:
|
||||
- 映射基本信息
|
||||
- 数据集信息
|
||||
@@ -220,26 +253,26 @@ async def get_mapping(
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
|
||||
|
||||
logger.info(f"Get mapping with template details: {mapping_id}")
|
||||
|
||||
|
||||
# 获取映射,并包含完整的模板信息
|
||||
mapping = await service.get_mapping_by_uuid(mapping_id, include_template=True)
|
||||
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Mapping not found: {mapping_id}"
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}")
|
||||
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=mapping
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -256,10 +289,10 @@ async def get_mappings_by_source(
|
||||
):
|
||||
"""
|
||||
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
|
||||
|
||||
|
||||
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询。
|
||||
默认包含关联的完整标注模板信息。
|
||||
|
||||
|
||||
参数:
|
||||
- dataset_id: 数据集ID
|
||||
- page: 页码(从1开始)
|
||||
@@ -268,12 +301,12 @@ async def get_mappings_by_source(
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
|
||||
|
||||
# 计算 skip
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
|
||||
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}, include_template={include_template}")
|
||||
|
||||
|
||||
# 获取数据和总数(包含模板信息)
|
||||
mappings, total = await service.get_mappings_by_source_with_count(
|
||||
dataset_id=dataset_id,
|
||||
@@ -281,10 +314,10 @@ async def get_mappings_by_source(
|
||||
limit=page_size,
|
||||
include_template=include_template
|
||||
)
|
||||
|
||||
|
||||
# 计算总页数
|
||||
total_pages = math.ceil(total / page_size) if total > 0 else 0
|
||||
|
||||
|
||||
# 构造分页响应
|
||||
paginated_data = PaginatedData(
|
||||
page=page,
|
||||
@@ -293,15 +326,15 @@ async def get_mappings_by_source(
|
||||
total_pages=total_pages,
|
||||
content=mappings
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}")
|
||||
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=paginated_data
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -328,24 +361,24 @@ async def delete_mapping(
|
||||
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
|
||||
token=settings.label_studio_user_token)
|
||||
service = DatasetMappingService(db)
|
||||
|
||||
|
||||
# 使用 mapping UUID 查询映射记录
|
||||
logger.debug(f"Deleting by mapping UUID: {project_id}")
|
||||
mapping = await service.get_mapping_by_uuid(project_id)
|
||||
|
||||
logger.debug(f"Mapping lookup result: {mapping}")
|
||||
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Mapping either not found or not specified."
|
||||
)
|
||||
|
||||
|
||||
id = mapping.id
|
||||
labeling_project_id = mapping.labeling_project_id
|
||||
|
||||
logger.debug(f"Found mapping: {id}, Label Studio project ID: {labeling_project_id}")
|
||||
|
||||
|
||||
# 1. 删除 Label Studio 项目
|
||||
try:
|
||||
logger.debug(f"Deleting Label Studio project: {labeling_project_id}")
|
||||
@@ -357,11 +390,11 @@ async def delete_mapping(
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting Label Studio project: {e}")
|
||||
# 继续执行,即使 Label Studio 项目删除失败也要删除映射记录
|
||||
|
||||
|
||||
# 2. 软删除映射记录
|
||||
soft_delete_success = await service.soft_delete_mapping(id)
|
||||
logger.debug(f"Soft delete result for mapping {id}: {soft_delete_success}")
|
||||
|
||||
|
||||
if not soft_delete_success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -378,7 +411,7 @@ async def delete_mapping(
|
||||
status="success"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user