From 985838808440e1ddc6f3033f8f0da18ed5607c8c Mon Sep 17 00:00:00 2001 From: Dallas98 <40557804+Dallas98@users.noreply.github.com> Date: Fri, 21 Nov 2025 17:28:25 +0800 Subject: [PATCH] feat: Refactor dataset file pagination and enhance retrieval functionality with new request structure #98 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Enhance knowledge base management with collection renaming, imp… * feat: Update Milvus integration with new API, enhance collection mana… * Merge branch 'refs/heads/main' into dev * feat: Refactor dataset file pagination and enhance retrieval function… * Merge branch 'main' into dev --- backend/pom.xml | 7 + .../application/CleaningTaskService.java | 12 +- .../DatasetFileApplicationService.java | 17 +- .../repository/DatasetFileRepository.java | 5 +- .../impl/DatasetFileRepositoryImpl.java | 13 +- .../rest/DatasetFileController.java | 37 ++-- .../application/KnowledgeBaseService.java | 89 ++++++++- .../rag/indexer/domain/model/RagChunk.java | 9 +- .../rag/indexer/domain/model/RagFile.java | 2 + .../infrastructure/event/RagEtlService.java | 38 ++-- .../infrastructure/milvus/MilvusService.java | 181 ++++++++++++++++-- .../interfaces/KnowledgeBaseController.java | 17 +- .../indexer/interfaces/dto/AddFilesReq.java | 2 +- .../indexer/interfaces/dto/ProcessType.java | 21 +- .../indexer/interfaces/dto/RetrieveReq.java | 21 ++ .../exception/KnowledgeBaseErrorCode.java | 7 +- .../common/interfaces/PagedResponse.java | 19 +- .../common/interfaces/PagingQuery.java | 7 + scripts/db/rag-management-init.sql | 1 + 19 files changed, 399 insertions(+), 106 deletions(-) create mode 100644 backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/RetrieveReq.java diff --git a/backend/pom.xml b/backend/pom.xml index 97739ea..7ca22c3 100644 --- a/backend/pom.xml +++ b/backend/pom.xml @@ -144,6 +144,13 @@ poi ${poi.version} + + + io.milvus + milvus-sdk-java + 2.6.6 + + 
diff --git a/backend/services/data-cleaning-service/src/main/java/com/datamate/cleaning/application/CleaningTaskService.java b/backend/services/data-cleaning-service/src/main/java/com/datamate/cleaning/application/CleaningTaskService.java index 7b02817..7e730de 100644 --- a/backend/services/data-cleaning-service/src/main/java/com/datamate/cleaning/application/CleaningTaskService.java +++ b/backend/services/data-cleaning-service/src/main/java/com/datamate/cleaning/application/CleaningTaskService.java @@ -4,16 +4,16 @@ package com.datamate.cleaning.application; import com.datamate.cleaning.application.scheduler.CleaningTaskScheduler; import com.datamate.cleaning.common.enums.CleaningTaskStatusEnum; import com.datamate.cleaning.common.enums.ExecutorType; - import com.datamate.cleaning.domain.model.TaskProcess; import com.datamate.cleaning.domain.repository.CleaningResultRepository; import com.datamate.cleaning.domain.repository.CleaningTaskRepository; import com.datamate.cleaning.domain.repository.OperatorInstanceRepository; - import com.datamate.cleaning.infrastructure.validator.CleanTaskValidator; import com.datamate.cleaning.interfaces.dto.*; import com.datamate.common.infrastructure.exception.BusinessException; import com.datamate.common.infrastructure.exception.SystemErrorCode; +import com.datamate.common.interfaces.PagedResponse; +import com.datamate.common.interfaces.PagingQuery; import com.datamate.datamanagement.application.DatasetApplicationService; import com.datamate.datamanagement.application.DatasetFileApplicationService; import com.datamate.datamanagement.common.enums.DatasetType; @@ -26,8 +26,6 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.data.domain.Page; -import org.springframework.data.domain.PageRequest; import org.springframework.stereotype.Service; import 
org.springframework.transaction.annotation.Transactional; import org.yaml.snakeyaml.DumperOptions; @@ -208,10 +206,10 @@ public class CleaningTaskService { private void scanDataset(String taskId, String srcDatasetId) { int pageNumber = 0; int pageSize = 500; - PageRequest pageRequest = PageRequest.of(pageNumber, pageSize); - Page datasetFiles; + PagingQuery pageRequest = new PagingQuery(pageNumber, pageSize); + PagedResponse datasetFiles; do { - datasetFiles = datasetFileService.getDatasetFiles(srcDatasetId, null, null, pageRequest); + datasetFiles = datasetFileService.getDatasetFiles(srcDatasetId, null, null,null, pageRequest); if (datasetFiles.getContent().isEmpty()) { break; } diff --git a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/application/DatasetFileApplicationService.java b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/application/DatasetFileApplicationService.java index 390f713..1b7e1ad 100644 --- a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/application/DatasetFileApplicationService.java +++ b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/application/DatasetFileApplicationService.java @@ -1,5 +1,6 @@ package com.datamate.datamanagement.application; +import com.baomidou.mybatisplus.core.metadata.IPage; import com.datamate.common.domain.model.ChunkUploadPreRequest; import com.datamate.common.domain.model.FileUploadResult; import com.datamate.common.domain.service.FileService; @@ -7,6 +8,8 @@ import com.datamate.common.domain.utils.AnalyzerUtils; import com.datamate.common.infrastructure.exception.BusinessAssert; import com.datamate.common.infrastructure.exception.BusinessException; import com.datamate.common.infrastructure.exception.SystemErrorCode; +import com.datamate.common.interfaces.PagedResponse; +import com.datamate.common.interfaces.PagingQuery; import 
com.datamate.datamanagement.common.enums.DuplicateMethod; import com.datamate.datamanagement.domain.contants.DatasetConstant; import com.datamate.datamanagement.domain.model.dataset.Dataset; @@ -23,14 +26,10 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import jakarta.servlet.http.HttpServletResponse; import lombok.extern.slf4j.Slf4j; -import org.apache.ibatis.session.RowBounds; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; import org.springframework.core.io.UrlResource; -import org.springframework.data.domain.Page; -import org.springframework.data.domain.PageImpl; -import org.springframework.data.domain.Pageable; import org.springframework.http.HttpHeaders; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -82,12 +81,10 @@ public class DatasetFileApplicationService { * 获取数据集文件列表 */ @Transactional(readOnly = true) - public Page getDatasetFiles(String datasetId, String fileType, - String status, Pageable pageable) { - RowBounds bounds = new RowBounds(pageable.getPageNumber() * pageable.getPageSize(), pageable.getPageSize()); - List content = datasetFileRepository.findByCriteria(datasetId, fileType, status, bounds); - long total = content.size() < pageable.getPageSize() && pageable.getPageNumber() == 0 ? 
content.size() : content.size() + (long) pageable.getPageNumber() * pageable.getPageSize(); - return new PageImpl<>(content, pageable, total); + public PagedResponse getDatasetFiles(String datasetId, String fileType, String status, String name, PagingQuery pagingQuery) { + IPage page = new com.baomidou.mybatisplus.extension.plugins.pagination.Page<>(pagingQuery.getPage(), pagingQuery.getSize()); + IPage files = datasetFileRepository.findByCriteria(datasetId, fileType, status, name, page); + return PagedResponse.of(files); } /** diff --git a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/infrastructure/persistence/repository/DatasetFileRepository.java b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/infrastructure/persistence/repository/DatasetFileRepository.java index de9880f..1ed09bc 100644 --- a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/infrastructure/persistence/repository/DatasetFileRepository.java +++ b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/infrastructure/persistence/repository/DatasetFileRepository.java @@ -1,8 +1,8 @@ package com.datamate.datamanagement.infrastructure.persistence.repository; +import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.extension.repository.IRepository; import com.datamate.datamanagement.domain.model.dataset.DatasetFile; -import org.apache.ibatis.session.RowBounds; import java.util.List; @@ -23,5 +23,6 @@ public interface DatasetFileRepository extends IRepository { DatasetFile findByDatasetIdAndFileName(String datasetId, String fileName); - List findByCriteria(String datasetId, String fileType, String status, RowBounds bounds); + IPage findByCriteria(String datasetId, String fileType, String status, String name, + IPage page); } diff --git 
a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/infrastructure/persistence/repository/impl/DatasetFileRepositoryImpl.java b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/infrastructure/persistence/repository/impl/DatasetFileRepositoryImpl.java index 277e576..be7cbd8 100644 --- a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/infrastructure/persistence/repository/impl/DatasetFileRepositoryImpl.java +++ b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/infrastructure/persistence/repository/impl/DatasetFileRepositoryImpl.java @@ -1,13 +1,14 @@ package com.datamate.datamanagement.infrastructure.persistence.repository.impl; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.extension.repository.CrudRepository; import com.datamate.datamanagement.domain.model.dataset.DatasetFile; import com.datamate.datamanagement.infrastructure.persistence.mapper.DatasetFileMapper; import com.datamate.datamanagement.infrastructure.persistence.repository.DatasetFileRepository; import lombok.RequiredArgsConstructor; -import org.apache.ibatis.session.RowBounds; import org.springframework.stereotype.Repository; +import org.springframework.util.StringUtils; import java.util.List; @@ -47,8 +48,12 @@ public class DatasetFileRepositoryImpl extends CrudRepository findByCriteria(String datasetId, String fileType, String status, RowBounds bounds) { - return datasetFileMapper.findByCriteria(datasetId, fileType, status, bounds); + public IPage findByCriteria(String datasetId, String fileType, String status, String name, + IPage page) { + return datasetFileMapper.selectPage(page, new LambdaQueryWrapper() + .eq(DatasetFile::getDatasetId, datasetId) + .eq(StringUtils.hasText(fileType), DatasetFile::getFileType, fileType) + 
.eq(StringUtils.hasText(status), DatasetFile::getStatus, status) + .like(StringUtils.hasText(name), DatasetFile::getFileName, name)); } } diff --git a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/interfaces/rest/DatasetFileController.java b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/interfaces/rest/DatasetFileController.java index 36c628e..e50eabc 100644 --- a/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/interfaces/rest/DatasetFileController.java +++ b/backend/services/data-management-service/src/main/java/com/datamate/datamanagement/interfaces/rest/DatasetFileController.java @@ -3,18 +3,20 @@ package com.datamate.datamanagement.interfaces.rest; import com.datamate.common.infrastructure.common.IgnoreResponseWrap; import com.datamate.common.infrastructure.common.Response; import com.datamate.common.infrastructure.exception.SystemErrorCode; +import com.datamate.common.interfaces.PagedResponse; +import com.datamate.common.interfaces.PagingQuery; import com.datamate.datamanagement.application.DatasetFileApplicationService; import com.datamate.datamanagement.domain.model.dataset.DatasetFile; import com.datamate.datamanagement.interfaces.converter.DatasetConverter; -import com.datamate.datamanagement.interfaces.dto.*; +import com.datamate.datamanagement.interfaces.dto.CopyFilesRequest; +import com.datamate.datamanagement.interfaces.dto.DatasetFileResponse; +import com.datamate.datamanagement.interfaces.dto.UploadFileRequest; +import com.datamate.datamanagement.interfaces.dto.UploadFilesPreRequest; import jakarta.servlet.http.HttpServletResponse; import jakarta.validation.Valid; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.io.Resource; -import org.springframework.data.domain.Page; -import org.springframework.data.domain.PageRequest; -import 
org.springframework.data.domain.Pageable; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -22,7 +24,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.*; import java.util.List; -import java.util.stream.Collectors; /** * 数据集文件 REST 控制器(UUID 模式) @@ -40,29 +41,17 @@ public class DatasetFileController { } @GetMapping - public ResponseEntity> getDatasetFiles( + public Response> getDatasetFiles( @PathVariable("datasetId") String datasetId, @RequestParam(value = "page", required = false, defaultValue = "0") Integer page, @RequestParam(value = "size", required = false, defaultValue = "20") Integer size, @RequestParam(value = "fileType", required = false) String fileType, - @RequestParam(value = "status", required = false) String status) { - Pageable pageable = PageRequest.of(page != null ? page : 0, size != null ? size : 20); - - Page filesPage = datasetFileApplicationService.getDatasetFiles( - datasetId, fileType, status, pageable); - - PagedDatasetFileResponse response = new PagedDatasetFileResponse(); - response.setContent(filesPage.getContent().stream() - .map(DatasetConverter.INSTANCE::convertToResponse) - .collect(Collectors.toList())); - response.setPage(filesPage.getNumber()); - response.setSize(filesPage.getSize()); - response.setTotalElements((int) filesPage.getTotalElements()); - response.setTotalPages(filesPage.getTotalPages()); - response.setFirst(filesPage.isFirst()); - response.setLast(filesPage.isLast()); - - return ResponseEntity.ok(Response.ok(response)); + @RequestParam(value = "status", required = false) String status, + @RequestParam(value = "name", required = false) String name) { + PagingQuery pagingQuery = new PagingQuery(page, size); + PagedResponse filesPage = datasetFileApplicationService.getDatasetFiles( + datasetId, fileType, status, name, pagingQuery); + return Response.ok(filesPage); } 
@GetMapping("/{fileId}") diff --git a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/application/KnowledgeBaseService.java b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/application/KnowledgeBaseService.java index 85f1e76..a1ca257 100644 --- a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/application/KnowledgeBaseService.java +++ b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/application/KnowledgeBaseService.java @@ -6,7 +6,9 @@ import com.datamate.common.infrastructure.exception.BusinessException; import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode; import com.datamate.common.interfaces.PagedResponse; import com.datamate.common.interfaces.PagingQuery; +import com.datamate.common.setting.domain.entity.ModelConfig; import com.datamate.common.setting.domain.repository.ModelConfigRepository; +import com.datamate.common.setting.infrastructure.client.ModelClient; import com.datamate.rag.indexer.domain.model.FileStatus; import com.datamate.rag.indexer.domain.model.KnowledgeBase; import com.datamate.rag.indexer.domain.model.RagChunk; @@ -16,8 +18,14 @@ import com.datamate.rag.indexer.domain.repository.RagFileRepository; import com.datamate.rag.indexer.infrastructure.event.DataInsertedEvent; import com.datamate.rag.indexer.infrastructure.milvus.MilvusService; import com.datamate.rag.indexer.interfaces.dto.*; -import io.milvus.param.collection.DropCollectionParam; -import io.milvus.param.dml.DeleteParam; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.model.embedding.EmbeddingModel; +import io.milvus.v2.service.collection.request.DropCollectionReq; +import io.milvus.v2.service.collection.request.RenameCollectionReq; +import io.milvus.v2.service.vector.request.DeleteReq; +import io.milvus.v2.service.vector.request.QueryReq; +import io.milvus.v2.service.vector.response.QueryResp; +import 
io.milvus.v2.service.vector.response.SearchResp; import lombok.RequiredArgsConstructor; import org.jetbrains.annotations.NotNull; import org.springframework.beans.BeanUtils; @@ -26,6 +34,7 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import org.springframework.util.StringUtils; +import java.util.Collections; import java.util.List; import java.util.Optional; @@ -63,10 +72,15 @@ public class KnowledgeBaseService { * @param knowledgeBaseId 知识库 ID * @param request 知识库更新请求 */ + @Transactional(rollbackFor = Exception.class) public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) { KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND)); if (StringUtils.hasText(request.getName())) { + milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder() + .collectionName(knowledgeBase.getName()) + .newCollectionName(request.getName()) + .build()); knowledgeBase.setName(request.getName()); } if (StringUtils.hasText(request.getDescription())) { @@ -75,13 +89,19 @@ public class KnowledgeBaseService { knowledgeBaseRepository.updateById(knowledgeBase); } - @Transactional + + /** + * 删除知识库 + * + * @param knowledgeBaseId 知识库 ID + */ + @Transactional(rollbackFor = Exception.class) public void delete(String knowledgeBaseId) { KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND)); knowledgeBaseRepository.removeById(knowledgeBaseId); ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId); - milvusService.getMilvusClient().dropCollection(DropCollectionParam.newBuilder().withCollectionName(knowledgeBase.getName()).build()); + 
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build()); } public KnowledgeBaseResp getById(String knowledgeBaseId) { @@ -147,14 +167,65 @@ public class KnowledgeBaseService { KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND)); ragFileRepository.removeByIds(request.getIds()); - milvusService.getMilvusClient().delete(DeleteParam.newBuilder() - .withCollectionName(knowledgeBase.getName()) - .withExpr("metadata[\"rag_file_id\"] in [" + org.apache.commons.lang3.StringUtils.join(request.getIds().stream().map(id -> "\"" + id + "\"").toArray(), ",") + "]") + milvusService.getMilvusClient().delete(DeleteReq.builder() + .collectionName(knowledgeBase.getName()) + .filter("metadata[\"rag_file_id\"] in [" + org.apache.commons.lang3.StringUtils.join(request.getIds().stream().map(id -> "\"" + id + "\"").toArray(), ",") + "]") .build()); } public PagedResponse getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) { - IPage page = new Page<>(pagingQuery.getPage(), pagingQuery.getSize()); - return PagedResponse.of(page.getRecords(), page.getCurrent(), page.getTotal(), page.getPages()); + KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) + .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND)); + QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder() + .collectionName(knowledgeBase.getName()) + .filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"") + .outputFields(Collections.singletonList("*")) + .limit(Long.valueOf(pagingQuery.getSize())) + .offset((long) (pagingQuery.getPage() - 1) * pagingQuery.getSize()) + .build()); + List queryResults = results.getQueryResults(); + List ragChunks = queryResults.stream() + 
.map(QueryResp.QueryResult::getEntity) + .map(item -> new RagChunk( + item.get("id").toString(), + item.get("text").toString(), + item.get("metadata").toString() + )).toList(); + + // 获取总数 + QueryResp countResults = milvusService.getMilvusClient().query(QueryReq.builder() + .collectionName(knowledgeBase.getName()) + .filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"") + .outputFields(Collections.singletonList("count(*)")) + .build()); + + long totalCount = Long.parseLong(countResults.getQueryResults().getFirst().getEntity().get("count(*)").toString()); + return PagedResponse.of(ragChunks, pagingQuery.getPage(), totalCount, (int) Math.ceil((double) totalCount / pagingQuery.getSize())); + } + + /** + * 检索知识库内容 + * + * @param request 检索请求 + * @return 检索结果 + */ + public SearchResp retrieve(RetrieveReq request) { + KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst())) + .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND)); + ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()); + EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig); + Embedding embedding = embeddingModel.embed(request.getQuery()).content(); + SearchResp searchResp = milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK()); + return searchResp; + +// request.getKnowledgeBaseIds().forEach(knowledgeId -> { +// KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeId)) +// .orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND)); +// ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()); +// EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig); +// Embedding embedding = embeddingModel.embed(request.getQuery()).content(); +// searchResp = 
milvusService.hybridSearch(knowledgeBase.getName(), request.getQuery(), embedding.vector(), request.getTopK()); +// }); +// return searchResp; } } \ No newline at end of file diff --git a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/domain/model/RagChunk.java b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/domain/model/RagChunk.java index 6a6b884..3dbf6bf 100644 --- a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/domain/model/RagChunk.java +++ b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/domain/model/RagChunk.java @@ -6,5 +6,10 @@ package com.datamate.rag.indexer.domain.model; * @author dallas * @since 2025-10-29 */ -public class RagChunk { -} + +public record RagChunk( + String id, + String text, + String metadata +) { +} \ No newline at end of file diff --git a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/domain/model/RagFile.java b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/domain/model/RagFile.java index ec0445a..64a3d78 100644 --- a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/domain/model/RagFile.java +++ b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/domain/model/RagFile.java @@ -44,4 +44,6 @@ public class RagFile extends BaseEntity { private Map metadata; private FileStatus status; + + private String errMsg; } diff --git a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/event/RagEtlService.java b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/event/RagEtlService.java index 7c7822a..f1c9f1d 100644 --- a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/event/RagEtlService.java +++ 
b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/event/RagEtlService.java @@ -9,7 +9,7 @@ import com.datamate.rag.indexer.domain.model.FileStatus; import com.datamate.rag.indexer.domain.model.RagFile; import com.datamate.rag.indexer.domain.repository.RagFileRepository; import com.datamate.rag.indexer.infrastructure.milvus.MilvusService; -import com.datamate.rag.indexer.interfaces.dto.ProcessType; +import com.datamate.rag.indexer.interfaces.dto.AddFilesReq; import com.google.common.collect.Lists; import dev.langchain4j.data.document.Document; import dev.langchain4j.data.document.DocumentParser; @@ -20,10 +20,7 @@ import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentPa import dev.langchain4j.data.document.parser.apache.poi.ApachePoiDocumentParser; import dev.langchain4j.data.document.parser.apache.tika.ApacheTikaDocumentParser; import dev.langchain4j.data.document.parser.markdown.MarkdownDocumentParser; -import dev.langchain4j.data.document.splitter.DocumentByLineSplitter; -import dev.langchain4j.data.document.splitter.DocumentByParagraphSplitter; -import dev.langchain4j.data.document.splitter.DocumentBySentenceSplitter; -import dev.langchain4j.data.document.splitter.DocumentByWordSplitter; +import dev.langchain4j.data.document.splitter.*; import dev.langchain4j.data.document.transformer.jsoup.HtmlToTextDocumentTransformer; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; @@ -85,6 +82,7 @@ public class RagEtlService { // 处理异常 log.error("Error processing RAG file: {}", ragFile.getFileId(), e); ragFile.setStatus(FileStatus.PROCESS_FAILED); + ragFile.setErrMsg(e.getMessage()); ragFileRepository.updateById(ragFile); } finally { SEMAPHORE.release(); @@ -109,7 +107,7 @@ public class RagEtlService { } document.metadata().put("rag_file_id", ragFile.getId()); // 使用文档分块器对文档进行分块 - DocumentSplitter splitter = 
documentSplitter(event.addFilesReq().getProcessType()); + DocumentSplitter splitter = documentSplitter(event.addFilesReq()); List split = splitter.split(document); // 更新分块数量 @@ -121,16 +119,19 @@ EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(model); // 调用嵌入模型获取嵌入向量 + if (!milvusService.hasCollection(event.knowledgeBase().getName())) { + milvusService.createCollection(event.knowledgeBase().getName(), embeddingModel.dimension()); + } + Lists.partition(split, 20).forEach(partition -> { - List content = embeddingModel.embedAll(partition).content(); - // 存储嵌入向量到 Milvus - milvusService.embeddingStore(embeddingModel, event.knowledgeBase().getName()).addAll(content, partition); + List embeddings = embeddingModel.embedAll(partition).content(); + milvusService.addAll(event.knowledgeBase().getName(), partition, embeddings); }); } /** * 根据文件类型返回对应的文档解析器 - * + * * @param fileType 文件类型 * @return 文档解析器 */ @@ -145,13 +146,14 @@ }; } - public DocumentSplitter documentSplitter(ProcessType processType) { - return switch (processType) { - case PARAGRAPH_CHUNK -> new DocumentByParagraphSplitter(1000, 100); - case CHAPTER_CHUNK -> new DocumentByLineSplitter(1000, 100); - case CUSTOM_SEPARATOR_CHUNK -> new DocumentBySentenceSplitter(1000, 100); - case LENGTH_CHUNK -> new DocumentByWordSplitter(1000, 100); - case DEFAULT_CHUNK -> new DocumentByLineSplitter(1000, 100); + public DocumentSplitter documentSplitter(AddFilesReq req) { + return switch (req.getProcessType()) { + case PARAGRAPH_CHUNK -> new DocumentByParagraphSplitter(req.getChunkSize(), req.getOverlapSize()); + case SENTENCE_CHUNK -> new DocumentBySentenceSplitter(req.getChunkSize(), req.getOverlapSize()); + case LENGTH_CHUNK -> new DocumentByCharacterSplitter(req.getChunkSize(), req.getOverlapSize()); + case DEFAULT_CHUNK -> new DocumentByWordSplitter(req.getChunkSize(), req.getOverlapSize()); + case CUSTOM_SEPARATOR_CHUNK -> + new 
DocumentByRegexSplitter(req.getDelimiter(), "", req.getChunkSize(), req.getOverlapSize()); }; } -} +} \ No newline at end of file diff --git a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/milvus/MilvusService.java b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/milvus/MilvusService.java index 9fc1b56..5283031 100644 --- a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/milvus/MilvusService.java +++ b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/infrastructure/milvus/MilvusService.java @@ -1,16 +1,34 @@ package com.datamate.rag.indexer.infrastructure.milvus; +import com.google.gson.*; +import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore; -import io.milvus.client.MilvusClient; -import io.milvus.client.MilvusServiceClient; -import io.milvus.param.ConnectParam; +import io.milvus.common.clientenum.FunctionType; +import io.milvus.v2.client.ConnectConfig; +import io.milvus.v2.client.MilvusClientV2; +import io.milvus.v2.common.DataType; +import io.milvus.v2.common.IndexParam; +import io.milvus.v2.service.collection.request.AddFieldReq; +import io.milvus.v2.service.collection.request.CreateCollectionReq; +import io.milvus.v2.service.collection.request.HasCollectionReq; +import io.milvus.v2.service.vector.request.AnnSearchReq; +import io.milvus.v2.service.vector.request.HybridSearchReq; +import io.milvus.v2.service.vector.request.InsertReq; +import io.milvus.v2.service.vector.request.data.BaseVector; +import io.milvus.v2.service.vector.request.data.EmbeddedText; +import io.milvus.v2.service.vector.request.data.FloatVec; +import io.milvus.v2.service.vector.response.SearchResp; import 
lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; +import java.util.*; + +import static dev.langchain4j.internal.Utils.randomUUID; + /** * Milvus 服务类 * @@ -24,28 +42,38 @@ public class MilvusService { private String milvusHost; @Value("${datamate.rag.milvus-port:19530}") private int milvusPort; + @Value("${datamate.rag.milvus-uri:http://milvus-standalone:19530}") + private String milvusUri; + private static final Gson GSON; - private volatile MilvusClient milvusClient; + static { + GSON = (new GsonBuilder()).setObjectToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE).create(); + } + + private volatile MilvusClientV2 milvusClient; public EmbeddingStore embeddingStore(EmbeddingModel embeddingModel, String knowledgeBaseName) { return MilvusEmbeddingStore.builder() - .host(milvusHost) - .port(milvusPort) + .uri(milvusUri) .collectionName(knowledgeBaseName) .dimension(embeddingModel.dimension()) .build(); } - public MilvusClient getMilvusClient() { + /** + * 单例模式获取 Milvus 客户端,不依赖 Spring 容器 + * + * @return MilvusClient + */ + public MilvusClientV2 getMilvusClient() { if (milvusClient == null) { synchronized (this) { if (milvusClient == null) { try { - ConnectParam connectParam = ConnectParam.newBuilder() - .withHost(milvusHost) - .withPort(milvusPort) + ConnectConfig connectConfig = ConnectConfig.builder() + .uri(milvusUri) .build(); - milvusClient = new MilvusServiceClient(connectParam); + milvusClient = new MilvusClientV2(connectConfig); log.info("Milvus client connected successfully"); } catch (Exception e) { log.error("Milvus client connection failed: {}", e.getMessage()); @@ -56,4 +84,135 @@ public class MilvusService { } return milvusClient; } + + + public boolean hasCollection(String collectionName) { + HasCollectionReq request = HasCollectionReq.builder().collectionName(collectionName).build(); + return getMilvusClient().hasCollection(request); + } + + public void 
createCollection(String collectionName, int dimension) { + CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder() + .build(); + schema.addField(AddFieldReq.builder() + .fieldName("id") + .dataType(DataType.VarChar) + .maxLength(36) + .isPrimaryKey(true) + .autoID(false) + .build()); + schema.addField(AddFieldReq.builder() + .fieldName("text") + .dataType(DataType.VarChar) + .maxLength(65535) + .enableAnalyzer(true) + .build()); + schema.addField(AddFieldReq.builder() + .fieldName("metadata") + .dataType(DataType.JSON) + .build()); + schema.addField(AddFieldReq.builder() + .fieldName("vector") + .dataType(DataType.FloatVector) + .dimension(dimension) + .build()); + schema.addField(AddFieldReq.builder() + .fieldName("sparse") + .dataType(DataType.SparseFloatVector) + .build()); + schema.addFunction(CreateCollectionReq.Function.builder() + .functionType(FunctionType.BM25) + .name("text_bm25_emb") + .inputFieldNames(Collections.singletonList("text")) + .outputFieldNames(Collections.singletonList("sparse")) + .build()); + + Map params = new HashMap<>(); + params.put("inverted_index_algo", "DAAT_MAXSCORE"); + params.put("bm25_k1", 1.2); + params.put("bm25_b", 0.75); + + List indexes = new ArrayList<>(); + indexes.add(IndexParam.builder() + .fieldName("sparse") + .indexType(IndexParam.IndexType.SPARSE_INVERTED_INDEX) + .metricType(IndexParam.MetricType.BM25) + .extraParams(params) + .build()); + indexes.add(IndexParam.builder() + .fieldName("vector") + .indexType(IndexParam.IndexType.FLAT) + .metricType(IndexParam.MetricType.COSINE) + .extraParams(Map.of()) + .build()); + + CreateCollectionReq createCollectionReq = CreateCollectionReq.builder() + .collectionName(collectionName) + .collectionSchema(schema) + .indexParams(indexes) + .build(); + this.getMilvusClient().createCollection(createCollectionReq); + } + + public void addAll(String collectionName, List textSegments, List embeddings) { + List data = 
convertToJsonObjects(textSegments, embeddings); + InsertReq insertReq = InsertReq.builder() + .collectionName(collectionName) + .data(data) + .build(); + this.getMilvusClient().insert(insertReq); + } + + public List convertToJsonObjects(List textSegments, List embeddings) { + List data = new ArrayList<>(); + for (int i = 0; i < textSegments.size(); i++) { + JsonObject jsonObject = new JsonObject(); + jsonObject.addProperty("id", randomUUID()); + jsonObject.addProperty("text", textSegments.get(i).text()); + jsonObject.add("metadata", GSON.toJsonTree(textSegments.get(i).metadata().toMap()).getAsJsonObject()); + JsonArray vectorArray = new JsonArray(); + for (float f : embeddings.get(i).vector()) { + vectorArray.add(f); + } + jsonObject.add("vector", vectorArray); + data.add(jsonObject); + } + return data; + } + + public SearchResp hybridSearch(String collectionName, String query, float[] queryDense, int topK) { + List queryTexts = Collections.singletonList(new EmbeddedText(query)); + List queryVectors = Collections.singletonList(new FloatVec(queryDense)); + + List searchRequests = new ArrayList<>(); + searchRequests.add(AnnSearchReq.builder() + .vectorFieldName("vector") + .vectors(queryVectors) + .params("{\"nprobe\": 10}") + .topK(topK) + .build()); + searchRequests.add(AnnSearchReq.builder() + .vectorFieldName("sparse") + .vectors(queryTexts) + .params("{\"drop_ratio_search\": 0.2}") + .topK(topK) + .build()); + CreateCollectionReq.Function ranker = CreateCollectionReq.Function.builder() + .name("rrf") + .functionType(FunctionType.RERANK) + .param("reranker", "rrf") + .param("k", "60") + .build(); + + + + SearchResp searchResp = this.getMilvusClient().hybridSearch(HybridSearchReq.builder() + .collectionName(collectionName) + .searchRequests(searchRequests) + .ranker(ranker) + .outFields(Arrays.asList("id", "text", "metadata")) + .topK(topK) + .build()); + return searchResp; + } } diff --git 
a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/KnowledgeBaseController.java b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/KnowledgeBaseController.java index 63266db..ef95de9 100644 --- a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/KnowledgeBaseController.java +++ b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/KnowledgeBaseController.java @@ -6,12 +6,12 @@ import com.datamate.rag.indexer.application.KnowledgeBaseService; import com.datamate.rag.indexer.domain.model.RagChunk; import com.datamate.rag.indexer.domain.model.RagFile; import com.datamate.rag.indexer.interfaces.dto.*; +import io.milvus.v2.service.vector.response.SearchResp; import jakarta.validation.Valid; import lombok.RequiredArgsConstructor; import org.springframework.web.bind.annotation.*; - /** * 知识库控制器 * @@ -124,8 +124,19 @@ public class KnowledgeBaseController { */ @GetMapping("/{knowledgeBaseId}/files/{ragFileId}") public PagedResponse getChunks(@PathVariable("knowledgeBaseId") String knowledgeBaseId, - @PathVariable("ragFileId") String ragFileId, - PagingQuery pagingQuery) { + @PathVariable("ragFileId") String ragFileId, + PagingQuery pagingQuery) { return knowledgeBaseService.getChunks(knowledgeBaseId, ragFileId, pagingQuery); } + + /** + * 检索知识库内容 + * + * @param request 检索请求 + * @return 检索结果 + */ + @PostMapping("/retrieve") + public SearchResp retrieve(@RequestBody @Valid RetrieveReq request) { + return knowledgeBaseService.retrieve(request); + } } \ No newline at end of file diff --git a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/AddFilesReq.java b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/AddFilesReq.java index bb0f64d..464c9eb 100644 --- 
a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/AddFilesReq.java +++ b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/AddFilesReq.java @@ -18,7 +18,7 @@ public class AddFilesReq { private ProcessType processType; private Integer chunkSize; private Integer overlapSize; - private String customSeparator; + private String delimiter; private List files; public record FileInfo(String id, String name) { diff --git a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/ProcessType.java b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/ProcessType.java index 7301163..a6933ec 100644 --- a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/ProcessType.java +++ b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/ProcessType.java @@ -7,27 +7,28 @@ package com.datamate.rag.indexer.interfaces.dto; * @since 2025-10-29 */ public enum ProcessType { - /** - * 章节分块 - */ - CHAPTER_CHUNK, /** * 段落分块 */ PARAGRAPH_CHUNK, /** - * 按长度分块 + * 按句子分块 + */ + SENTENCE_CHUNK, + + /** + * 按长度分块,字符串分块 */ LENGTH_CHUNK, + /** + * 默认分块,按单词分块 + */ + DEFAULT_CHUNK, + /** * 自定义分割符分块 */ CUSTOM_SEPARATOR_CHUNK, - - /** - * 默认分块 - */ - DEFAULT_CHUNK, } diff --git a/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/RetrieveReq.java b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/RetrieveReq.java new file mode 100644 index 0000000..6362523 --- /dev/null +++ b/backend/services/rag-indexer-service/src/main/java/com/datamate/rag/indexer/interfaces/dto/RetrieveReq.java @@ -0,0 +1,21 @@ +package com.datamate.rag.indexer.interfaces.dto; + +import lombok.Getter; +import lombok.Setter; + +import java.util.List; + +/** + * 检索请求 + * + * @author dallas + * @since 2025-11-20 + */ +@Getter 
+@Setter +public class RetrieveReq { + private String query; + private int topK; + private Float threshold; + private List knowledgeBaseIds; +} diff --git a/backend/shared/domain-common/src/main/java/com/datamate/common/infrastructure/exception/KnowledgeBaseErrorCode.java b/backend/shared/domain-common/src/main/java/com/datamate/common/infrastructure/exception/KnowledgeBaseErrorCode.java index d56bdb9..58755ba 100644 --- a/backend/shared/domain-common/src/main/java/com/datamate/common/infrastructure/exception/KnowledgeBaseErrorCode.java +++ b/backend/shared/domain-common/src/main/java/com/datamate/common/infrastructure/exception/KnowledgeBaseErrorCode.java @@ -16,7 +16,12 @@ public enum KnowledgeBaseErrorCode implements ErrorCode { /** * 知识库不存在 */ - KNOWLEDGE_BASE_NOT_FOUND("knowledge.0001", "知识库不存在"); + KNOWLEDGE_BASE_NOT_FOUND("knowledge.0001", "知识库不存在"), + + /** + * 文件不存在 + */ + RAG_FILE_NOT_FOUND("knowledge.0002", "文件不存在"); private final String code; private final String message; diff --git a/backend/shared/domain-common/src/main/java/com/datamate/common/interfaces/PagedResponse.java b/backend/shared/domain-common/src/main/java/com/datamate/common/interfaces/PagedResponse.java index 4a9647a..17270d9 100644 --- a/backend/shared/domain-common/src/main/java/com/datamate/common/interfaces/PagedResponse.java +++ b/backend/shared/domain-common/src/main/java/com/datamate/common/interfaces/PagedResponse.java @@ -1,6 +1,6 @@ package com.datamate.common.interfaces; -import lombok.AllArgsConstructor; +import com.baomidou.mybatisplus.core.metadata.IPage; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; @@ -10,9 +10,8 @@ import java.util.List; @Getter @Setter @NoArgsConstructor -@AllArgsConstructor -public class PagedResponse { - // 当前页码(从 0 开始) +public class PagedResponse { + // 当前页码(从 1 开始) private long page; // 每页数量 private long size; @@ -36,6 +35,14 @@ public class PagedResponse { this.content = content; } + public PagedResponse(long page, 
long size, long totalElements, long totalPages, List content) { + this.page = page; + this.size = size; + this.totalElements = totalElements; + this.totalPages = totalPages; + this.content = content; + } + public static PagedResponse of(List content) { return new PagedResponse<>(content); } @@ -43,4 +50,8 @@ public class PagedResponse { public static PagedResponse of(List content, long page, long totalElements, long totalPages) { return new PagedResponse<>(content, page, totalElements, totalPages); } + + public static PagedResponse of(IPage page) { + return new PagedResponse<>(page.getCurrent(), page.getSize(), page.getTotal(), page.getPages(), page.getRecords()); + } } diff --git a/backend/shared/domain-common/src/main/java/com/datamate/common/interfaces/PagingQuery.java b/backend/shared/domain-common/src/main/java/com/datamate/common/interfaces/PagingQuery.java index 5c646dd..66356c6 100644 --- a/backend/shared/domain-common/src/main/java/com/datamate/common/interfaces/PagingQuery.java +++ b/backend/shared/domain-common/src/main/java/com/datamate/common/interfaces/PagingQuery.java @@ -1,8 +1,10 @@ package com.datamate.common.interfaces; import lombok.Getter; +import lombok.NoArgsConstructor; @Getter +@NoArgsConstructor public class PagingQuery { /** * 页码,从0开始 @@ -28,4 +30,9 @@ public class PagingQuery { this.size = size; } } + + public PagingQuery(Integer page, Integer size) { + setPage(page); + setSize(size); + } } diff --git a/scripts/db/rag-management-init.sql b/scripts/db/rag-management-init.sql index 0e429ae..ebf0aac 100644 --- a/scripts/db/rag-management-init.sql +++ b/scripts/db/rag-management-init.sql @@ -22,6 +22,7 @@ create table if not exists t_rag_file chunk_count INT COMMENT '切片数', metadata JSON COMMENT '元数据', status VARCHAR(50) COMMENT '文件状态', + err_msg text NULL COMMENT '错误信息', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', created_by 
VARCHAR(255) COMMENT '创建者',