feat: Refactor dataset file pagination and enhance retrieval functionality with new request structure #98

* feat: Enhance knowledge base management with collection renaming, imp…

* feat: Update Milvus integration with new API, enhance collection mana…

* Merge branch 'refs/heads/main' into dev

* feat: Refactor dataset file pagination and enhance retrieval function…

* Merge branch 'main' into dev
Dallas98
2025-11-21 17:28:25 +08:00
committed by GitHub
parent 536ef9f556
commit 9858388084
19 changed files with 399 additions and 106 deletions

View File

@@ -6,7 +6,9 @@ import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode;
import com.datamate.common.interfaces.PagedResponse;
import com.datamate.common.interfaces.PagingQuery;
import com.datamate.common.setting.domain.entity.ModelConfig;
import com.datamate.common.setting.domain.repository.ModelConfigRepository;
import com.datamate.common.setting.infrastructure.client.ModelClient;
import com.datamate.rag.indexer.domain.model.FileStatus;
import com.datamate.rag.indexer.domain.model.KnowledgeBase;
import com.datamate.rag.indexer.domain.model.RagChunk;
@@ -16,8 +18,14 @@ import com.datamate.rag.indexer.domain.repository.RagFileRepository;
import com.datamate.rag.indexer.infrastructure.event.DataInsertedEvent;
import com.datamate.rag.indexer.infrastructure.milvus.MilvusService;
import com.datamate.rag.indexer.interfaces.dto.*;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.dml.DeleteParam;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import io.milvus.v2.service.collection.request.DropCollectionReq;
import io.milvus.v2.service.collection.request.RenameCollectionReq;
import io.milvus.v2.service.vector.request.DeleteReq;
import io.milvus.v2.service.vector.request.QueryReq;
import io.milvus.v2.service.vector.response.QueryResp;
import io.milvus.v2.service.vector.response.SearchResp;
import lombok.RequiredArgsConstructor;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.BeanUtils;
@@ -26,6 +34,7 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
@@ -63,10 +72,15 @@ public class KnowledgeBaseService {
* @param knowledgeBaseId knowledge base ID
* @param request knowledge base update request
*/
@Transactional(rollbackFor = Exception.class)
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
if (StringUtils.hasText(request.getName())) {
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
.collectionName(knowledgeBase.getName())
.newCollectionName(request.getName())
.build());
knowledgeBase.setName(request.getName());
}
if (StringUtils.hasText(request.getDescription())) {
@@ -75,13 +89,19 @@ public class KnowledgeBaseService {
knowledgeBaseRepository.updateById(knowledgeBase);
}
@Transactional
/**
* Delete a knowledge base
*
* @param knowledgeBaseId knowledge base ID
*/
@Transactional(rollbackFor = Exception.class)
public void delete(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
knowledgeBaseRepository.removeById(knowledgeBaseId);
ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId);
milvusService.getMilvusClient().dropCollection(DropCollectionParam.newBuilder().withCollectionName(knowledgeBase.getName()).build());
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build());
}
public KnowledgeBaseResp getById(String knowledgeBaseId) {
@@ -147,14 +167,65 @@ public class KnowledgeBaseService {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
ragFileRepository.removeByIds(request.getIds());
milvusService.getMilvusClient().delete(DeleteParam.newBuilder()
.withCollectionName(knowledgeBase.getName())
.withExpr("metadata[\"rag_file_id\"] in [" + org.apache.commons.lang3.StringUtils.join(request.getIds().stream().map(id -> "\"" + id + "\"").toArray(), ",") + "]")
milvusService.getMilvusClient().delete(DeleteReq.builder()
.collectionName(knowledgeBase.getName())
.filter("metadata[\"rag_file_id\"] in [" + org.apache.commons.lang3.StringUtils.join(request.getIds().stream().map(id -> "\"" + id + "\"").toArray(), ",") + "]")
.build());
}
public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) {
IPage<RagChunk> page = new Page<>(pagingQuery.getPage(), pagingQuery.getSize());
return PagedResponse.of(page.getRecords(), page.getCurrent(), page.getTotal(), page.getPages());
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder()
.collectionName(knowledgeBase.getName())
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
.outputFields(Collections.singletonList("*"))
.limit(Long.valueOf(pagingQuery.getSize()))
.offset((long) (pagingQuery.getPage() - 1) * pagingQuery.getSize())
.build());
List<QueryResp.QueryResult> queryResults = results.getQueryResults();
List<RagChunk> ragChunks = queryResults.stream()
.map(QueryResp.QueryResult::getEntity)
.map(item -> new RagChunk(
item.get("id").toString(),
item.get("text").toString(),
item.get("metadata").toString()
)).toList();
// Fetch the total count
QueryResp countResults = milvusService.getMilvusClient().query(QueryReq.builder()
.collectionName(knowledgeBase.getName())
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
.outputFields(Collections.singletonList("count(*)"))
.build());
long totalCount = Long.parseLong(countResults.getQueryResults().getFirst().getEntity().get("count(*)").toString());
return PagedResponse.of(ragChunks, pagingQuery.getPage(), totalCount, (int) Math.ceil((double) totalCount / pagingQuery.getSize()));
}
/**
* Retrieve knowledge base content
*
* @param request retrieval request
* @return retrieval results
*/
public SearchResp retrieve(RetrieveReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
SearchResp searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
return searchResp;
// request.getKnowledgeBaseIds().forEach(knowledgeId -> {
// KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeId))
// .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
// ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
// EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
// Embedding embedding = embeddingModel.embed(request.getQuery()).content();
// searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK());
// });
// return searchResp;
}
}
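
For context, the new getChunks implementation above pages through Milvus with the V2 query API: offset/limit select the page, and a second count(*) query supplies the total. Below is a minimal standalone sketch of that pattern, assuming a reachable Milvus instance; the URI, the collection name kb_demo, and the file ID are hypothetical.

import io.milvus.v2.client.ConnectConfig;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.service.vector.request.QueryReq;
import io.milvus.v2.service.vector.response.QueryResp;

import java.util.Collections;

public class ChunkPagingSketch {
    public static void main(String[] args) {
        // Hypothetical connection details; adjust to your deployment.
        MilvusClientV2 client = new MilvusClientV2(
                ConnectConfig.builder().uri("http://localhost:19530").build());

        int page = 2, size = 10;
        String fileFilter = "metadata[\"rag_file_id\"] == \"file-123\"";

        // Page of chunks: offset is (page - 1) * size, limit is the page size.
        QueryResp pageResp = client.query(QueryReq.builder()
                .collectionName("kb_demo")
                .filter(fileFilter)
                .outputFields(Collections.singletonList("*"))
                .limit((long) size)
                .offset((long) (page - 1) * size)
                .build());
        pageResp.getQueryResults().forEach(r ->
                System.out.println(r.getEntity().get("id") + " -> " + r.getEntity().get("text")));

        // Total matching rows via a count(*) aggregation query.
        QueryResp countResp = client.query(QueryReq.builder()
                .collectionName("kb_demo")
                .filter(fileFilter)
                .outputFields(Collections.singletonList("count(*)"))
                .build());
        long total = Long.parseLong(
                countResp.getQueryResults().get(0).getEntity().get("count(*)").toString());
        System.out.println("total = " + total + ", pages = " + (total + size - 1) / size);
    }
}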

View File

@@ -6,5 +6,10 @@ package com.datamate.rag.indexer.domain.model;
* @author dallas
* @since 2025-10-29
*/
public class RagChunk {
}
public record RagChunk(
String id,
String text,
String metadata
) {
}

View File

@@ -44,4 +44,6 @@ public class RagFile extends BaseEntity<String> {
private Map<String, Object> metadata;
private FileStatus status;
private String errMsg;
}

View File

@@ -9,7 +9,7 @@ import com.datamate.rag.indexer.domain.model.FileStatus;
import com.datamate.rag.indexer.domain.model.RagFile;
import com.datamate.rag.indexer.domain.repository.RagFileRepository;
import com.datamate.rag.indexer.infrastructure.milvus.MilvusService;
import com.datamate.rag.indexer.interfaces.dto.ProcessType;
import com.datamate.rag.indexer.interfaces.dto.AddFilesReq;
import com.google.common.collect.Lists;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentParser;
@@ -20,10 +20,7 @@ import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentPa
import dev.langchain4j.data.document.parser.apache.poi.ApachePoiDocumentParser;
import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser;
import dev.langchain4j.data.document.parser.markdown.MarkdownDocumentParser;
import dev.langchain4j.data.document.splitter.DocumentByLineSplitter;
import dev.langchain4j.data.document.splitter.DocumentByParagraphSplitter;
import dev.langchain4j.data.document.splitter.DocumentBySentenceSplitter;
import dev.langchain4j.data.document.splitter.DocumentByWordSplitter;
import dev.langchain4j.data.document.splitter.*;
import dev.langchain4j.data.document.transformer.jsoup.HtmlToTextDocumentTransformer;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
@@ -85,6 +82,7 @@ public class RagEtlService {
// Handle the processing failure
log.error("Error processing RAG file: {}", ragFile.getFileId(), e);
ragFile.setStatus(FileStatus.PROCESS_FAILED);
ragFile.setErrMsg(e.getMessage());
ragFileRepository.updateById(ragFile);
} finally {
SEMAPHORE.release();
@@ -109,7 +107,7 @@ public class RagEtlService {
}
document.metadata().put("rag_file_id", ragFile.getId());
// Split the document with the configured document splitter
DocumentSplitter splitter = documentSplitter(event.addFilesReq().getProcessType());
DocumentSplitter splitter = documentSplitter(event.addFilesReq());
List<TextSegment> split = splitter.split(document);
// Update the chunk count
@@ -121,16 +119,19 @@ public class RagEtlService {
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(model);
// Invoke the embedding model to compute embedding vectors
if (!milvusService.hasCollection(event.knowledgeBase().getName())) {
milvusService.createCollection(event.knowledgeBase().getName(), embeddingModel.dimension());
}
Lists.partition(split, 20).forEach(partition -> {
List<Embedding> content = embeddingModel.embedAll(partition).content();
// Store the embeddings in Milvus
milvusService.embeddingStore(embeddingModel, event.knowledgeBase().getName()).addAll(content, partition);
List<Embedding> embeddings = embeddingModel.embedAll(partition).content();
milvusService.addAll(event.knowledgeBase().getName(), partition, embeddings);
});
}
/**
* Return the document parser matching the file type
*
* @param fileType file type
* @return document parser
*/
@@ -145,13 +146,14 @@ public class RagEtlService {
};
}
public DocumentSplitter documentSplitter(ProcessType processType) {
return switch (processType) {
case PARAGRAPH_CHUNK -> new DocumentByParagraphSplitter(1000, 100);
case CHAPTER_CHUNK -> new DocumentByLineSplitter(1000, 100);
case CUSTOM_SEPARATOR_CHUNK -> new DocumentBySentenceSplitter(1000, 100);
case LENGTH_CHUNK -> new DocumentByWordSplitter(1000, 100);
case DEFAULT_CHUNK -> new DocumentByLineSplitter(1000, 100);
public DocumentSplitter documentSplitter(AddFilesReq req) {
return switch (req.getProcessType()) {
case PARAGRAPH_CHUNK -> new DocumentByParagraphSplitter(req.getChunkSize(), req.getOverlapSize());
case SENTENCE_CHUNK -> new DocumentBySentenceSplitter(req.getChunkSize(), req.getOverlapSize());
case LENGTH_CHUNK -> new DocumentByCharacterSplitter(req.getChunkSize(), req.getOverlapSize());
case DEFAULT_CHUNK -> new DocumentByWordSplitter(req.getChunkSize(), req.getOverlapSize());
case CUSTOM_SEPARATOR_CHUNK ->
new DocumentByRegexSplitter(req.getDelimiter(), "", req.getChunkSize(), req.getOverlapSize());
};
}
}
}
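
To make the splitter semantics concrete: each of these langchain4j splitters takes a maximum segment size and an overlap size (both in characters, per my reading of the API) and cuts a Document into TextSegments. A self-contained sketch with made-up sample text:

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.splitter.DocumentByParagraphSplitter;
import dev.langchain4j.data.segment.TextSegment;

import java.util.List;

public class SplitterSketch {
    public static void main(String[] args) {
        // 100-char segments with a 20-char overlap, mirroring chunkSize/overlapSize above.
        DocumentSplitter splitter = new DocumentByParagraphSplitter(100, 20);
        Document doc = Document.from(
                "First paragraph about RAG ingestion.\n\nSecond paragraph about Milvus storage.");
        List<TextSegment> segments = splitter.split(doc);
        segments.forEach(segment -> System.out.println(segment.text()));
    }
}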

View File

@@ -1,16 +1,34 @@
package com.datamate.rag.indexer.infrastructure.milvus;
import com.google.gson.*;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore;
import io.milvus.client.MilvusClient;
import io.milvus.client.MilvusServiceClient;
import io.milvus.param.ConnectParam;
import io.milvus.common.clientenum.FunctionType;
import io.milvus.v2.client.ConnectConfig;
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.AddFieldReq;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.HasCollectionReq;
import io.milvus.v2.service.vector.request.AnnSearchReq;
import io.milvus.v2.service.vector.request.HybridSearchReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.data.BaseVector;
import io.milvus.v2.service.vector.request.data.EmbeddedText;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.response.SearchResp;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.util.*;
import static dev.langchain4j.internal.Utils.randomUUID;
/**
* Milvus service class
*
@@ -24,28 +42,38 @@ public class MilvusService {
private String milvusHost;
@Value("${datamate.rag.milvus-port:19530}")
private int milvusPort;
@Value("${datamate.rag.milvus-uri:http://milvus-standalone:19530}")
private String milvusUri;
private static final Gson GSON;
private volatile MilvusClient milvusClient;
static {
GSON = (new GsonBuilder()).setObjectToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE).create();
}
private volatile MilvusClientV2 milvusClient;
public EmbeddingStore<TextSegment> embeddingStore(EmbeddingModel embeddingModel, String knowledgeBaseName) {
return MilvusEmbeddingStore.builder()
.host(milvusHost)
.port(milvusPort)
.uri(milvusUri)
.collectionName(knowledgeBaseName)
.dimension(embeddingModel.dimension())
.build();
}
public MilvusClient getMilvusClient() {
/**
* Get the Milvus client as a lazily-initialized singleton, without relying on the Spring container
*
* @return MilvusClientV2
*/
public MilvusClientV2 getMilvusClient() {
if (milvusClient == null) {
synchronized (this) {
if (milvusClient == null) {
try {
ConnectParam connectParam = ConnectParam.newBuilder()
.withHost(milvusHost)
.withPort(milvusPort)
ConnectConfig connectConfig = ConnectConfig.builder()
.uri(milvusUri)
.build();
milvusClient = new MilvusServiceClient(connectParam);
milvusClient = new MilvusClientV2(connectConfig);
log.info("Milvus client connected successfully");
} catch (Exception e) {
log.error("Milvus client connection failed: {}", e.getMessage());
@@ -56,4 +84,135 @@ public class MilvusService {
}
return milvusClient;
}
public boolean hasCollection(String collectionName) {
HasCollectionReq request = HasCollectionReq.builder().collectionName(collectionName).build();
return getMilvusClient().hasCollection(request);
}
public void createCollection(String collectionName, int dimension) {
CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
.build();
schema.addField(AddFieldReq.builder()
.fieldName("id")
.dataType(DataType.VarChar)
.maxLength(36)
.isPrimaryKey(true)
.autoID(false)
.build());
schema.addField(AddFieldReq.builder()
.fieldName("text")
.dataType(DataType.VarChar)
.maxLength(65535)
.enableAnalyzer(true)
.build());
schema.addField(AddFieldReq.builder()
.fieldName("metadata")
.dataType(DataType.JSON)
.build());
schema.addField(AddFieldReq.builder()
.fieldName("vector")
.dataType(DataType.FloatVector)
.dimension(dimension)
.build());
schema.addField(AddFieldReq.builder()
.fieldName("sparse")
.dataType(DataType.SparseFloatVector)
.build());
schema.addFunction(CreateCollectionReq.Function.builder()
.functionType(FunctionType.BM25)
.name("text_bm25_emb")
.inputFieldNames(Collections.singletonList("text"))
.outputFieldNames(Collections.singletonList("sparse"))
.build());
Map<String, Object> params = new HashMap<>();
params.put("inverted_index_algo", "DAAT_MAXSCORE");
params.put("bm25_k1", 1.2);
params.put("bm25_b", 0.75);
List<IndexParam> indexes = new ArrayList<>();
indexes.add(IndexParam.builder()
.fieldName("sparse")
.indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX)
.metricType(IndexParam.MetricType.BM25)
.extraParams(params)
.build());
indexes.add(IndexParam.builder()
.fieldName("vector")
.indexType(IndexParam.IndexType.FLAT)
.metricType(IndexParam.MetricType.COSINE)
.extraParams(Map.of())
.build());
CreateCollectionReq createCollectionReq = CreateCollectionReq.builder()
.collectionName(collectionName)
.collectionSchema(schema)
.indexParams(indexes)
.build();
this.getMilvusClient().createCollection(createCollectionReq);
}
public void addAll(String collectionName, List<TextSegment> textSegments, List<Embedding> embeddings) {
List<JsonObject> data = convertToJsonObjects(textSegments, embeddings);
InsertReq insertReq = InsertReq.builder()
.collectionName(collectionName)
.data(data)
.build();
this.getMilvusClient().insert(insertReq);
}
public List<JsonObject> convertToJsonObjects(List<TextSegment> textSegments, List<Embedding> embeddings) {
List<JsonObject> data = new ArrayList<>();
for (int i = 0; i < textSegments.size(); i++) {
JsonObject jsonObject = new JsonObject();
jsonObject.addProperty("id", randomUUID());
jsonObject.addProperty("text", textSegments.get(i).text());
jsonObject.add("metadata", GSON.toJsonTree(textSegments.get(i).metadata().toMap()).getAsJsonObject());
JsonArray vectorArray = new JsonArray();
for (float f : embeddings.get(i).vector()) {
vectorArray.add(f);
}
jsonObject.add("vector", vectorArray);
data.add(jsonObject);
}
return data;
}
public SearchResp hybridSearch(String collectionName, String query, float[] queryDense, int topK) {
List<BaseVector> queryTexts = Collections.singletonList(new EmbeddedText(query));
List<BaseVector> queryVectors = Collections.singletonList(new FloatVec(queryDense));
List<AnnSearchReq> searchRequests = new ArrayList<>();
searchRequests.add(AnnSearchReq.builder()
.vectorFieldName("vector")
.vectors(queryVectors)
.params("{\"nprobe\": 10}")
.topK(topK)
.build());
searchRequests.add(AnnSearchReq.builder()
.vectorFieldName("sparse")
.vectors(queryTexts)
.params("{\"drop_ratio_search\": 0.2}")
.topK(topK)
.build());
CreateCollectionReq.Function ranker = CreateCollectionReq.Function.builder()
.name("rrf")
.functionType(FunctionType.RERANK)
.param("reranker", "rrf")
.param("k", "60")
.build();
SearchResp searchResp = this.getMilvusClient().hybridSearch(HybridSearchReq.builder()
.collectionName(collectionName)
.searchRequests(searchRequests)
.ranker(ranker)
.outFields(Arrays.asList("id", "text", "metadata"))
.topK(topK)
.build());
return searchResp;
}
}
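
For reference, a hedged sketch of consuming the SearchResp that hybridSearch returns; the accessors below are from the Milvus V2 Java SDK as I understand it, and the text field name matches the schema created in createCollection above.

import io.milvus.v2.service.vector.response.SearchResp;

import java.util.List;

public class SearchRespSketch {
    // Print id, score and text for each hit of a single-query hybrid search.
    static void printHits(SearchResp resp) {
        // getSearchResults() holds one list per query vector; hybridSearch above sends one query.
        List<List<SearchResp.SearchResult>> perQuery = resp.getSearchResults();
        for (SearchResp.SearchResult hit : perQuery.get(0)) {
            System.out.printf("id=%s score=%.4f text=%s%n",
                    hit.getId(), hit.getScore(), hit.getEntity().get("text"));
        }
    }
}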

View File

@@ -6,12 +6,12 @@ import com.datamate.rag.indexer.application.KnowledgeBaseService;
import com.datamate.rag.indexer.domain.model.RagChunk;
import com.datamate.rag.indexer.domain.model.RagFile;
import com.datamate.rag.indexer.interfaces.dto.*;
import io.milvus.v2.service.vector.response.SearchResp;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.*;
/**
* Knowledge base controller
*
@@ -124,8 +124,19 @@ public class KnowledgeBaseController {
*/
@GetMapping("/{knowledgeBaseId}/files/{ragFileId}")
public PagedResponse<RagChunk> getChunks(@PathVariable("knowledgeBaseId") String knowledgeBaseId,
@PathVariable("ragFileId") String ragFileId,
PagingQuery pagingQuery) {
@PathVariable("ragFileId") String ragFileId,
PagingQuery pagingQuery) {
return knowledgeBaseService.getChunks(knowledgeBaseId, ragFileId, pagingQuery);
}
/**
* Retrieve knowledge base content
*
* @param request retrieval request
* @return retrieval results
*/
@PostMapping("/retrieve")
public SearchResp retrieve(@RequestBody @Valid RetrieveReq request) {
return knowledgeBaseService.retrieve(request);
}
}
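
A quick way to exercise the new retrieve endpoint from Java's built-in HTTP client; the base URL and path prefix are assumptions, since the controller's class-level @RequestMapping is not visible in this diff.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class RetrieveCallSketch {
    public static void main(String[] args) throws Exception {
        // Body fields mirror RetrieveReq: query, topK, threshold, knowledgeBaseIds.
        String body = """
                {"query": "what is RAG?", "topK": 5, "knowledgeBaseIds": ["kb-1"]}""";
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/knowledge-bases/retrieve")) // hypothetical path
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}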

View File

@@ -18,7 +18,7 @@ public class AddFilesReq {
private ProcessType processType;
private Integer chunkSize;
private Integer overlapSize;
private String customSeparator;
private String delimiter;
private List<FileInfo> files;
public record FileInfo(String id, String name) {

View File

@@ -7,27 +7,28 @@ package com.datamate.rag.indexer.interfaces.dto;
* @since 2025-10-29
*/
public enum ProcessType {
/**
* Chapter-based chunking
*/
CHAPTER_CHUNK,
/**
* Paragraph-based chunking
*/
PARAGRAPH_CHUNK,
/**
* Length-based chunking
* Sentence-based chunking
*/
SENTENCE_CHUNK,
/**
* Length-based chunking, by characters
*/
LENGTH_CHUNK,
/**
* Default chunking, by words
*/
DEFAULT_CHUNK,
/**
* Custom-delimiter chunking
*/
CUSTOM_SEPARATOR_CHUNK,
/**
* Default chunking
*/
DEFAULT_CHUNK,
}

View File

@@ -0,0 +1,21 @@
package com.datamate.rag.indexer.interfaces.dto;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
/**
* Retrieval request
*
* @author dallas
* @since 2025-11-20
*/
@Getter
@Setter
public class RetrieveReq {
private String query;
private int topK;
private Float threshold;
private List<String> knowledgeBaseIds;
}