feat: Refactor knowledge base retrieval to return detailed search results and enhance API integration #108

This commit is contained in:
Dallas98
2025-11-25 21:21:21 +08:00
committed by GitHub
parent b50c12d135
commit bc26cfba55
9 changed files with 168 additions and 75 deletions

View File

@@ -76,16 +76,14 @@ public class KnowledgeBaseService {
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
if (StringUtils.hasText(request.getName())) {
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
.collectionName(knowledgeBase.getName())
.newCollectionName(request.getName())
.build());
knowledgeBase.setName(request.getName());
}
if (StringUtils.hasText(request.getDescription())) {
knowledgeBase.setDescription(request.getDescription());
}
knowledgeBaseRepository.updateById(knowledgeBase);
}
@@ -147,7 +145,7 @@ public class KnowledgeBaseService {
RagFile ragFile = new RagFile();
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
ragFile.setFileId(fileInfo.id());
ragFile.setFileName(fileInfo.name());
ragFile.setFileName(fileInfo.fileName());
ragFile.setStatus(FileStatus.UNPROCESSED);
return ragFile;
}).toList();
@@ -209,23 +207,19 @@ public class KnowledgeBaseService {
* @param request 检索请求
* @return 检索结果
*/
public SearchResp retrieve(RetrieveReq request) {
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
SearchResp searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
return searchResp;
List<SearchResp.SearchResult> searchResults = searchResp.getSearchResults().getFirst();
// request.getKnowledgeBaseIds().forEach(knowledgeId -> {
// KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeId))
// .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
// ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
// EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
// Embedding embedding = embeddingModel.embed(request.getQuery()).content();
// searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
// });
// return searchResp;
searchResults.forEach(item -> {
String metadata = item.getEntity().get("metadata").toString();
item.getEntity().put("metadata", metadata);
});
return searchResults;
}
}

View File

@@ -15,6 +15,7 @@ import io.milvus.v2.service.collection.request.AddFieldReq;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.HasCollectionReq;
import io.milvus.v2.service.vector.request.AnnSearchReq;
import io.milvus.v2.service.vector.request.FunctionScore;
import io.milvus.v2.service.vector.request.HybridSearchReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.data.BaseVector;
@@ -197,21 +198,26 @@ public class MilvusService {
.params("{\"drop_ratio_search\": 0.2}")
.topK(topK)
.build());
CreateCollectionReq.Function ranker = CreateCollectionReq.Function.builder()
.name("rrf")
.name("weight")
.functionType(FunctionType.RERANK)
.param("reranker", "rrf")
.param("k", "60")
.param("reranker", "weighted")
.param("weights", "[0.1, 0.9]")
.param("norm_score", "true")
.build();
FunctionScore functionScore = FunctionScore.builder()
.functions(Collections.singletonList(ranker))
.build();
SearchResp searchResp = this.getMilvusClient().hybridSearch(HybridSearchReq.builder()
.collectionName(collectionName)
.searchRequests(searchRequests)
.ranker(ranker)
.functionScore(functionScore)
.outFields(Arrays.asList("id", "text", "metadata"))
.topK(topK)
.limit(topK)
.build());
return searchResp;
}

View File

@@ -11,6 +11,8 @@ import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.*;
import java.util.List;
/**
* 知识库控制器
@@ -136,7 +138,7 @@ public class KnowledgeBaseController {
* @return 检索结果
*/
@PostMapping("/retrieve")
public SearchResp retrieve(@RequestBody @Valid RetrieveReq request) {
public List<SearchResp.SearchResult> retrieve(@RequestBody @Valid RetrieveReq request) {
return knowledgeBaseService.retrieve(request);
}
}

View File

