feat(auth): 为数据管理和RAG服务增加资源访问控制

- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证
- 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证
- 修改DatasetRepository接口和实现类,增加按创建者过滤的方法
- 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法
- 在RAG索引器服务中添加知识库访问权限检查和作用域过滤
- 更新实体元对象处理器以使用请求用户上下文获取当前用户
- 在前端设置页面添加用户权限管理功能和角色权限控制
- 为Python标注服务增加用户上下文和数据集访问权限验证
This commit is contained in:
2026-02-06 14:58:46 +08:00
parent 056cee11cc
commit 6a4c4ae3d7
28 changed files with 1063 additions and 158 deletions

View File

@@ -3,6 +3,7 @@ package com.datamate.datamanagement.application;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.domain.utils.ChunksSaver; import com.datamate.common.domain.utils.ChunksSaver;
import com.datamate.common.setting.application.SysParamApplicationService; import com.datamate.common.setting.application.SysParamApplicationService;
import com.datamate.datamanagement.interfaces.dto.*; import com.datamate.datamanagement.interfaces.dto.*;
@@ -64,6 +65,7 @@ public class DatasetApplicationService {
private final CollectionTaskClient collectionTaskClient; private final CollectionTaskClient collectionTaskClient;
private final DatasetFileApplicationService datasetFileApplicationService; private final DatasetFileApplicationService datasetFileApplicationService;
private final SysParamApplicationService sysParamService; private final SysParamApplicationService sysParamService;
private final ResourceAccessService resourceAccessService;
@Value("${datamate.data-management.base-path:/dataset}") @Value("${datamate.data-management.base-path:/dataset}")
private String datasetBasePath; private String datasetBasePath;
@@ -102,6 +104,7 @@ public class DatasetApplicationService {
public Dataset updateDataset(String datasetId, UpdateDatasetRequest updateDatasetRequest) { public Dataset updateDataset(String datasetId, UpdateDatasetRequest updateDatasetRequest) {
Dataset dataset = datasetRepository.getById(datasetId); Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND); BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
if (StringUtils.hasText(updateDatasetRequest.getName())) { if (StringUtils.hasText(updateDatasetRequest.getName())) {
dataset.setName(updateDatasetRequest.getName()); dataset.setName(updateDatasetRequest.getName());
@@ -151,6 +154,7 @@ public class DatasetApplicationService {
public void deleteDataset(String datasetId) { public void deleteDataset(String datasetId) {
Dataset dataset = datasetRepository.getById(datasetId); Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND); BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
long childCount = datasetRepository.countByParentId(datasetId); long childCount = datasetRepository.countByParentId(datasetId);
BusinessAssert.isTrue(childCount == 0, DataManagementErrorCode.DATASET_HAS_CHILDREN); BusinessAssert.isTrue(childCount == 0, DataManagementErrorCode.DATASET_HAS_CHILDREN);
datasetRepository.removeById(datasetId); datasetRepository.removeById(datasetId);
@@ -164,6 +168,7 @@ public class DatasetApplicationService {
public Dataset getDataset(String datasetId) { public Dataset getDataset(String datasetId) {
Dataset dataset = datasetRepository.getById(datasetId); Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND); BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
List<DatasetFile> datasetFiles = datasetFileRepository.findAllVisibleByDatasetId(datasetId); List<DatasetFile> datasetFiles = datasetFileRepository.findAllVisibleByDatasetId(datasetId);
dataset.setFiles(datasetFiles); dataset.setFiles(datasetFiles);
applyVisibleFileCounts(Collections.singletonList(dataset)); applyVisibleFileCounts(Collections.singletonList(dataset));
@@ -176,7 +181,8 @@ public class DatasetApplicationService {
@Transactional(readOnly = true) @Transactional(readOnly = true)
public PagedResponse<DatasetResponse> getDatasets(DatasetPagingQuery query) { public PagedResponse<DatasetResponse> getDatasets(DatasetPagingQuery query) {
IPage<Dataset> page = new Page<>(query.getPage(), query.getSize()); IPage<Dataset> page = new Page<>(query.getPage(), query.getSize());
page = datasetRepository.findByCriteria(page, query); String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = datasetRepository.findByCriteria(page, query, ownerFilterUserId);
String datasetPvcName = getDatasetPvcName(); String datasetPvcName = getDatasetPvcName();
applyVisibleFileCounts(page.getRecords()); applyVisibleFileCounts(page.getRecords());
List<DatasetResponse> datasetResponses = DatasetConverter.INSTANCE.convertToResponse(page.getRecords()); List<DatasetResponse> datasetResponses = DatasetConverter.INSTANCE.convertToResponse(page.getRecords());
@@ -189,6 +195,7 @@ public class DatasetApplicationService {
BusinessAssert.isTrue(StringUtils.hasText(datasetId), CommonErrorCode.PARAM_ERROR); BusinessAssert.isTrue(StringUtils.hasText(datasetId), CommonErrorCode.PARAM_ERROR);
Dataset dataset = datasetRepository.getById(datasetId); Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND); BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
Set<String> sourceTags = normalizeTagNames(dataset.getTags()); Set<String> sourceTags = normalizeTagNames(dataset.getTags());
if (sourceTags.isEmpty()) { if (sourceTags.isEmpty()) {
return Collections.emptyList(); return Collections.emptyList();
@@ -198,10 +205,12 @@ public class DatasetApplicationService {
SIMILAR_DATASET_CANDIDATE_MAX, SIMILAR_DATASET_CANDIDATE_MAX,
Math.max(safeLimit * SIMILAR_DATASET_CANDIDATE_FACTOR, safeLimit) Math.max(safeLimit * SIMILAR_DATASET_CANDIDATE_FACTOR, safeLimit)
); );
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
List<Dataset> candidates = datasetRepository.findSimilarByTags( List<Dataset> candidates = datasetRepository.findSimilarByTags(
new ArrayList<>(sourceTags), new ArrayList<>(sourceTags),
datasetId, datasetId,
candidateLimit candidateLimit,
ownerFilterUserId
); );
if (CollectionUtils.isEmpty(candidates)) { if (CollectionUtils.isEmpty(candidates)) {
return Collections.emptyList(); return Collections.emptyList();
@@ -436,6 +445,7 @@ public class DatasetApplicationService {
if (dataset == null) { if (dataset == null) {
throw new IllegalArgumentException("Dataset not found: " + datasetId); throw new IllegalArgumentException("Dataset not found: " + datasetId);
} }
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
Map<String, Object> statistics = new HashMap<>(); Map<String, Object> statistics = new HashMap<>();
@@ -485,7 +495,11 @@ public class DatasetApplicationService {
* 获取所有数据集的汇总统计信息 * 获取所有数据集的汇总统计信息
*/ */
public AllDatasetStatisticsResponse getAllDatasetStatistics() { public AllDatasetStatisticsResponse getAllDatasetStatistics() {
return datasetRepository.getAllDatasetStatistics(); if (resourceAccessService.isAdmin()) {
return datasetRepository.getAllDatasetStatistics();
}
String currentUserId = resourceAccessService.requireCurrentUserId();
return datasetRepository.getAllDatasetStatisticsByCreatedBy(currentUserId);
} }
/** /**

View File

@@ -2,6 +2,7 @@ package com.datamate.datamanagement.application;
import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessAssert; import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.CommonErrorCode; import com.datamate.common.infrastructure.exception.CommonErrorCode;
import com.datamate.common.interfaces.PagedResponse; import com.datamate.common.interfaces.PagedResponse;
@@ -40,6 +41,7 @@ import java.util.UUID;
public class KnowledgeSetApplicationService { public class KnowledgeSetApplicationService {
private final KnowledgeSetRepository knowledgeSetRepository; private final KnowledgeSetRepository knowledgeSetRepository;
private final TagMapper tagMapper; private final TagMapper tagMapper;
private final ResourceAccessService resourceAccessService;
public KnowledgeSet createKnowledgeSet(CreateKnowledgeSetRequest request) { public KnowledgeSet createKnowledgeSet(CreateKnowledgeSetRequest request) {
BusinessAssert.isTrue(knowledgeSetRepository.findByName(request.getName()) == null, BusinessAssert.isTrue(knowledgeSetRepository.findByName(request.getName()) == null,
@@ -64,6 +66,7 @@ public class KnowledgeSetApplicationService {
public KnowledgeSet updateKnowledgeSet(String setId, UpdateKnowledgeSetRequest request) { public KnowledgeSet updateKnowledgeSet(String setId, UpdateKnowledgeSetRequest request) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId); KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND); BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
BusinessAssert.isTrue(!isReadOnlyStatus(knowledgeSet.getStatus()), BusinessAssert.isTrue(!isReadOnlyStatus(knowledgeSet.getStatus()),
DataManagementErrorCode.KNOWLEDGE_SET_STATUS_ERROR); DataManagementErrorCode.KNOWLEDGE_SET_STATUS_ERROR);
@@ -119,6 +122,7 @@ public class KnowledgeSetApplicationService {
public void deleteKnowledgeSet(String setId) { public void deleteKnowledgeSet(String setId) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId); KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND); BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
knowledgeSetRepository.removeById(setId); knowledgeSetRepository.removeById(setId);
} }
@@ -126,13 +130,15 @@ public class KnowledgeSetApplicationService {
public KnowledgeSet getKnowledgeSet(String setId) { public KnowledgeSet getKnowledgeSet(String setId) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId); KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND); BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
return knowledgeSet; return knowledgeSet;
} }
@Transactional(readOnly = true) @Transactional(readOnly = true)
public PagedResponse<KnowledgeSetResponse> getKnowledgeSets(KnowledgeSetPagingQuery query) { public PagedResponse<KnowledgeSetResponse> getKnowledgeSets(KnowledgeSetPagingQuery query) {
IPage<KnowledgeSet> page = new Page<>(query.getPage(), query.getSize()); IPage<KnowledgeSet> page = new Page<>(query.getPage(), query.getSize());
page = knowledgeSetRepository.findByCriteria(page, query); String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = knowledgeSetRepository.findByCriteria(page, query, ownerFilterUserId);
List<KnowledgeSetResponse> responses = KnowledgeConverter.INSTANCE.convertSetResponses(page.getRecords()); List<KnowledgeSetResponse> responses = KnowledgeConverter.INSTANCE.convertSetResponses(page.getRecords());
return PagedResponse.of(responses, page.getCurrent(), page.getTotal(), page.getPages()); return PagedResponse.of(responses, page.getCurrent(), page.getTotal(), page.getPages());
} }

View File

@@ -25,9 +25,11 @@ public interface DatasetRepository extends IRepository<Dataset> {
AllDatasetStatisticsResponse getAllDatasetStatistics(); AllDatasetStatisticsResponse getAllDatasetStatistics();
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query); AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy);
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy);
long countByParentId(String parentDatasetId); long countByParentId(String parentDatasetId);
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit); List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy);
} }

View File

@@ -11,5 +11,5 @@ import com.datamate.datamanagement.interfaces.dto.KnowledgeSetPagingQuery;
public interface KnowledgeSetRepository extends IRepository<KnowledgeSet> { public interface KnowledgeSetRepository extends IRepository<KnowledgeSet> {
KnowledgeSet findByName(String name); KnowledgeSet findByName(String name);
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query); IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy);
} }

View File

@@ -51,10 +51,34 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
@Override @Override
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query) { public AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy) {
List<Dataset> datasets = lambdaQuery()
.eq(Dataset::getCreatedBy, createdBy)
.list();
long totalFiles = datasets.stream()
.map(Dataset::getFileCount)
.filter(java.util.Objects::nonNull)
.mapToLong(Long::longValue)
.sum();
long totalSize = datasets.stream()
.map(Dataset::getSizeBytes)
.filter(java.util.Objects::nonNull)
.mapToLong(Long::longValue)
.sum();
AllDatasetStatisticsResponse response = new AllDatasetStatisticsResponse();
response.setTotalDatasets(datasets.size());
response.setTotalFiles(totalFiles);
response.setTotalSize(totalSize);
return response;
}
@Override
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy) {
LambdaQueryWrapper<Dataset> wrapper = new LambdaQueryWrapper<Dataset>() LambdaQueryWrapper<Dataset> wrapper = new LambdaQueryWrapper<Dataset>()
.eq(query.getType() != null, Dataset::getDatasetType, query.getType()) .eq(query.getType() != null, Dataset::getDatasetType, query.getType())
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus()); .eq(query.getStatus() != null, Dataset::getStatus, query.getStatus())
.eq(StringUtils.isNotBlank(createdBy), Dataset::getCreatedBy, createdBy);
if (query.getParentDatasetId() != null) { if (query.getParentDatasetId() != null) {
if (StringUtils.isBlank(query.getParentDatasetId())) { if (StringUtils.isBlank(query.getParentDatasetId())) {
@@ -92,7 +116,7 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
} }
@Override @Override
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit) { public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy) {
if (limit <= 0 || tagNames == null || tagNames.isEmpty()) { if (limit <= 0 || tagNames == null || tagNames.isEmpty()) {
return Collections.emptyList(); return Collections.emptyList();
} }
@@ -109,6 +133,9 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
if (StringUtils.isNotBlank(excludedDatasetId)) { if (StringUtils.isNotBlank(excludedDatasetId)) {
wrapper.ne(Dataset::getId, excludedDatasetId.trim()); wrapper.ne(Dataset::getId, excludedDatasetId.trim());
} }
if (StringUtils.isNotBlank(createdBy)) {
wrapper.eq(Dataset::getCreatedBy, createdBy);
}
wrapper.apply("tags IS NOT NULL AND JSON_VALID(tags) = 1 AND JSON_LENGTH(tags) > 0"); wrapper.apply("tags IS NOT NULL AND JSON_VALID(tags) = 1 AND JSON_LENGTH(tags) > 0");
wrapper.and(condition -> { wrapper.and(condition -> {
boolean hasCondition = false; boolean hasCondition = false;

View File

@@ -25,7 +25,7 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
} }
@Override @Override
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query) { public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy) {
LambdaQueryWrapper<KnowledgeSet> wrapper = new LambdaQueryWrapper<KnowledgeSet>() LambdaQueryWrapper<KnowledgeSet> wrapper = new LambdaQueryWrapper<KnowledgeSet>()
.eq(query.getStatus() != null, KnowledgeSet::getStatus, query.getStatus()) .eq(query.getStatus() != null, KnowledgeSet::getStatus, query.getStatus())
.eq(StringUtils.isNotBlank(query.getDomain()), KnowledgeSet::getDomain, query.getDomain()) .eq(StringUtils.isNotBlank(query.getDomain()), KnowledgeSet::getDomain, query.getDomain())
@@ -34,7 +34,8 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
.eq(StringUtils.isNotBlank(query.getSensitivity()), KnowledgeSet::getSensitivity, query.getSensitivity()) .eq(StringUtils.isNotBlank(query.getSensitivity()), KnowledgeSet::getSensitivity, query.getSensitivity())
.eq(query.getSourceType() != null, KnowledgeSet::getSourceType, query.getSourceType()) .eq(query.getSourceType() != null, KnowledgeSet::getSourceType, query.getSourceType())
.ge(query.getValidFrom() != null, KnowledgeSet::getValidFrom, query.getValidFrom()) .ge(query.getValidFrom() != null, KnowledgeSet::getValidFrom, query.getValidFrom())
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo()); .le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo())
.eq(StringUtils.isNotBlank(createdBy), KnowledgeSet::getCreatedBy, createdBy);
if (StringUtils.isNotBlank(query.getKeyword())) { if (StringUtils.isNotBlank(query.getKeyword())) {
wrapper.and(w -> w.like(KnowledgeSet::getName, query.getKeyword()) wrapper.and(w -> w.like(KnowledgeSet::getName, query.getKeyword())

View File

@@ -2,8 +2,11 @@ package com.datamate.rag.indexer.application;
import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.BusinessException; import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode; import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import com.datamate.common.interfaces.PagedResponse; import com.datamate.common.interfaces.PagedResponse;
import com.datamate.common.interfaces.PagingQuery; import com.datamate.common.interfaces.PagingQuery;
import com.datamate.common.setting.domain.entity.ModelConfig; import com.datamate.common.setting.domain.entity.ModelConfig;
@@ -55,6 +58,7 @@ public class KnowledgeBaseService {
private final ApplicationEventPublisher eventPublisher; private final ApplicationEventPublisher eventPublisher;
private final ModelConfigRepository modelConfigRepository; private final ModelConfigRepository modelConfigRepository;
private final MilvusService milvusService; private final MilvusService milvusService;
private final ResourceAccessService resourceAccessService;
/** /**
* 创建知识库 * 创建知识库
@@ -77,8 +81,7 @@ public class KnowledgeBaseService {
*/ */
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) { public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) { if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder() milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
.collectionName(knowledgeBase.getName()) .collectionName(knowledgeBase.getName())
@@ -98,16 +101,14 @@ public class KnowledgeBaseService {
*/ */
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void delete(String knowledgeBaseId) { public void delete(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
knowledgeBaseRepository.removeById(knowledgeBaseId); knowledgeBaseRepository.removeById(knowledgeBaseId);
ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId); ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId);
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build()); milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build());
} }
public KnowledgeBaseResp getById(String knowledgeBaseId) { public KnowledgeBaseResp getById(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBaseResp resp = getKnowledgeBaseResp(knowledgeBase); KnowledgeBaseResp resp = getKnowledgeBaseResp(knowledgeBase);
resp.setEmbedding(modelConfigRepository.getById(knowledgeBase.getEmbeddingModel())); resp.setEmbedding(modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()));
resp.setChat(modelConfigRepository.getById(knowledgeBase.getChatModel())); resp.setChat(modelConfigRepository.getById(knowledgeBase.getChatModel()));
@@ -133,7 +134,8 @@ public class KnowledgeBaseService {
public PagedResponse<KnowledgeBaseResp> list(KnowledgeBaseQueryReq request) { public PagedResponse<KnowledgeBaseResp> list(KnowledgeBaseQueryReq request) {
IPage<KnowledgeBase> page = new Page<>(request.getPage(), request.getSize()); IPage<KnowledgeBase> page = new Page<>(request.getPage(), request.getSize());
page = knowledgeBaseRepository.page(page, request); String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = knowledgeBaseRepository.page(page, request, ownerFilterUserId);
// 将 KnowledgeBase 转换为 KnowledgeBaseResp,并计算 fileCount 和 chunkCount // 将 KnowledgeBase 转换为 KnowledgeBaseResp,并计算 fileCount 和 chunkCount
List<KnowledgeBaseResp> respList = page.getRecords().stream().map(this::getKnowledgeBaseResp).toList(); List<KnowledgeBaseResp> respList = page.getRecords().stream().map(this::getKnowledgeBaseResp).toList();
@@ -143,8 +145,7 @@ public class KnowledgeBaseService {
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void addFiles(AddFilesReq request) { public void addFiles(AddFilesReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseId())) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseId());
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
List<RagFile> ragFiles = request.getFiles().stream().map(fileInfo -> { List<RagFile> ragFiles = request.getFiles().stream().map(fileInfo -> {
RagFile ragFile = new RagFile(); RagFile ragFile = new RagFile();
ragFile.setKnowledgeBaseId(knowledgeBase.getId()); ragFile.setKnowledgeBaseId(knowledgeBase.getId());
@@ -170,6 +171,7 @@ public class KnowledgeBaseService {
} }
public PagedResponse<RagFile> listFiles(String knowledgeBaseId, RagFileReq request) { public PagedResponse<RagFile> listFiles(String knowledgeBaseId, RagFileReq request) {
getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize()); IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
request.setKnowledgeBaseId(knowledgeBaseId); request.setKnowledgeBaseId(knowledgeBaseId);
page = ragFileRepository.page(page, request); page = ragFileRepository.page(page, request);
@@ -177,8 +179,13 @@ public class KnowledgeBaseService {
} }
public PagedResponse<KnowledgeBaseFileSearchResp> searchFiles(KnowledgeBaseFileSearchReq request) { public PagedResponse<KnowledgeBaseFileSearchResp> searchFiles(KnowledgeBaseFileSearchReq request) {
boolean admin = resourceAccessService.isAdmin();
List<String> scopedKnowledgeBaseIds = resolveSearchScopeKnowledgeBaseIds(request, admin);
if (!admin && scopedKnowledgeBaseIds.isEmpty()) {
return PagedResponse.of(Collections.emptyList(), request.getPage(), 0L, 0);
}
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize()); IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
page = ragFileRepository.searchPage(page, request); page = ragFileRepository.searchPage(page, request, scopedKnowledgeBaseIds);
List<RagFile> records = page.getRecords(); List<RagFile> records = page.getRecords();
if (records.isEmpty()) { if (records.isEmpty()) {
return PagedResponse.of(Collections.emptyList(), page.getCurrent(), page.getTotal(), page.getPages()); return PagedResponse.of(Collections.emptyList(), page.getCurrent(), page.getTotal(), page.getPages());
@@ -213,8 +220,7 @@ public class KnowledgeBaseService {
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) { public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
ragFileRepository.removeByIds(request.getIds()); ragFileRepository.removeByIds(request.getIds());
milvusService.getMilvusClient().delete(DeleteReq.builder() milvusService.getMilvusClient().delete(DeleteReq.builder()
.collectionName(knowledgeBase.getName()) .collectionName(knowledgeBase.getName())
@@ -223,8 +229,7 @@ public class KnowledgeBaseService {
} }
public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) { public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder() QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder()
.collectionName(knowledgeBase.getName()) .collectionName(knowledgeBase.getName())
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"") .filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
@@ -259,8 +264,7 @@ public class KnowledgeBaseService {
* @return 检索结果 * @return 检索结果
*/ */
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) { public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst())) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseIds().getFirst());
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()); ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig); EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
Embedding embedding = embeddingModel.embed(request.getQuery()).content(); Embedding embedding = embeddingModel.embed(request.getQuery()).content();
@@ -273,4 +277,27 @@ public class KnowledgeBaseService {
}); });
return searchResults; return searchResults;
} }
private KnowledgeBase getKnowledgeBaseWithAccessCheck(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
resourceAccessService.assertOwnerAccess(knowledgeBase.getCreatedBy());
return knowledgeBase;
}
private List<String> resolveSearchScopeKnowledgeBaseIds(KnowledgeBaseFileSearchReq request, boolean admin) {
if (admin) {
return Collections.emptyList();
}
String currentUserId = resourceAccessService.requireCurrentUserId();
List<String> ownedKnowledgeBaseIds = knowledgeBaseRepository.listIdsByCreatedBy(currentUserId);
if (!StringUtils.hasText(request.getKnowledgeBaseId())) {
return ownedKnowledgeBaseIds;
}
BusinessAssert.isTrue(
ownedKnowledgeBaseIds.contains(request.getKnowledgeBaseId()),
SystemErrorCode.INSUFFICIENT_PERMISSIONS
);
return Collections.singletonList(request.getKnowledgeBaseId());
}
} }

View File

@@ -5,6 +5,8 @@ import com.baomidou.mybatisplus.extension.repository.IRepository;
import com.datamate.rag.indexer.domain.model.KnowledgeBase; import com.datamate.rag.indexer.domain.model.KnowledgeBase;
import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq; import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
import java.util.List;
/** /**
* 知识库仓储接口 * 知识库仓储接口
* *
@@ -19,5 +21,7 @@ public interface KnowledgeBaseRepository extends IRepository<KnowledgeBase> {
* @param request 查询请求 * @param request 查询请求
* @return 知识库分页结果 * @return 知识库分页结果
*/ */
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request); IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy);
List<String> listIdsByCreatedBy(String createdBy);
} }

View File

@@ -23,5 +23,5 @@ public interface RagFileRepository extends IRepository<RagFile> {
IPage<RagFile> page(IPage<RagFile> page, RagFileReq request); IPage<RagFile> page(IPage<RagFile> page, RagFileReq request);
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request); IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds);
} }

View File

@@ -10,6 +10,9 @@ import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
/** /**
* 知识库仓储实现类 * 知识库仓储实现类
* *
@@ -20,12 +23,28 @@ import org.springframework.util.StringUtils;
public class KnowledgeBaseRepositoryImpl extends CrudRepository<KnowledgeBaseMapper, KnowledgeBase> implements KnowledgeBaseRepository { public class KnowledgeBaseRepositoryImpl extends CrudRepository<KnowledgeBaseMapper, KnowledgeBase> implements KnowledgeBaseRepository {
@Override @Override
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request) { public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy) {
return this.page(page, new LambdaQueryWrapper<KnowledgeBase>() return this.page(page, new LambdaQueryWrapper<KnowledgeBase>()
.like(StringUtils.hasText(request.getName()), KnowledgeBase::getName, request.getName()) .like(StringUtils.hasText(request.getName()), KnowledgeBase::getName, request.getName())
.like(StringUtils.hasText(request.getDescription()), KnowledgeBase::getDescription, request.getDescription()) .like(StringUtils.hasText(request.getDescription()), KnowledgeBase::getDescription, request.getDescription())
.like(StringUtils.hasText(request.getCreatedBy()), KnowledgeBase::getCreatedBy, request.getCreatedBy()) .like(StringUtils.hasText(request.getCreatedBy()), KnowledgeBase::getCreatedBy, request.getCreatedBy())
.like(StringUtils.hasText(request.getUpdatedBy()), KnowledgeBase::getUpdatedBy, request.getUpdatedBy()) .like(StringUtils.hasText(request.getUpdatedBy()), KnowledgeBase::getUpdatedBy, request.getUpdatedBy())
.eq(StringUtils.hasText(createdBy), KnowledgeBase::getCreatedBy, createdBy)
.orderByDesc(KnowledgeBase::getCreatedAt)); .orderByDesc(KnowledgeBase::getCreatedAt));
} }
@Override
public List<String> listIdsByCreatedBy(String createdBy) {
if (!StringUtils.hasText(createdBy)) {
return Collections.emptyList();
}
return lambdaQuery()
.select(KnowledgeBase::getId)
.eq(KnowledgeBase::getCreatedBy, createdBy)
.list()
.stream()
.map(KnowledgeBase::getId)
.filter(StringUtils::hasText)
.toList();
}
} }

View File

@@ -52,9 +52,12 @@ public class RagFileRepositoryImpl extends CrudRepository<RagFileMapper, RagFile
} }
@Override @Override
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request) { public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds) {
return lambdaQuery() return lambdaQuery()
.eq(StringUtils.hasText(request.getKnowledgeBaseId()), RagFile::getKnowledgeBaseId, request.getKnowledgeBaseId()) .eq(StringUtils.hasText(request.getKnowledgeBaseId()), RagFile::getKnowledgeBaseId, request.getKnowledgeBaseId())
.in(!StringUtils.hasText(request.getKnowledgeBaseId()) && knowledgeBaseIds != null && !knowledgeBaseIds.isEmpty(),
RagFile::getKnowledgeBaseId,
knowledgeBaseIds)
.like(StringUtils.hasText(request.getFileName()), RagFile::getFileName, request.getFileName()) .like(StringUtils.hasText(request.getFileName()), RagFile::getFileName, request.getFileName())
.likeRight(StringUtils.hasText(request.getRelativePath()), RagFile::getRelativePath, normalizeRelativePath(request.getRelativePath())) .likeRight(StringUtils.hasText(request.getRelativePath()), RagFile::getRelativePath, normalizeRelativePath(request.getRelativePath()))
.page(page); .page(page);

View File

@@ -0,0 +1,58 @@
package com.datamate.common.auth.application;
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.Objects;
/**
* 资源访问控制服务(基于请求用户上下文)
*/
@Service
public class ResourceAccessService {
public static final String ADMIN_ROLE_CODE = "ROLE_ADMIN";
public boolean isAdmin() {
return RequestUserContextHolder.hasRole(ADMIN_ROLE_CODE);
}
public String getCurrentUserId() {
return RequestUserContextHolder.getCurrentUserId();
}
public String requireCurrentUserId() {
String currentUserId = getCurrentUserId();
BusinessAssert.isTrue(StringUtils.hasText(currentUserId), SystemErrorCode.INSUFFICIENT_PERMISSIONS);
return currentUserId;
}
/**
* 资源列表查询的 owner 过滤:
* - 管理员返回 null(不过滤)
* - 非管理员返回当前用户ID
*/
public String resolveOwnerFilterUserId() {
if (isAdmin()) {
return null;
}
return requireCurrentUserId();
}
/**
* 校验当前用户是否可访问 owner 资源
*/
public void assertOwnerAccess(String ownerUserId) {
if (isAdmin()) {
return;
}
String currentUserId = requireCurrentUserId();
BusinessAssert.isTrue(
StringUtils.hasText(ownerUserId) && Objects.equals(ownerUserId, currentUserId),
SystemErrorCode.INSUFFICIENT_PERMISSIONS
);
}
}

View File

@@ -0,0 +1,40 @@
package com.datamate.common.auth.infrastructure.context;
import lombok.Getter;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* 请求级用户上下文
*/
@Getter
public class RequestUserContext {
private final String userId;
private final String username;
private final List<String> roles;
private RequestUserContext(String userId, String username, List<String> roles) {
this.userId = userId;
this.username = username;
this.roles = roles == null ? Collections.emptyList() : List.copyOf(roles);
}
public static RequestUserContext of(String userId, String username, List<String> roles) {
return new RequestUserContext(userId, username, roles);
}
public static RequestUserContext empty() {
return new RequestUserContext(null, null, Collections.emptyList());
}
public boolean hasRole(String roleCode) {
if (!StringUtils.hasText(roleCode)) {
return false;
}
return roles.stream().anyMatch(role -> StringUtils.hasText(role) && Objects.equals(role.trim(), roleCode));
}
}

View File

@@ -0,0 +1,49 @@
package com.datamate.common.auth.infrastructure.context;
import org.springframework.core.NamedThreadLocal;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
/**
* 请求级用户上下文持有器
*/
public final class RequestUserContextHolder {
private static final ThreadLocal<RequestUserContext> USER_CONTEXT_HOLDER =
new NamedThreadLocal<>("datamate-request-user-context");
private RequestUserContextHolder() {
}
public static void set(RequestUserContext context) {
USER_CONTEXT_HOLDER.set(context == null ? RequestUserContext.empty() : context);
}
public static RequestUserContext get() {
RequestUserContext context = USER_CONTEXT_HOLDER.get();
return context == null ? RequestUserContext.empty() : context;
}
public static String getCurrentUserId() {
return get().getUserId();
}
public static List<String> getCurrentRoles() {
List<String> roles = get().getRoles();
return roles == null ? Collections.emptyList() : roles;
}
public static boolean hasRole(String roleCode) {
if (!StringUtils.hasText(roleCode)) {
return false;
}
return getCurrentRoles().stream()
.anyMatch(role -> StringUtils.hasText(role) && roleCode.equalsIgnoreCase(role.trim()));
}
public static void clear() {
USER_CONTEXT_HOLDER.remove();
}
}

View File

@@ -0,0 +1,53 @@
package com.datamate.common.auth.infrastructure.context;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.HandlerInterceptor;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* 从网关透传请求头中提取用户上下文
*/
@Component
public class RequestUserContextInterceptor implements HandlerInterceptor {
private static final String HEADER_USER_ID = "X-User-Id";
private static final String HEADER_USER_NAME = "X-User-Name";
private static final String HEADER_USER_ROLES = "X-User-Roles";
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
String userId = normalizeValue(request.getHeader(HEADER_USER_ID));
String username = normalizeValue(request.getHeader(HEADER_USER_NAME));
List<String> roleCodes = parseRoleCodes(request.getHeader(HEADER_USER_ROLES));
RequestUserContextHolder.set(RequestUserContext.of(userId, username, roleCodes));
return true;
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) {
RequestUserContextHolder.clear();
}
private String normalizeValue(String value) {
if (!StringUtils.hasText(value)) {
return null;
}
return value.trim();
}
private List<String> parseRoleCodes(String roleHeader) {
if (!StringUtils.hasText(roleHeader)) {
return Collections.emptyList();
}
return Arrays.stream(roleHeader.split(","))
.map(String::trim)
.filter(StringUtils::hasText)
.toList();
}
}

View File

@@ -0,0 +1,21 @@
package com.datamate.common.auth.infrastructure.context;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
* 请求用户上下文拦截器注册
*/
@Configuration
@RequiredArgsConstructor
public class RequestUserContextWebMvcConfigurer implements WebMvcConfigurer {
private final RequestUserContextInterceptor requestUserContextInterceptor;
@Override
public void addInterceptors(InterceptorRegistry registry) {
registry.addInterceptor(requestUserContextInterceptor).addPathPatterns("/**");
}
}

View File

@@ -1,9 +1,11 @@
package com.datamate.common.infrastructure.config; package com.datamate.common.infrastructure.config;
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler; import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.MetaObject;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.util.StringUtils;
import java.time.LocalDateTime; import java.time.LocalDateTime;
@@ -44,17 +46,10 @@ public class EntityMetaObjectHandler implements MetaObjectHandler {
* 获取当前用户(需要根据你的安全框架实现) * 获取当前用户(需要根据你的安全框架实现)
*/ */
private String getCurrentUser() { private String getCurrentUser() {
// todo 这里需要根据你的安全框架实现,例如Spring Security、Shiro等 String currentUserId = RequestUserContextHolder.getCurrentUserId();
// 示例:返回默认用户或从SecurityContext获取 if (StringUtils.hasText(currentUserId)) {
try { return currentUserId;
// 如果是Spring Security
// return SecurityContextHolder.getContext().getAuthentication().getName();
// 临时返回默认值,请根据实际情况修改
return "system";
} catch (Exception e) {
log.error("Error getting current user", e);
return "unknown";
} }
return "system";
} }
} }

View File

@@ -1,12 +1,51 @@
import { useState } from "react"; import { useEffect, useMemo, useState } from "react";
import { Menu } from "antd"; import { Menu } from "antd";
import { SettingOutlined } from "@ant-design/icons"; import { SettingOutlined, TeamOutlined } from "@ant-design/icons";
import { Component } from "lucide-react"; import { Component } from "lucide-react";
import SystemConfig from "./SystemConfig"; import SystemConfig from "./SystemConfig";
import ModelAccess from "./ModelAccess"; import ModelAccess from "./ModelAccess";
import UserPermissionManagement from "./UserPermissionManagement";
import { useAppSelector } from "@/store/hooks";
import { hasPermission, PermissionCodes } from "@/auth/permissions";
export default function SettingsPage() { export default function SettingsPage() {
const [activeTab, setActiveTab] = useState("model-access"); const permissions = useAppSelector((state) => state.auth.permissions);
const canManageUsers = hasPermission(permissions, PermissionCodes.userManage);
const canViewRoles = hasPermission(permissions, PermissionCodes.roleManage);
const canViewPermissions = hasPermission(
permissions,
PermissionCodes.permissionManage
);
const tabs = useMemo(() => {
const nextTabs = [
{
key: "model-access",
icon: <Component className="w-4 h-4" />,
label: "模型接入",
},
{
key: "system-config",
icon: <SettingOutlined />,
label: "参数配置",
},
];
if (canManageUsers || canViewRoles || canViewPermissions) {
nextTabs.push({
key: "user-permission",
icon: <TeamOutlined />,
label: "用户与权限",
});
}
return nextTabs;
}, [canManageUsers, canViewPermissions, canViewRoles]);
const [activeTab, setActiveTab] = useState<string>(tabs[0]?.key ?? "model-access");
useEffect(() => {
const hasActiveTab = tabs.some((tab) => tab.key === activeTab);
if (!hasActiveTab && tabs.length > 0) {
setActiveTab(tabs[0].key);
}
}, [activeTab, tabs]);
return ( return (
<div className="h-screen flex"> <div className="h-screen flex">
@@ -18,21 +57,10 @@ export default function SettingsPage() {
<div className="h-full"> <div className="h-full">
<Menu <Menu
mode="inline" mode="inline"
items={[ items={tabs}
{
key: "model-access",
icon: <Component className="w-4 h-4" />,
label: "模型接入",
},
{
key: "system-config",
icon: <SettingOutlined />,
label: "参数配置",
},
]}
selectedKeys={[activeTab]} selectedKeys={[activeTab]}
onClick={({ key }) => { onClick={({ key }) => {
setActiveTab(key); setActiveTab(String(key));
}} }}
/> />
</div> </div>
@@ -41,6 +69,13 @@ export default function SettingsPage() {
{/* 内容区域,根据 activeTab 渲染不同的组件 */} {/* 内容区域,根据 activeTab 渲染不同的组件 */}
{activeTab === "system-config" && <SystemConfig />} {activeTab === "system-config" && <SystemConfig />}
{activeTab === "model-access" && <ModelAccess />} {activeTab === "model-access" && <ModelAccess />}
{activeTab === "user-permission" && (
<UserPermissionManagement
canManageUsers={canManageUsers}
canViewRoles={canViewRoles}
canViewPermissions={canViewPermissions}
/>
)}
</div> </div>
</div> </div>
); );

View File

@@ -0,0 +1,321 @@
import { useCallback, useEffect, useMemo, useState } from "react";
import {
Button,
Card,
Empty,
message,
Modal,
Select,
Space,
Table,
Tag,
Typography,
} from "antd";
import type { ColumnsType } from "antd/es/table";
import {
assignUserRolesUsingPut,
listAuthPermissionsUsingGet,
listAuthRolesUsingGet,
listAuthUsersUsingGet,
} from "./settings.apis";
import type {
AuthPermissionInfo,
AuthRoleInfo,
AuthUserWithRoles,
} from "./settings.apis";
interface ApiResponse<T> {
code: string;
message: string;
data: T;
}
interface UserPermissionManagementProps {
canManageUsers: boolean;
canViewRoles: boolean;
canViewPermissions: boolean;
}
export default function UserPermissionManagement({
canManageUsers,
canViewRoles,
canViewPermissions,
}: UserPermissionManagementProps) {
const [loading, setLoading] = useState(false);
const [users, setUsers] = useState<AuthUserWithRoles[]>([]);
const [roles, setRoles] = useState<AuthRoleInfo[]>([]);
const [permissions, setPermissions] = useState<AuthPermissionInfo[]>([]);
const [editingUser, setEditingUser] = useState<AuthUserWithRoles | null>(null);
const [selectedRoleCodes, setSelectedRoleCodes] = useState<string[]>([]);
const [submitting, setSubmitting] = useState(false);
const canShowAnything = canManageUsers || canViewRoles || canViewPermissions;
const canAssignRoles = canManageUsers && roles.length > 0;
const roleNameMap = useMemo(
() => new Map(roles.map((role) => [role.roleCode, role.roleName || role.roleCode])),
[roles]
);
const roleCodeToIdMap = useMemo(
() => new Map(roles.map((role) => [role.roleCode, role.id])),
[roles]
);
const loadData = useCallback(async () => {
setLoading(true);
try {
const requestTasks: Array<Promise<unknown>> = [];
if (canManageUsers || canViewRoles || canViewPermissions) {
requestTasks.push(listAuthUsersUsingGet());
}
if (canManageUsers || canViewRoles) {
requestTasks.push(listAuthRolesUsingGet());
}
if (canViewPermissions) {
requestTasks.push(listAuthPermissionsUsingGet());
}
const responses = await Promise.all(requestTasks);
let index = 0;
if (canManageUsers || canViewRoles || canViewPermissions) {
const userResponse = responses[index++] as ApiResponse<AuthUserWithRoles[]>;
setUsers(userResponse?.data ?? []);
}
if (canManageUsers || canViewRoles) {
const roleResponse = responses[index++] as ApiResponse<AuthRoleInfo[]>;
setRoles(roleResponse?.data ?? []);
} else {
setRoles([]);
}
if (canViewPermissions) {
const permissionResponse = responses[index++] as ApiResponse<AuthPermissionInfo[]>;
setPermissions(permissionResponse?.data ?? []);
} else {
setPermissions([]);
}
} catch (error) {
message.error("加载用户权限信息失败");
console.error("加载用户权限信息失败:", error);
} finally {
setLoading(false);
}
}, [canManageUsers, canViewPermissions, canViewRoles]);
useEffect(() => {
if (!canShowAnything) {
return;
}
void loadData();
}, [canShowAnything, loadData]);
const userColumns: ColumnsType<AuthUserWithRoles> = [
{
title: "用户名",
dataIndex: "username",
key: "username",
width: 180,
},
{
title: "姓名",
dataIndex: "fullName",
key: "fullName",
width: 180,
render: (value?: string) => value || "-",
},
{
title: "邮箱",
dataIndex: "email",
key: "email",
render: (value?: string) => value || "-",
},
{
title: "状态",
dataIndex: "enabled",
key: "enabled",
width: 120,
render: (enabled?: boolean) =>
enabled ? <Tag color="green"></Tag> : <Tag color="default"></Tag>,
},
{
title: "角色",
dataIndex: "roleCodes",
key: "roleCodes",
render: (roleCodes: string[]) => (
<Space wrap>
{(roleCodes ?? []).map((roleCode) => (
<Tag key={roleCode}>{roleNameMap.get(roleCode) || roleCode}</Tag>
))}
</Space>
),
},
{
title: "操作",
key: "actions",
width: 120,
render: (_, record) => (
<Button
type="link"
disabled={!canAssignRoles}
onClick={() => {
setEditingUser(record);
setSelectedRoleCodes(record.roleCodes ?? []);
}}
>
</Button>
),
},
];
const roleColumns: ColumnsType<AuthRoleInfo> = [
{ title: "角色编码", dataIndex: "roleCode", key: "roleCode", width: 220 },
{ title: "角色名称", dataIndex: "roleName", key: "roleName", width: 180 },
{
title: "状态",
dataIndex: "enabled",
key: "enabled",
width: 120,
render: (enabled?: boolean) =>
enabled ? <Tag color="green"></Tag> : <Tag color="default"></Tag>,
},
{
title: "描述",
dataIndex: "description",
key: "description",
render: (value?: string) => value || "-",
},
];
const permissionColumns: ColumnsType<AuthPermissionInfo> = [
{
title: "权限编码",
dataIndex: "permissionCode",
key: "permissionCode",
width: 260,
},
{
title: "权限名称",
dataIndex: "permissionName",
key: "permissionName",
width: 200,
},
{
title: "模块",
dataIndex: "module",
key: "module",
width: 140,
render: (value?: string) => value || "-",
},
{
title: "动作",
dataIndex: "action",
key: "action",
width: 120,
render: (value?: string) => value || "-",
},
{
title: "接口",
key: "api",
render: (_, record) =>
record.pathPattern ? `${record.method || "ALL"} ${record.pathPattern}` : "-",
},
];
const handleAssignRoles = async () => {
if (!editingUser) {
return;
}
if (selectedRoleCodes.length === 0) {
message.warning("请至少选择一个角色");
return;
}
const roleIds = selectedRoleCodes
.map((roleCode) => roleCodeToIdMap.get(roleCode))
.filter((roleId): roleId is string => Boolean(roleId));
if (roleIds.length !== selectedRoleCodes.length) {
message.error("角色映射失败,请刷新后重试");
return;
}
setSubmitting(true);
try {
await assignUserRolesUsingPut(editingUser.id, roleIds);
message.success("角色分配成功");
setEditingUser(null);
setSelectedRoleCodes([]);
await loadData();
} catch (error) {
message.error("角色分配失败");
console.error("角色分配失败:", error);
} finally {
setSubmitting(false);
}
};
if (!canShowAnything) {
return <Empty description="当前账号无用户与权限管理权限" />;
}
return (
<Space direction="vertical" size={16} className="w-full">
<Card title="用户管理">
<Table
loading={loading}
rowKey="id"
dataSource={users}
columns={userColumns}
pagination={{ pageSize: 10, showSizeChanger: false }}
/>
</Card>
{canViewRoles && (
<Card title="角色列表">
<Table
loading={loading}
rowKey="id"
dataSource={roles}
columns={roleColumns}
pagination={{ pageSize: 8, showSizeChanger: false }}
/>
</Card>
)}
{canViewPermissions && (
<Card title="权限列表">
<Table
loading={loading}
rowKey="id"
dataSource={permissions}
columns={permissionColumns}
pagination={{ pageSize: 10, showSizeChanger: false }}
/>
</Card>
)}
<Modal
title={`分配角色 - ${editingUser?.username || ""}`}
open={Boolean(editingUser)}
confirmLoading={submitting}
onOk={() => {
void handleAssignRoles();
}}
onCancel={() => {
setEditingUser(null);
setSelectedRoleCodes([]);
}}
>
{roles.length === 0 ? (
<Typography.Text type="secondary"></Typography.Text>
) : (
<Select
mode="multiple"
className="w-full"
placeholder="请选择角色"
value={selectedRoleCodes}
onChange={(values) => setSelectedRoleCodes(values)}
options={roles.map((role) => ({
value: role.roleCode,
label: `${role.roleName} (${role.roleCode})`,
}))}
/>
)}
</Modal>
</Space>
);
}

View File

@@ -1,11 +1,11 @@
import { get, post, put, del } from "@/utils/request"; import { get, post, put, del } from "@/utils/request";
// 模型相关接口 // 模型相关接口
export function queryModelProvidersUsingGet(params?: any) { export function queryModelProvidersUsingGet(params?: Record<string, unknown>) {
return get("/api/models/providers", params); return get("/api/models/providers", params);
} }
export function queryModelListUsingGet(data: any) { export function queryModelListUsingGet(data: Record<string, unknown>) {
return get("/api/models/list", data); return get("/api/models/list", data);
} }
@@ -15,12 +15,12 @@ export function queryModelDetailByIdUsingGet(id: string | number) {
export function updateModelByIdUsingPut( export function updateModelByIdUsingPut(
id: string | number, id: string | number,
data: any data: Record<string, unknown>
) { ) {
return put(`/api/models/${id}`, data); return put(`/api/models/${id}`, data);
} }
export function createModelUsingPost(data: any) { export function createModelUsingPost(data: Record<string, unknown>) {
return post("/api/models/create", data); return post("/api/models/create", data);
} }
@@ -28,13 +28,60 @@ export function deleteModelByIdUsingDelete(id: string | number) {
return del(`/api/models/${id}`); return del(`/api/models/${id}`);
} }
// 获取系统参数列表 // 获取系统参数列表
export function getSysParamList() { export function getSysParamList() {
return get('/api/sys-param/list'); return get("/api/sys-param/list");
} }
// 更新系统参数值 // 更新系统参数值
export const updateSysParamValue = async (params: { id: string; paramValue: string }) => { export const updateSysParamValue = async (params: {
id: string;
paramValue: string;
}) => {
return put(`/api/sys-param/${params.id}`, params); return put(`/api/sys-param/${params.id}`, params);
}; };
export interface AuthUserWithRoles {
id: number;
username: string;
fullName?: string;
email?: string;
enabled?: boolean;
roleCodes: string[];
}
export interface AuthRoleInfo {
id: string;
roleCode: string;
roleName: string;
description?: string;
enabled?: boolean;
}
export interface AuthPermissionInfo {
id: string;
permissionCode: string;
permissionName: string;
module?: string;
action?: string;
pathPattern?: string;
method?: string;
enabled?: boolean;
}
// 用户与权限管理接口
export function listAuthUsersUsingGet() {
return get("/api/auth/users");
}
export function listAuthRolesUsingGet() {
return get("/api/auth/roles");
}
export function listAuthPermissionsUsingGet() {
return get("/api/auth/permissions");
}
export function assignUserRolesUsingPut(userId: number, roleIds: string[]) {
return put(`/api/auth/users/${userId}/roles`, { roleIds });
}

View File

@@ -17,12 +17,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db from app.db.session import get_db
from app.module.shared.schema import StandardResponse from app.module.shared.schema import StandardResponse
from app.module.dataset import DatasetManagementService from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger from app.core.logging import get_logger
from ..schema.auto import ( from ..security import (
CreateAutoAnnotationTaskRequest, RequestUserContext,
AutoAnnotationTaskResponse, assert_dataset_access,
get_request_user_context,
)
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
) )
from ..service.auto import AutoAnnotationTaskService from ..service.auto import AutoAnnotationTaskService
@@ -37,15 +42,16 @@ service = AutoAnnotationTaskService()
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]]) @router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_auto_annotation_tasks( async def list_auto_annotation_tasks(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""获取自动标注任务列表。 """获取自动标注任务列表。
前端当前不传分页参数,这里直接返回所有未删除任务。 前端当前不传分页参数,这里直接返回所有未删除任务。
""" """
tasks = await service.list_tasks(db) tasks = await service.list_tasks(db, user_context)
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
@@ -54,28 +60,30 @@ async def list_auto_annotation_tasks(
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse]) @router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task( async def create_auto_annotation_task(
request: CreateAutoAnnotationTaskRequest, request: CreateAutoAnnotationTaskRequest,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""创建自动标注任务。 """创建自动标注任务。
当前仅创建任务记录并置为 pending,实际执行由后续调度/worker 完成。 当前仅创建任务记录并置为 pending,实际执行由后续调度/worker 完成。
""" """
logger.info( logger.info(
"Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s", "Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
request.name, request.name,
request.dataset_id, request.dataset_id,
request.config.model_dump(by_alias=True), request.config.model_dump(by_alias=True),
request.file_ids, request.file_ids,
) )
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建 # 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
dataset_name = None dataset_name = None
total_images = 0 total_images = 0
try: await assert_dataset_access(db, request.dataset_id, user_context)
dm_client = DatasetManagementService(db) try:
dm_client = DatasetManagementService(db)
# Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount # Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount
dataset = await dm_client.get_dataset(request.dataset_id) dataset = await dm_client.get_dataset(request.dataset_id)
if dataset is not None: if dataset is not None:
@@ -103,16 +111,17 @@ async def create_auto_annotation_task(
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse]) @router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status( async def get_auto_annotation_task_status(
task_id: str = Path(..., description="任务ID"), task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""获取单个自动标注任务状态。 """获取单个自动标注任务状态。
前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。 前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。
""" """
task = await service.get_task(db, task_id) task = await service.get_task(db, task_id, user_context)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
@@ -124,13 +133,14 @@ async def get_auto_annotation_task_status(
@router.delete("/{task_id}", response_model=StandardResponse[bool]) @router.delete("/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task( async def delete_auto_annotation_task(
task_id: str = Path(..., description="任务ID"), task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""删除(软删除)自动标注任务,仅标记 deleted_at。""" """删除(软删除)自动标注任务,仅标记 deleted_at。"""
ok = await service.soft_delete_task(db, task_id) ok = await service.soft_delete_task(db, task_id, user_context)
if not ok: if not ok:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
@@ -142,10 +152,11 @@ async def delete_auto_annotation_task(
@router.get("/{task_id}/download") @router.get("/{task_id}/download")
async def download_auto_annotation_result( async def download_auto_annotation_result(
task_id: str = Path(..., description="任务ID"), task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""下载指定自动标注任务的结果 ZIP。""" """下载指定自动标注任务的结果 ZIP。"""
import io import io
@@ -154,7 +165,7 @@ async def download_auto_annotation_result(
import tempfile import tempfile
# 复用服务层获取任务信息 # 复用服务层获取任务信息
task = await service.get_task(db, task_id) task = await service.get_task(db, task_id, user_context)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")

View File

@@ -27,6 +27,10 @@ from app.module.annotation.schema.editor import (
UpsertAnnotationResponse, UpsertAnnotationResponse,
) )
from app.module.annotation.service.editor import AnnotationEditorService from app.module.annotation.service.editor import AnnotationEditorService
from app.module.annotation.security import (
RequestUserContext,
get_request_user_context,
)
from app.module.shared.schema import StandardResponse from app.module.shared.schema import StandardResponse
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -44,8 +48,9 @@ router = APIRouter(
async def get_editor_project_info( async def get_editor_project_info(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"), project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
info = await service.get_project_info(project_id) info = await service.get_project_info(project_id)
return StandardResponse(code=200, message="success", data=info) return StandardResponse(code=200, message="success", data=info)
@@ -64,8 +69,9 @@ async def list_editor_tasks(
description="是否排除已被转换为TXT的源文档文件(PDF/DOC/DOCX,仅文本数据集生效)", description="是否排除已被转换为TXT的源文档文件(PDF/DOC/DOCX,仅文本数据集生效)",
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.list_tasks( result = await service.list_tasks(
project_id, project_id,
page=page, page=page,
@@ -86,8 +92,9 @@ async def get_editor_task(
None, alias="segmentIndex", description="段落索引(分段模式下使用)" None, alias="segmentIndex", description="段落索引(分段模式下使用)"
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
task = await service.get_task(project_id, file_id, segment_index=segment_index) task = await service.get_task(project_id, file_id, segment_index=segment_index)
return StandardResponse(code=200, message="success", data=task) return StandardResponse(code=200, message="success", data=task)
@@ -103,8 +110,9 @@ async def get_editor_task_segment(
..., ge=0, alias="segmentIndex", description="段落索引(从0开始)" ..., ge=0, alias="segmentIndex", description="段落索引(从0开始)"
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.get_task_segment(project_id, file_id, segment_index) result = await service.get_task_segment(project_id, file_id, segment_index)
return StandardResponse(code=200, message="success", data=result) return StandardResponse(code=200, message="success", data=result)
@@ -118,8 +126,9 @@ async def upsert_editor_annotation(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"), project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"), file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.upsert_annotation(project_id, file_id, request) result = await service.upsert_annotation(project_id, file_id, request)
return StandardResponse(code=200, message="success", data=result) return StandardResponse(code=200, message="success", data=result)
@@ -132,11 +141,12 @@ async def check_file_version(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"), project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"), file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
检查文件是否有新版本 检查文件是否有新版本
""" """
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.check_file_version(project_id, file_id) result = await service.check_file_version(project_id, file_id)
return StandardResponse(code=200, message="success", data=result) return StandardResponse(code=200, message="success", data=result)
@@ -149,10 +159,11 @@ async def use_new_version(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"), project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"), file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
使用文件新版本并清空标注 使用文件新版本并清空标注
""" """
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.use_new_version(project_id, file_id) result = await service.use_new_version(project_id, file_id)
return StandardResponse(code=200, message="success", data=result) return StandardResponse(code=200, message="success", data=result)

View File

@@ -12,6 +12,11 @@ from app.module.shared.schema import StandardResponse, PaginatedData
from app.module.dataset import DatasetManagementService from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger from app.core.logging import get_logger
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..service.mapping import DatasetMappingService from ..service.mapping import DatasetMappingService
from ..service.template import AnnotationTemplateService from ..service.template import AnnotationTemplateService
from ..service.knowledge_sync import KnowledgeSyncService from ..service.knowledge_sync import KnowledgeSyncService
@@ -42,7 +47,9 @@ async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201 "", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
) )
async def create_mapping( async def create_mapping(
request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db) request: DatasetMappingCreateRequest,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
创建数据集映射 创建数据集映射
@@ -58,6 +65,8 @@ async def create_mapping(
mapping_service = DatasetMappingService(db) mapping_service = DatasetMappingService(db)
template_service = AnnotationTemplateService() template_service = AnnotationTemplateService()
await assert_dataset_access(db, request.dataset_id, user_context)
logger.info(f"Create dataset mapping request: {request.dataset_id}") logger.info(f"Create dataset mapping request: {request.dataset_id}")
# 从DM服务获取数据集信息 # 从DM服务获取数据集信息
@@ -163,7 +172,7 @@ async def create_mapping(
try: try:
from ..service.editor import AnnotationEditorService from ..service.editor import AnnotationEditorService
editor_service = AnnotationEditorService(db) editor_service = AnnotationEditorService(db, user_context)
# 异步预计算切片(不阻塞创建响应) # 异步预计算切片(不阻塞创建响应)
segmentation_result = ( segmentation_result = (
await editor_service.precompute_segmentation_for_project( await editor_service.precompute_segmentation_for_project(
@@ -202,6 +211,7 @@ async def list_mappings(
False, description="是否包含模板详情", alias="includeTemplate" False, description="是否包含模板详情", alias="includeTemplate"
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
查询所有映射关系(分页) 查询所有映射关系(分页)
@@ -230,6 +240,8 @@ async def list_mappings(
limit=size, limit=size,
include_deleted=False, include_deleted=False,
include_template=include_template, include_template=include_template,
current_user_id=user_context.user_id,
is_admin=user_context.is_admin,
) )
# 计算总页数 # 计算总页数
@@ -256,7 +268,11 @@ async def list_mappings(
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse]) @router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)): async def get_mapping(
mapping_id: str,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
""" """
根据 UUID 查询单个映射关系(包含关联的标注模板详情) 根据 UUID 查询单个映射关系(包含关联的标注模板详情)
@@ -278,6 +294,7 @@ async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Mapping not found: {mapping_id}" status_code=404, detail=f"Mapping not found: {mapping_id}"
) )
await assert_dataset_access(db, mapping.dataset_id, user_context)
logger.info( logger.info(
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}" f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
@@ -304,6 +321,7 @@ async def get_mappings_by_source(
True, description="是否包含模板详情", alias="includeTemplate" True, description="是否包含模板详情", alias="includeTemplate"
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
根据源数据集 ID 查询所有映射关系(分页,包含模板详情) 根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
@@ -319,6 +337,7 @@ async def get_mappings_by_source(
""" """
try: try:
service = DatasetMappingService(db) service = DatasetMappingService(db)
await assert_dataset_access(db, dataset_id, user_context)
# 计算 skip # 计算 skip
skip = (page - 1) * size skip = (page - 1) * size
@@ -333,6 +352,8 @@ async def get_mappings_by_source(
skip=skip, skip=skip,
limit=size, limit=size,
include_template=include_template, include_template=include_template,
current_user_id=user_context.user_id,
is_admin=user_context.is_admin,
) )
# 计算总页数 # 计算总页数
@@ -364,6 +385,7 @@ async def get_mappings_by_source(
async def delete_mapping( async def delete_mapping(
project_id: str = Path(..., description="映射UUID(path param)"), project_id: str = Path(..., description="映射UUID(path param)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
删除映射关系(软删除) 删除映射关系(软删除)
@@ -387,6 +409,7 @@ async def delete_mapping(
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Mapping either not found or not specified." status_code=404, detail=f"Mapping either not found or not specified."
) )
await assert_dataset_access(db, mapping.dataset_id, user_context)
id = mapping.id id = mapping.id
dataset_id = mapping.dataset_id dataset_id = mapping.dataset_id
@@ -428,6 +451,7 @@ async def update_mapping(
project_id: str = Path(..., description="映射UUID(path param)"), project_id: str = Path(..., description="映射UUID(path param)"),
request: DatasetMappingUpdateRequest = None, request: DatasetMappingUpdateRequest = None,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
更新标注项目信息 更新标注项目信息
@@ -456,6 +480,7 @@ async def update_mapping(
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Mapping not found: {project_id}" status_code=404, detail=f"Mapping not found: {project_id}"
) )
await assert_dataset_access(db, mapping_orm.dataset_id, user_context)
# 构建更新数据 # 构建更新数据
update_values = {} update_values = {}

View File

@@ -10,6 +10,11 @@ from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger from app.core.logging import get_logger
from app.core.config import settings from app.core.config import settings
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..service.mapping import DatasetMappingService from ..service.mapping import DatasetMappingService
from ..schema import ( from ..schema import (
SyncDatasetRequest, SyncDatasetRequest,
@@ -32,7 +37,8 @@ logger = get_logger(__name__)
@router.post("/sync", response_model=StandardResponse[SyncDatasetResponse]) @router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
async def sync_dataset_content( async def sync_dataset_content(
request: SyncDatasetRequest, request: SyncDatasetRequest,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
Sync Dataset Content (Files and Annotations) Sync Dataset Content (Files and Annotations)
@@ -51,6 +57,7 @@ async def sync_dataset_content(
status_code=404, status_code=404,
detail=f"Mapping not found: {request.id}" detail=f"Mapping not found: {request.id}"
) )
await assert_dataset_access(db, mapping.dataset_id, user_context)
dm_client = DatasetManagementService(db) dm_client = DatasetManagementService(db)
dataset_info = await dm_client.get_dataset(mapping.dataset_id) dataset_info = await dm_client.get_dataset(mapping.dataset_id)
@@ -82,7 +89,8 @@ async def sync_dataset_content(
@router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse]) @router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse])
async def sync_annotations( async def sync_annotations(
request: SyncAnnotationsRequest, request: SyncAnnotationsRequest,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
Sync Annotations Only (Bidirectional Support) Sync Annotations Only (Bidirectional Support)
@@ -102,6 +110,7 @@ async def sync_annotations(
status_code=404, status_code=404,
detail=f"Mapping not found: {request.id}" detail=f"Mapping not found: {request.id}"
) )
await assert_dataset_access(db, mapping.dataset_id, user_context)
result = SyncAnnotationsResponse( result = SyncAnnotationsResponse(
id=mapping.id, id=mapping.id,
@@ -156,7 +165,8 @@ async def check_label_studio_connection():
async def update_file_tags( async def update_file_tags(
request: UpdateFileTagsRequest, request: UpdateFileTagsRequest,
file_id: str = Path(..., description="文件ID"), file_id: str = Path(..., description="文件ID"),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
Update File Tags (Partial Update with Auto Format Conversion) Update File Tags (Partial Update with Auto Format Conversion)
@@ -189,6 +199,7 @@ async def update_file_tags(
raise HTTPException(status_code=404, detail=f"File not found: {file_id}") raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
await assert_dataset_access(db, dataset_id, user_context)
# 查找数据集关联的模板ID # 查找数据集关联的模板ID
from ..service.mapping import DatasetMappingService from ..service.mapping import DatasetMappingService

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
from fastapi import HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.dataset_management import Dataset
HEADER_USER_ID = "X-User-Id"
HEADER_USER_NAME = "X-User-Name"
HEADER_USER_ROLES = "X-User-Roles"
ADMIN_ROLE_CODE = "ROLE_ADMIN"
@dataclass(frozen=True)
class RequestUserContext:
user_id: str
username: str | None
roles: Tuple[str, ...]
@property
def is_admin(self) -> bool:
return any(role.upper() == ADMIN_ROLE_CODE for role in self.roles)
def get_request_user_context(request: Request) -> RequestUserContext:
user_id = (request.headers.get(HEADER_USER_ID) or "").strip()
username = (request.headers.get(HEADER_USER_NAME) or "").strip() or None
role_header = request.headers.get(HEADER_USER_ROLES) or ""
roles = tuple(
role.strip()
for role in role_header.split(",")
if role and role.strip()
)
if not user_id:
raise HTTPException(status_code=403, detail="权限不足:缺少用户身份")
return RequestUserContext(user_id=user_id, username=username, roles=roles)
def ensure_dataset_owner_access(
user_context: RequestUserContext,
dataset_owner_user_id: str | None,
dataset_id: str,
) -> None:
if user_context.is_admin:
return
if not dataset_owner_user_id or dataset_owner_user_id != user_context.user_id:
raise HTTPException(
status_code=403,
detail=f"无权访问数据集: {dataset_id}",
)
async def assert_dataset_access(
db: AsyncSession,
dataset_id: str,
user_context: RequestUserContext,
) -> None:
owner_result = await db.execute(
select(Dataset.created_by).where(Dataset.id == dataset_id)
)
dataset_owner = owner_result.scalar_one_or_none()
if dataset_owner is None:
raise HTTPException(status_code=404, detail=f"数据集不存在: {dataset_id}")
ensure_dataset_owner_access(user_context, str(dataset_owner), dataset_id)

View File

@@ -5,11 +5,12 @@ from typing import List, Optional
from datetime import datetime from datetime import datetime
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.annotation_management import AutoAnnotationTask from app.db.models.annotation_management import AutoAnnotationTask
from app.db.models.dataset_management import Dataset, DatasetFiles from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.security import RequestUserContext
from ..schema.auto import ( from ..schema.auto import (
CreateAutoAnnotationTaskRequest, CreateAutoAnnotationTaskRequest,
@@ -17,7 +18,7 @@ from ..schema.auto import (
) )
class AutoAnnotationTaskService: class AutoAnnotationTaskService:
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)""" """自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
async def create_task( async def create_task(
@@ -63,15 +64,27 @@ class AutoAnnotationTaskService:
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id] resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
return resp return resp
async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]: def _apply_dataset_scope(self, query, user_context: RequestUserContext):
"""获取未软删除的自动标注任务列表,按创建时间倒序。""" if user_context.is_admin:
return query
result = await db.execute( return query.join(
select(AutoAnnotationTask) Dataset,
.where(AutoAnnotationTask.deleted_at.is_(None)) AutoAnnotationTask.dataset_id == Dataset.id,
.order_by(AutoAnnotationTask.created_at.desc()) ).where(Dataset.created_by == user_context.user_id)
)
tasks: List[AutoAnnotationTask] = list(result.scalars().all()) async def list_tasks(
self,
db: AsyncSession,
user_context: RequestUserContext,
) -> List[AutoAnnotationTaskResponse]:
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
query = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(
query.order_by(AutoAnnotationTask.created_at.desc())
)
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
responses: List[AutoAnnotationTaskResponse] = [] responses: List[AutoAnnotationTaskResponse] = []
for task in tasks: for task in tasks:
@@ -87,16 +100,21 @@ class AutoAnnotationTaskService:
return responses return responses
async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]: async def get_task(
result = await db.execute( self,
select(AutoAnnotationTask).where( db: AsyncSession,
AutoAnnotationTask.id == task_id, task_id: str,
AutoAnnotationTask.deleted_at.is_(None), user_context: RequestUserContext,
) ) -> Optional[AutoAnnotationTaskResponse]:
) query = select(AutoAnnotationTask).where(
task = result.scalar_one_or_none() AutoAnnotationTask.id == task_id,
if not task: AutoAnnotationTask.deleted_at.is_(None),
return None )
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return None
resp = AutoAnnotationTaskResponse.model_validate(task) resp = AutoAnnotationTaskResponse.model_validate(task)
try: try:
@@ -138,16 +156,21 @@ class AutoAnnotationTaskService:
return [task.dataset_id] return [task.dataset_id]
return [] return []
async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool: async def soft_delete_task(
result = await db.execute( self,
select(AutoAnnotationTask).where( db: AsyncSession,
AutoAnnotationTask.id == task_id, task_id: str,
AutoAnnotationTask.deleted_at.is_(None), user_context: RequestUserContext,
) ) -> bool:
) query = select(AutoAnnotationTask).where(
task = result.scalar_one_or_none() AutoAnnotationTask.id == task_id,
if not task: AutoAnnotationTask.deleted_at.is_(None),
return False )
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return False
task.deleted_at = datetime.now() task.deleted_at = datetime.now()
await db.commit() await db.commit()

View File

@@ -54,6 +54,10 @@ from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
from app.module.annotation.service.annotation_text_splitter import ( from app.module.annotation.service.annotation_text_splitter import (
AnnotationTextSplitter, AnnotationTextSplitter,
) )
from app.module.annotation.security import (
RequestUserContext,
ensure_dataset_owner_access,
)
from app.module.annotation.service.text_fetcher import ( from app.module.annotation.service.text_fetcher import (
fetch_text_content_via_download_api, fetch_text_content_via_download_api,
) )
@@ -104,8 +108,9 @@ class AnnotationEditorService:
# 分段阈值:超过此字符数自动分段 # 分段阈值:超过此字符数自动分段
SEGMENT_THRESHOLD = 200 SEGMENT_THRESHOLD = 200
def __init__(self, db: AsyncSession): def __init__(self, db: AsyncSession, user_context: RequestUserContext):
self.db = db self.db = db
self.user_context = user_context
self.template_service = AnnotationTemplateService() self.template_service = AnnotationTemplateService()
@staticmethod @staticmethod
@@ -157,14 +162,24 @@ class AnnotationEditorService:
async def _get_project_or_404(self, project_id: str) -> LabelingProject: async def _get_project_or_404(self, project_id: str) -> LabelingProject:
result = await self.db.execute( result = await self.db.execute(
select(LabelingProject).where( select(LabelingProject, Dataset.created_by).join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(
LabelingProject.id == project_id, LabelingProject.id == project_id,
LabelingProject.deleted_at.is_(None), LabelingProject.deleted_at.is_(None),
) )
) )
project = result.scalar_one_or_none() row = result.first()
if not project: if not row:
raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}") raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
project = row[0]
dataset_owner = row[1]
ensure_dataset_owner_access(
self.user_context,
str(dataset_owner) if dataset_owner is not None else None,
project.dataset_id,
)
return project return project
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]: async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:

View File

@@ -478,7 +478,9 @@ class DatasetMappingService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
include_deleted: bool = False, include_deleted: bool = False,
include_template: bool = False include_template: bool = False,
current_user_id: Optional[str] = None,
is_admin: bool = False,
) -> Tuple[List[DatasetMappingResponse], int]: ) -> Tuple[List[DatasetMappingResponse], int]:
""" """
获取所有映射及总数(用于分页) 获取所有映射及总数(用于分页)
@@ -495,9 +497,16 @@ class DatasetMappingService:
query = self._build_query_with_dataset_name() query = self._build_query_with_dataset_name()
if not include_deleted: if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None)) query = query.where(LabelingProject.deleted_at.is_(None))
if not is_admin:
query = query.where(Dataset.created_by == current_user_id)
# 获取总数 # 获取总数
count_query = select(func.count()).select_from(LabelingProject) count_query = select(func.count()).select_from(LabelingProject)
if not is_admin:
count_query = count_query.join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(Dataset.created_by == current_user_id)
if not include_deleted: if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None)) count_query = count_query.where(LabelingProject.deleted_at.is_(None))
@@ -557,7 +566,9 @@ class DatasetMappingService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
include_deleted: bool = False, include_deleted: bool = False,
include_template: bool = False include_template: bool = False,
current_user_id: Optional[str] = None,
is_admin: bool = False,
) -> Tuple[List[DatasetMappingResponse], int]: ) -> Tuple[List[DatasetMappingResponse], int]:
""" """
根据源数据集ID获取映射关系及总数(用于分页) 根据源数据集ID获取映射关系及总数(用于分页)
@@ -578,11 +589,18 @@ class DatasetMappingService:
if not include_deleted: if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None)) query = query.where(LabelingProject.deleted_at.is_(None))
if not is_admin:
query = query.where(Dataset.created_by == current_user_id)
# 获取总数 # 获取总数
count_query = select(func.count()).select_from(LabelingProject).where( count_query = select(func.count()).select_from(LabelingProject).where(
LabelingProject.dataset_id == dataset_id LabelingProject.dataset_id == dataset_id
) )
if not is_admin:
count_query = count_query.join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(Dataset.created_by == current_user_id)
if not include_deleted: if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None)) count_query = count_query.where(LabelingProject.deleted_at.is_(None))