feat: Refactor knowledge base retrieval to return detailed search results and enhance API integration #108

This commit is contained in:
Dallas98
2025-11-25 21:21:21 +08:00
committed by GitHub
parent b50c12d135
commit bc26cfba55
9 changed files with 168 additions and 75 deletions

View File

@@ -76,16 +76,14 @@ public class KnowledgeBaseService {
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
// Look up the knowledge base by id; a null result from the repository is turned into a
// KNOWLEDGE_BASE_NOT_FOUND BusinessException.
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
// NOTE(review): the two guards below look like the before/after forms of one condition in
// this change view. The second one additionally skips the rename when the requested name
// equals the stored name.
if (StringUtils.hasText(request.getName())) {
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
// Rename the Milvus collection from the stored name to the requested name, then mirror the
// new name on the entity. The rename is issued before updateById and no compensation is
// visible here if the later database update fails - TODO confirm this is acceptable.
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
.collectionName(knowledgeBase.getName())
.newCollectionName(request.getName())
.build());
knowledgeBase.setName(request.getName());
}
// NOTE(review): a guarded and an unguarded description assignment both appear below. The
// unguarded one copies request.getDescription() as-is, including null or blank values -
// confirm that this is the intended way to clear a description.
if (StringUtils.hasText(request.getDescription())) {
knowledgeBase.setDescription(request.getDescription());
}
knowledgeBase.setDescription(request.getDescription());
// Persist the modified entity.
knowledgeBaseRepository.updateById(knowledgeBase);
}
@@ -147,7 +145,7 @@ public class KnowledgeBaseService {
RagFile ragFile = new RagFile();
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
ragFile.setFileId(fileInfo.id());
ragFile.setFileName(fileInfo.name());
ragFile.setFileName(fileInfo.fileName());
ragFile.setStatus(FileStatus.UNPROCESSED);
return ragFile;
}).toList();
@@ -209,23 +207,19 @@ public class KnowledgeBaseService {
* @param request 检索请求
* @return 检索结果
*/
public SearchResp retrieve(RetrieveReq request) {
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
// NOTE(review): two signatures appear above; they look like the before/after forms of the
// same method in this change view. The final statement of the body returns the result list.
// Only the first entry of request.getKnowledgeBaseIds() is looked up; any further ids are
// not read by this method. Presumably getFirst() throws on an empty list - verify that the
// request validation rejects an empty id list.
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
// Resolve the embedding model referenced by the knowledge base and embed the query text.
// NOTE(review): modelConfig is passed on without a null check - confirm getById cannot
// return null for a stored knowledge base.
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
// Hybrid search against the collection named after the knowledge base, passing both the raw
// query text and its embedding vector, limited by the requested topK.
SearchResp searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
return searchResp;
// Take the first element of getSearchResults(); presumably that is the hit list for the
// single query issued above - confirm against the Milvus SearchResp documentation, and
// confirm the outer list is never empty.
List<SearchResp.SearchResult> searchResults = searchResp.getSearchResults().getFirst();
// request.getKnowledgeBaseIds().forEach(knowledgeId -> {
// KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeId))
// .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
// ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
// EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
// Embedding embedding = embeddingModel.embed(request.getQuery()).content();
// searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
// });
// return searchResp;
// Replace each hit's "metadata" entity value with its toString() form, in place.
// NOTE(review): get("metadata") is dereferenced without a null check; assuming getEntity()
// behaves like a Map, a hit without that key would fail here - TODO confirm.
searchResults.forEach(item -> {
String metadata = item.getEntity().get("metadata").toString();
item.getEntity().put("metadata", metadata);
});
return searchResults;
}
}

View File

