From ad9f41ffd79d4f9cb5c3953152ae74cccc8535d1 Mon Sep 17 00:00:00 2001 From: Jinglong Wang <56037774+JasonW404@users.noreply.github.com> Date: Fri, 24 Oct 2025 17:14:42 +0800 Subject: [PATCH] feat: Dataset pagination; camelCase support in schemas (#22) implement pagination for dataset mappings. update response models to support camelCase parameters. --- .../app/api/project/list.py | 76 ++++++++++++---- .../app/schemas/common.py | 40 ++++++++- .../app/schemas/dataset_mapping.py | 17 ++-- .../app/schemas/label_studio.py | 13 +-- .../app/services/dataset_mapping_service.py | 87 +++++++++++++++++-- 5 files changed, 195 insertions(+), 38 deletions(-) diff --git a/runtime/label-studio-adapter/app/api/project/list.py b/runtime/label-studio-adapter/app/api/project/list.py index c3bcafa..af2efc2 100644 --- a/runtime/label-studio-adapter/app/api/project/list.py +++ b/runtime/label-studio-adapter/app/api/project/list.py @@ -1,40 +1,60 @@ from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.ext.asyncio import AsyncSession from typing import List +import math from app.db.database import get_db from app.services.dataset_mapping_service import DatasetMappingService from app.schemas.dataset_mapping import DatasetMappingResponse -from app.schemas import StandardResponse +from app.schemas.common import StandardResponse, PaginatedData from app.core.logging import get_logger from . import project_router logger = get_logger(__name__) -@project_router.get("/mappings/list", response_model=StandardResponse[List[DatasetMappingResponse]]) +@project_router.get("/mappings/list", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]]) async def list_mappings( - skip: int = Query(0, ge=0, description="Number of records to skip"), - limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"), + page: int = Query(1, ge=1, description="页码(从1开始)"), + page_size: int = Query(20, ge=1, le=100, description="每页记录数"), db: AsyncSession = Depends(get_db) ): """ - 查询所有映射关系 + 查询所有映射关系(分页) - 返回所有有效的数据集映射关系(未被软删除的) + 返回所有有效的数据集映射关系(未被软删除的),支持分页查询 """ try: service = DatasetMappingService(db) - logger.info(f"Listing mappings, skip={skip}, limit={limit}") + # 计算 skip + skip = (page - 1) * page_size - mappings = await service.get_all_mappings(skip=skip, limit=limit) + logger.info(f"Listing mappings, page={page}, page_size={page_size}") - logger.info(f"Found {len(mappings)} mappings") + # 获取数据和总数 + mappings, total = await service.get_all_mappings_with_count( + skip=skip, + limit=page_size + ) + + # 计算总页数 + total_pages = math.ceil(total / page_size) if total > 0 else 0 + + # 构造分页响应 + paginated_data = PaginatedData( + page=page, + size=page_size, + total_elements=total, + total_pages=total_pages, + content=mappings + ) + + logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}") return StandardResponse( code=200, message="success", - data=mappings + data=paginated_data ) except Exception as e: @@ -78,29 +98,51 @@ async def get_mapping( raise HTTPException(status_code=500, detail="Internal server error") -@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[List[DatasetMappingResponse]]) +@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]]) async def get_mappings_by_source( source_dataset_id: str, + page: int = Query(1, ge=1, description="页码(从1开始)"), + page_size: int = Query(20, ge=1, le=100, description="每页记录数"), db: AsyncSession = Depends(get_db) ): """ - 根据源数据集 ID 查询所有映射关系 + 根据源数据集 ID 查询所有映射关系(分页) - 返回该数据集创建的所有标注项目(包括已删除的) + 返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询 """ try: service = DatasetMappingService(db) - logger.info(f"Get mappings by source dataset id: {source_dataset_id}") + # 计算 skip + skip = (page - 1) * page_size - mappings = await service.get_mappings_by_source_dataset_id(source_dataset_id) + logger.info(f"Get mappings by source dataset id: {source_dataset_id}, page={page}, page_size={page_size}") - logger.info(f"Found {len(mappings)} mappings") + # 获取数据和总数 + mappings, total = await service.get_mappings_by_source_with_count( + source_dataset_id=source_dataset_id, + skip=skip, + limit=page_size + ) + + # 计算总页数 + total_pages = math.ceil(total / page_size) if total > 0 else 0 + + # 构造分页响应 + paginated_data = PaginatedData( + page=page, + size=page_size, + total_elements=total, + total_pages=total_pages, + content=mappings + ) + + logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}") return StandardResponse( code=200, message="success", - data=mappings + data=paginated_data ) except HTTPException: diff --git a/runtime/label-studio-adapter/app/schemas/common.py b/runtime/label-studio-adapter/app/schemas/common.py index f931844..00f73f5 100644 --- a/runtime/label-studio-adapter/app/schemas/common.py +++ b/runtime/label-studio-adapter/app/schemas/common.py @@ -1,13 +1,27 @@ """ 通用响应模型 """ -from typing import Generic, TypeVar, Optional +from typing import Generic, TypeVar, Optional, List from pydantic import BaseModel, Field # 定义泛型类型变量 T = TypeVar('T') -class StandardResponse(BaseModel, Generic[T]): +# 定义一个将 snake_case 转换为 camelCase 的函数 +def to_camel(string: str) -> str: + """将 snake_case 字符串转换为 camelCase""" + components = string.split('_') + # 首字母小写,其余单词首字母大写 + return components[0] + ''.join(x.title() for x in components[1:]) + +class BaseResponseModel(BaseModel): + """基础响应模型,启用别名生成器""" + + class Config: + populate_by_name = True + alias_generator = to_camel + +class StandardResponse(BaseResponseModel, Generic[T]): """ 标准API响应格式 @@ -18,6 +32,8 @@ class StandardResponse(BaseModel, Generic[T]): data: Optional[T] = Field(None, description="响应数据") class Config: + populate_by_name = True + alias_generator = to_camel json_schema_extra = { "example": { "code": 200, @@ -25,3 +41,23 @@ class StandardResponse(BaseModel, Generic[T]): "data": {} } } + + +class PaginatedData(BaseResponseModel, Generic[T]): + """分页数据容器""" + page: int = Field(..., description="当前页码(从1开始)") + size: int = Field(..., description="页大小") + total_elements: int = Field(..., description="总条数") + total_pages: int = Field(..., description="总页数") + content: List[T] = Field(..., description="当前页数据") + + class Config: + json_schema_extra = { + "example": { + "page": 1, + "size": 20, + "totalElements": 100, + "totalPages": 5, + "content": [] + } + } diff --git a/runtime/label-studio-adapter/app/schemas/dataset_mapping.py b/runtime/label-studio-adapter/app/schemas/dataset_mapping.py index 9806c77..1249335 100644 --- a/runtime/label-studio-adapter/app/schemas/dataset_mapping.py +++ b/runtime/label-studio-adapter/app/schemas/dataset_mapping.py @@ -1,8 +1,10 @@ -from pydantic import BaseModel, Field +from pydantic import Field from typing import Optional from datetime import datetime -class DatasetMappingBase(BaseModel): +from .common import BaseResponseModel + +class DatasetMappingBase(BaseResponseModel): """数据集映射 基础模型""" source_dataset_id: str = Field(..., description="源数据集ID") @@ -10,14 +12,14 @@ class DatasetMappingCreateRequest(DatasetMappingBase): """数据集映射 创建 请求模型""" pass -class DatasetMappingCreateResponse(BaseModel): +class DatasetMappingCreateResponse(BaseResponseModel): """数据集映射 创建 响应模型""" mapping_id: str = Field(..., description="映射UUID") labelling_project_id: str = Field(..., description="Label Studio项目ID") labelling_project_name: str = Field(..., description="Label Studio项目名称") message: str = Field(..., description="响应消息") -class DatasetMappingUpdateRequest(BaseModel): +class DatasetMappingUpdateRequest(BaseResponseModel): """数据集映射 更新 请求模型""" source_dataset_id: Optional[str] = Field(None, description="源数据集ID") @@ -32,13 +34,14 @@ class DatasetMappingResponse(DatasetMappingBase): class Config: from_attributes = True + populate_by_name = True -class SyncDatasetRequest(BaseModel): +class SyncDatasetRequest(BaseResponseModel): """同步数据集请求模型""" mapping_id: str = Field(..., description="映射ID(mapping UUID)") batch_size: int = Field(50, ge=1, le=100, description="批处理大小") -class SyncDatasetResponse(BaseModel): +class SyncDatasetResponse(BaseResponseModel): """同步数据集响应模型""" mapping_id: str = Field(..., description="映射UUID") status: str = Field(..., description="同步状态") @@ -46,7 +49,7 @@ class SyncDatasetResponse(BaseModel): total_files: int = Field(0, description="总文件数量") message: str = Field(..., description="响应消息") -class DeleteDatasetResponse(BaseModel): +class DeleteDatasetResponse(BaseResponseModel): """删除数据集响应模型""" mapping_id: str = Field(..., description="映射UUID") status: str = Field(..., description="删除状态") diff --git a/runtime/label-studio-adapter/app/schemas/label_studio.py b/runtime/label-studio-adapter/app/schemas/label_studio.py index 66f5e71..00e5417 100644 --- a/runtime/label-studio-adapter/app/schemas/label_studio.py +++ b/runtime/label-studio-adapter/app/schemas/label_studio.py @@ -1,8 +1,9 @@ -from pydantic import BaseModel, Field +from pydantic import Field from typing import Dict, Any, Optional, List from datetime import datetime +from .common import BaseResponseModel -class LabelStudioProject(BaseModel): +class LabelStudioProject(BaseResponseModel): """Label Studio项目模型""" id: int = Field(..., description="项目ID") title: str = Field(..., description="项目标题") @@ -11,7 +12,7 @@ class LabelStudioProject(BaseModel): created_at: Optional[datetime] = Field(None, description="创建时间") updated_at: Optional[datetime] = Field(None, description="更新时间") -class LabelStudioTaskData(BaseModel): +class LabelStudioTaskData(BaseResponseModel): """Label Studio任务数据模型""" image: Optional[str] = Field(None, description="图像URL") text: Optional[str] = Field(None, description="文本内容") @@ -19,19 +20,19 @@ class LabelStudioTaskData(BaseModel): video: Optional[str] = Field(None, description="视频URL") filename: Optional[str] = Field(None, description="文件名") -class LabelStudioTask(BaseModel): +class LabelStudioTask(BaseResponseModel): """Label Studio任务模型""" data: LabelStudioTaskData = Field(..., description="任务数据") project: Optional[int] = Field(None, description="项目ID") meta: Optional[Dict[str, Any]] = Field(None, description="元数据") -class LabelStudioCreateProjectRequest(BaseModel): +class LabelStudioCreateProjectRequest(BaseResponseModel): """创建Label Studio项目请求模型""" title: str = Field(..., description="项目标题") description: str = Field("", description="项目描述") label_config: str = Field(..., description="标注配置") -class LabelStudioCreateTaskRequest(BaseModel): +class LabelStudioCreateTaskRequest(BaseResponseModel): """创建Label Studio任务请求模型""" data: Dict[str, Any] = Field(..., description="任务数据") project: Optional[int] = Field(None, description="项目ID") \ No newline at end of file diff --git a/runtime/label-studio-adapter/app/services/dataset_mapping_service.py b/runtime/label-studio-adapter/app/services/dataset_mapping_service.py index a97e165..1493199 100644 --- a/runtime/label-studio-adapter/app/services/dataset_mapping_service.py +++ b/runtime/label-studio-adapter/app/services/dataset_mapping_service.py @@ -1,7 +1,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy import update -from typing import Optional, List +from sqlalchemy import update, func +from typing import Optional, List, Tuple from datetime import datetime import uuid @@ -213,11 +213,86 @@ class DatasetMappingService: logger.debug(f"Found {len(mappings)} mappings") return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings] - async def count_mappings(self) -> int: + async def count_mappings(self, include_deleted: bool = False) -> int: """统计映射总数""" + query = select(func.count()).select_from(DatasetMapping) + + if not include_deleted: + query = query.where(DatasetMapping.deleted_at.is_(None)) + + result = await self.db.execute(query) + return result.scalar_one() + + async def get_all_mappings_with_count( + self, + skip: int = 0, + limit: int = 100, + include_deleted: bool = False + ) -> Tuple[List[DatasetMappingResponse], int]: + """获取所有映射及总数(用于分页)""" + logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}") + + # 构建查询 + query = select(DatasetMapping) + if not include_deleted: + query = query.where(DatasetMapping.deleted_at.is_(None)) + + # 获取总数 + count_query = select(func.count()).select_from(DatasetMapping) + if not include_deleted: + count_query = count_query.where(DatasetMapping.deleted_at.is_(None)) + + count_result = await self.db.execute(count_query) + total = count_result.scalar_one() + + # 获取数据 result = await self.db.execute( - select(DatasetMapping) - .where(DatasetMapping.deleted_at.is_(None)) + query + .offset(skip) + .limit(limit) + .order_by(DatasetMapping.created_at.desc()) ) mappings = result.scalars().all() - return len(mappings) \ No newline at end of file + + logger.debug(f"Found {len(mappings)} mappings, total: {total}") + return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total + + async def get_mappings_by_source_with_count( + self, + source_dataset_id: str, + skip: int = 0, + limit: int = 100, + include_deleted: bool = False + ) -> Tuple[List[DatasetMappingResponse], int]: + """根据源数据集ID获取映射关系及总数(用于分页)""" + logger.debug(f"Get mappings by source dataset id with count: {source_dataset_id}") + + # 构建查询 + query = select(DatasetMapping).where( + DatasetMapping.source_dataset_id == source_dataset_id + ) + + if not include_deleted: + query = query.where(DatasetMapping.deleted_at.is_(None)) + + # 获取总数 + count_query = select(func.count()).select_from(DatasetMapping).where( + DatasetMapping.source_dataset_id == source_dataset_id + ) + if not include_deleted: + count_query = count_query.where(DatasetMapping.deleted_at.is_(None)) + + count_result = await self.db.execute(count_query) + total = count_result.scalar_one() + + # 获取数据 + result = await self.db.execute( + query + .offset(skip) + .limit(limit) + .order_by(DatasetMapping.created_at.desc()) + ) + mappings = result.scalars().all() + + logger.debug(f"Found {len(mappings)} mappings, total: {total}") + return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total \ No newline at end of file