feat(auth): 为数据管理和RAG服务增加资源访问控制

- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证
- 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证
- 修改DatasetRepository接口和实现类,增加按创建者过滤的方法
- 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法
- 在RAG索引器服务中添加知识库访问权限检查和作用域过滤
- 更新实体元对象处理器以使用请求用户上下文获取当前用户
- 在前端设置页面添加用户权限管理功能和角色权限控制
- 为Python标注服务增加用户上下文和数据集访问权限验证
This commit is contained in:
2026-02-06 14:58:46 +08:00
parent 056cee11cc
commit 6a4c4ae3d7
28 changed files with 1063 additions and 158 deletions

View File

@@ -3,6 +3,7 @@ package com.datamate.datamanagement.application;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.domain.utils.ChunksSaver;
import com.datamate.common.setting.application.SysParamApplicationService;
import com.datamate.datamanagement.interfaces.dto.*;
@@ -64,6 +65,7 @@ public class DatasetApplicationService {
private final CollectionTaskClient collectionTaskClient;
private final DatasetFileApplicationService datasetFileApplicationService;
private final SysParamApplicationService sysParamService;
private final ResourceAccessService resourceAccessService;
@Value("${datamate.data-management.base-path:/dataset}")
private String datasetBasePath;
@@ -102,6 +104,7 @@ public class DatasetApplicationService {
public Dataset updateDataset(String datasetId, UpdateDatasetRequest updateDatasetRequest) {
Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
if (StringUtils.hasText(updateDatasetRequest.getName())) {
dataset.setName(updateDatasetRequest.getName());
@@ -151,6 +154,7 @@ public class DatasetApplicationService {
public void deleteDataset(String datasetId) {
Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
long childCount = datasetRepository.countByParentId(datasetId);
BusinessAssert.isTrue(childCount == 0, DataManagementErrorCode.DATASET_HAS_CHILDREN);
datasetRepository.removeById(datasetId);
@@ -164,6 +168,7 @@ public class DatasetApplicationService {
public Dataset getDataset(String datasetId) {
Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
List<DatasetFile> datasetFiles = datasetFileRepository.findAllVisibleByDatasetId(datasetId);
dataset.setFiles(datasetFiles);
applyVisibleFileCounts(Collections.singletonList(dataset));
@@ -176,7 +181,8 @@ public class DatasetApplicationService {
@Transactional(readOnly = true)
public PagedResponse<DatasetResponse> getDatasets(DatasetPagingQuery query) {
IPage<Dataset> page = new Page<>(query.getPage(), query.getSize());
page = datasetRepository.findByCriteria(page, query);
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = datasetRepository.findByCriteria(page, query, ownerFilterUserId);
String datasetPvcName = getDatasetPvcName();
applyVisibleFileCounts(page.getRecords());
List<DatasetResponse> datasetResponses = DatasetConverter.INSTANCE.convertToResponse(page.getRecords());
@@ -189,6 +195,7 @@ public class DatasetApplicationService {
BusinessAssert.isTrue(StringUtils.hasText(datasetId), CommonErrorCode.PARAM_ERROR);
Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
Set<String> sourceTags = normalizeTagNames(dataset.getTags());
if (sourceTags.isEmpty()) {
return Collections.emptyList();
@@ -198,10 +205,12 @@ public class DatasetApplicationService {
SIMILAR_DATASET_CANDIDATE_MAX,
Math.max(safeLimit * SIMILAR_DATASET_CANDIDATE_FACTOR, safeLimit)
);
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
List<Dataset> candidates = datasetRepository.findSimilarByTags(
new ArrayList<>(sourceTags),
datasetId,
candidateLimit
candidateLimit,
ownerFilterUserId
);
if (CollectionUtils.isEmpty(candidates)) {
return Collections.emptyList();
@@ -436,6 +445,7 @@ public class DatasetApplicationService {
if (dataset == null) {
throw new IllegalArgumentException("Dataset not found: " + datasetId);
}
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
Map<String, Object> statistics = new HashMap<>();
@@ -485,7 +495,11 @@ public class DatasetApplicationService {
* 获取所有数据集的汇总统计信息
*/
public AllDatasetStatisticsResponse getAllDatasetStatistics() {
return datasetRepository.getAllDatasetStatistics();
if (resourceAccessService.isAdmin()) {
return datasetRepository.getAllDatasetStatistics();
}
String currentUserId = resourceAccessService.requireCurrentUserId();
return datasetRepository.getAllDatasetStatisticsByCreatedBy(currentUserId);
}
/**

View File

@@ -2,6 +2,7 @@ package com.datamate.datamanagement.application;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.CommonErrorCode;
import com.datamate.common.interfaces.PagedResponse;
@@ -40,6 +41,7 @@ import java.util.UUID;
public class KnowledgeSetApplicationService {
private final KnowledgeSetRepository knowledgeSetRepository;
private final TagMapper tagMapper;
private final ResourceAccessService resourceAccessService;
public KnowledgeSet createKnowledgeSet(CreateKnowledgeSetRequest request) {
BusinessAssert.isTrue(knowledgeSetRepository.findByName(request.getName()) == null,
@@ -64,6 +66,7 @@ public class KnowledgeSetApplicationService {
public KnowledgeSet updateKnowledgeSet(String setId, UpdateKnowledgeSetRequest request) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
BusinessAssert.isTrue(!isReadOnlyStatus(knowledgeSet.getStatus()),
DataManagementErrorCode.KNOWLEDGE_SET_STATUS_ERROR);
@@ -119,6 +122,7 @@ public class KnowledgeSetApplicationService {
public void deleteKnowledgeSet(String setId) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
knowledgeSetRepository.removeById(setId);
}
@@ -126,13 +130,15 @@ public class KnowledgeSetApplicationService {
public KnowledgeSet getKnowledgeSet(String setId) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
return knowledgeSet;
}
@Transactional(readOnly = true)
public PagedResponse<KnowledgeSetResponse> getKnowledgeSets(KnowledgeSetPagingQuery query) {
IPage<KnowledgeSet> page = new Page<>(query.getPage(), query.getSize());
page = knowledgeSetRepository.findByCriteria(page, query);
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = knowledgeSetRepository.findByCriteria(page, query, ownerFilterUserId);
List<KnowledgeSetResponse> responses = KnowledgeConverter.INSTANCE.convertSetResponses(page.getRecords());
return PagedResponse.of(responses, page.getCurrent(), page.getTotal(), page.getPages());
}

View File

@@ -25,9 +25,11 @@ public interface DatasetRepository extends IRepository<Dataset> {
AllDatasetStatisticsResponse getAllDatasetStatistics();
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query);
AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy);
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy);
long countByParentId(String parentDatasetId);
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit);
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy);
}

View File

@@ -11,5 +11,5 @@ import com.datamate.datamanagement.interfaces.dto.KnowledgeSetPagingQuery;
public interface KnowledgeSetRepository extends IRepository<KnowledgeSet> {
KnowledgeSet findByName(String name);
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query);
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy);
}

View File

@@ -51,10 +51,34 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
@Override
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query) {
public AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy) {
List<Dataset> datasets = lambdaQuery()
.eq(Dataset::getCreatedBy, createdBy)
.list();
long totalFiles = datasets.stream()
.map(Dataset::getFileCount)
.filter(java.util.Objects::nonNull)
.mapToLong(Long::longValue)
.sum();
long totalSize = datasets.stream()
.map(Dataset::getSizeBytes)
.filter(java.util.Objects::nonNull)
.mapToLong(Long::longValue)
.sum();
AllDatasetStatisticsResponse response = new AllDatasetStatisticsResponse();
response.setTotalDatasets(datasets.size());
response.setTotalFiles(totalFiles);
response.setTotalSize(totalSize);
return response;
}
@Override
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy) {
LambdaQueryWrapper<Dataset> wrapper = new LambdaQueryWrapper<Dataset>()
.eq(query.getType() != null, Dataset::getDatasetType, query.getType())
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus());
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus())
.eq(StringUtils.isNotBlank(createdBy), Dataset::getCreatedBy, createdBy);
if (query.getParentDatasetId() != null) {
if (StringUtils.isBlank(query.getParentDatasetId())) {
@@ -92,7 +116,7 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
}
@Override
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit) {
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy) {
if (limit <= 0 || tagNames == null || tagNames.isEmpty()) {
return Collections.emptyList();
}
@@ -109,6 +133,9 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
if (StringUtils.isNotBlank(excludedDatasetId)) {
wrapper.ne(Dataset::getId, excludedDatasetId.trim());
}
if (StringUtils.isNotBlank(createdBy)) {
wrapper.eq(Dataset::getCreatedBy, createdBy);
}
wrapper.apply("tags IS NOT NULL AND JSON_VALID(tags) = 1 AND JSON_LENGTH(tags) > 0");
wrapper.and(condition -> {
boolean hasCondition = false;

View File

@@ -25,7 +25,7 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
}
@Override
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query) {
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy) {
LambdaQueryWrapper<KnowledgeSet> wrapper = new LambdaQueryWrapper<KnowledgeSet>()
.eq(query.getStatus() != null, KnowledgeSet::getStatus, query.getStatus())
.eq(StringUtils.isNotBlank(query.getDomain()), KnowledgeSet::getDomain, query.getDomain())
@@ -34,7 +34,8 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
.eq(StringUtils.isNotBlank(query.getSensitivity()), KnowledgeSet::getSensitivity, query.getSensitivity())
.eq(query.getSourceType() != null, KnowledgeSet::getSourceType, query.getSourceType())
.ge(query.getValidFrom() != null, KnowledgeSet::getValidFrom, query.getValidFrom())
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo());
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo())
.eq(StringUtils.isNotBlank(createdBy), KnowledgeSet::getCreatedBy, createdBy);
if (StringUtils.isNotBlank(query.getKeyword())) {
wrapper.and(w -> w.like(KnowledgeSet::getName, query.getKeyword())

View File

@@ -2,8 +2,11 @@ package com.datamate.rag.indexer.application;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import com.datamate.common.interfaces.PagedResponse;
import com.datamate.common.interfaces.PagingQuery;
import com.datamate.common.setting.domain.entity.ModelConfig;
@@ -55,6 +58,7 @@ public class KnowledgeBaseService {
private final ApplicationEventPublisher eventPublisher;
private final ModelConfigRepository modelConfigRepository;
private final MilvusService milvusService;
private final ResourceAccessService resourceAccessService;
/**
* 创建知识库
@@ -77,8 +81,7 @@ public class KnowledgeBaseService {
*/
@Transactional(rollbackFor = Exception.class)
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
.collectionName(knowledgeBase.getName())
@@ -98,16 +101,14 @@ public class KnowledgeBaseService {
*/
@Transactional(rollbackFor = Exception.class)
public void delete(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
knowledgeBaseRepository.removeById(knowledgeBaseId);
ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId);
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build());
}
public KnowledgeBaseResp getById(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
KnowledgeBaseResp resp = getKnowledgeBaseResp(knowledgeBase);
resp.setEmbedding(modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()));
resp.setChat(modelConfigRepository.getById(knowledgeBase.getChatModel()));
@@ -133,7 +134,8 @@ public class KnowledgeBaseService {
public PagedResponse<KnowledgeBaseResp> list(KnowledgeBaseQueryReq request) {
IPage<KnowledgeBase> page = new Page<>(request.getPage(), request.getSize());
page = knowledgeBaseRepository.page(page, request);
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = knowledgeBaseRepository.page(page, request, ownerFilterUserId);
// 将 KnowledgeBase 转换为 KnowledgeBaseResp,并计算 fileCount 和 chunkCount
List<KnowledgeBaseResp> respList = page.getRecords().stream().map(this::getKnowledgeBaseResp).toList();
@@ -143,8 +145,7 @@ public class KnowledgeBaseService {
@Transactional(rollbackFor = Exception.class)
public void addFiles(AddFilesReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseId()))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseId());
List<RagFile> ragFiles = request.getFiles().stream().map(fileInfo -> {
RagFile ragFile = new RagFile();
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
@@ -170,6 +171,7 @@ public class KnowledgeBaseService {
}
public PagedResponse<RagFile> listFiles(String knowledgeBaseId, RagFileReq request) {
getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
request.setKnowledgeBaseId(knowledgeBaseId);
page = ragFileRepository.page(page, request);
@@ -177,8 +179,13 @@ public class KnowledgeBaseService {
}
public PagedResponse<KnowledgeBaseFileSearchResp> searchFiles(KnowledgeBaseFileSearchReq request) {
boolean admin = resourceAccessService.isAdmin();
List<String> scopedKnowledgeBaseIds = resolveSearchScopeKnowledgeBaseIds(request, admin);
if (!admin && scopedKnowledgeBaseIds.isEmpty()) {
return PagedResponse.of(Collections.emptyList(), request.getPage(), 0L, 0);
}
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
page = ragFileRepository.searchPage(page, request);
page = ragFileRepository.searchPage(page, request, scopedKnowledgeBaseIds);
List<RagFile> records = page.getRecords();
if (records.isEmpty()) {
return PagedResponse.of(Collections.emptyList(), page.getCurrent(), page.getTotal(), page.getPages());
@@ -213,8 +220,7 @@ public class KnowledgeBaseService {
@Transactional(rollbackFor = Exception.class)
public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
ragFileRepository.removeByIds(request.getIds());
milvusService.getMilvusClient().delete(DeleteReq.builder()
.collectionName(knowledgeBase.getName())
@@ -223,8 +229,7 @@ public class KnowledgeBaseService {
}
public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder()
.collectionName(knowledgeBase.getName())
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
@@ -259,8 +264,7 @@ public class KnowledgeBaseService {
* @return 检索结果
*/
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseIds().getFirst());
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
@@ -273,4 +277,27 @@ public class KnowledgeBaseService {
});
return searchResults;
}
private KnowledgeBase getKnowledgeBaseWithAccessCheck(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
resourceAccessService.assertOwnerAccess(knowledgeBase.getCreatedBy());
return knowledgeBase;
}
private List<String> resolveSearchScopeKnowledgeBaseIds(KnowledgeBaseFileSearchReq request, boolean admin) {
if (admin) {
return Collections.emptyList();
}
String currentUserId = resourceAccessService.requireCurrentUserId();
List<String> ownedKnowledgeBaseIds = knowledgeBaseRepository.listIdsByCreatedBy(currentUserId);
if (!StringUtils.hasText(request.getKnowledgeBaseId())) {
return ownedKnowledgeBaseIds;
}
BusinessAssert.isTrue(
ownedKnowledgeBaseIds.contains(request.getKnowledgeBaseId()),
SystemErrorCode.INSUFFICIENT_PERMISSIONS
);
return Collections.singletonList(request.getKnowledgeBaseId());
}
}

View File

@@ -5,6 +5,8 @@ import com.baomidou.mybatisplus.extension.repository.IRepository;
import com.datamate.rag.indexer.domain.model.KnowledgeBase;
import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
import java.util.List;
/**
* 知识库仓储接口
*
@@ -19,5 +21,7 @@ public interface KnowledgeBaseRepository extends IRepository<KnowledgeBase> {
* @param request 查询请求
* @return 知识库分页结果
*/
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request);
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy);
List<String> listIdsByCreatedBy(String createdBy);
}

View File

@@ -23,5 +23,5 @@ public interface RagFileRepository extends IRepository<RagFile> {
IPage<RagFile> page(IPage<RagFile> page, RagFileReq request);
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request);
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds);
}

View File

@@ -10,6 +10,9 @@ import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
import org.springframework.stereotype.Repository;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
/**
* 知识库仓储实现类
*
@@ -20,12 +23,28 @@ import org.springframework.util.StringUtils;
public class KnowledgeBaseRepositoryImpl extends CrudRepository<KnowledgeBaseMapper, KnowledgeBase> implements KnowledgeBaseRepository {
@Override
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request) {
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy) {
return this.page(page, new LambdaQueryWrapper<KnowledgeBase>()
.like(StringUtils.hasText(request.getName()), KnowledgeBase::getName, request.getName())
.like(StringUtils.hasText(request.getDescription()), KnowledgeBase::getDescription, request.getDescription())
.like(StringUtils.hasText(request.getCreatedBy()), KnowledgeBase::getCreatedBy, request.getCreatedBy())
.like(StringUtils.hasText(request.getUpdatedBy()), KnowledgeBase::getUpdatedBy, request.getUpdatedBy())
.eq(StringUtils.hasText(createdBy), KnowledgeBase::getCreatedBy, createdBy)
.orderByDesc(KnowledgeBase::getCreatedAt));
}
@Override
public List<String> listIdsByCreatedBy(String createdBy) {
if (!StringUtils.hasText(createdBy)) {
return Collections.emptyList();
}
return lambdaQuery()
.select(KnowledgeBase::getId)
.eq(KnowledgeBase::getCreatedBy, createdBy)
.list()
.stream()
.map(KnowledgeBase::getId)
.filter(StringUtils::hasText)
.toList();
}
}

View File

@@ -52,9 +52,12 @@ public class RagFileRepositoryImpl extends CrudRepository<RagFileMapper, RagFile
}
@Override
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request) {
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds) {
return lambdaQuery()
.eq(StringUtils.hasText(request.getKnowledgeBaseId()), RagFile::getKnowledgeBaseId, request.getKnowledgeBaseId())
.in(!StringUtils.hasText(request.getKnowledgeBaseId()) && knowledgeBaseIds != null && !knowledgeBaseIds.isEmpty(),
RagFile::getKnowledgeBaseId,
knowledgeBaseIds)
.like(StringUtils.hasText(request.getFileName()), RagFile::getFileName, request.getFileName())
.likeRight(StringUtils.hasText(request.getRelativePath()), RagFile::getRelativePath, normalizeRelativePath(request.getRelativePath()))
.page(page);

View File

@@ -0,0 +1,58 @@
package com.datamate.common.auth.application;
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.Objects;
/**
* 资源访问控制服务(基于请求用户上下文)
*/
@Service
public class ResourceAccessService {
public static final String ADMIN_ROLE_CODE = "ROLE_ADMIN";
public boolean isAdmin() {
return RequestUserContextHolder.hasRole(ADMIN_ROLE_CODE);
}
public String getCurrentUserId() {
return RequestUserContextHolder.getCurrentUserId();
}
public String requireCurrentUserId() {
String currentUserId = getCurrentUserId();
BusinessAssert.isTrue(StringUtils.hasText(currentUserId), SystemErrorCode.INSUFFICIENT_PERMISSIONS);
return currentUserId;
}
/**
* 资源列表查询的 owner 过滤:
* - 管理员返回 null(不过滤)
* - 非管理员返回当前用户ID
*/
public String resolveOwnerFilterUserId() {
if (isAdmin()) {
return null;
}
return requireCurrentUserId();
}
/**
* 校验当前用户是否可访问 owner 资源
*/
public void assertOwnerAccess(String ownerUserId) {
if (isAdmin()) {
return;
}
String currentUserId = requireCurrentUserId();
BusinessAssert.isTrue(
StringUtils.hasText(ownerUserId) && Objects.equals(ownerUserId, currentUserId),
SystemErrorCode.INSUFFICIENT_PERMISSIONS
);
}
}

View File

@@ -0,0 +1,40 @@
package com.datamate.common.auth.infrastructure.context;
import lombok.Getter;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* 请求级用户上下文
*/
@Getter
public class RequestUserContext {
private final String userId;
private final String username;
private final List<String> roles;
private RequestUserContext(String userId, String username, List<String> roles) {
this.userId = userId;
this.username = username;
this.roles = roles == null ? Collections.emptyList() : List.copyOf(roles);
}
public static RequestUserContext of(String userId, String username, List<String> roles) {
return new RequestUserContext(userId, username, roles);
}
public static RequestUserContext empty() {
return new RequestUserContext(null, null, Collections.emptyList());
}
public boolean hasRole(String roleCode) {
if (!StringUtils.hasText(roleCode)) {
return false;
}
return roles.stream().anyMatch(role -> StringUtils.hasText(role) && Objects.equals(role.trim(), roleCode));
}
}

View File

@@ -0,0 +1,49 @@
package com.datamate.common.auth.infrastructure.context;
import org.springframework.core.NamedThreadLocal;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
/**
* 请求级用户上下文持有器
*/
public final class RequestUserContextHolder {
private static final ThreadLocal<RequestUserContext> USER_CONTEXT_HOLDER =
new NamedThreadLocal<>("datamate-request-user-context");
private RequestUserContextHolder() {
}
public static void set(RequestUserContext context) {
USER_CONTEXT_HOLDER.set(context == null ? RequestUserContext.empty() : context);
}
public static RequestUserContext get() {
RequestUserContext context = USER_CONTEXT_HOLDER.get();
return context == null ? RequestUserContext.empty() : context;
}
public static String getCurrentUserId() {
return get().getUserId();
}
public static List<String> getCurrentRoles() {
List<String> roles = get().getRoles();
return roles == null ? Collections.emptyList() : roles;
}
public static boolean hasRole(String roleCode) {
if (!StringUtils.hasText(roleCode)) {
return false;
}
return getCurrentRoles().stream()
.anyMatch(role -> StringUtils.hasText(role) && roleCode.equalsIgnoreCase(role.trim()));
}
public static void clear() {
USER_CONTEXT_HOLDER.remove();
}
}

View File

@@ -0,0 +1,53 @@
package com.datamate.common.auth.infrastructure.context;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.HandlerInterceptor;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* 从网关透传请求头中提取用户上下文
*/
@Component
public class RequestUserContextInterceptor implements HandlerInterceptor {
private static final String HEADER_USER_ID = "X-User-Id";
private static final String HEADER_USER_NAME = "X-User-Name";
private static final String HEADER_USER_ROLES = "X-User-Roles";
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
String userId = normalizeValue(request.getHeader(HEADER_USER_ID));
String username = normalizeValue(request.getHeader(HEADER_USER_NAME));
List<String> roleCodes = parseRoleCodes(request.getHeader(HEADER_USER_ROLES));
RequestUserContextHolder.set(RequestUserContext.of(userId, username, roleCodes));
return true;
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) {
RequestUserContextHolder.clear();
}
private String normalizeValue(String value) {
if (!StringUtils.hasText(value)) {
return null;
}
return value.trim();
}
private List<String> parseRoleCodes(String roleHeader) {
if (!StringUtils.hasText(roleHeader)) {
return Collections.emptyList();
}
return Arrays.stream(roleHeader.split(","))
.map(String::trim)
.filter(StringUtils::hasText)
.toList();
}
}

View File

@@ -0,0 +1,21 @@
package com.datamate.common.auth.infrastructure.context;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
* 请求用户上下文拦截器注册
*/
@Configuration
@RequiredArgsConstructor
public class RequestUserContextWebMvcConfigurer implements WebMvcConfigurer {
private final RequestUserContextInterceptor requestUserContextInterceptor;
@Override
public void addInterceptors(InterceptorRegistry registry) {
registry.addInterceptor(requestUserContextInterceptor).addPathPatterns("/**");
}
}

View File

@@ -1,9 +1,11 @@
package com.datamate.common.infrastructure.config;
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.reflection.MetaObject;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.StringUtils;
import java.time.LocalDateTime;
@@ -44,17 +46,10 @@ public class EntityMetaObjectHandler implements MetaObjectHandler {
* 获取当前用户(需要根据你的安全框架实现)
*/
private String getCurrentUser() {
// todo 这里需要根据你的安全框架实现,例如Spring Security、Shiro等
// 示例:返回默认用户或从SecurityContext获取
try {
// 如果是Spring Security
// return SecurityContextHolder.getContext().getAuthentication().getName();
// 临时返回默认值,请根据实际情况修改
return "system";
} catch (Exception e) {
log.error("Error getting current user", e);
return "unknown";
String currentUserId = RequestUserContextHolder.getCurrentUserId();
if (StringUtils.hasText(currentUserId)) {
return currentUserId;
}
return "system";
}
}