feat(template): 添加模板搜索功能和优化数据获取

- 添加 keyword 参数支持模板名称和描述模糊搜索
- 在 useFetchData hook 中添加 filterParamMapper 参数用于过滤参数映射
- 为模板列表页面实现内置标志过滤器映射功能
- 优化模板配置更新逻辑,改进数据验证和转换流程
- 完善模板服务中的条件查询,支持多字段模糊匹配
- 更新数据获取 hook 的依赖数组以正确处理轮询逻辑
This commit is contained in:
2026-01-22 21:25:04 +08:00
parent d22d677efe
commit ccb581d501
4 changed files with 305 additions and 245 deletions

View File

@@ -17,13 +17,23 @@ import { useDebouncedEffect } from "./useDebouncedEffect";
import Loading from "@/utils/loading"; import Loading from "@/utils/loading";
import { App } from "antd"; import { App } from "antd";
type FetchParams = Record<string, unknown>;
type FetchResult<T> = {
data?: {
content?: Partial<T>[];
totalElements?: number;
total?: number;
};
};
export default function useFetchData<T>( export default function useFetchData<T>(
fetchFunc: (params?: any) => Promise<any>, fetchFunc: (params?: FetchParams) => Promise<FetchResult<T>>,
mapDataFunc: (data: Partial<T>) => T = (data) => data as T, mapDataFunc: (data: Partial<T>) => T = (data) => data as T,
pollingInterval: number = 30000, // 默认30秒轮询一次 pollingInterval: number = 30000, // 默认30秒轮询一次
autoRefresh: boolean = false, // 是否自动开始轮询,默认 false autoRefresh: boolean = false, // 是否自动开始轮询,默认 false
additionalPollingFuncs: (() => Promise<any>)[] = [], // 额外的轮询函数 additionalPollingFuncs: (() => Promise<unknown>)[] = [], // 额外的轮询函数
pageOffset: number = 1 pageOffset: number = 1,
filterParamMapper?: (filters: Record<string, unknown>) => Record<string, unknown>
) { ) {
const { message } = App.useApp(); const { message } = App.useApp();
@@ -111,6 +121,7 @@ export default function useFetchData<T>(
} }
try { try {
const mappedFilterParams = filterParamMapper ? filterParamMapper(filter) : {};
// 同时执行主要数据获取和额外的轮询函数 // 同时执行主要数据获取和额外的轮询函数
const promises = [ const promises = [
fetchFunc({ fetchFunc({
@@ -121,6 +132,7 @@ export default function useFetchData<T>(
type: getFirstOfArray(filter?.type) || undefined, type: getFirstOfArray(filter?.type) || undefined,
status: getFirstOfArray(filter?.status) || undefined, status: getFirstOfArray(filter?.status) || undefined,
tags: filter?.tags?.length ? filter.tags.join(",") : undefined, tags: filter?.tags?.length ? filter.tags.join(",") : undefined,
...mappedFilterParams,
page: current - pageOffset, page: current - pageOffset,
size: pageSize, // Use camelCase for HTTP query params size: pageSize, // Use camelCase for HTTP query params
}), }),
@@ -167,9 +179,11 @@ export default function useFetchData<T>(
mapDataFunc, mapDataFunc,
isPolling, isPolling,
clearPollingTimer, clearPollingTimer,
pageOffset,
pollingInterval, pollingInterval,
message, message,
additionalPollingFuncs, additionalPollingFuncs,
filterParamMapper,
] ]
); );
@@ -215,7 +229,7 @@ export default function useFetchData<T>(
return () => { return () => {
clearPollingTimer(); clearPollingTimer();
}; };
}, [clearPollingTimer]); }, [autoRefresh, startPolling, clearPollingTimer]);
return { return {
loading, loading,

View File

@@ -56,6 +56,31 @@ const TemplateList: React.FC = () => {
}, },
]; ];
const BUILT_IN_FLAG = {
TRUE: "true",
FALSE: "false",
} as const;
const mapTemplateFilters = (filters: Record<string, string[]>) => {
const getFirstValue = (values?: string[]) =>
values && values.length > 0 ? values[0] : undefined;
const builtInRaw = getFirstValue(filters.builtIn);
const builtIn =
builtInRaw === BUILT_IN_FLAG.TRUE
? true
: builtInRaw === BUILT_IN_FLAG.FALSE
? false
: undefined;
return {
category: getFirstValue(filters.category),
dataType: getFirstValue(filters.dataType),
labelingType: getFirstValue(filters.labelingType),
builtIn,
};
};
// Modals // Modals
const [isFormVisible, setIsFormVisible] = useState(false); const [isFormVisible, setIsFormVisible] = useState(false);
const [isDetailVisible, setIsDetailVisible] = useState(false); const [isDetailVisible, setIsDetailVisible] = useState(false);
@@ -71,7 +96,15 @@ const TemplateList: React.FC = () => {
fetchData, fetchData,
handleFiltersChange, handleFiltersChange,
handleKeywordChange, handleKeywordChange,
} = useFetchData(queryAnnotationTemplatesUsingGet, undefined, undefined, undefined, undefined, 0); } = useFetchData(
queryAnnotationTemplatesUsingGet,
undefined,
undefined,
undefined,
undefined,
0,
mapTemplateFilters
);
const handleCreate = () => { const handleCreate = () => {
setFormMode("create"); setFormMode("create");

View File

@@ -67,6 +67,7 @@ async def get_template(
async def list_template( async def list_template(
page: int = Query(1, ge=1, description="页码"), page: int = Query(1, ge=1, description="页码"),
size: int = Query(10, ge=1, le=100, description="每页大小"), size: int = Query(10, ge=1, le=100, description="每页大小"),
keyword: Optional[str] = Query(None, description="关键词"),
category: Optional[str] = Query(None, description="分类筛选"), category: Optional[str] = Query(None, description="分类筛选"),
dataType: Optional[str] = Query(None, alias="dataType", description="数据类型筛选"), dataType: Optional[str] = Query(None, alias="dataType", description="数据类型筛选"),
labelingType: Optional[str] = Query(None, alias="labelingType", description="标注类型筛选"), labelingType: Optional[str] = Query(None, alias="labelingType", description="标注类型筛选"),
@@ -78,6 +79,7 @@ async def list_template(
- **page**: 页码(从1开始) - **page**: 页码(从1开始)
- **size**: 每页大小(1-100) - **size**: 每页大小(1-100)
- **keyword**: 关键词(匹配名称/描述)
- **category**: 模板分类筛选 - **category**: 模板分类筛选
- **dataType**: 数据类型筛选 - **dataType**: 数据类型筛选
- **labelingType**: 标注类型筛选 - **labelingType**: 标注类型筛选
@@ -90,7 +92,8 @@ async def list_template(
category=category, category=category,
data_type=dataType, data_type=dataType,
labeling_type=labelingType, labeling_type=labelingType,
built_in=builtIn built_in=builtIn,
keyword=keyword
) )
return StandardResponse(code=200, message="success", data=templates) return StandardResponse(code=200, message="success", data=templates)

View File

@@ -3,7 +3,7 @@ Annotation Template Service
""" """
from typing import Optional, List from typing import Optional, List
from datetime import datetime from datetime import datetime
from sqlalchemy import select, func from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from uuid import uuid4 from uuid import uuid4
from fastapi import HTTPException from fastapi import HTTPException
@@ -185,7 +185,8 @@ class AnnotationTemplateService:
category: Optional[str] = None, category: Optional[str] = None,
data_type: Optional[str] = None, data_type: Optional[str] = None,
labeling_type: Optional[str] = None, labeling_type: Optional[str] = None,
built_in: Optional[bool] = None built_in: Optional[bool] = None,
keyword: Optional[str] = None
) -> AnnotationTemplateListResponse: ) -> AnnotationTemplateListResponse:
""" """
获取模板列表 获取模板列表
@@ -213,6 +214,14 @@ class AnnotationTemplateService:
conditions.append(AnnotationTemplate.labeling_type == labeling_type) # type: ignore conditions.append(AnnotationTemplate.labeling_type == labeling_type) # type: ignore
if built_in is not None: if built_in is not None:
conditions.append(AnnotationTemplate.built_in == built_in) # type: ignore conditions.append(AnnotationTemplate.built_in == built_in) # type: ignore
if keyword:
like_keyword = f"%{keyword}%"
conditions.append(
or_(
AnnotationTemplate.name.ilike(like_keyword), # type: ignore
AnnotationTemplate.description.ilike(like_keyword) # type: ignore
)
)
# 查询总数 # 查询总数
count_result = await db.execute( count_result = await db.execute(
@@ -273,13 +282,14 @@ class AnnotationTemplateService:
for field, value in update_data.items(): for field, value in update_data.items():
if field == 'configuration' and value is not None: if field == 'configuration' and value is not None:
# 验证配置JSON # 验证配置JSON
config_dict = value.model_dump(mode='json', by_alias=False) config = value if isinstance(value, TemplateConfiguration) else TemplateConfiguration.model_validate(value)
config_dict = config.model_dump(mode='json', by_alias=False)
valid, error = LabelStudioConfigValidator.validate_configuration_json(config_dict) valid, error = LabelStudioConfigValidator.validate_configuration_json(config_dict)
if not valid: if not valid:
raise HTTPException(status_code=400, detail=f"Invalid configuration: {error}") raise HTTPException(status_code=400, detail=f"Invalid configuration: {error}")
# 重新生成Label Studio XML配置(用于验证) # 重新生成Label Studio XML配置(用于验证)
label_config = self.generate_label_studio_config(value) label_config = self.generate_label_studio_config(config)
# 验证生成的XML # 验证生成的XML
valid, error = LabelStudioConfigValidator.validate_xml(label_config) valid, error = LabelStudioConfigValidator.validate_xml(label_config)