feat: Dataset pagination; camelCase support in schemas (#22)

Implement pagination for dataset mapping queries.
Update response models to support camelCase field names.
Jinglong Wang
2025-10-24 17:14:42 +08:00
committed by GitHub
parent f9dbefd737
commit ad9f41ffd7
5 changed files with 195 additions and 38 deletions
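
For context, a paginated response from the updated endpoints should look roughly like the following (a hypothetical illustration based on the PaginatedData schema below; the counts are made up, and FastAPI serializes response models by alias, so the snake_case fields total_elements and total_pages come out as camelCase keys):

{
  "code": 200,
  "message": "success",
  "data": {
    "page": 1,
    "size": 20,
    "totalElements": 42,
    "totalPages": 3,
    "content": [ ... ]
  }
}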

View File

@@ -1,40 +1,60 @@
 from fastapi import APIRouter, Depends, HTTPException, Query
 from sqlalchemy.ext.asyncio import AsyncSession
 from typing import List
+import math
 from app.db.database import get_db
 from app.services.dataset_mapping_service import DatasetMappingService
 from app.schemas.dataset_mapping import DatasetMappingResponse
-from app.schemas import StandardResponse
+from app.schemas.common import StandardResponse, PaginatedData
 from app.core.logging import get_logger
 from . import project_router
 
 logger = get_logger(__name__)
 
-@project_router.get("/mappings/list", response_model=StandardResponse[List[DatasetMappingResponse]])
+@project_router.get("/mappings/list", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
 async def list_mappings(
-    skip: int = Query(0, ge=0, description="Number of records to skip"),
-    limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"),
+    page: int = Query(1, ge=1, description="Page number (1-based)"),
+    page_size: int = Query(20, ge=1, le=100, description="Records per page"),
     db: AsyncSession = Depends(get_db)
 ):
     """
-    Query all mapping relationships
+    Query all mapping relationships (paginated)
 
-    Returns all active dataset mappings (not soft-deleted)
+    Returns all active dataset mappings (not soft-deleted), with pagination support
     """
     try:
         service = DatasetMappingService(db)
-        logger.info(f"Listing mappings, skip={skip}, limit={limit}")
-        mappings = await service.get_all_mappings(skip=skip, limit=limit)
-        logger.info(f"Found {len(mappings)} mappings")
+        # Compute the skip offset
+        skip = (page - 1) * page_size
+        logger.info(f"Listing mappings, page={page}, page_size={page_size}")
+        # Fetch the page of data and the total count
+        mappings, total = await service.get_all_mappings_with_count(
+            skip=skip,
+            limit=page_size
+        )
+        # Compute the total number of pages
+        total_pages = math.ceil(total / page_size) if total > 0 else 0
+        # Build the paginated response
+        paginated_data = PaginatedData(
+            page=page,
+            size=page_size,
+            total_elements=total,
+            total_pages=total_pages,
+            content=mappings
+        )
+        logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
         return StandardResponse(
             code=200,
             message="success",
-            data=mappings
+            data=paginated_data
         )
     except Exception as e:
@@ -78,29 +98,51 @@ async def get_mapping(
         raise HTTPException(status_code=500, detail="Internal server error")
 
-@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[List[DatasetMappingResponse]])
+@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
 async def get_mappings_by_source(
     source_dataset_id: str,
+    page: int = Query(1, ge=1, description="Page number (1-based)"),
+    page_size: int = Query(20, ge=1, le=100, description="Records per page"),
     db: AsyncSession = Depends(get_db)
 ):
     """
-    Query all mappings by source dataset ID
+    Query all mappings by source dataset ID (paginated)
 
-    Returns all labeling projects created from this dataset (including deleted ones)
+    Returns all labeling projects created from this dataset (including deleted ones), with pagination support
     """
     try:
         service = DatasetMappingService(db)
-        logger.info(f"Get mappings by source dataset id: {source_dataset_id}")
-        mappings = await service.get_mappings_by_source_dataset_id(source_dataset_id)
-        logger.info(f"Found {len(mappings)} mappings")
+        # Compute the skip offset
+        skip = (page - 1) * page_size
+        logger.info(f"Get mappings by source dataset id: {source_dataset_id}, page={page}, page_size={page_size}")
+        # Fetch the page of data and the total count
+        mappings, total = await service.get_mappings_by_source_with_count(
+            source_dataset_id=source_dataset_id,
+            skip=skip,
+            limit=page_size
+        )
+        # Compute the total number of pages
+        total_pages = math.ceil(total / page_size) if total > 0 else 0
+        # Build the paginated response
+        paginated_data = PaginatedData(
+            page=page,
+            size=page_size,
+            total_elements=total,
+            total_pages=total_pages,
+            content=mappings
+        )
+        logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
        return StandardResponse(
             code=200,
             message="success",
-            data=mappings
+            data=paginated_data
         )
     except HTTPException:
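
Assuming the router above, a client pages through results with plain snake_case query parameters; the camelCase aliasing applies to response bodies, while the Query parameters define no aliases. Note also that get_mappings_by_source_with_count defaults to include_deleted=False, so the "including deleted ones" wording in the docstring only holds if a caller opts in. A hypothetical request:

GET /mappings/by-source/{source_dataset_id}?page=2&page_size=50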

View File

@@ -1,13 +1,27 @@
"""
通用响应模型
"""
from typing import Generic, TypeVar, Optional
from typing import Generic, TypeVar, Optional, List
from pydantic import BaseModel, Field
# 定义泛型类型变量
T = TypeVar('T')
class StandardResponse(BaseModel, Generic[T]):
# 定义一个将 snake_case 转换为 camelCase 的函数
def to_camel(string: str) -> str:
"""将 snake_case 字符串转换为 camelCase"""
components = string.split('_')
# 首字母小写,其余单词首字母大写
return components[0] + ''.join(x.title() for x in components[1:])
class BaseResponseModel(BaseModel):
"""基础响应模型,启用别名生成器"""
class Config:
populate_by_name = True
alias_generator = to_camel
class StandardResponse(BaseResponseModel, Generic[T]):
"""
标准API响应格式
@@ -18,6 +32,8 @@ class StandardResponse(BaseModel, Generic[T]):
     data: Optional[T] = Field(None, description="Response payload")
 
     class Config:
         populate_by_name = True
+        alias_generator = to_camel
         json_schema_extra = {
             "example": {
                 "code": 200,
@@ -25,3 +41,23 @@ class StandardResponse(BaseModel, Generic[T]):
"data": {}
}
}
class PaginatedData(BaseResponseModel, Generic[T]):
"""分页数据容器"""
page: int = Field(..., description="当前页码(从1开始)")
size: int = Field(..., description="页大小")
total_elements: int = Field(..., description="总条数")
total_pages: int = Field(..., description="总页数")
content: List[T] = Field(..., description="当前页数据")
class Config:
json_schema_extra = {
"example": {
"page": 1,
"size": 20,
"totalElements": 100,
"totalPages": 5,
"content": []
}
}
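
A minimal sketch of how these aliases behave, assuming Pydantic v2 (the snippet and its values are illustrative, not part of the commit):

from app.schemas.common import PaginatedData, to_camel

assert to_camel("total_elements") == "totalElements"

page = PaginatedData(page=1, size=20, total_elements=2, total_pages=1, content=["a", "b"])
# by_alias=True emits the camelCase contract used by the API responses
print(page.model_dump(by_alias=True))
# -> {'page': 1, 'size': 20, 'totalElements': 2, 'totalPages': 1, 'content': ['a', 'b']}

# populate_by_name=True keeps snake_case construction working, and the
# generated aliases mean camelCase input is accepted as well:
PaginatedData(page=1, size=20, totalElements=2, totalPages=1, content=[])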

View File

@@ -1,8 +1,10 @@
-from pydantic import BaseModel, Field
+from pydantic import Field
 from typing import Optional
 from datetime import datetime
 
+from .common import BaseResponseModel
+
-class DatasetMappingBase(BaseModel):
+class DatasetMappingBase(BaseResponseModel):
     """Dataset mapping base model"""
     source_dataset_id: str = Field(..., description="Source dataset ID")
@@ -10,14 +12,14 @@ class DatasetMappingCreateRequest(DatasetMappingBase):
"""数据集映射 创建 请求模型"""
pass
class DatasetMappingCreateResponse(BaseModel):
class DatasetMappingCreateResponse(BaseResponseModel):
"""数据集映射 创建 响应模型"""
mapping_id: str = Field(..., description="映射UUID")
labelling_project_id: str = Field(..., description="Label Studio项目ID")
labelling_project_name: str = Field(..., description="Label Studio项目名称")
message: str = Field(..., description="响应消息")
class DatasetMappingUpdateRequest(BaseModel):
class DatasetMappingUpdateRequest(BaseResponseModel):
"""数据集映射 更新 请求模型"""
source_dataset_id: Optional[str] = Field(None, description="源数据集ID")
@@ -32,13 +34,14 @@ class DatasetMappingResponse(DatasetMappingBase):
     class Config:
         from_attributes = True
+        populate_by_name = True
 
-class SyncDatasetRequest(BaseModel):
+class SyncDatasetRequest(BaseResponseModel):
     """Sync dataset request model"""
     mapping_id: str = Field(..., description="Mapping ID (mapping UUID)")
     batch_size: int = Field(50, ge=1, le=100, description="Batch size")
 
-class SyncDatasetResponse(BaseModel):
+class SyncDatasetResponse(BaseResponseModel):
     """Sync dataset response model"""
     mapping_id: str = Field(..., description="Mapping UUID")
     status: str = Field(..., description="Sync status")
@@ -46,7 +49,7 @@ class SyncDatasetResponse(BaseModel):
     total_files: int = Field(0, description="Total number of files")
     message: str = Field(..., description="Response message")
 
-class DeleteDatasetResponse(BaseModel):
+class DeleteDatasetResponse(BaseResponseModel):
     """Delete dataset response model"""
     mapping_id: str = Field(..., description="Mapping UUID")
     status: str = Field(..., description="Deletion status")

View File

@@ -1,8 +1,9 @@
-from pydantic import BaseModel, Field
+from pydantic import Field
 from typing import Dict, Any, Optional, List
 from datetime import datetime
 
+from .common import BaseResponseModel
+
-class LabelStudioProject(BaseModel):
+class LabelStudioProject(BaseResponseModel):
     """Label Studio project model"""
     id: int = Field(..., description="Project ID")
     title: str = Field(..., description="Project title")
@@ -11,7 +12,7 @@ class LabelStudioProject(BaseModel):
     created_at: Optional[datetime] = Field(None, description="Creation time")
     updated_at: Optional[datetime] = Field(None, description="Update time")
 
-class LabelStudioTaskData(BaseModel):
+class LabelStudioTaskData(BaseResponseModel):
     """Label Studio task data model"""
     image: Optional[str] = Field(None, description="Image URL")
     text: Optional[str] = Field(None, description="Text content")
@@ -19,19 +20,19 @@ class LabelStudioTaskData(BaseModel):
     video: Optional[str] = Field(None, description="Video URL")
     filename: Optional[str] = Field(None, description="File name")
 
-class LabelStudioTask(BaseModel):
+class LabelStudioTask(BaseResponseModel):
     """Label Studio task model"""
     data: LabelStudioTaskData = Field(..., description="Task data")
     project: Optional[int] = Field(None, description="Project ID")
     meta: Optional[Dict[str, Any]] = Field(None, description="Metadata")
 
-class LabelStudioCreateProjectRequest(BaseModel):
+class LabelStudioCreateProjectRequest(BaseResponseModel):
     """Create Label Studio project request model"""
     title: str = Field(..., description="Project title")
     description: str = Field("", description="Project description")
     label_config: str = Field(..., description="Labeling configuration")
 
-class LabelStudioCreateTaskRequest(BaseModel):
+class LabelStudioCreateTaskRequest(BaseResponseModel):
     """Create Label Studio task request model"""
     data: Dict[str, Any] = Field(..., description="Task data")
     project: Optional[int] = Field(None, description="Project ID")
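
One serialization nuance worth noting, sketched below under the assumption of Pydantic v2 defaults: model_dump() uses field names, so payloads sent onward to Label Studio keep snake_case keys such as label_config; only an explicit by_alias=True would switch to camelCase.

req = LabelStudioCreateProjectRequest(title="demo", label_config="<View></View>")
req.model_dump()
# -> {'title': 'demo', 'description': '', 'label_config': '<View></View>'}
req.model_dump(by_alias=True)
# -> {'title': 'demo', 'description': '', 'labelConfig': '<View></View>'}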

View File

@@ -1,7 +1,7 @@
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
-from sqlalchemy import update
-from typing import Optional, List
+from sqlalchemy import update, func
+from typing import Optional, List, Tuple
 from datetime import datetime
 import uuid
@@ -213,11 +213,86 @@ class DatasetMappingService:
logger.debug(f"Found {len(mappings)} mappings")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
async def count_mappings(self) -> int:
async def count_mappings(self, include_deleted: bool = False) -> int:
"""统计映射总数"""
query = select(func.count()).select_from(DatasetMapping)
if not include_deleted:
query = query.where(DatasetMapping.deleted_at.is_(None))
result = await self.db.execute(query)
return result.scalar_one()
async def get_all_mappings_with_count(
self,
skip: int = 0,
limit: int = 100,
include_deleted: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
"""获取所有映射及总数(用于分页)"""
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
# 构建查询
query = select(DatasetMapping)
if not include_deleted:
query = query.where(DatasetMapping.deleted_at.is_(None))
# 获取总数
count_query = select(func.count()).select_from(DatasetMapping)
if not include_deleted:
count_query = count_query.where(DatasetMapping.deleted_at.is_(None))
count_result = await self.db.execute(count_query)
total = count_result.scalar_one()
# 获取数据
result = await self.db.execute(
select(DatasetMapping)
.where(DatasetMapping.deleted_at.is_(None))
query
.offset(skip)
.limit(limit)
.order_by(DatasetMapping.created_at.desc())
)
mappings = result.scalars().all()
return len(mappings)
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
async def get_mappings_by_source_with_count(
self,
source_dataset_id: str,
skip: int = 0,
limit: int = 100,
include_deleted: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
"""根据源数据集ID获取映射关系及总数(用于分页)"""
logger.debug(f"Get mappings by source dataset id with count: {source_dataset_id}")
# 构建查询
query = select(DatasetMapping).where(
DatasetMapping.source_dataset_id == source_dataset_id
)
if not include_deleted:
query = query.where(DatasetMapping.deleted_at.is_(None))
# 获取总数
count_query = select(func.count()).select_from(DatasetMapping).where(
DatasetMapping.source_dataset_id == source_dataset_id
)
if not include_deleted:
count_query = count_query.where(DatasetMapping.deleted_at.is_(None))
count_result = await self.db.execute(count_query)
total = count_result.scalar_one()
# 获取数据
result = await self.db.execute(
query
.offset(skip)
.limit(limit)
.order_by(DatasetMapping.created_at.desc())
)
mappings = result.scalars().all()
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
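
Finally, the page math the router layers on top of these methods, as a small self-contained sketch (the numbers are hypothetical):

import math

page, page_size = 3, 20
skip = (page - 1) * page_size                                    # 40 rows skipped
# mappings, total = await service.get_all_mappings_with_count(skip=skip, limit=page_size)
total = 45                                                       # suppose the count query returns 45
total_pages = math.ceil(total / page_size) if total > 0 else 0   # ceil(45 / 20) = 3
# page 3 therefore holds rows 41-45, so len(mappings) == 5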