You've already forked DataMate
feat: Refactor knowledge base retrieval to return detailed search results and enhance API integration #108
This commit is contained in:
@@ -76,16 +76,14 @@ public class KnowledgeBaseService {
|
||||
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
if (StringUtils.hasText(request.getName())) {
|
||||
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
|
||||
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
|
||||
.collectionName(knowledgeBase.getName())
|
||||
.newCollectionName(request.getName())
|
||||
.build());
|
||||
knowledgeBase.setName(request.getName());
|
||||
}
|
||||
if (StringUtils.hasText(request.getDescription())) {
|
||||
knowledgeBase.setDescription(request.getDescription());
|
||||
}
|
||||
knowledgeBase.setDescription(request.getDescription());
|
||||
knowledgeBaseRepository.updateById(knowledgeBase);
|
||||
}
|
||||
|
||||
@@ -147,7 +145,7 @@ public class KnowledgeBaseService {
|
||||
RagFile ragFile = new RagFile();
|
||||
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
|
||||
ragFile.setFileId(fileInfo.id());
|
||||
ragFile.setFileName(fileInfo.name());
|
||||
ragFile.setFileName(fileInfo.fileName());
|
||||
ragFile.setStatus(FileStatus.UNPROCESSED);
|
||||
return ragFile;
|
||||
}).toList();
|
||||
@@ -209,23 +207,19 @@ public class KnowledgeBaseService {
|
||||
* @param request 检索请求
|
||||
* @return 检索结果
|
||||
*/
|
||||
public SearchResp retrieve(RetrieveReq request) {
|
||||
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
|
||||
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
|
||||
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
|
||||
SearchResp searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
|
||||
return searchResp;
|
||||
List<SearchResp.SearchResult> searchResults = searchResp.getSearchResults().getFirst();
|
||||
|
||||
// request.getKnowledgeBaseIds().forEach(knowledgeId -> {
|
||||
// KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeId))
|
||||
// .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
// ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
|
||||
// EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
|
||||
// Embedding embedding = embeddingModel.embed(request.getQuery()).content();
|
||||
// searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
|
||||
// });
|
||||
// return searchResp;
|
||||
searchResults.forEach(item -> {
|
||||
String metadata = item.getEntity().get("metadata").toString();
|
||||
item.getEntity().put("metadata", metadata);
|
||||
});
|
||||
return searchResults;
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import io.milvus.v2.service.collection.request.AddFieldReq;
|
||||
import io.milvus.v2.service.collection.request.CreateCollectionReq;
|
||||
import io.milvus.v2.service.collection.request.HasCollectionReq;
|
||||
import io.milvus.v2.service.vector.request.AnnSearchReq;
|
||||
import io.milvus.v2.service.vector.request.FunctionScore;
|
||||
import io.milvus.v2.service.vector.request.HybridSearchReq;
|
||||
import io.milvus.v2.service.vector.request.InsertReq;
|
||||
import io.milvus.v2.service.vector.request.data.BaseVector;
|
||||
@@ -197,21 +198,26 @@ public class MilvusService {
|
||||
.params("{\"drop_ratio_search\": 0.2}")
|
||||
.topK(topK)
|
||||
.build());
|
||||
|
||||
CreateCollectionReq.Function ranker = CreateCollectionReq.Function.builder()
|
||||
.name("rrf")
|
||||
.name("weight")
|
||||
.functionType(FunctionType.RERANK)
|
||||
.param("reranker", "rrf")
|
||||
.param("k", "60")
|
||||
.param("reranker", "weighted")
|
||||
.param("weights", "[0.1, 0.9]")
|
||||
.param("norm_score", "true")
|
||||
.build();
|
||||
|
||||
FunctionScore functionScore = FunctionScore.builder()
|
||||
.functions(Collections.singletonList(ranker))
|
||||
.build();
|
||||
|
||||
|
||||
SearchResp searchResp = this.getMilvusClient().hybridSearch(HybridSearchReq.builder()
|
||||
.collectionName(collectionName)
|
||||
.searchRequests(searchRequests)
|
||||
.ranker(ranker)
|
||||
.functionScore(functionScore)
|
||||
.outFields(Arrays.asList("id", "text", "metadata"))
|
||||
.topK(topK)
|
||||
.limit(topK)
|
||||
.build());
|
||||
return searchResp;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import jakarta.validation.Valid;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
/**
|
||||
* 知识库控制器
|
||||
@@ -136,7 +138,7 @@ public class KnowledgeBaseController {
|
||||
* @return 检索结果
|
||||
*/
|
||||
@PostMapping("/retrieve")
|
||||
public SearchResp retrieve(@RequestBody @Valid RetrieveReq request) {
|
||||
public List<SearchResp.SearchResult> retrieve(@RequestBody @Valid RetrieveReq request) {
|
||||
return knowledgeBaseService.retrieve(request);
|
||||
}
|
||||
}
|
||||
@@ -21,6 +21,6 @@ public class AddFilesReq {
|
||||
private String delimiter;
|
||||
private List<FileInfo> files;
|
||||
|
||||
public record FileInfo(String id, String name) {
|
||||
public record FileInfo(String id, String fileName) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,12 +20,12 @@ public class KnowledgeBaseCreateReq {
|
||||
*/
|
||||
@NotEmpty(message = "知识库名称不能为空")
|
||||
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
|
||||
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
|
||||
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
|
||||
private String name;
|
||||
/**
|
||||
* 知识库描述
|
||||
*/
|
||||
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
|
||||
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
|
||||
private String description;
|
||||
|
||||
/**
|
||||
|
||||
@@ -20,11 +20,11 @@ public class KnowledgeBaseUpdateReq {
|
||||
*/
|
||||
@NotEmpty(message = "知识库名称不能为空")
|
||||
@Size(min = 1, max = 255, message = "知识库名称长度必须在 1 到 255 之间")
|
||||
@Pattern(regexp = "^[a-zA-Z0-9_]+$", message = "知识库名称只能包含字母、数字和下划线")
|
||||
@Pattern(regexp = "^[a-zA-Z][a-zA-Z0-9_]*$", message = "知识库名称只能包含字母、数字和下划线")
|
||||
private String name;
|
||||
/**
|
||||
* 知识库描述
|
||||
*/
|
||||
@Size(min = 1, max = 512, message = "知识库描述长度必须在 1 到 512 之间")
|
||||
@Size(max = 512, message = "知识库描述长度必须在 0 到 512 之间")
|
||||
private String description;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user