@@ -21,6 +21,6 @@ public class AddFilesReq {
private String delimiter;
private List<FileInfo> files;
public record FileInfo(String id, String name) {
public record FileInfo(String id, String fileName) {
}
}

View File

@@ -20,12 +20,12 @@ public class KnowledgeBaseCreateReq {
*/
@NotEmpty(message = "知识库名称不能为空")
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
private String name;
/**
* 知识库描述
*/
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
private String description;
/**

View File

@@ -20,11 +20,11 @@ public class KnowledgeBaseUpdateReq {
*/
@NotEmpty(message = "知识库名称不能为空")
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
private String name;
/**
* 知识库描述
*/
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
private String description;
}

View File

@@ -1,6 +1,6 @@
import type React from "react";
import { useEffect, useState } from "react";
import { Table, Badge, Button, Breadcrumb, Tooltip, App } from "antd";
import { Table, Badge, Button, Breadcrumb, Tooltip, App, Card, Input, Empty, Spin } from "antd";
import {
DeleteOutlined,
EditOutlined,
@@ -16,17 +16,39 @@ import {
deleteKnowledgeBaseFileByIdUsingDelete,
queryKnowledgeBaseByIdUsingGet,
queryKnowledgeBaseFilesUsingGet,
retrieveKnowledgeBaseContent,
} from "../knowledge-base.api";
import useFetchData from "@/hooks/useFetchData";
import AddDataDialog from "../components/AddDataDialog";
import CreateKnowledgeBase from "../components/CreateKnowledgeBase";
interface StatisticItem {
icon?: React.ReactNode;
label: string;
value: string | number;
}
interface RagChunk {
id: string;
text: string;
metadata: string;
}
interface RecallResult {
score: number;
entity: RagChunk;
id?: string | object;
primaryKey?: string;
}
const KnowledgeBaseDetailPage: React.FC = () => {
const navigate = useNavigate();
const { message } = App.useApp();
const { id } = useParams<{ id: string }>();
const [knowledgeBase, setKnowledgeBase] = useState<KnowledgeBaseItem>(null);
const [knowledgeBase, setKnowledgeBase] = useState<KnowledgeBaseItem | undefined>(undefined);
const [showEdit, setShowEdit] = useState(false);
const [activeTab, setActiveTab] = useState<'fileList' | 'recallTest'>('fileList');
const [recallLoading, setRecallLoading] = useState(false);
const [recallResults, setRecallResults] = useState<RecallResult[]>([]);
const [recallQuery, setRecallQuery] = useState("");
const fetchKnowledgeBaseDetails = async (id: string) => {
const { data } = await queryKnowledgeBaseByIdUsingGet(id);
@@ -55,12 +77,12 @@ const KnowledgeBaseDetailPage: React.FC = () => {
// File table logic
const handleDeleteFile = async (file: KBFile) => {
try {
await deleteKnowledgeBaseFileByIdUsingDelete(knowledgeBase.id, {
await deleteKnowledgeBaseFileByIdUsingDelete(knowledgeBase!.id, {
ids: [file.id]
});
message.success("文件已删除");
fetchFiles();
} catch (error) {
} catch {
message.error("文件删除失败");
}
};
@@ -72,11 +94,30 @@ const KnowledgeBaseDetailPage: React.FC = () => {
};
const handleRefreshPage = () => {
if (knowledgeBase) {
fetchKnowledgeBaseDetails(knowledgeBase.id);
}
fetchFiles();
setShowEdit(false);
};
const handleRecallTest = async () => {
if (!recallQuery || !knowledgeBase?.id) return;
setRecallLoading(true);
try {
const result = await retrieveKnowledgeBaseContent({
query: recallQuery,
topK: 10,
threshold: 0.2,
knowledgeBaseIds: [knowledgeBase.id],
});
setRecallResults(result?.data || []);
} catch {
setRecallResults([]);
}
setRecallLoading(false);
};
const operations = [
{
key: "edit",
@@ -104,7 +145,7 @@ const KnowledgeBaseDetailPage: React.FC = () => {
cancelText: "取消",
okText: "删除",
okType: "danger",
onConfirm: () => handleDeleteKB(knowledgeBase),
onConfirm: () => knowledgeBase && handleDeleteKB(knowledgeBase),
},
icon: <DeleteOutlined className="w-4 h-4" />,
},
@@ -134,9 +175,13 @@ const KnowledgeBaseDetailPage: React.FC = () => {
dataIndex: "status",
key: "vectorizationStatus",
width: 120,
render: (status: any) => (
<Badge color={status?.color} text={status?.label} />
),
render: (status: unknown) => {
if (typeof status === 'object' && status !== null) {
const s = status as { color?: string; label?: string };
return <Badge color={s.color} text={s.label} />;
}
return <Badge color="default" text={String(status)} />;
},
},
{
title: "分块数",
@@ -164,7 +209,7 @@ const KnowledgeBaseDetailPage: React.FC = () => {
key: "actions",
align: "right" as const,
width: 100,
render: (_: any, file: KBFile) => (
render: (_: unknown, file: KBFile) => (
<div>
{fileOps.map((op) => (
<Tooltip key={op.key} title={op.label}>
@@ -193,7 +238,9 @@ const KnowledgeBaseDetailPage: React.FC = () => {
</div>
<DetailHeader
data={knowledgeBase}
statistics={knowledgeBase?.statistics || []}
statistics={knowledgeBase && Array.isArray((knowledgeBase as { statistics?: StatisticItem[] }).statistics)
? ((knowledgeBase as { statistics?: StatisticItem[] }).statistics ?? [])
: []}
operations={operations}
/>
<CreateKnowledgeBase
@@ -205,25 +252,34 @@ const KnowledgeBaseDetailPage: React.FC = () => {
/>
<div className="flex-1 border-card p-6 mt-4">
<div className="flex items-center justify-between mb-4 gap-3">
<div className="flex items-center gap-2">
<Button type={activeTab === 'fileList' ? 'primary' : 'default'} onClick={() => setActiveTab('fileList')}>
</Button>
<Button type={activeTab === 'recallTest' ? 'primary' : 'default'} onClick={() => setActiveTab('recallTest')}>
</Button>
</div>
{activeTab === 'fileList' && (
<>
<div className="flex-1">
<SearchControls
searchTerm={searchParams.keyword}
onSearchChange={(keyword) =>
setSearchParams({ ...searchParams, keyword })
}
onSearchChange={(keyword) => setSearchParams({ ...searchParams, keyword })}
searchPlaceholder="搜索文件名..."
filters={[]}
onFiltersChange={handleFiltersChange}
onClearFilters={() =>
setSearchParams({ ...searchParams, filter: {} })
}
onClearFilters={() => setSearchParams({ ...searchParams, filter: { type: [], status: [], tags: [] } })}
showViewToggle={false}
showReload={false}
/>
</div>
<AddDataDialog knowledgeBase={knowledgeBase} onDataAdded={handleRefreshPage} />
</>
)}
</div>
{activeTab === 'fileList' ? (
<Table
loading={loading}
columns={fileColumns}
@@ -232,6 +288,41 @@ const KnowledgeBaseDetailPage: React.FC = () => {
pagination={pagination}
scroll={{ y: "calc(100vh - 30rem)" }}
/>
) : (
<div className="p-2">
<div style={{ fontSize: 14, fontWeight: 300, marginBottom: 8 }}></div>
<div className="flex items-center mb-4">
<Input.Search
value={recallQuery}
onChange={e => setRecallQuery(e.target.value)}
onSearch={handleRecallTest}
placeholder="请输入召回测试问题"
enterButton="检索"
loading={recallLoading}
style={{ width: "100%", fontSize: 18, height: 48 }}
/>
</div>
{recallLoading ? (
<Spin className="mt-8" />
) : recallResults.length === 0 ? (
<Empty description="暂无召回结果" />
) : (
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
{recallResults.map((item, idx) => (
<Card key={idx} title={`得分:${item.score?.toFixed(4) ?? "-"}`}
extra={<span style={{ fontSize: 12 }}>ID: {item.entity?.id ?? "-"}</span>}
style={{ wordBreak: "break-all" }}
>
<div style={{ marginBottom: 8, fontWeight: 500 }}>{item.entity?.text ?? ""}</div>
<div style={{ fontSize: 12, color: '#888' }}>
metadata: <pre style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-all', margin: 0 }}>{item.entity?.metadata}</pre>
</div>
</Card>
))}
</div>
)}
</div>
)}
</div>
</div>
);

View File

@@ -120,9 +120,7 @@ export default function AddDataDialog({ knowledgeBase, onDataAdded }) {
};
const handleAddData = async () => {
const selectedFiles = [];
if (selectedFiles.length === 0) {
if (getSelectedFilesCount() === 0) {
message.warning("请至少选择一个文件");
return;
}
@@ -130,7 +128,7 @@ export default function AddDataDialog({ knowledgeBase, onDataAdded }) {
try {
// 构造符合API要求的请求数据
const requestData = {
files: Object.entries(selectedFilesMap),
files: Object.values(selectedFilesMap),
processType: newKB.processType,
chunkSize: Number(newKB.chunkSize), // 确保是数字类型
overlapSize: Number(newKB.overlapSize), // 确保是数字类型

View File

@@ -35,15 +35,17 @@ export function addKnowledgeBaseFilesUsingPost(baseId: string, data: any) {
return post(`/api/knowledge-base/${baseId}/files`, data);
}
// 获取知识生成文件详情
export function queryKnowledgeBaseFilesByIdUsingGet(
baseId: string,
fileId: string
) {
return get(`/api/knowledge-base/${baseId}/files/${fileId}`);
}
// 删除知识生成文件
export function deleteKnowledgeBaseFileByIdUsingDelete(baseId: string, data: any) {
return del(`/api/knowledge-base/${baseId}/files`, data);
}
// 检索知识库内容
export function retrieveKnowledgeBaseContent(data: {
query: string;
topK?: number;
threshold?: number;
knowledgeBaseIds: string[];
}) {
return post("/api/knowledge-base/retrieve", data);
}