You've already forked DataMate
feat: Refactor knowledge base retrieval to return detailed search results and enhance API integration #108
This commit is contained in:
@@ -76,16 +76,14 @@ public class KnowledgeBaseService {
|
||||
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
if (StringUtils.hasText(request.getName())) {
|
||||
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
|
||||
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
|
||||
.collectionName(knowledgeBase.getName())
|
||||
.newCollectionName(request.getName())
|
||||
.build());
|
||||
knowledgeBase.setName(request.getName());
|
||||
}
|
||||
if (StringUtils.hasText(request.getDescription())) {
|
||||
knowledgeBase.setDescription(request.getDescription());
|
||||
}
|
||||
knowledgeBaseRepository.updateById(knowledgeBase);
|
||||
}
|
||||
|
||||
@@ -147,7 +145,7 @@ public class KnowledgeBaseService {
|
||||
RagFile ragFile = new RagFile();
|
||||
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
|
||||
ragFile.setFileId(fileInfo.id());
|
||||
ragFile.setFileName(fileInfo.name());
|
||||
ragFile.setFileName(fileInfo.fileName());
|
||||
ragFile.setStatus(FileStatus.UNPROCESSED);
|
||||
return ragFile;
|
||||
}).toList();
|
||||
@@ -209,23 +207,19 @@ public class KnowledgeBaseService {
|
||||
* @param request 检索请求
|
||||
* @return 检索结果
|
||||
*/
|
||||
public SearchResp retrieve(RetrieveReq request) {
|
||||
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
|
||||
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
|
||||
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
|
||||
SearchResp searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
|
||||
return searchResp;
|
||||
List<SearchResp.SearchResult> searchResults = searchResp.getSearchResults().getFirst();
|
||||
|
||||
// request.getKnowledgeBaseIds().forEach(knowledgeId -> {
|
||||
// KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeId))
|
||||
// .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
// ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
|
||||
// EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
|
||||
// Embedding embedding = embeddingModel.embed(request.getQuery()).content();
|
||||
// searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
|
||||
// });
|
||||
// return searchResp;
|
||||
searchResults.forEach(item -> {
|
||||
String metadata = item.getEntity().get("metadata").toString();
|
||||
item.getEntity().put("metadata", metadata);
|
||||
});
|
||||
return searchResults;
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import io.milvus.v2.service.collection.request.AddFieldReq;
|
||||
import io.milvus.v2.service.collection.request.CreateCollectionReq;
|
||||
import io.milvus.v2.service.collection.request.HasCollectionReq;
|
||||
import io.milvus.v2.service.vector.request.AnnSearchReq;
|
||||
import io.milvus.v2.service.vector.request.FunctionScore;
|
||||
import io.milvus.v2.service.vector.request.HybridSearchReq;
|
||||
import io.milvus.v2.service.vector.request.InsertReq;
|
||||
import io.milvus.v2.service.vector.request.data.BaseVector;
|
||||
@@ -197,21 +198,26 @@ public class MilvusService {
|
||||
.params("{\"drop_ratio_search\": 0.2}")
|
||||
.topK(topK)
|
||||
.build());
|
||||
|
||||
CreateCollectionReq.Function ranker = CreateCollectionReq.Function.builder()
|
||||
.name("rrf")
|
||||
.name("weight")
|
||||
.functionType(FunctionType.RERANK)
|
||||
.param("reranker", "rrf")
|
||||
.param("k", "60")
|
||||
.param("reranker", "weighted")
|
||||
.param("weights", "[0.1, 0.9]")
|
||||
.param("norm_score", "true")
|
||||
.build();
|
||||
|
||||
FunctionScore functionScore = FunctionScore.builder()
|
||||
.functions(Collections.singletonList(ranker))
|
||||
.build();
|
||||
|
||||
|
||||
SearchResp searchResp = this.getMilvusClient().hybridSearch(HybridSearchReq.builder()
|
||||
.collectionName(collectionName)
|
||||
.searchRequests(searchRequests)
|
||||
.ranker(ranker)
|
||||
.functionScore(functionScore)
|
||||
.outFields(Arrays.asList("id", "text", "metadata"))
|
||||
.topK(topK)
|
||||
.limit(topK)
|
||||
.build());
|
||||
return searchResp;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import jakarta.validation.Valid;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
/**
|
||||
* 知识库控制器
|
||||
@@ -136,7 +138,7 @@ public class KnowledgeBaseController {
|
||||
* @return 检索结果
|
||||
*/
|
||||
@PostMapping("/retrieve")
|
||||
public SearchResp retrieve(@RequestBody @Valid RetrieveReq request) {
|
||||
public List<SearchResp.SearchResult> retrieve(@RequestBody @Valid RetrieveReq request) {
|
||||
return knowledgeBaseService.retrieve(request);
|
||||
}
|
||||
}
|
||||
@@ -21,6 +21,6 @@ public class AddFilesReq {
|
||||
private String delimiter;
|
||||
private List<FileInfo> files;
|
||||
|
||||
public record FileInfo(String id, String name) {
|
||||
public record FileInfo(String id, String fileName) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,12 +20,12 @@ public class KnowledgeBaseCreateReq {
|
||||
*/
|
||||
@NotEmpty(message = "知识库名称不能为空")
|
||||
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
|
||||
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
|
||||
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
|
||||
private String name;
|
||||
/**
|
||||
* 知识库描述
|
||||
*/
|
||||
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
|
||||
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
|
||||
private String description;
|
||||
|
||||
/**
|
||||
|
||||
@@ -20,11 +20,11 @@ public class KnowledgeBaseUpdateReq {
|
||||
*/
|
||||
@NotEmpty(message = "知识库名称不能为空")
|
||||
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
|
||||
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
|
||||
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
|
||||
private String name;
|
||||
/**
|
||||
* 知识库描述
|
||||
*/
|
||||
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
|
||||
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
|
||||
private String description;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type React from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Table, Badge, Button, Breadcrumb, Tooltip, App } from "antd";
|
||||
import { Table, Badge, Button, Breadcrumb, Tooltip, App, Card, Input, Empty, Spin } from "antd";
|
||||
import {
|
||||
DeleteOutlined,
|
||||
EditOutlined,
|
||||
@@ -16,17 +16,39 @@ import {
|
||||
deleteKnowledgeBaseFileByIdUsingDelete,
|
||||
queryKnowledgeBaseByIdUsingGet,
|
||||
queryKnowledgeBaseFilesUsingGet,
|
||||
retrieveKnowledgeBaseContent,
|
||||
} from "../knowledge-base.api";
|
||||
import useFetchData from "@/hooks/useFetchData";
|
||||
import AddDataDialog from "../components/AddDataDialog";
|
||||
import CreateKnowledgeBase from "../components/CreateKnowledgeBase";
|
||||
|
||||
interface StatisticItem {
|
||||
icon?: React.ReactNode;
|
||||
label: string;
|
||||
value: string | number;
|
||||
}
|
||||
interface RagChunk {
|
||||
id: string;
|
||||
text: string;
|
||||
metadata: string;
|
||||
}
|
||||
interface RecallResult {
|
||||
score: number;
|
||||
entity: RagChunk;
|
||||
id?: string | object;
|
||||
primaryKey?: string;
|
||||
}
|
||||
|
||||
const KnowledgeBaseDetailPage: React.FC = () => {
|
||||
const navigate = useNavigate();
|
||||
const { message } = App.useApp();
|
||||
const { id } = useParams<{ id: string }>();
|
||||
const [knowledgeBase, setKnowledgeBase] = useState<KnowledgeBaseItem>(null);
|
||||
const [knowledgeBase, setKnowledgeBase] = useState<KnowledgeBaseItem | undefined>(undefined);
|
||||
const [showEdit, setShowEdit] = useState(false);
|
||||
const [activeTab, setActiveTab] = useState<'fileList' | 'recallTest'>('fileList');
|
||||
const [recallLoading, setRecallLoading] = useState(false);
|
||||
const [recallResults, setRecallResults] = useState<RecallResult[]>([]);
|
||||
const [recallQuery, setRecallQuery] = useState("");
|
||||
|
||||
const fetchKnowledgeBaseDetails = async (id: string) => {
|
||||
const { data } = await queryKnowledgeBaseByIdUsingGet(id);
|
||||
@@ -55,12 +77,12 @@ const KnowledgeBaseDetailPage: React.FC = () => {
|
||||
// File table logic
|
||||
const handleDeleteFile = async (file: KBFile) => {
|
||||
try {
|
||||
await deleteKnowledgeBaseFileByIdUsingDelete(knowledgeBase.id, {
|
||||
await deleteKnowledgeBaseFileByIdUsingDelete(knowledgeBase!.id, {
|
||||
ids: [file.id]
|
||||
});
|
||||
message.success("文件已删除");
|
||||
fetchFiles();
|
||||
} catch (error) {
|
||||
} catch {
|
||||
message.error("文件删除失败");
|
||||
}
|
||||
};
|
||||
@@ -72,11 +94,30 @@ const KnowledgeBaseDetailPage: React.FC = () => {
|
||||
};
|
||||
|
||||
const handleRefreshPage = () => {
|
||||
if (knowledgeBase) {
|
||||
fetchKnowledgeBaseDetails(knowledgeBase.id);
|
||||
}
|
||||
fetchFiles();
|
||||
setShowEdit(false);
|
||||
};
|
||||
|
||||
const handleRecallTest = async () => {
|
||||
if (!recallQuery || !knowledgeBase?.id) return;
|
||||
setRecallLoading(true);
|
||||
try {
|
||||
const result = await retrieveKnowledgeBaseContent({
|
||||
query: recallQuery,
|
||||
topK: 10,
|
||||
threshold: 0.2,
|
||||
knowledgeBaseIds: [knowledgeBase.id],
|
||||
});
|
||||
setRecallResults(result?.data || []);
|
||||
} catch {
|
||||
setRecallResults([]);
|
||||
}
|
||||
setRecallLoading(false);
|
||||
};
|
||||
|
||||
const operations = [
|
||||
{
|
||||
key: "edit",
|
||||
@@ -104,7 +145,7 @@ const KnowledgeBaseDetailPage: React.FC = () => {
|
||||
cancelText: "取消",
|
||||
okText: "删除",
|
||||
okType: "danger",
|
||||
onConfirm: () => handleDeleteKB(knowledgeBase),
|
||||
onConfirm: () => knowledgeBase && handleDeleteKB(knowledgeBase),
|
||||
},
|
||||
icon: <DeleteOutlined className="w-4 h-4" />,
|
||||
},
|
||||
@@ -134,9 +175,13 @@ const KnowledgeBaseDetailPage: React.FC = () => {
|
||||
dataIndex: "status",
|
||||
key: "vectorizationStatus",
|
||||
width: 120,
|
||||
render: (status: any) => (
|
||||
<Badge color={status?.color} text={status?.label} />
|
||||
),
|
||||
render: (status: unknown) => {
|
||||
if (typeof status === 'object' && status !== null) {
|
||||
const s = status as { color?: string; label?: string };
|
||||
return <Badge color={s.color} text={s.label} />;
|
||||
}
|
||||
return <Badge color="default" text={String(status)} />;
|
||||
},
|
||||
},
|
||||
{
|
||||
title: "分块数",
|
||||
@@ -164,7 +209,7 @@ const KnowledgeBaseDetailPage: React.FC = () => {
|
||||
key: "actions",
|
||||
align: "right" as const,
|
||||
width: 100,
|
||||
render: (_: any, file: KBFile) => (
|
||||
render: (_: unknown, file: KBFile) => (
|
||||
<div>
|
||||
{fileOps.map((op) => (
|
||||
<Tooltip key={op.key} title={op.label}>
|
||||
@@ -193,7 +238,9 @@ const KnowledgeBaseDetailPage: React.FC = () => {
|
||||
</div>
|
||||
<DetailHeader
|
||||
data={knowledgeBase}
|
||||
statistics={knowledgeBase?.statistics || []}
|
||||
statistics={knowledgeBase && Array.isArray((knowledgeBase as { statistics?: StatisticItem[] }).statistics)
|
||||
? ((knowledgeBase as { statistics?: StatisticItem[] }).statistics ?? [])
|
||||
: []}
|
||||
operations={operations}
|
||||
/>
|
||||
<CreateKnowledgeBase
|
||||
@@ -205,25 +252,34 @@ const KnowledgeBaseDetailPage: React.FC = () => {
|
||||
/>
|
||||
<div className="flex-1 border-card p-6 mt-4">
|
||||
<div className="flex items-center justify-between mb-4 gap-3">
|
||||
<div className="flex items-center gap-2">
|
||||
<Button type={activeTab === 'fileList' ? 'primary' : 'default'} onClick={() => setActiveTab('fileList')}>
|
||||
文件列表
|
||||
</Button>
|
||||
<Button type={activeTab === 'recallTest' ? 'primary' : 'default'} onClick={() => setActiveTab('recallTest')}>
|
||||
召回测试
|
||||
</Button>
|
||||
</div>
|
||||
{activeTab === 'fileList' && (
|
||||
<>
|
||||
<div className="flex-1">
|
||||
<SearchControls
|
||||
searchTerm={searchParams.keyword}
|
||||
onSearchChange={(keyword) =>
|
||||
setSearchParams({ ...searchParams, keyword })
|
||||
}
|
||||
onSearchChange={(keyword) => setSearchParams({ ...searchParams, keyword })}
|
||||
searchPlaceholder="搜索文件名..."
|
||||
filters={[]}
|
||||
onFiltersChange={handleFiltersChange}
|
||||
onClearFilters={() =>
|
||||
setSearchParams({ ...searchParams, filter: {} })
|
||||
}
|
||||
onClearFilters={() => setSearchParams({ ...searchParams, filter: { type: [], status: [], tags: [] } })}
|
||||
showViewToggle={false}
|
||||
showReload={false}
|
||||
/>
|
||||
</div>
|
||||
<AddDataDialog knowledgeBase={knowledgeBase} onDataAdded={handleRefreshPage} />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{activeTab === 'fileList' ? (
|
||||
<Table
|
||||
loading={loading}
|
||||
columns={fileColumns}
|
||||
@@ -232,6 +288,41 @@ const KnowledgeBaseDetailPage: React.FC = () => {
|
||||
pagination={pagination}
|
||||
scroll={{ y: "calc(100vh - 30rem)" }}
|
||||
/>
|
||||
) : (
|
||||
<div className="p-2">
|
||||
<div style={{ fontSize: 14, fontWeight: 300, marginBottom: 8 }}>基于语义文本检索和全文检索后的加权平均结果</div>
|
||||
<div className="flex items-center mb-4">
|
||||
<Input.Search
|
||||
value={recallQuery}
|
||||
onChange={e => setRecallQuery(e.target.value)}
|
||||
onSearch={handleRecallTest}
|
||||
placeholder="请输入召回测试问题"
|
||||
enterButton="检索"
|
||||
loading={recallLoading}
|
||||
style={{ width: "100%", fontSize: 18, height: 48 }}
|
||||
/>
|
||||
</div>
|
||||
{recallLoading ? (
|
||||
<Spin className="mt-8" />
|
||||
) : recallResults.length === 0 ? (
|
||||
<Empty description="暂无召回结果" />
|
||||
) : (
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
{recallResults.map((item, idx) => (
|
||||
<Card key={idx} title={`得分:${item.score?.toFixed(4) ?? "-"}`}
|
||||
extra={<span style={{ fontSize: 12 }}>ID: {item.entity?.id ?? "-"}</span>}
|
||||
style={{ wordBreak: "break-all" }}
|
||||
>
|
||||
<div style={{ marginBottom: 8, fontWeight: 500 }}>{item.entity?.text ?? ""}</div>
|
||||
<div style={{ fontSize: 12, color: '#888' }}>
|
||||
metadata: <pre style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-all', margin: 0 }}>{item.entity?.metadata}</pre>
|
||||
</div>
|
||||
</Card>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -120,9 +120,7 @@ export default function AddDataDialog({ knowledgeBase, onDataAdded }) {
|
||||
};
|
||||
|
||||
const handleAddData = async () => {
|
||||
const selectedFiles = [];
|
||||
|
||||
if (selectedFiles.length === 0) {
|
||||
if (getSelectedFilesCount() === 0) {
|
||||
message.warning("请至少选择一个文件");
|
||||
return;
|
||||
}
|
||||
@@ -130,7 +128,7 @@ export default function AddDataDialog({ knowledgeBase, onDataAdded }) {
|
||||
try {
|
||||
// 构造符合API要求的请求数据
|
||||
const requestData = {
|
||||
files: Object.entries(selectedFilesMap),
|
||||
files: Object.values(selectedFilesMap),
|
||||
processType: newKB.processType,
|
||||
chunkSize: Number(newKB.chunkSize), // 确保是数字类型
|
||||
overlapSize: Number(newKB.overlapSize), // 确保是数字类型
|
||||
|
||||
@@ -35,15 +35,17 @@ export function addKnowledgeBaseFilesUsingPost(baseId: string, data: any) {
|
||||
return post(`/api/knowledge-base/${baseId}/files`, data);
|
||||
}
|
||||
|
||||
// 获取知识生成文件详情
|
||||
export function queryKnowledgeBaseFilesByIdUsingGet(
|
||||
baseId: string,
|
||||
fileId: string
|
||||
) {
|
||||
return get(`/api/knowledge-base/${baseId}/files/${fileId}`);
|
||||
}
|
||||
|
||||
// 删除知识生成文件
|
||||
export function deleteKnowledgeBaseFileByIdUsingDelete(baseId: string, data: any) {
|
||||
return del(`/api/knowledge-base/${baseId}/files`, data);
|
||||
}
|
||||
|
||||
// 检索知识库内容
|
||||
export function retrieveKnowledgeBaseContent(data: {
|
||||
query: string;
|
||||
topK?: number;
|
||||
threshold?: number;
|
||||
knowledgeBaseIds: string[];
|
||||
}) {
|
||||
return post("/api/knowledge-base/retrieve", data);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user