feature: LabelStudio jumps without login (#201)

This commit is contained in:
hefanli
2025-12-25 16:49:06 +08:00
committed by GitHub
parent 87e73d3bf7
commit 29e4a333a9
4 changed files with 230 additions and 157 deletions

View File

@@ -11,7 +11,7 @@ import CardView from "@/components/CardView";
import type { AnnotationTask } from "../annotation.model"; import type { AnnotationTask } from "../annotation.model";
import useFetchData from "@/hooks/useFetchData"; import useFetchData from "@/hooks/useFetchData";
import { import {
deleteAnnotationTaskByIdUsingDelete, deleteAnnotationTaskByIdUsingDelete, loginAnnotationUsingGet,
queryAnnotationTasksUsingGet, queryAnnotationTasksUsingGet,
syncAnnotationTaskUsingPost, syncAnnotationTaskUsingPost,
} from "../annotation.api"; } from "../annotation.api";
@@ -76,6 +76,7 @@ export default function DataAnnotation() {
if (labelingProjId) { if (labelingProjId) {
// only open external Label Studio when we have a configured base url // only open external Label Studio when we have a configured base url
await loginAnnotationUsingGet(labelingProjId)
if (base) { if (base) {
const target = `${base}/projects/${labelingProjId}/data`; const target = `${base}/projects/${labelingProjId}/data`;
window.open(target, "_blank"); window.open(target, "_blank");

View File

@@ -18,6 +18,10 @@ export function deleteAnnotationTaskByIdUsingDelete(mappingId: string) {
return del(`/api/annotation/project/${mappingId}`); return del(`/api/annotation/project/${mappingId}`);
} }
/**
 * Ask the backend to perform the Label Studio login for the given project.
 *
 * The URL must be a template literal: with ordinary quotes the `${mappingId}`
 * placeholder is sent verbatim instead of being replaced by the id.
 *
 * @param mappingId id placed in the `/api/annotation/project/{id}/login` path
 * @returns whatever the shared `get` helper returns for this request
 */
export function loginAnnotationUsingGet(mappingId: string) {
  return get(`/api/annotation/project/${mappingId}/login`);
}
// 标签配置管理 // 标签配置管理
export function getTagConfigUsingGet() { export function getTagConfigUsingGet() {
return get("/api/annotation/tags/config"); return get("/api/annotation/tags/config");
@@ -43,4 +47,4 @@ export function deleteAnnotationTemplateByIdUsingDelete(
templateId: string | number templateId: string | number
) { ) {
return del(`/api/annotation/template/${templateId}`); return del(`/api/annotation/template/${templateId}`);
} }

View File

@@ -1,11 +1,12 @@
import httpx import httpx
import re
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from app.core.config import settings from app.core.config import settings
from app.core.logging import get_logger from app.core.logging import get_logger
from .schema import ( from .schema import (
LabelStudioProject, LabelStudioProject,
LabelStudioCreateProjectRequest, LabelStudioCreateProjectRequest,
LabelStudioCreateTaskRequest LabelStudioCreateTaskRequest
) )
@@ -14,11 +15,11 @@ logger = get_logger(__name__)
class Client: class Client:
"""Label Studio服务客户端 """Label Studio服务客户端
使用 HTTP REST API 直接与 Label Studio 交互 使用 HTTP REST API 直接与 Label Studio 交互
认证方式:使用 Authorization: Token {token} 头部进行认证 认证方式:使用 Authorization: Token {token} 头部进行认证
""" """
# 默认标注配置模板 # 默认标注配置模板
DEFAULT_LABEL_CONFIGS = { DEFAULT_LABEL_CONFIGS = {
"image": """ "image": """
@@ -57,15 +58,15 @@ class Client:
</View> </View>
""" """
} }
def __init__( def __init__(
self, self,
base_url: Optional[str] = None, base_url: Optional[str] = None,
token: Optional[str] = None, token: Optional[str] = None,
timeout: float = 30.0 timeout: float = 30.0
): ):
"""初始化 Label Studio 客户端 """初始化 Label Studio 客户端
Args: Args:
base_url: Label Studio 服务地址 base_url: Label Studio 服务地址
token: API Token(使用 Authorization: Token {token} 头部) token: API Token(使用 Authorization: Token {token} 头部)
@@ -74,10 +75,10 @@ class Client:
self.base_url = (base_url or settings.label_studio_base_url).rstrip("/") self.base_url = (base_url or settings.label_studio_base_url).rstrip("/")
self.token = token or settings.label_studio_user_token self.token = token or settings.label_studio_user_token
self.timeout = timeout self.timeout = timeout
if not self.token: if not self.token:
raise ValueError("Label Studio API token is required") raise ValueError("Label Studio API token is required")
# 初始化 HTTP 客户端 # 初始化 HTTP 客户端
self.client = httpx.AsyncClient( self.client = httpx.AsyncClient(
base_url=self.base_url, base_url=self.base_url,
@@ -87,46 +88,80 @@ class Client:
"Content-Type": "application/json" "Content-Type": "application/json"
} }
) )
logger.debug(f"Label Studio client initialized: {self.base_url}") logger.debug(f"Label Studio client initialized: {self.base_url}")
def get_label_config_by_type(self, data_type: str) -> str: def get_label_config_by_type(self, data_type: str) -> str:
"""根据数据类型获取标注配置""" """根据数据类型获取标注配置"""
return self.DEFAULT_LABEL_CONFIGS.get(data_type.lower(), self.DEFAULT_LABEL_CONFIGS["image"]) return self.DEFAULT_LABEL_CONFIGS.get(data_type.lower(), self.DEFAULT_LABEL_CONFIGS["image"])
@staticmethod
def get_csrf_token(html: str) -> str:
m = re.search(r'name="csrfmiddlewaretoken"\s+value="([^"]+)"', html)
if not m:
raise IOError("CSRF Token not found")
return m.group(1)
async def login_label_studio(self) -> Optional["httpx.Response"]:
    """Log in to Label Studio through its HTML login form.

    Fetches ``/user/login/`` to obtain the CSRF form token and the cookies
    issued with it, then posts the configured username/password back to the
    same URL together with that token and those cookies.

    Returns:
        The response to the login POST. Its status is not checked here, so
        the caller decides how to treat it. ``None`` is returned when the
        login page cannot be fetched, the CSRF token is missing, or any
        other error occurs; the failure is logged.
    """
    try:
        login_page = await self.client.get("/user/login/")
        login_page.raise_for_status()

        # A Cookie request header carries only ``name=value`` pairs, so strip
        # the attributes (Path, Expires, SameSite, ...) from each Set-Cookie
        # value instead of echoing them back as if they were cookies.
        cookie_header = "; ".join(
            raw.split(";", 1)[0].strip()
            for raw in login_page.headers.get_list("set-cookie")
        )

        form = {
            "email": settings.label_studio_username,
            "password": settings.label_studio_password,
            "csrfmiddlewaretoken": self.get_csrf_token(login_page.text),
        }
        # Override the client's default JSON content type for the form post.
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        if cookie_header:
            headers["Cookie"] = cookie_header

        login_response = await self.client.post("/user/login/", data=form, headers=headers)
        # Log the outcome only: the body is a whole page and adds nothing useful.
        logger.info(f"Label Studio login response: HTTP {login_response.status_code}")
        return login_response
    except httpx.HTTPStatusError as e:
        logger.error(f"Login failed HTTP {e.response.status_code}: {e.response.text}")
        return None
    except Exception as e:
        # exc_info records the traceback; a stray positional argument would
        # be treated as a format argument and break the log call instead.
        logger.error(f"Error while login: {e}", exc_info=True)
        return None
async def create_project( async def create_project(
self, self,
title: str, title: str,
description: str = "", description: str = "",
label_config: Optional[str] = None, label_config: Optional[str] = None,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""创建Label Studio项目""" """创建Label Studio项目"""
try: try:
logger.debug(f"Creating Label Studio project: {title}") logger.debug(f"Creating Label Studio project: {title}")
logger.debug(f"Label Studio URL: {self.base_url}/api/projects") logger.debug(f"Label Studio URL: {self.base_url}/api/projects")
project_data = { project_data = {
"title": title, "title": title,
"description": description, "description": description,
"label_config": label_config or "<View></View>" "label_config": label_config or "<View></View>"
} }
# Log the request body for debugging # Log the request body for debugging
logger.debug(f"Request body: {project_data}") logger.debug(f"Request body: {project_data}")
logger.debug(f"Label config being sent:\n{project_data['label_config']}") logger.debug(f"Label config being sent:\n{project_data['label_config']}")
response = await self.client.post("/api/projects", json=project_data) response = await self.client.post("/api/projects", json=project_data)
response.raise_for_status() response.raise_for_status()
project = response.json() project = response.json()
project_id = project.get("id") project_id = project.get("id")
if not project_id: if not project_id:
raise Exception("Label Studio response does not contain project ID") raise Exception("Label Studio response does not contain project ID")
logger.debug(f"Project created successfully, ID: {project_id}") logger.debug(f"Project created successfully, ID: {project_id}")
return project return project
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error( logger.error(
f"Create project failed - HTTP {e.response.status_code}\n" f"Create project failed - HTTP {e.response.status_code}\n"
@@ -151,7 +186,7 @@ class Client:
except Exception as e: except Exception as e:
logger.error(f"Error while creating Label Studio project: {str(e)}", exc_info=True) logger.error(f"Error while creating Label Studio project: {str(e)}", exc_info=True)
return None return None
async def import_tasks( async def import_tasks(
self, self,
project_id: int, project_id: int,
@@ -162,7 +197,7 @@ class Client:
"""批量导入任务到Label Studio项目""" """批量导入任务到Label Studio项目"""
try: try:
logger.debug(f"Importing {len(tasks)} tasks into project {project_id}") logger.debug(f"Importing {len(tasks)} tasks into project {project_id}")
response = await self.client.post( response = await self.client.post(
f"/api/projects/{project_id}/import", f"/api/projects/{project_id}/import",
json=tasks, json=tasks,
@@ -172,20 +207,20 @@ class Client:
} }
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
task_count = result.get("task_count", len(tasks)) task_count = result.get("task_count", len(tasks))
logger.debug(f"Tasks imported successfully: {task_count}") logger.debug(f"Tasks imported successfully: {task_count}")
return result return result
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"Import tasks failed HTTP {e.response.status_code}: {e.response.text}") logger.error(f"Import tasks failed HTTP {e.response.status_code}: {e.response.text}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error while importing tasks: {e}") logger.error(f"Error while importing tasks: {e}")
return None return None
async def create_tasks_batch( async def create_tasks_batch(
self, self,
project_id: str, project_id: str,
@@ -201,7 +236,7 @@ class Client:
except Exception as e: except Exception as e:
logger.error(f"Error while creating tasks in batch: {e}") logger.error(f"Error while creating tasks in batch: {e}")
return None return None
async def create_task( async def create_task(
self, self,
project_id: str, project_id: str,
@@ -213,13 +248,13 @@ class Client:
task = {"data": data} task = {"data": data}
if meta: if meta:
task["meta"] = meta task["meta"] = meta
return await self.create_tasks_batch(project_id, [task]) return await self.create_tasks_batch(project_id, [task])
except Exception as e: except Exception as e:
logger.error(f"Error while creating single task: {e}") logger.error(f"Error while creating single task: {e}")
return None return None
async def get_project_tasks( async def get_project_tasks(
self, self,
project_id: str, project_id: str,
@@ -227,12 +262,12 @@ class Client:
page_size: int = 1000 page_size: int = 1000
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""获取项目任务信息 """获取项目任务信息
Args: Args:
project_id: 项目ID project_id: 项目ID
page: 页码(从1开始)。如果为None,则获取所有任务 page: 页码(从1开始)。如果为None,则获取所有任务
page_size: 每页大小 page_size: 每页大小
Returns: Returns:
如果指定了page参数,返回包含分页信息的字典: 如果指定了page参数,返回包含分页信息的字典:
{ {
@@ -242,9 +277,9 @@ class Client:
"project_id": 项目ID, "project_id": 项目ID,
"tasks": 当前页的任务列表 "tasks": 当前页的任务列表
} }
如果page为None,返回包含所有任务的字典: 如果page为None,返回包含所有任务的字典:
"count": 总任务数, "count": 总任务数,
"project_id": 项目ID, "project_id": 项目ID,
"tasks": 所有任务列表 "tasks": 所有任务列表
@@ -252,11 +287,11 @@ class Client:
""" """
try: try:
pid = int(project_id) pid = int(project_id)
# 如果指定了page,直接获取单页任务 # 如果指定了page,直接获取单页任务
if page is not None: if page is not None:
logger.debug(f"Fetching tasks for project {pid}, page {page} (page_size={page_size})") logger.debug(f"Fetching tasks for project {pid}, page {page} (page_size={page_size})")
response = await self.client.get( response = await self.client.get(
f"/api/tasks", f"/api/tasks",
params={ params={
@@ -266,9 +301,9 @@ class Client:
} }
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
# 返回单页结果,包含分页信息 # 返回单页结果,包含分页信息
return { return {
"count": result.get("total", len(result.get("tasks", []))), "count": result.get("total", len(result.get("tasks", []))),
@@ -277,11 +312,11 @@ class Client:
"project_id": pid, "project_id": pid,
"tasks": result.get("tasks", []) "tasks": result.get("tasks", [])
} }
# 如果未指定page,获取所有任务 # 如果未指定page,获取所有任务
logger.debug(f"(page) not specified, fetching all tasks.") logger.debug(f"(page) not specified, fetching all tasks.")
all_tasks = [] all_tasks = []
response = await self.client.get( response = await self.client.get(
f"/api/tasks", f"/api/tasks",
params={ params={
@@ -289,31 +324,31 @@ class Client:
} }
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
tasks = result.get("tasks", []) tasks = result.get("tasks", [])
if not tasks: if not tasks:
logger.debug(f"No tasks found for this project.") logger.debug(f"No tasks found for this project.")
all_tasks.extend(tasks) all_tasks.extend(tasks)
logger.debug(f"Fetched {len(tasks)} tasks.") logger.debug(f"Fetched {len(tasks)} tasks.")
# 返回所有任务,不包含分页信息 # 返回所有任务,不包含分页信息
return { return {
"count": len(all_tasks), "count": len(all_tasks),
"project_id": pid, "project_id": pid,
"tasks": all_tasks "tasks": all_tasks
} }
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"获取项目任务失败 HTTP {e.response.status_code}: {e.response.text}") logger.error(f"获取项目任务失败 HTTP {e.response.status_code}: {e.response.text}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"获取项目任务时发生错误: {e}") logger.error(f"获取项目任务时发生错误: {e}")
return None return None
async def delete_task( async def delete_task(
self, self,
task_id: int task_id: int
@@ -321,20 +356,20 @@ class Client:
"""删除单个任务""" """删除单个任务"""
try: try:
logger.debug(f"Deleting task: {task_id}") logger.debug(f"Deleting task: {task_id}")
response = await self.client.delete(f"/api/tasks/{task_id}") response = await self.client.delete(f"/api/tasks/{task_id}")
response.raise_for_status() response.raise_for_status()
logger.debug(f"Task deleted: {task_id}") logger.debug(f"Task deleted: {task_id}")
return True return True
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"Delete task {task_id} failed HTTP {e.response.status_code}: {e.response.text}") logger.error(f"Delete task {task_id} failed HTTP {e.response.status_code}: {e.response.text}")
return False return False
except Exception as e: except Exception as e:
logger.error(f"Error while deleting task {task_id}: {e}") logger.error(f"Error while deleting task {task_id}: {e}")
return False return False
async def delete_tasks_batch( async def delete_tasks_batch(
self, self,
task_ids: List[int] task_ids: List[int]
@@ -342,24 +377,24 @@ class Client:
"""批量删除任务""" """批量删除任务"""
try: try:
logger.debug(f"Deleting {len(task_ids)} tasks in batch") logger.debug(f"Deleting {len(task_ids)} tasks in batch")
successful_deletions = 0 successful_deletions = 0
failed_deletions = 0 failed_deletions = 0
for task_id in task_ids: for task_id in task_ids:
if await self.delete_task(task_id): if await self.delete_task(task_id):
successful_deletions += 1 successful_deletions += 1
else: else:
failed_deletions += 1 failed_deletions += 1
logger.debug(f"Batch deletion finished: success {successful_deletions}, failed {failed_deletions}") logger.debug(f"Batch deletion finished: success {successful_deletions}, failed {failed_deletions}")
return { return {
"successful": successful_deletions, "successful": successful_deletions,
"failed": failed_deletions, "failed": failed_deletions,
"total": len(task_ids) "total": len(task_ids)
} }
except Exception as e: except Exception as e:
logger.error(f"Error while deleting tasks in batch: {e}") logger.error(f"Error while deleting tasks in batch: {e}")
return { return {
@@ -367,72 +402,72 @@ class Client:
"failed": len(task_ids), "failed": len(task_ids),
"total": len(task_ids) "total": len(task_ids)
} }
async def get_project(self, project_id: int) -> Optional[Dict[str, Any]]: async def get_project(self, project_id: int) -> Optional[Dict[str, Any]]:
"""获取项目信息""" """获取项目信息"""
try: try:
logger.debug(f"Fetching project info: {project_id}") logger.debug(f"Fetching project info: {project_id}")
response = await self.client.get(f"/api/projects/{project_id}") response = await self.client.get(f"/api/projects/{project_id}")
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"Get project info failed HTTP {e.response.status_code}: {e.response.text}") logger.error(f"Get project info failed HTTP {e.response.status_code}: {e.response.text}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error while getting project info: {e}") logger.error(f"Error while getting project info: {e}")
return None return None
async def delete_project(self, project_id: int) -> bool: async def delete_project(self, project_id: int) -> bool:
"""删除项目""" """删除项目"""
try: try:
logger.debug(f"Deleting project: {project_id}") logger.debug(f"Deleting project: {project_id}")
response = await self.client.delete(f"/api/projects/{project_id}") response = await self.client.delete(f"/api/projects/{project_id}")
response.raise_for_status() response.raise_for_status()
logger.debug(f"Project deleted: {project_id}") logger.debug(f"Project deleted: {project_id}")
return True return True
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"Delete project {project_id} failed HTTP {e.response.status_code}: {e.response.text}") logger.error(f"Delete project {project_id} failed HTTP {e.response.status_code}: {e.response.text}")
return False return False
except Exception as e: except Exception as e:
logger.error(f"Error while deleting project {project_id}: {e}") logger.error(f"Error while deleting project {project_id}: {e}")
return False return False
async def get_task_annotations( async def get_task_annotations(
self, self,
task_id: int task_id: int
) -> Optional[List[Dict[str, Any]]]: ) -> Optional[List[Dict[str, Any]]]:
"""获取任务的标注结果 """获取任务的标注结果
Args: Args:
task_id: 任务ID task_id: 任务ID
Returns: Returns:
标注结果列表,每个标注包含完整的annotation信息 标注结果列表,每个标注包含完整的annotation信息
""" """
try: try:
logger.debug(f"Fetching annotations for task: {task_id}") logger.debug(f"Fetching annotations for task: {task_id}")
response = await self.client.get(f"/api/tasks/{task_id}/annotations") response = await self.client.get(f"/api/tasks/{task_id}/annotations")
response.raise_for_status() response.raise_for_status()
annotations = response.json() annotations = response.json()
logger.debug(f"Fetched {len(annotations)} annotations for task {task_id}") logger.debug(f"Fetched {len(annotations)} annotations for task {task_id}")
return annotations return annotations
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"Get task annotations failed HTTP {e.response.status_code}: {e.response.text}") logger.error(f"Get task annotations failed HTTP {e.response.status_code}: {e.response.text}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error while getting task annotations: {e}") logger.error(f"Error while getting task annotations: {e}")
return None return None
async def create_annotation( async def create_annotation(
self, self,
task_id: int, task_id: int,
@@ -440,111 +475,111 @@ class Client:
completed_by: Optional[int] = None completed_by: Optional[int] = None
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""为任务创建新的标注 """为任务创建新的标注
Args: Args:
task_id: 任务ID task_id: 任务ID
result: 标注结果列表 result: 标注结果列表
completed_by: 完成标注的用户ID(可选) completed_by: 完成标注的用户ID(可选)
Returns: Returns:
创建的标注信息,失败返回None 创建的标注信息,失败返回None
""" """
try: try:
logger.debug(f"Creating annotation for task: {task_id}") logger.debug(f"Creating annotation for task: {task_id}")
annotation_data = { annotation_data = {
"result": result, "result": result,
"task": task_id "task": task_id
} }
if completed_by: if completed_by:
annotation_data["completed_by"] = completed_by annotation_data["completed_by"] = completed_by
response = await self.client.post( response = await self.client.post(
f"/api/tasks/{task_id}/annotations", f"/api/tasks/{task_id}/annotations",
json=annotation_data json=annotation_data
) )
response.raise_for_status() response.raise_for_status()
annotation = response.json() annotation = response.json()
logger.debug(f"Created annotation {annotation.get('id')} for task {task_id}") logger.debug(f"Created annotation {annotation.get('id')} for task {task_id}")
return annotation return annotation
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"Create annotation failed HTTP {e.response.status_code}: {e.response.text}") logger.error(f"Create annotation failed HTTP {e.response.status_code}: {e.response.text}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error while creating annotation: {e}") logger.error(f"Error while creating annotation: {e}")
return None return None
async def update_annotation( async def update_annotation(
self, self,
annotation_id: int, annotation_id: int,
result: List[Dict[str, Any]] result: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""更新已存在的标注 """更新已存在的标注
Args: Args:
annotation_id: 标注ID annotation_id: 标注ID
result: 新的标注结果列表 result: 新的标注结果列表
Returns: Returns:
更新后的标注信息,失败返回None 更新后的标注信息,失败返回None
""" """
try: try:
logger.debug(f"Updating annotation: {annotation_id}") logger.debug(f"Updating annotation: {annotation_id}")
annotation_data = { annotation_data = {
"result": result "result": result
} }
response = await self.client.patch( response = await self.client.patch(
f"/api/annotations/{annotation_id}", f"/api/annotations/{annotation_id}",
json=annotation_data json=annotation_data
) )
response.raise_for_status() response.raise_for_status()
annotation = response.json() annotation = response.json()
logger.debug(f"Updated annotation {annotation_id}") logger.debug(f"Updated annotation {annotation_id}")
return annotation return annotation
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"Update annotation failed HTTP {e.response.status_code}: {e.response.text}") logger.error(f"Update annotation failed HTTP {e.response.status_code}: {e.response.text}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error while updating annotation: {e}") logger.error(f"Error while updating annotation: {e}")
return None return None
async def delete_annotation( async def delete_annotation(
self, self,
annotation_id: int annotation_id: int
) -> bool: ) -> bool:
"""删除标注 """删除标注
Args: Args:
annotation_id: 标注ID annotation_id: 标注ID
Returns: Returns:
成功返回True,失败返回False 成功返回True,失败返回False
""" """
try: try:
logger.debug(f"Deleting annotation: {annotation_id}") logger.debug(f"Deleting annotation: {annotation_id}")
response = await self.client.delete(f"/api/annotations/{annotation_id}") response = await self.client.delete(f"/api/annotations/{annotation_id}")
response.raise_for_status() response.raise_for_status()
logger.debug(f"Deleted annotation {annotation_id}") logger.debug(f"Deleted annotation {annotation_id}")
return True return True
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"Delete annotation failed HTTP {e.response.status_code}: {e.response.text}") logger.error(f"Delete annotation failed HTTP {e.response.status_code}: {e.response.text}")
return False return False
except Exception as e: except Exception as e:
logger.error(f"Error while deleting annotation: {e}") logger.error(f"Error while deleting annotation: {e}")
return False return False
async def create_local_storage( async def create_local_storage(
self, self,
project_id: int, project_id: int,
@@ -555,7 +590,7 @@ class Client:
description: Optional[str] = None description: Optional[str] = None
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""创建本地文件存储配置 """创建本地文件存储配置
Args: Args:
project_id: Label Studio 项目 ID project_id: Label Studio 项目 ID
path: 本地文件路径(在 Label Studio 容器中的路径) path: 本地文件路径(在 Label Studio 容器中的路径)
@@ -563,37 +598,37 @@ class Client:
use_blob_urls: 是否使用 blob URLs(建议 True) use_blob_urls: 是否使用 blob URLs(建议 True)
regex_filter: 文件过滤正则表达式(可选) regex_filter: 文件过滤正则表达式(可选)
description: 存储描述(可选) description: 存储描述(可选)
Returns: Returns:
创建的存储配置信息,失败返回 None 创建的存储配置信息,失败返回 None
""" """
try: try:
logger.debug(f"Creating local storage for project {project_id}: {path}") logger.debug(f"Creating local storage for project {project_id}: {path}")
storage_data = { storage_data = {
"project": project_id, "project": project_id,
"path": path, "path": path,
"title": title, "title": title,
"use_blob_urls": use_blob_urls "use_blob_urls": use_blob_urls
} }
if regex_filter: if regex_filter:
storage_data["regex_filter"] = regex_filter storage_data["regex_filter"] = regex_filter
if description: if description:
storage_data["description"] = description storage_data["description"] = description
response = await self.client.post( response = await self.client.post(
"/api/storages/localfiles/", "/api/storages/localfiles/",
json=storage_data json=storage_data
) )
response.raise_for_status() response.raise_for_status()
storage = response.json() storage = response.json()
storage_id = storage.get("id") storage_id = storage.get("id")
logger.debug(f"Local storage created successfully, ID: {storage_id}") logger.debug(f"Local storage created successfully, ID: {storage_id}")
return storage return storage
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.error(f"Create local storage failed HTTP {e.response.status_code}: {e.response.text}") logger.error(f"Create local storage failed HTTP {e.response.status_code}: {e.response.text}")
return None return None

View File

@@ -2,7 +2,7 @@ from typing import Optional
import math import math
import uuid import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Path from fastapi import APIRouter, Depends, HTTPException, Query, Path, Response
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db from app.db.session import get_db
@@ -29,6 +29,39 @@ router = APIRouter(
) )
logger = get_logger(__name__) logger = get_logger(__name__)
@router.get("/{mapping_id}/login")
async def list_mappings(
    db: AsyncSession = Depends(get_db)
):
    """Log in to Label Studio server-side and relay the result to the caller.

    Performs the Label Studio form login with the configured credentials and
    returns the upstream status, body and headers, re-emitting every
    ``Set-Cookie`` header individually so none of them is merged or lost.

    The ``mapping_id`` path segment and the ``db`` session are not used.

    Raises:
        HTTPException: 502 when the Label Studio login produced no response,
            500 for any unexpected error.
    """
    # NOTE(review): this handler shares its name with the paginated list
    # handler defined later in this module. Both routes are registered because
    # the decorator runs at definition time, but the name is rebound, so it
    # should get a distinct name when callers/tooling allow.

    # Headers describing the upstream transfer rather than the relayed body:
    # the body forwarded below is already decoded, so copying these would
    # advertise a wrong length or encoding. Set-Cookie is re-added per cookie.
    skipped_headers = {
        "set-cookie",
        "content-length",
        "content-encoding",
        "transfer-encoding",
        "connection",
    }
    try:
        ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
                                      token=settings.label_studio_user_token)
        target_response = await ls_client.login_label_studio()
        if target_response is None:
            # The client returns None and logs the cause when login fails.
            raise HTTPException(status_code=502, detail="Label Studio login failed")

        headers = {
            key: value
            for key, value in target_response.headers.items()
            if key.lower() not in skipped_headers
        }
        response = Response(
            content=target_response.content,
            status_code=target_response.status_code,
            headers=headers
        )
        # Append each cookie separately; a merged header would be unparsable.
        for cookie in target_response.headers.get_list("set-cookie"):
            response.headers.append("set-cookie", cookie)
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error while logining in LabelStudio: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
@router.post("", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201) @router.post("", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201)
async def create_mapping( async def create_mapping(
request: DatasetMappingCreateRequest, request: DatasetMappingCreateRequest,
@@ -36,12 +69,12 @@ async def create_mapping(
): ):
""" """
创建数据集映射 创建数据集映射
根据指定的DM程序中的数据集,创建Label Studio中的数据集, 根据指定的DM程序中的数据集,创建Label Studio中的数据集,
在数据库中记录这一关联关系,返回Label Studio数据集的ID 在数据库中记录这一关联关系,返回Label Studio数据集的ID
注意:一个数据集可以创建多个标注项目 注意:一个数据集可以创建多个标注项目
支持通过 template_id 指定标注模板,如果提供了模板ID,则使用模板的配置 支持通过 template_id 指定标注模板,如果提供了模板ID,则使用模板的配置
""" """
try: try:
@@ -51,9 +84,9 @@ async def create_mapping(
mapping_service = DatasetMappingService(db) mapping_service = DatasetMappingService(db)
sync_service = SyncService(dm_client, ls_client, mapping_service) sync_service = SyncService(dm_client, ls_client, mapping_service)
template_service = AnnotationTemplateService() template_service = AnnotationTemplateService()
logger.info(f"Create dataset mapping request: {request.dataset_id}") logger.info(f"Create dataset mapping request: {request.dataset_id}")
# 从DM服务获取数据集信息 # 从DM服务获取数据集信息
dataset_info = await dm_client.get_dataset(request.dataset_id) dataset_info = await dm_client.get_dataset(request.dataset_id)
if not dataset_info: if not dataset_info:
@@ -61,11 +94,11 @@ async def create_mapping(
status_code=404, status_code=404,
detail=f"Dataset not found in DM service: {request.dataset_id}" detail=f"Dataset not found in DM service: {request.dataset_id}"
) )
project_name = request.name or \ project_name = request.name or \
dataset_info.name or \ dataset_info.name or \
"A new project from DataMate" "A new project from DataMate"
project_description = request.description or \ project_description = request.description or \
dataset_info.description or \ dataset_info.description or \
f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})" f"Imported from DM dataset {dataset_info.name} ({dataset_info.id})"
@@ -89,15 +122,15 @@ async def create_mapping(
description=project_description, description=project_description,
label_config=label_config # 传递模板配置 label_config=label_config # 传递模板配置
) )
if not project_data: if not project_data:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail="Fail to create Label Studio project." detail="Fail to create Label Studio project."
) )
project_id = project_data["id"] project_id = project_data["id"]
# 配置本地存储:dataset/<id> # 配置本地存储:dataset/<id>
local_storage_path = f"{settings.label_studio_local_document_root}/{request.dataset_id}" local_storage_path = f"{settings.label_studio_local_document_root}/{request.dataset_id}"
storage_result = await ls_client.create_local_storage( storage_result = await ls_client.create_local_storage(
@@ -107,7 +140,7 @@ async def create_mapping(
use_blob_urls=True, use_blob_urls=True,
description=f"Local storage for dataset {dataset_info.name}" description=f"Local storage for dataset {dataset_info.name}"
) )
if not storage_result: if not storage_result:
# 本地存储配置失败,记录警告但不中断流程 # 本地存储配置失败,记录警告但不中断流程
logger.warning(f"Failed to configure local storage for project {project_id}") logger.warning(f"Failed to configure local storage for project {project_id}")
@@ -124,28 +157,28 @@ async def create_mapping(
# 创建映射关系,包含项目名称(先持久化映射以获得 mapping.id) # 创建映射关系,包含项目名称(先持久化映射以获得 mapping.id)
mapping = await mapping_service.create_mapping(labeling_project) mapping = await mapping_service.create_mapping(labeling_project)
# 进行一次同步,使用创建后的 mapping.id # 进行一次同步,使用创建后的 mapping.id
await sync_service.sync_dataset_files(mapping.id, 100) await sync_service.sync_dataset_files(mapping.id, 100)
response_data = DatasetMappingCreateResponse( response_data = DatasetMappingCreateResponse(
id=mapping.id, id=mapping.id,
labeling_project_id=str(mapping.labeling_project_id), labeling_project_id=str(mapping.labeling_project_id),
labeling_project_name=mapping.name or project_name labeling_project_name=mapping.name or project_name
) )
return StandardResponse( return StandardResponse(
code=201, code=201,
message="success", message="success",
data=response_data data=response_data
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error while creating dataset mapping: {e}") logger.error(f"Error while creating dataset mapping: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]]) @router.get("", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
async def list_mappings( async def list_mappings(
page: int = Query(1, ge=1, description="页码(从1开始)"), page: int = Query(1, ge=1, description="页码(从1开始)"),
@@ -155,10 +188,10 @@ async def list_mappings(
): ):
""" """
查询所有映射关系(分页) 查询所有映射关系(分页)
返回所有有效的数据集映射关系(未被软删除的),支持分页查询。 返回所有有效的数据集映射关系(未被软删除的),支持分页查询。
可选择是否包含完整的标注模板信息(默认不包含,以提高列表查询性能)。 可选择是否包含完整的标注模板信息(默认不包含,以提高列表查询性能)。
参数: 参数:
- page: 页码(从1开始) - page: 页码(从1开始)
- pageSize: 每页记录数 - pageSize: 每页记录数
@@ -166,12 +199,12 @@ async def list_mappings(
""" """
try: try:
service = DatasetMappingService(db) service = DatasetMappingService(db)
# 计算 skip # 计算 skip
skip = (page - 1) * page_size skip = (page - 1) * page_size
logger.info(f"List mappings: page={page}, page_size={page_size}, include_template={include_template}") logger.info(f"List mappings: page={page}, page_size={page_size}, include_template={include_template}")
# 获取数据和总数 # 获取数据和总数
mappings, total = await service.get_all_mappings_with_count( mappings, total = await service.get_all_mappings_with_count(
skip=skip, skip=skip,
@@ -179,10 +212,10 @@ async def list_mappings(
include_deleted=False, include_deleted=False,
include_template=include_template include_template=include_template
) )
# 计算总页数 # 计算总页数
total_pages = math.ceil(total / page_size) if total > 0 else 0 total_pages = math.ceil(total / page_size) if total > 0 else 0
# 构造分页响应 # 构造分页响应
paginated_data = PaginatedData( paginated_data = PaginatedData(
page=page, page=page,
@@ -191,15 +224,15 @@ async def list_mappings(
total_pages=total_pages, total_pages=total_pages,
content=mappings content=mappings
) )
logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}") logger.info(f"List mappings: page={page}, returned {len(mappings)}/{total}, templates_included: {include_template}")
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
data=paginated_data data=paginated_data
) )
except Exception as e: except Exception as e:
logger.error(f"Error listing mappings: {e}") logger.error(f"Error listing mappings: {e}")
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@@ -211,7 +244,7 @@ async def get_mapping(
): ):
""" """
根据 UUID 查询单个映射关系(包含关联的标注模板详情) 根据 UUID 查询单个映射关系(包含关联的标注模板详情)
返回数据集映射关系以及关联的完整标注模板信息,包括: 返回数据集映射关系以及关联的完整标注模板信息,包括:
- 映射基本信息 - 映射基本信息
- 数据集信息 - 数据集信息
@@ -220,26 +253,26 @@ async def get_mapping(
""" """
try: try:
service = DatasetMappingService(db) service = DatasetMappingService(db)
logger.info(f"Get mapping with template details: {mapping_id}") logger.info(f"Get mapping with template details: {mapping_id}")
# 获取映射,并包含完整的模板信息 # 获取映射,并包含完整的模板信息
mapping = await service.get_mapping_by_uuid(mapping_id, include_template=True) mapping = await service.get_mapping_by_uuid(mapping_id, include_template=True)
if not mapping: if not mapping:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Mapping not found: {mapping_id}" detail=f"Mapping not found: {mapping_id}"
) )
logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}") logger.info(f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}")
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
data=mapping data=mapping
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -256,10 +289,10 @@ async def get_mappings_by_source(
): ):
""" """
根据源数据集 ID 查询所有映射关系(分页,包含模板详情) 根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询。 返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询。
默认包含关联的完整标注模板信息。 默认包含关联的完整标注模板信息。
参数: 参数:
- dataset_id: 数据集ID - dataset_id: 数据集ID
- page: 页码(从1开始) - page: 页码(从1开始)
@@ -268,12 +301,12 @@ async def get_mappings_by_source(
""" """
try: try:
service = DatasetMappingService(db) service = DatasetMappingService(db)
# 计算 skip # 计算 skip
skip = (page - 1) * page_size skip = (page - 1) * page_size
logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}, include_template={include_template}") logger.info(f"Get mappings by source dataset id: {dataset_id}, page={page}, page_size={page_size}, include_template={include_template}")
# 获取数据和总数(包含模板信息) # 获取数据和总数(包含模板信息)
mappings, total = await service.get_mappings_by_source_with_count( mappings, total = await service.get_mappings_by_source_with_count(
dataset_id=dataset_id, dataset_id=dataset_id,
@@ -281,10 +314,10 @@ async def get_mappings_by_source(
limit=page_size, limit=page_size,
include_template=include_template include_template=include_template
) )
# 计算总页数 # 计算总页数
total_pages = math.ceil(total / page_size) if total > 0 else 0 total_pages = math.ceil(total / page_size) if total > 0 else 0
# 构造分页响应 # 构造分页响应
paginated_data = PaginatedData( paginated_data = PaginatedData(
page=page, page=page,
@@ -293,15 +326,15 @@ async def get_mappings_by_source(
total_pages=total_pages, total_pages=total_pages,
content=mappings content=mappings
) )
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}") logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}, templates_included: {include_template}")
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
data=paginated_data data=paginated_data
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -328,24 +361,24 @@ async def delete_mapping(
ls_client = LabelStudioClient(base_url=settings.label_studio_base_url, ls_client = LabelStudioClient(base_url=settings.label_studio_base_url,
token=settings.label_studio_user_token) token=settings.label_studio_user_token)
service = DatasetMappingService(db) service = DatasetMappingService(db)
# 使用 mapping UUID 查询映射记录 # 使用 mapping UUID 查询映射记录
logger.debug(f"Deleting by mapping UUID: {project_id}") logger.debug(f"Deleting by mapping UUID: {project_id}")
mapping = await service.get_mapping_by_uuid(project_id) mapping = await service.get_mapping_by_uuid(project_id)
logger.debug(f"Mapping lookup result: {mapping}") logger.debug(f"Mapping lookup result: {mapping}")
if not mapping: if not mapping:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Mapping either not found or not specified." detail=f"Mapping either not found or not specified."
) )
id = mapping.id id = mapping.id
labeling_project_id = mapping.labeling_project_id labeling_project_id = mapping.labeling_project_id
logger.debug(f"Found mapping: {id}, Label Studio project ID: {labeling_project_id}") logger.debug(f"Found mapping: {id}, Label Studio project ID: {labeling_project_id}")
# 1. 删除 Label Studio 项目 # 1. 删除 Label Studio 项目
try: try:
logger.debug(f"Deleting Label Studio project: {labeling_project_id}") logger.debug(f"Deleting Label Studio project: {labeling_project_id}")
@@ -357,11 +390,11 @@ async def delete_mapping(
except Exception as e: except Exception as e:
logger.error(f"Error deleting Label Studio project: {e}") logger.error(f"Error deleting Label Studio project: {e}")
# 继续执行,即使 Label Studio 项目删除失败也要删除映射记录 # 继续执行,即使 Label Studio 项目删除失败也要删除映射记录
# 2. 软删除映射记录 # 2. 软删除映射记录
soft_delete_success = await service.soft_delete_mapping(id) soft_delete_success = await service.soft_delete_mapping(id)
logger.debug(f"Soft delete result for mapping {id}: {soft_delete_success}") logger.debug(f"Soft delete result for mapping {id}: {soft_delete_success}")
if not soft_delete_success: if not soft_delete_success:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
@@ -378,7 +411,7 @@ async def delete_mapping(
status="success" status="success"
) )
) )
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e: