You've already forked DataMate
feat(auth): 为数据管理和RAG服务增加资源访问控制
- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证 - 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证 - 修改DatasetRepository接口和实现类,增加按创建者过滤的方法 - 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法 - 在RAG索引器服务中添加知识库访问权限检查和作用域过滤 - 更新实体元对象处理器以使用请求用户上下文获取当前用户 - 在前端设置页面添加用户权限管理功能和角色权限控制 - 为Python标注服务增加用户上下文和数据集访问权限验证
This commit is contained in:
@@ -3,6 +3,7 @@ package com.datamate.datamanagement.application;
|
||||
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
|
||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import com.datamate.common.auth.application.ResourceAccessService;
|
||||
import com.datamate.common.domain.utils.ChunksSaver;
|
||||
import com.datamate.common.setting.application.SysParamApplicationService;
|
||||
import com.datamate.datamanagement.interfaces.dto.*;
|
||||
@@ -64,6 +65,7 @@ public class DatasetApplicationService {
|
||||
private final CollectionTaskClient collectionTaskClient;
|
||||
private final DatasetFileApplicationService datasetFileApplicationService;
|
||||
private final SysParamApplicationService sysParamService;
|
||||
private final ResourceAccessService resourceAccessService;
|
||||
|
||||
@Value("${datamate.data-management.base-path:/dataset}")
|
||||
private String datasetBasePath;
|
||||
@@ -102,6 +104,7 @@ public class DatasetApplicationService {
|
||||
public Dataset updateDataset(String datasetId, UpdateDatasetRequest updateDatasetRequest) {
|
||||
Dataset dataset = datasetRepository.getById(datasetId);
|
||||
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||
|
||||
if (StringUtils.hasText(updateDatasetRequest.getName())) {
|
||||
dataset.setName(updateDatasetRequest.getName());
|
||||
@@ -151,6 +154,7 @@ public class DatasetApplicationService {
|
||||
public void deleteDataset(String datasetId) {
|
||||
Dataset dataset = datasetRepository.getById(datasetId);
|
||||
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||
long childCount = datasetRepository.countByParentId(datasetId);
|
||||
BusinessAssert.isTrue(childCount == 0, DataManagementErrorCode.DATASET_HAS_CHILDREN);
|
||||
datasetRepository.removeById(datasetId);
|
||||
@@ -164,6 +168,7 @@ public class DatasetApplicationService {
|
||||
public Dataset getDataset(String datasetId) {
|
||||
Dataset dataset = datasetRepository.getById(datasetId);
|
||||
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||
List<DatasetFile> datasetFiles = datasetFileRepository.findAllVisibleByDatasetId(datasetId);
|
||||
dataset.setFiles(datasetFiles);
|
||||
applyVisibleFileCounts(Collections.singletonList(dataset));
|
||||
@@ -176,7 +181,8 @@ public class DatasetApplicationService {
|
||||
@Transactional(readOnly = true)
|
||||
public PagedResponse<DatasetResponse> getDatasets(DatasetPagingQuery query) {
|
||||
IPage<Dataset> page = new Page<>(query.getPage(), query.getSize());
|
||||
page = datasetRepository.findByCriteria(page, query);
|
||||
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
|
||||
page = datasetRepository.findByCriteria(page, query, ownerFilterUserId);
|
||||
String datasetPvcName = getDatasetPvcName();
|
||||
applyVisibleFileCounts(page.getRecords());
|
||||
List<DatasetResponse> datasetResponses = DatasetConverter.INSTANCE.convertToResponse(page.getRecords());
|
||||
@@ -189,6 +195,7 @@ public class DatasetApplicationService {
|
||||
BusinessAssert.isTrue(StringUtils.hasText(datasetId), CommonErrorCode.PARAM_ERROR);
|
||||
Dataset dataset = datasetRepository.getById(datasetId);
|
||||
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||
Set<String> sourceTags = normalizeTagNames(dataset.getTags());
|
||||
if (sourceTags.isEmpty()) {
|
||||
return Collections.emptyList();
|
||||
@@ -198,10 +205,12 @@ public class DatasetApplicationService {
|
||||
SIMILAR_DATASET_CANDIDATE_MAX,
|
||||
Math.max(safeLimit * SIMILAR_DATASET_CANDIDATE_FACTOR, safeLimit)
|
||||
);
|
||||
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
|
||||
List<Dataset> candidates = datasetRepository.findSimilarByTags(
|
||||
new ArrayList<>(sourceTags),
|
||||
datasetId,
|
||||
candidateLimit
|
||||
candidateLimit,
|
||||
ownerFilterUserId
|
||||
);
|
||||
if (CollectionUtils.isEmpty(candidates)) {
|
||||
return Collections.emptyList();
|
||||
@@ -436,6 +445,7 @@ public class DatasetApplicationService {
|
||||
if (dataset == null) {
|
||||
throw new IllegalArgumentException("Dataset not found: " + datasetId);
|
||||
}
|
||||
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||
|
||||
Map<String, Object> statistics = new HashMap<>();
|
||||
|
||||
@@ -485,7 +495,11 @@ public class DatasetApplicationService {
|
||||
* 获取所有数据集的汇总统计信息
|
||||
*/
|
||||
public AllDatasetStatisticsResponse getAllDatasetStatistics() {
|
||||
return datasetRepository.getAllDatasetStatistics();
|
||||
if (resourceAccessService.isAdmin()) {
|
||||
return datasetRepository.getAllDatasetStatistics();
|
||||
}
|
||||
String currentUserId = resourceAccessService.requireCurrentUserId();
|
||||
return datasetRepository.getAllDatasetStatisticsByCreatedBy(currentUserId);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.datamate.datamanagement.application;
|
||||
|
||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import com.datamate.common.auth.application.ResourceAccessService;
|
||||
import com.datamate.common.infrastructure.exception.BusinessAssert;
|
||||
import com.datamate.common.infrastructure.exception.CommonErrorCode;
|
||||
import com.datamate.common.interfaces.PagedResponse;
|
||||
@@ -40,6 +41,7 @@ import java.util.UUID;
|
||||
public class KnowledgeSetApplicationService {
|
||||
private final KnowledgeSetRepository knowledgeSetRepository;
|
||||
private final TagMapper tagMapper;
|
||||
private final ResourceAccessService resourceAccessService;
|
||||
|
||||
public KnowledgeSet createKnowledgeSet(CreateKnowledgeSetRequest request) {
|
||||
BusinessAssert.isTrue(knowledgeSetRepository.findByName(request.getName()) == null,
|
||||
@@ -64,6 +66,7 @@ public class KnowledgeSetApplicationService {
|
||||
public KnowledgeSet updateKnowledgeSet(String setId, UpdateKnowledgeSetRequest request) {
|
||||
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
|
||||
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
|
||||
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
|
||||
BusinessAssert.isTrue(!isReadOnlyStatus(knowledgeSet.getStatus()),
|
||||
DataManagementErrorCode.KNOWLEDGE_SET_STATUS_ERROR);
|
||||
|
||||
@@ -119,6 +122,7 @@ public class KnowledgeSetApplicationService {
|
||||
public void deleteKnowledgeSet(String setId) {
|
||||
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
|
||||
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
|
||||
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
|
||||
knowledgeSetRepository.removeById(setId);
|
||||
}
|
||||
|
||||
@@ -126,13 +130,15 @@ public class KnowledgeSetApplicationService {
|
||||
public KnowledgeSet getKnowledgeSet(String setId) {
|
||||
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
|
||||
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
|
||||
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
|
||||
return knowledgeSet;
|
||||
}
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public PagedResponse<KnowledgeSetResponse> getKnowledgeSets(KnowledgeSetPagingQuery query) {
|
||||
IPage<KnowledgeSet> page = new Page<>(query.getPage(), query.getSize());
|
||||
page = knowledgeSetRepository.findByCriteria(page, query);
|
||||
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
|
||||
page = knowledgeSetRepository.findByCriteria(page, query, ownerFilterUserId);
|
||||
List<KnowledgeSetResponse> responses = KnowledgeConverter.INSTANCE.convertSetResponses(page.getRecords());
|
||||
return PagedResponse.of(responses, page.getCurrent(), page.getTotal(), page.getPages());
|
||||
}
|
||||
|
||||
@@ -25,9 +25,11 @@ public interface DatasetRepository extends IRepository<Dataset> {
|
||||
|
||||
AllDatasetStatisticsResponse getAllDatasetStatistics();
|
||||
|
||||
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query);
|
||||
AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy);
|
||||
|
||||
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy);
|
||||
|
||||
long countByParentId(String parentDatasetId);
|
||||
|
||||
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit);
|
||||
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy);
|
||||
}
|
||||
|
||||
@@ -11,5 +11,5 @@ import com.datamate.datamanagement.interfaces.dto.KnowledgeSetPagingQuery;
|
||||
public interface KnowledgeSetRepository extends IRepository<KnowledgeSet> {
|
||||
KnowledgeSet findByName(String name);
|
||||
|
||||
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query);
|
||||
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy);
|
||||
}
|
||||
|
||||
@@ -51,10 +51,34 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
|
||||
|
||||
|
||||
@Override
|
||||
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query) {
|
||||
public AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy) {
|
||||
List<Dataset> datasets = lambdaQuery()
|
||||
.eq(Dataset::getCreatedBy, createdBy)
|
||||
.list();
|
||||
long totalFiles = datasets.stream()
|
||||
.map(Dataset::getFileCount)
|
||||
.filter(java.util.Objects::nonNull)
|
||||
.mapToLong(Long::longValue)
|
||||
.sum();
|
||||
long totalSize = datasets.stream()
|
||||
.map(Dataset::getSizeBytes)
|
||||
.filter(java.util.Objects::nonNull)
|
||||
.mapToLong(Long::longValue)
|
||||
.sum();
|
||||
AllDatasetStatisticsResponse response = new AllDatasetStatisticsResponse();
|
||||
response.setTotalDatasets(datasets.size());
|
||||
response.setTotalFiles(totalFiles);
|
||||
response.setTotalSize(totalSize);
|
||||
return response;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy) {
|
||||
LambdaQueryWrapper<Dataset> wrapper = new LambdaQueryWrapper<Dataset>()
|
||||
.eq(query.getType() != null, Dataset::getDatasetType, query.getType())
|
||||
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus());
|
||||
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus())
|
||||
.eq(StringUtils.isNotBlank(createdBy), Dataset::getCreatedBy, createdBy);
|
||||
|
||||
if (query.getParentDatasetId() != null) {
|
||||
if (StringUtils.isBlank(query.getParentDatasetId())) {
|
||||
@@ -92,7 +116,7 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit) {
|
||||
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy) {
|
||||
if (limit <= 0 || tagNames == null || tagNames.isEmpty()) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
@@ -109,6 +133,9 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
|
||||
if (StringUtils.isNotBlank(excludedDatasetId)) {
|
||||
wrapper.ne(Dataset::getId, excludedDatasetId.trim());
|
||||
}
|
||||
if (StringUtils.isNotBlank(createdBy)) {
|
||||
wrapper.eq(Dataset::getCreatedBy, createdBy);
|
||||
}
|
||||
wrapper.apply("tags IS NOT NULL AND JSON_VALID(tags) = 1 AND JSON_LENGTH(tags) > 0");
|
||||
wrapper.and(condition -> {
|
||||
boolean hasCondition = false;
|
||||
|
||||
@@ -25,7 +25,7 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
|
||||
}
|
||||
|
||||
@Override
|
||||
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query) {
|
||||
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy) {
|
||||
LambdaQueryWrapper<KnowledgeSet> wrapper = new LambdaQueryWrapper<KnowledgeSet>()
|
||||
.eq(query.getStatus() != null, KnowledgeSet::getStatus, query.getStatus())
|
||||
.eq(StringUtils.isNotBlank(query.getDomain()), KnowledgeSet::getDomain, query.getDomain())
|
||||
@@ -34,7 +34,8 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
|
||||
.eq(StringUtils.isNotBlank(query.getSensitivity()), KnowledgeSet::getSensitivity, query.getSensitivity())
|
||||
.eq(query.getSourceType() != null, KnowledgeSet::getSourceType, query.getSourceType())
|
||||
.ge(query.getValidFrom() != null, KnowledgeSet::getValidFrom, query.getValidFrom())
|
||||
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo());
|
||||
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo())
|
||||
.eq(StringUtils.isNotBlank(createdBy), KnowledgeSet::getCreatedBy, createdBy);
|
||||
|
||||
if (StringUtils.isNotBlank(query.getKeyword())) {
|
||||
wrapper.and(w -> w.like(KnowledgeSet::getName, query.getKeyword())
|
||||
|
||||
@@ -2,8 +2,11 @@ package com.datamate.rag.indexer.application;
|
||||
|
||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import com.datamate.common.auth.application.ResourceAccessService;
|
||||
import com.datamate.common.infrastructure.exception.BusinessAssert;
|
||||
import com.datamate.common.infrastructure.exception.BusinessException;
|
||||
import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode;
|
||||
import com.datamate.common.infrastructure.exception.SystemErrorCode;
|
||||
import com.datamate.common.interfaces.PagedResponse;
|
||||
import com.datamate.common.interfaces.PagingQuery;
|
||||
import com.datamate.common.setting.domain.entity.ModelConfig;
|
||||
@@ -55,6 +58,7 @@ public class KnowledgeBaseService {
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
private final ModelConfigRepository modelConfigRepository;
|
||||
private final MilvusService milvusService;
|
||||
private final ResourceAccessService resourceAccessService;
|
||||
|
||||
/**
|
||||
* 创建知识库
|
||||
@@ -77,8 +81,7 @@ public class KnowledgeBaseService {
|
||||
*/
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
|
||||
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
|
||||
.collectionName(knowledgeBase.getName())
|
||||
@@ -98,16 +101,14 @@ public class KnowledgeBaseService {
|
||||
*/
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public void delete(String knowledgeBaseId) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||
knowledgeBaseRepository.removeById(knowledgeBaseId);
|
||||
ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId);
|
||||
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build());
|
||||
}
|
||||
|
||||
public KnowledgeBaseResp getById(String knowledgeBaseId) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||
KnowledgeBaseResp resp = getKnowledgeBaseResp(knowledgeBase);
|
||||
resp.setEmbedding(modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()));
|
||||
resp.setChat(modelConfigRepository.getById(knowledgeBase.getChatModel()));
|
||||
@@ -133,7 +134,8 @@ public class KnowledgeBaseService {
|
||||
|
||||
public PagedResponse<KnowledgeBaseResp> list(KnowledgeBaseQueryReq request) {
|
||||
IPage<KnowledgeBase> page = new Page<>(request.getPage(), request.getSize());
|
||||
page = knowledgeBaseRepository.page(page, request);
|
||||
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
|
||||
page = knowledgeBaseRepository.page(page, request, ownerFilterUserId);
|
||||
|
||||
// 将 KnowledgeBase 转换为 KnowledgeBaseResp,并计算 fileCount 和 chunkCount
|
||||
List<KnowledgeBaseResp> respList = page.getRecords().stream().map(this::getKnowledgeBaseResp).toList();
|
||||
@@ -143,8 +145,7 @@ public class KnowledgeBaseService {
|
||||
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public void addFiles(AddFilesReq request) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseId()))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseId());
|
||||
List<RagFile> ragFiles = request.getFiles().stream().map(fileInfo -> {
|
||||
RagFile ragFile = new RagFile();
|
||||
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
|
||||
@@ -170,6 +171,7 @@ public class KnowledgeBaseService {
|
||||
}
|
||||
|
||||
public PagedResponse<RagFile> listFiles(String knowledgeBaseId, RagFileReq request) {
|
||||
getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
|
||||
request.setKnowledgeBaseId(knowledgeBaseId);
|
||||
page = ragFileRepository.page(page, request);
|
||||
@@ -177,8 +179,13 @@ public class KnowledgeBaseService {
|
||||
}
|
||||
|
||||
public PagedResponse<KnowledgeBaseFileSearchResp> searchFiles(KnowledgeBaseFileSearchReq request) {
|
||||
boolean admin = resourceAccessService.isAdmin();
|
||||
List<String> scopedKnowledgeBaseIds = resolveSearchScopeKnowledgeBaseIds(request, admin);
|
||||
if (!admin && scopedKnowledgeBaseIds.isEmpty()) {
|
||||
return PagedResponse.of(Collections.emptyList(), request.getPage(), 0L, 0);
|
||||
}
|
||||
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
|
||||
page = ragFileRepository.searchPage(page, request);
|
||||
page = ragFileRepository.searchPage(page, request, scopedKnowledgeBaseIds);
|
||||
List<RagFile> records = page.getRecords();
|
||||
if (records.isEmpty()) {
|
||||
return PagedResponse.of(Collections.emptyList(), page.getCurrent(), page.getTotal(), page.getPages());
|
||||
@@ -213,8 +220,7 @@ public class KnowledgeBaseService {
|
||||
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||
ragFileRepository.removeByIds(request.getIds());
|
||||
milvusService.getMilvusClient().delete(DeleteReq.builder()
|
||||
.collectionName(knowledgeBase.getName())
|
||||
@@ -223,8 +229,7 @@ public class KnowledgeBaseService {
|
||||
}
|
||||
|
||||
public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||
QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder()
|
||||
.collectionName(knowledgeBase.getName())
|
||||
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
|
||||
@@ -259,8 +264,7 @@ public class KnowledgeBaseService {
|
||||
* @return 检索结果
|
||||
*/
|
||||
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseIds().getFirst());
|
||||
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
|
||||
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
|
||||
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
|
||||
@@ -273,4 +277,27 @@ public class KnowledgeBaseService {
|
||||
});
|
||||
return searchResults;
|
||||
}
|
||||
|
||||
private KnowledgeBase getKnowledgeBaseWithAccessCheck(String knowledgeBaseId) {
|
||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||
resourceAccessService.assertOwnerAccess(knowledgeBase.getCreatedBy());
|
||||
return knowledgeBase;
|
||||
}
|
||||
|
||||
private List<String> resolveSearchScopeKnowledgeBaseIds(KnowledgeBaseFileSearchReq request, boolean admin) {
|
||||
if (admin) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
String currentUserId = resourceAccessService.requireCurrentUserId();
|
||||
List<String> ownedKnowledgeBaseIds = knowledgeBaseRepository.listIdsByCreatedBy(currentUserId);
|
||||
if (!StringUtils.hasText(request.getKnowledgeBaseId())) {
|
||||
return ownedKnowledgeBaseIds;
|
||||
}
|
||||
BusinessAssert.isTrue(
|
||||
ownedKnowledgeBaseIds.contains(request.getKnowledgeBaseId()),
|
||||
SystemErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
);
|
||||
return Collections.singletonList(request.getKnowledgeBaseId());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import com.baomidou.mybatisplus.extension.repository.IRepository;
|
||||
import com.datamate.rag.indexer.domain.model.KnowledgeBase;
|
||||
import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 知识库仓储接口
|
||||
*
|
||||
@@ -19,5 +21,7 @@ public interface KnowledgeBaseRepository extends IRepository<KnowledgeBase> {
|
||||
* @param request 查询请求
|
||||
* @return 知识库分页结果
|
||||
*/
|
||||
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request);
|
||||
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy);
|
||||
|
||||
List<String> listIdsByCreatedBy(String createdBy);
|
||||
}
|
||||
|
||||
@@ -23,5 +23,5 @@ public interface RagFileRepository extends IRepository<RagFile> {
|
||||
|
||||
IPage<RagFile> page(IPage<RagFile> page, RagFileReq request);
|
||||
|
||||
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request);
|
||||
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,9 @@ import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 知识库仓储实现类
|
||||
*
|
||||
@@ -20,12 +23,28 @@ import org.springframework.util.StringUtils;
|
||||
public class KnowledgeBaseRepositoryImpl extends CrudRepository<KnowledgeBaseMapper, KnowledgeBase> implements KnowledgeBaseRepository {
|
||||
|
||||
@Override
|
||||
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request) {
|
||||
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy) {
|
||||
return this.page(page, new LambdaQueryWrapper<KnowledgeBase>()
|
||||
.like(StringUtils.hasText(request.getName()), KnowledgeBase::getName, request.getName())
|
||||
.like(StringUtils.hasText(request.getDescription()), KnowledgeBase::getDescription, request.getDescription())
|
||||
.like(StringUtils.hasText(request.getCreatedBy()), KnowledgeBase::getCreatedBy, request.getCreatedBy())
|
||||
.like(StringUtils.hasText(request.getUpdatedBy()), KnowledgeBase::getUpdatedBy, request.getUpdatedBy())
|
||||
.eq(StringUtils.hasText(createdBy), KnowledgeBase::getCreatedBy, createdBy)
|
||||
.orderByDesc(KnowledgeBase::getCreatedAt));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> listIdsByCreatedBy(String createdBy) {
|
||||
if (!StringUtils.hasText(createdBy)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return lambdaQuery()
|
||||
.select(KnowledgeBase::getId)
|
||||
.eq(KnowledgeBase::getCreatedBy, createdBy)
|
||||
.list()
|
||||
.stream()
|
||||
.map(KnowledgeBase::getId)
|
||||
.filter(StringUtils::hasText)
|
||||
.toList();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,9 +52,12 @@ public class RagFileRepositoryImpl extends CrudRepository<RagFileMapper, RagFile
|
||||
}
|
||||
|
||||
@Override
|
||||
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request) {
|
||||
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds) {
|
||||
return lambdaQuery()
|
||||
.eq(StringUtils.hasText(request.getKnowledgeBaseId()), RagFile::getKnowledgeBaseId, request.getKnowledgeBaseId())
|
||||
.in(!StringUtils.hasText(request.getKnowledgeBaseId()) && knowledgeBaseIds != null && !knowledgeBaseIds.isEmpty(),
|
||||
RagFile::getKnowledgeBaseId,
|
||||
knowledgeBaseIds)
|
||||
.like(StringUtils.hasText(request.getFileName()), RagFile::getFileName, request.getFileName())
|
||||
.likeRight(StringUtils.hasText(request.getRelativePath()), RagFile::getRelativePath, normalizeRelativePath(request.getRelativePath()))
|
||||
.page(page);
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package com.datamate.common.auth.application;
|
||||
|
||||
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
|
||||
import com.datamate.common.infrastructure.exception.BusinessAssert;
|
||||
import com.datamate.common.infrastructure.exception.SystemErrorCode;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 资源访问控制服务(基于请求用户上下文)
|
||||
*/
|
||||
@Service
|
||||
public class ResourceAccessService {
|
||||
public static final String ADMIN_ROLE_CODE = "ROLE_ADMIN";
|
||||
|
||||
public boolean isAdmin() {
|
||||
return RequestUserContextHolder.hasRole(ADMIN_ROLE_CODE);
|
||||
}
|
||||
|
||||
public String getCurrentUserId() {
|
||||
return RequestUserContextHolder.getCurrentUserId();
|
||||
}
|
||||
|
||||
public String requireCurrentUserId() {
|
||||
String currentUserId = getCurrentUserId();
|
||||
BusinessAssert.isTrue(StringUtils.hasText(currentUserId), SystemErrorCode.INSUFFICIENT_PERMISSIONS);
|
||||
return currentUserId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 资源列表查询的 owner 过滤:
|
||||
* - 管理员返回 null(不过滤)
|
||||
* - 非管理员返回当前用户ID
|
||||
*/
|
||||
public String resolveOwnerFilterUserId() {
|
||||
if (isAdmin()) {
|
||||
return null;
|
||||
}
|
||||
return requireCurrentUserId();
|
||||
}
|
||||
|
||||
/**
|
||||
* 校验当前用户是否可访问 owner 资源
|
||||
*/
|
||||
public void assertOwnerAccess(String ownerUserId) {
|
||||
if (isAdmin()) {
|
||||
return;
|
||||
}
|
||||
String currentUserId = requireCurrentUserId();
|
||||
BusinessAssert.isTrue(
|
||||
StringUtils.hasText(ownerUserId) && Objects.equals(ownerUserId, currentUserId),
|
||||
SystemErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package com.datamate.common.auth.infrastructure.context;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 请求级用户上下文
|
||||
*/
|
||||
@Getter
|
||||
public class RequestUserContext {
|
||||
private final String userId;
|
||||
private final String username;
|
||||
private final List<String> roles;
|
||||
|
||||
private RequestUserContext(String userId, String username, List<String> roles) {
|
||||
this.userId = userId;
|
||||
this.username = username;
|
||||
this.roles = roles == null ? Collections.emptyList() : List.copyOf(roles);
|
||||
}
|
||||
|
||||
public static RequestUserContext of(String userId, String username, List<String> roles) {
|
||||
return new RequestUserContext(userId, username, roles);
|
||||
}
|
||||
|
||||
public static RequestUserContext empty() {
|
||||
return new RequestUserContext(null, null, Collections.emptyList());
|
||||
}
|
||||
|
||||
public boolean hasRole(String roleCode) {
|
||||
if (!StringUtils.hasText(roleCode)) {
|
||||
return false;
|
||||
}
|
||||
return roles.stream().anyMatch(role -> StringUtils.hasText(role) && Objects.equals(role.trim(), roleCode));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package com.datamate.common.auth.infrastructure.context;
|
||||
|
||||
import org.springframework.core.NamedThreadLocal;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 请求级用户上下文持有器
|
||||
*/
|
||||
public final class RequestUserContextHolder {
|
||||
private static final ThreadLocal<RequestUserContext> USER_CONTEXT_HOLDER =
|
||||
new NamedThreadLocal<>("datamate-request-user-context");
|
||||
|
||||
private RequestUserContextHolder() {
|
||||
}
|
||||
|
||||
public static void set(RequestUserContext context) {
|
||||
USER_CONTEXT_HOLDER.set(context == null ? RequestUserContext.empty() : context);
|
||||
}
|
||||
|
||||
public static RequestUserContext get() {
|
||||
RequestUserContext context = USER_CONTEXT_HOLDER.get();
|
||||
return context == null ? RequestUserContext.empty() : context;
|
||||
}
|
||||
|
||||
public static String getCurrentUserId() {
|
||||
return get().getUserId();
|
||||
}
|
||||
|
||||
public static List<String> getCurrentRoles() {
|
||||
List<String> roles = get().getRoles();
|
||||
return roles == null ? Collections.emptyList() : roles;
|
||||
}
|
||||
|
||||
public static boolean hasRole(String roleCode) {
|
||||
if (!StringUtils.hasText(roleCode)) {
|
||||
return false;
|
||||
}
|
||||
return getCurrentRoles().stream()
|
||||
.anyMatch(role -> StringUtils.hasText(role) && roleCode.equalsIgnoreCase(role.trim()));
|
||||
}
|
||||
|
||||
public static void clear() {
|
||||
USER_CONTEXT_HOLDER.remove();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
package com.datamate.common.auth.infrastructure.context;
|
||||
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.servlet.HandlerInterceptor;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 从网关透传请求头中提取用户上下文
|
||||
*/
|
||||
@Component
|
||||
public class RequestUserContextInterceptor implements HandlerInterceptor {
|
||||
private static final String HEADER_USER_ID = "X-User-Id";
|
||||
private static final String HEADER_USER_NAME = "X-User-Name";
|
||||
private static final String HEADER_USER_ROLES = "X-User-Roles";
|
||||
|
||||
@Override
|
||||
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
|
||||
String userId = normalizeValue(request.getHeader(HEADER_USER_ID));
|
||||
String username = normalizeValue(request.getHeader(HEADER_USER_NAME));
|
||||
List<String> roleCodes = parseRoleCodes(request.getHeader(HEADER_USER_ROLES));
|
||||
RequestUserContextHolder.set(RequestUserContext.of(userId, username, roleCodes));
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) {
|
||||
RequestUserContextHolder.clear();
|
||||
}
|
||||
|
||||
private String normalizeValue(String value) {
|
||||
if (!StringUtils.hasText(value)) {
|
||||
return null;
|
||||
}
|
||||
return value.trim();
|
||||
}
|
||||
|
||||
private List<String> parseRoleCodes(String roleHeader) {
|
||||
if (!StringUtils.hasText(roleHeader)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return Arrays.stream(roleHeader.split(","))
|
||||
.map(String::trim)
|
||||
.filter(StringUtils::hasText)
|
||||
.toList();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.datamate.common.auth.infrastructure.context;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
|
||||
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
|
||||
|
||||
/**
|
||||
* 请求用户上下文拦截器注册
|
||||
*/
|
||||
@Configuration
|
||||
@RequiredArgsConstructor
|
||||
public class RequestUserContextWebMvcConfigurer implements WebMvcConfigurer {
|
||||
private final RequestUserContextInterceptor requestUserContextInterceptor;
|
||||
|
||||
@Override
|
||||
public void addInterceptors(InterceptorRegistry registry) {
|
||||
registry.addInterceptor(requestUserContextInterceptor).addPathPatterns("/**");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package com.datamate.common.infrastructure.config;
|
||||
|
||||
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
|
||||
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.ibatis.reflection.MetaObject;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@@ -44,17 +46,10 @@ public class EntityMetaObjectHandler implements MetaObjectHandler {
|
||||
* 获取当前用户(需要根据你的安全框架实现)
|
||||
*/
|
||||
private String getCurrentUser() {
|
||||
// todo 这里需要根据你的安全框架实现,例如Spring Security、Shiro等
|
||||
// 示例:返回默认用户或从SecurityContext获取
|
||||
try {
|
||||
// 如果是Spring Security
|
||||
// return SecurityContextHolder.getContext().getAuthentication().getName();
|
||||
|
||||
// 临时返回默认值,请根据实际情况修改
|
||||
return "system";
|
||||
} catch (Exception e) {
|
||||
log.error("Error getting current user", e);
|
||||
return "unknown";
|
||||
String currentUserId = RequestUserContextHolder.getCurrentUserId();
|
||||
if (StringUtils.hasText(currentUserId)) {
|
||||
return currentUserId;
|
||||
}
|
||||
return "system";
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user