You've already forked DataMate
feat: Dataset pagination; camelCase support in schemas (#22)
implement pagination for dataset mappings. update response models to support camelCase parameters.
This commit is contained in:
@@ -1,40 +1,60 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List
|
||||
import math
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.services.dataset_mapping_service import DatasetMappingService
|
||||
from app.schemas.dataset_mapping import DatasetMappingResponse
|
||||
from app.schemas import StandardResponse
|
||||
from app.schemas.common import StandardResponse, PaginatedData
|
||||
from app.core.logging import get_logger
|
||||
from . import project_router
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@project_router.get("/mappings/list", response_model=StandardResponse[List[DatasetMappingResponse]])
|
||||
@project_router.get("/mappings/list", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
|
||||
async def list_mappings(
|
||||
skip: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"),
|
||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页记录数"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
查询所有映射关系
|
||||
查询所有映射关系(分页)
|
||||
|
||||
返回所有有效的数据集映射关系(未被软删除的)
|
||||
返回所有有效的数据集映射关系(未被软删除的),支持分页查询
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
|
||||
logger.info(f"Listing mappings, skip={skip}, limit={limit}")
|
||||
# 计算 skip
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
mappings = await service.get_all_mappings(skip=skip, limit=limit)
|
||||
logger.info(f"Listing mappings, page={page}, page_size={page_size}")
|
||||
|
||||
logger.info(f"Found {len(mappings)} mappings")
|
||||
# 获取数据和总数
|
||||
mappings, total = await service.get_all_mappings_with_count(
|
||||
skip=skip,
|
||||
limit=page_size
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
total_pages = math.ceil(total / page_size) if total > 0 else 0
|
||||
|
||||
# 构造分页响应
|
||||
paginated_data = PaginatedData(
|
||||
page=page,
|
||||
size=page_size,
|
||||
total_elements=total,
|
||||
total_pages=total_pages,
|
||||
content=mappings
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=mappings
|
||||
data=paginated_data
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -78,29 +98,51 @@ async def get_mapping(
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[List[DatasetMappingResponse]])
|
||||
@project_router.get("/mappings/by-source/{source_dataset_id}", response_model=StandardResponse[PaginatedData[DatasetMappingResponse]])
|
||||
async def get_mappings_by_source(
|
||||
source_dataset_id: str,
|
||||
page: int = Query(1, ge=1, description="页码(从1开始)"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页记录数"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据源数据集 ID 查询所有映射关系
|
||||
根据源数据集 ID 查询所有映射关系(分页)
|
||||
|
||||
返回该数据集创建的所有标注项目(包括已删除的)
|
||||
返回该数据集创建的所有标注项目(不包括已删除的),支持分页查询
|
||||
"""
|
||||
try:
|
||||
service = DatasetMappingService(db)
|
||||
|
||||
logger.info(f"Get mappings by source dataset id: {source_dataset_id}")
|
||||
# 计算 skip
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
mappings = await service.get_mappings_by_source_dataset_id(source_dataset_id)
|
||||
logger.info(f"Get mappings by source dataset id: {source_dataset_id}, page={page}, page_size={page_size}")
|
||||
|
||||
logger.info(f"Found {len(mappings)} mappings")
|
||||
# 获取数据和总数
|
||||
mappings, total = await service.get_mappings_by_source_with_count(
|
||||
source_dataset_id=source_dataset_id,
|
||||
skip=skip,
|
||||
limit=page_size
|
||||
)
|
||||
|
||||
# 计算总页数
|
||||
total_pages = math.ceil(total / page_size) if total > 0 else 0
|
||||
|
||||
# 构造分页响应
|
||||
paginated_data = PaginatedData(
|
||||
page=page,
|
||||
size=page_size,
|
||||
total_elements=total,
|
||||
total_pages=total_pages,
|
||||
content=mappings
|
||||
)
|
||||
|
||||
logger.info(f"Found {len(mappings)} mappings on page {page}, total: {total}")
|
||||
|
||||
return StandardResponse(
|
||||
code=200,
|
||||
message="success",
|
||||
data=mappings
|
||||
data=paginated_data
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@@ -1,13 +1,27 @@
|
||||
"""
|
||||
通用响应模型
|
||||
"""
|
||||
from typing import Generic, TypeVar, Optional
|
||||
from typing import Generic, TypeVar, Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# 定义泛型类型变量
|
||||
T = TypeVar('T')
|
||||
|
||||
class StandardResponse(BaseModel, Generic[T]):
|
||||
# 定义一个将 snake_case 转换为 camelCase 的函数
|
||||
def to_camel(string: str) -> str:
    """Convert a snake_case identifier to camelCase.

    The first underscore-separated component is kept exactly as written.
    Every following component has its first character upper-cased and the
    remainder lower-cased, and the pieces are then concatenated.

    ``str.capitalize`` is used rather than ``str.title`` because ``title``
    also upper-cases any letter that follows a non-letter (such as a digit),
    which would turn ``file_md5sum`` into ``fileMd5Sum`` instead of
    ``fileMd5sum``.

    Args:
        string: The snake_case name to convert.

    Returns:
        The camelCase form of ``string``. A name without underscores is
        returned unchanged.
    """
    head, *tail = string.split('_')
    return head + ''.join(part.capitalize() for part in tail)
|
||||
|
||||
class BaseResponseModel(BaseModel):
    """Shared base schema that enables generated camelCase aliases.

    The nested ``Config`` registers ``to_camel`` as the alias generator and
    sets ``populate_by_name`` so the declared snake_case field names are
    still accepted when populating a model.
    """

    class Config:
        alias_generator = to_camel
        populate_by_name = True
|
||||
|
||||
class StandardResponse(BaseResponseModel, Generic[T]):
|
||||
"""
|
||||
标准API响应格式
|
||||
|
||||
@@ -18,6 +32,8 @@ class StandardResponse(BaseModel, Generic[T]):
|
||||
data: Optional[T] = Field(None, description="响应数据")
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
alias_generator = to_camel
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"code": 200,
|
||||
@@ -25,3 +41,23 @@ class StandardResponse(BaseModel, Generic[T]):
|
||||
"data": {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PaginatedData(BaseResponseModel, Generic[T]):
    """Generic container for one page of items plus paging metadata."""

    # Current page number (1-based, per the field description).
    page: int = Field(..., description="当前页码(从1开始)")
    # Page size.
    size: int = Field(..., description="页大小")
    # Total number of items.
    total_elements: int = Field(..., description="总条数")
    # Total number of pages.
    total_pages: int = Field(..., description="总页数")
    # Items belonging to the current page.
    content: List[T] = Field(..., description="当前页数据")

    class Config:
        # Example payload shown in the generated JSON schema; keys use the
        # camelCase alias spelling.
        json_schema_extra = {
            "example": {"page": 1, "size": 20, "totalElements": 100, "totalPages": 5, "content": []}
        }
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
class DatasetMappingBase(BaseModel):
|
||||
from .common import BaseResponseModel
|
||||
|
||||
class DatasetMappingBase(BaseResponseModel):
|
||||
"""数据集映射 基础模型"""
|
||||
source_dataset_id: str = Field(..., description="源数据集ID")
|
||||
|
||||
@@ -10,14 +12,14 @@ class DatasetMappingCreateRequest(DatasetMappingBase):
|
||||
"""数据集映射 创建 请求模型"""
|
||||
pass
|
||||
|
||||
class DatasetMappingCreateResponse(BaseModel):
|
||||
class DatasetMappingCreateResponse(BaseResponseModel):
    """Response schema returned after creating a dataset mapping."""

    # UUID of the mapping.
    mapping_id: str = Field(..., description="映射UUID")
    # ID of the Label Studio project.
    labelling_project_id: str = Field(..., description="Label Studio项目ID")
    # Name of the Label Studio project.
    labelling_project_name: str = Field(..., description="Label Studio项目名称")
    # Response message.
    message: str = Field(..., description="响应消息")
|
||||
|
||||
class DatasetMappingUpdateRequest(BaseModel):
|
||||
class DatasetMappingUpdateRequest(BaseResponseModel):
|
||||
"""数据集映射 更新 请求模型"""
|
||||
source_dataset_id: Optional[str] = Field(None, description="源数据集ID")
|
||||
|
||||
@@ -32,13 +34,14 @@ class DatasetMappingResponse(DatasetMappingBase):
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
class SyncDatasetRequest(BaseModel):
|
||||
class SyncDatasetRequest(BaseResponseModel):
    """Request schema for synchronising a dataset."""

    # Mapping ID (the mapping UUID).
    mapping_id: str = Field(..., description="映射ID(mapping UUID)")
    # Batch size; defaults to 50 and is constrained to the range 1..100.
    batch_size: int = Field(default=50, ge=1, le=100, description="批处理大小")
|
||||
|
||||
class SyncDatasetResponse(BaseModel):
|
||||
class SyncDatasetResponse(BaseResponseModel):
|
||||
"""同步数据集响应模型"""
|
||||
mapping_id: str = Field(..., description="映射UUID")
|
||||
status: str = Field(..., description="同步状态")
|
||||
@@ -46,7 +49,7 @@ class SyncDatasetResponse(BaseModel):
|
||||
total_files: int = Field(0, description="总文件数量")
|
||||
message: str = Field(..., description="响应消息")
|
||||
|
||||
class DeleteDatasetResponse(BaseModel):
|
||||
class DeleteDatasetResponse(BaseResponseModel):
|
||||
"""删除数据集响应模型"""
|
||||
mapping_id: str = Field(..., description="映射UUID")
|
||||
status: str = Field(..., description="删除状态")
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from .common import BaseResponseModel
|
||||
|
||||
class LabelStudioProject(BaseModel):
|
||||
class LabelStudioProject(BaseResponseModel):
|
||||
"""Label Studio项目模型"""
|
||||
id: int = Field(..., description="项目ID")
|
||||
title: str = Field(..., description="项目标题")
|
||||
@@ -11,7 +12,7 @@ class LabelStudioProject(BaseModel):
|
||||
created_at: Optional[datetime] = Field(None, description="创建时间")
|
||||
updated_at: Optional[datetime] = Field(None, description="更新时间")
|
||||
|
||||
class LabelStudioTaskData(BaseModel):
|
||||
class LabelStudioTaskData(BaseResponseModel):
|
||||
"""Label Studio任务数据模型"""
|
||||
image: Optional[str] = Field(None, description="图像URL")
|
||||
text: Optional[str] = Field(None, description="文本内容")
|
||||
@@ -19,19 +20,19 @@ class LabelStudioTaskData(BaseModel):
|
||||
video: Optional[str] = Field(None, description="视频URL")
|
||||
filename: Optional[str] = Field(None, description="文件名")
|
||||
|
||||
class LabelStudioTask(BaseModel):
|
||||
class LabelStudioTask(BaseResponseModel):
    """Schema for a Label Studio task."""

    # Task payload.
    data: LabelStudioTaskData = Field(..., description="任务数据")
    # Project ID; optional, defaults to None.
    project: Optional[int] = Field(default=None, description="项目ID")
    # Free-form metadata; optional, defaults to None.
    meta: Optional[Dict[str, Any]] = Field(default=None, description="元数据")
|
||||
|
||||
class LabelStudioCreateProjectRequest(BaseModel):
|
||||
class LabelStudioCreateProjectRequest(BaseResponseModel):
    """Request schema for creating a Label Studio project."""

    # NOTE(review): this class inherits the camelCase alias generator from
    # BaseResponseModel, so dumping with by_alias=True would emit
    # ``labelConfig`` rather than ``label_config`` - confirm that outbound
    # Label Studio payloads are serialised by field name.

    # Project title.
    title: str = Field(..., description="项目标题")
    # Project description; defaults to an empty string.
    description: str = Field(default="", description="项目描述")
    # Labelling configuration.
    label_config: str = Field(..., description="标注配置")
|
||||
|
||||
class LabelStudioCreateTaskRequest(BaseModel):
|
||||
class LabelStudioCreateTaskRequest(BaseResponseModel):
|
||||
"""创建Label Studio任务请求模型"""
|
||||
data: Dict[str, Any] = Field(..., description="任务数据")
|
||||
project: Optional[int] = Field(None, description="项目ID")
|
||||
@@ -1,7 +1,7 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import update
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import update, func
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
@@ -213,11 +213,86 @@ class DatasetMappingService:
|
||||
logger.debug(f"Found {len(mappings)} mappings")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings]
|
||||
|
||||
async def count_mappings(self) -> int:
|
||||
async def count_mappings(self, include_deleted: bool = False) -> int:
    """Return the number of dataset mappings.

    Args:
        include_deleted: When False (the default), rows whose ``deleted_at``
            is not NULL are left out of the count.

    Returns:
        The single scalar produced by the ``COUNT`` query.
    """
    stmt = select(func.count()).select_from(DatasetMapping)
    if not include_deleted:
        stmt = stmt.where(DatasetMapping.deleted_at.is_(None))
    return (await self.db.execute(stmt)).scalar_one()
|
||||
|
||||
async def get_all_mappings_with_count(
|
||||
self,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False
|
||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||
"""获取所有映射及总数(用于分页)"""
|
||||
logger.debug(f"List all mappings with count, skip: {skip}, limit: {limit}")
|
||||
|
||||
# 构建查询
|
||||
query = select(DatasetMapping)
|
||||
if not include_deleted:
|
||||
query = query.where(DatasetMapping.deleted_at.is_(None))
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count()).select_from(DatasetMapping)
|
||||
if not include_deleted:
|
||||
count_query = count_query.where(DatasetMapping.deleted_at.is_(None))
|
||||
|
||||
count_result = await self.db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# 获取数据
|
||||
result = await self.db.execute(
|
||||
select(DatasetMapping)
|
||||
.where(DatasetMapping.deleted_at.is_(None))
|
||||
query
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.order_by(DatasetMapping.created_at.desc())
|
||||
)
|
||||
mappings = result.scalars().all()
|
||||
return len(mappings)
|
||||
|
||||
logger.debug(f"Found {len(mappings)} mappings, total: {total}")
|
||||
return [DatasetMappingResponse.model_validate(mapping) for mapping in mappings], total
|
||||
|
||||
async def get_mappings_by_source_with_count(
    self,
    source_dataset_id: str,
    skip: int = 0,
    limit: int = 100,
    include_deleted: bool = False
) -> Tuple[List[DatasetMappingResponse], int]:
    """Fetch one page of mappings for a source dataset plus the total count.

    Args:
        source_dataset_id: Value matched against
            ``DatasetMapping.source_dataset_id``.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).
        include_deleted: When False (the default), rows whose ``deleted_at``
            is not NULL are excluded from both the count and the page.

    Returns:
        A ``(mappings, total)`` tuple: the page of rows validated into
        ``DatasetMappingResponse`` objects, ordered by ``created_at``
        descending, and the count of all rows matching the same filters.
    """
    logger.debug(f"Get mappings by source dataset id with count: {source_dataset_id}")

    # One condition list is shared by both statements, so the total and the
    # page are always computed over the same set of rows.
    conditions = [DatasetMapping.source_dataset_id == source_dataset_id]
    if not include_deleted:
        conditions.append(DatasetMapping.deleted_at.is_(None))

    # Total number of matching rows, ignoring offset/limit.
    total_stmt = select(func.count()).select_from(DatasetMapping).where(*conditions)
    total = (await self.db.execute(total_stmt)).scalar_one()

    # The requested page of rows.
    page_stmt = (
        select(DatasetMapping)
        .where(*conditions)
        .offset(skip)
        .limit(limit)
        .order_by(DatasetMapping.created_at.desc())
    )
    rows = (await self.db.execute(page_stmt)).scalars().all()

    logger.debug(f"Found {len(rows)} mappings, total: {total}")
    return [DatasetMappingResponse.model_validate(row) for row in rows], total
|
||||
Reference in New Issue
Block a user