@@ -15,6 +15,7 @@ import io.milvus.v2.service.collection.request.AddFieldReq;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.HasCollectionReq;
import io.milvus.v2.service.vector.request.AnnSearchReq;
import io.milvus.v2.service.vector.request.FunctionScore;
import io.milvus.v2.service.vector.request.HybridSearchReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.data.BaseVector;
@@ -197,21 +198,26 @@ public class MilvusService {
.params("{\"drop_ratio_search\": 0.2}")
.topK(topK)
.build());
CreateCollectionReq.Function ranker = CreateCollectionReq.Function.builder()
.name("rrf")
.name("weight")
.functionType(FunctionType.RERANK)
.param("reranker", "rrf")
.param("k", "60")
.param("reranker", "weighted")
.param("weights", "[0.1, 0.9]")
.param("norm_score", "true")
.build();
FunctionScore functionScore = FunctionScore.builder()
.functions(Collections.singletonList(ranker))
.build();
SearchResp searchResp = this.getMilvusClient().hybridSearch(HybridSearchReq.builder()
.collectionName(collectionName)
.searchRequests(searchRequests)
.ranker(ranker)
.functionScore(functionScore)
.outFields(Arrays.asList("id", "text", "metadata"))
.topK(topK)
.limit(topK)
.build());
return searchResp;
}

View File

@@ -11,6 +11,8 @@ import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.*;
import java.util.List;
/**
* 知识库控制器
@@ -136,7 +138,7 @@ public class KnowledgeBaseController {
* @return 检索结果
*/
@PostMapping("/retrieve")
public SearchResp retrieve(@RequestBody @Valid RetrieveReq request) {
public List<SearchResp.SearchResult> retrieve(@RequestBody @Valid RetrieveReq request) {
// NOTE(review): two signatures appear above; they look like the before/after forms of the
// same handler in this change view and differ only in the declared return type.
// The handler is a thin pass-through: the request body is handed to the service and the
// service's return value is returned unchanged.
return knowledgeBaseService.retrieve(request);
}
}

View File

@@ -21,6 +21,6 @@ public class AddFilesReq {
private String delimiter;
private List<FileInfo> files;
// Describes one file to add: an id plus a display/file name.
// NOTE(review): two record headers are shown below; they look like the before/after forms in
// this change view and differ only in the second component (name vs fileName). Renaming a
// record component also renames its generated accessor (name() becomes fileName()), so
// callers must use the matching accessor. Any serialized property name derived from the
// component would presumably change as well - TODO confirm the impact on API clients.
public record FileInfo(String id, String name) {
public record FileInfo(String id, String fileName) {
}
}

View File

@@ -20,12 +20,12 @@ public class KnowledgeBaseCreateReq {
*/
@NotEmpty(message = "知识库名称不能为空")
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
// NOTE(review): two @Pattern lines appear below; they look like the before/after forms in
// this change view. The second regex additionally requires the first character to be a
// letter, yet both messages only state the allowed character set (letters, digits and
// underscores). Consider mentioning the leading-letter rule in the message so a rejected
// name such as "1abc" gets an accurate explanation.
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
private String name;
/**
* Knowledge base description.
*/
// NOTE(review): two @Size lines appear below; they look like the before/after forms in this
// change view. The first requires at least one character, the second only caps the length at
// 512. No @NotEmpty/@NotNull is visible on this field.
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
private String description;
/**

View File

@@ -20,11 +20,11 @@ public class KnowledgeBaseUpdateReq {
*/
@NotEmpty(message = "知识库名称不能为空")
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
// NOTE(review): two @Pattern lines appear below; they look like the before/after forms in
// this change view. The second regex additionally requires the first character to be a
// letter, yet both messages only state the allowed character set (letters, digits and
// underscores). Consider mentioning the leading-letter rule in the message so a rejected
// name such as "1abc" gets an accurate explanation.
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
private String name;
/**
* Knowledge base description.
*/
// NOTE(review): two @Size lines appear below; they look like the before/after forms in this
// change view. The first requires at least one character, the second only caps the length at
// 512. No @NotEmpty/@NotNull is visible on this field.
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
private String description;
}