You've already forked DataMate
feat(auth): 为数据管理和RAG服务增加资源访问控制
- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证 - 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证 - 修改DatasetRepository接口和实现类,增加按创建者过滤的方法 - 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法 - 在RAG索引器服务中添加知识库访问权限检查和作用域过滤 - 更新实体元对象处理器以使用请求用户上下文获取当前用户 - 在前端设置页面添加用户权限管理功能和角色权限控制 - 为Python标注服务增加用户上下文和数据集访问权限验证
This commit is contained in:
@@ -3,6 +3,7 @@ package com.datamate.datamanagement.application;
|
|||||||
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
|
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
|
||||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
|
import com.datamate.common.auth.application.ResourceAccessService;
|
||||||
import com.datamate.common.domain.utils.ChunksSaver;
|
import com.datamate.common.domain.utils.ChunksSaver;
|
||||||
import com.datamate.common.setting.application.SysParamApplicationService;
|
import com.datamate.common.setting.application.SysParamApplicationService;
|
||||||
import com.datamate.datamanagement.interfaces.dto.*;
|
import com.datamate.datamanagement.interfaces.dto.*;
|
||||||
@@ -64,6 +65,7 @@ public class DatasetApplicationService {
|
|||||||
private final CollectionTaskClient collectionTaskClient;
|
private final CollectionTaskClient collectionTaskClient;
|
||||||
private final DatasetFileApplicationService datasetFileApplicationService;
|
private final DatasetFileApplicationService datasetFileApplicationService;
|
||||||
private final SysParamApplicationService sysParamService;
|
private final SysParamApplicationService sysParamService;
|
||||||
|
private final ResourceAccessService resourceAccessService;
|
||||||
|
|
||||||
@Value("${datamate.data-management.base-path:/dataset}")
|
@Value("${datamate.data-management.base-path:/dataset}")
|
||||||
private String datasetBasePath;
|
private String datasetBasePath;
|
||||||
@@ -102,6 +104,7 @@ public class DatasetApplicationService {
|
|||||||
public Dataset updateDataset(String datasetId, UpdateDatasetRequest updateDatasetRequest) {
|
public Dataset updateDataset(String datasetId, UpdateDatasetRequest updateDatasetRequest) {
|
||||||
Dataset dataset = datasetRepository.getById(datasetId);
|
Dataset dataset = datasetRepository.getById(datasetId);
|
||||||
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||||
|
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||||
|
|
||||||
if (StringUtils.hasText(updateDatasetRequest.getName())) {
|
if (StringUtils.hasText(updateDatasetRequest.getName())) {
|
||||||
dataset.setName(updateDatasetRequest.getName());
|
dataset.setName(updateDatasetRequest.getName());
|
||||||
@@ -151,6 +154,7 @@ public class DatasetApplicationService {
|
|||||||
public void deleteDataset(String datasetId) {
|
public void deleteDataset(String datasetId) {
|
||||||
Dataset dataset = datasetRepository.getById(datasetId);
|
Dataset dataset = datasetRepository.getById(datasetId);
|
||||||
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||||
|
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||||
long childCount = datasetRepository.countByParentId(datasetId);
|
long childCount = datasetRepository.countByParentId(datasetId);
|
||||||
BusinessAssert.isTrue(childCount == 0, DataManagementErrorCode.DATASET_HAS_CHILDREN);
|
BusinessAssert.isTrue(childCount == 0, DataManagementErrorCode.DATASET_HAS_CHILDREN);
|
||||||
datasetRepository.removeById(datasetId);
|
datasetRepository.removeById(datasetId);
|
||||||
@@ -164,6 +168,7 @@ public class DatasetApplicationService {
|
|||||||
public Dataset getDataset(String datasetId) {
|
public Dataset getDataset(String datasetId) {
|
||||||
Dataset dataset = datasetRepository.getById(datasetId);
|
Dataset dataset = datasetRepository.getById(datasetId);
|
||||||
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||||
|
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||||
List<DatasetFile> datasetFiles = datasetFileRepository.findAllVisibleByDatasetId(datasetId);
|
List<DatasetFile> datasetFiles = datasetFileRepository.findAllVisibleByDatasetId(datasetId);
|
||||||
dataset.setFiles(datasetFiles);
|
dataset.setFiles(datasetFiles);
|
||||||
applyVisibleFileCounts(Collections.singletonList(dataset));
|
applyVisibleFileCounts(Collections.singletonList(dataset));
|
||||||
@@ -176,7 +181,8 @@ public class DatasetApplicationService {
|
|||||||
@Transactional(readOnly = true)
|
@Transactional(readOnly = true)
|
||||||
public PagedResponse<DatasetResponse> getDatasets(DatasetPagingQuery query) {
|
public PagedResponse<DatasetResponse> getDatasets(DatasetPagingQuery query) {
|
||||||
IPage<Dataset> page = new Page<>(query.getPage(), query.getSize());
|
IPage<Dataset> page = new Page<>(query.getPage(), query.getSize());
|
||||||
page = datasetRepository.findByCriteria(page, query);
|
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
|
||||||
|
page = datasetRepository.findByCriteria(page, query, ownerFilterUserId);
|
||||||
String datasetPvcName = getDatasetPvcName();
|
String datasetPvcName = getDatasetPvcName();
|
||||||
applyVisibleFileCounts(page.getRecords());
|
applyVisibleFileCounts(page.getRecords());
|
||||||
List<DatasetResponse> datasetResponses = DatasetConverter.INSTANCE.convertToResponse(page.getRecords());
|
List<DatasetResponse> datasetResponses = DatasetConverter.INSTANCE.convertToResponse(page.getRecords());
|
||||||
@@ -189,6 +195,7 @@ public class DatasetApplicationService {
|
|||||||
BusinessAssert.isTrue(StringUtils.hasText(datasetId), CommonErrorCode.PARAM_ERROR);
|
BusinessAssert.isTrue(StringUtils.hasText(datasetId), CommonErrorCode.PARAM_ERROR);
|
||||||
Dataset dataset = datasetRepository.getById(datasetId);
|
Dataset dataset = datasetRepository.getById(datasetId);
|
||||||
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
|
||||||
|
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||||
Set<String> sourceTags = normalizeTagNames(dataset.getTags());
|
Set<String> sourceTags = normalizeTagNames(dataset.getTags());
|
||||||
if (sourceTags.isEmpty()) {
|
if (sourceTags.isEmpty()) {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
@@ -198,10 +205,12 @@ public class DatasetApplicationService {
|
|||||||
SIMILAR_DATASET_CANDIDATE_MAX,
|
SIMILAR_DATASET_CANDIDATE_MAX,
|
||||||
Math.max(safeLimit * SIMILAR_DATASET_CANDIDATE_FACTOR, safeLimit)
|
Math.max(safeLimit * SIMILAR_DATASET_CANDIDATE_FACTOR, safeLimit)
|
||||||
);
|
);
|
||||||
|
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
|
||||||
List<Dataset> candidates = datasetRepository.findSimilarByTags(
|
List<Dataset> candidates = datasetRepository.findSimilarByTags(
|
||||||
new ArrayList<>(sourceTags),
|
new ArrayList<>(sourceTags),
|
||||||
datasetId,
|
datasetId,
|
||||||
candidateLimit
|
candidateLimit,
|
||||||
|
ownerFilterUserId
|
||||||
);
|
);
|
||||||
if (CollectionUtils.isEmpty(candidates)) {
|
if (CollectionUtils.isEmpty(candidates)) {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
@@ -436,6 +445,7 @@ public class DatasetApplicationService {
|
|||||||
if (dataset == null) {
|
if (dataset == null) {
|
||||||
throw new IllegalArgumentException("Dataset not found: " + datasetId);
|
throw new IllegalArgumentException("Dataset not found: " + datasetId);
|
||||||
}
|
}
|
||||||
|
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
|
||||||
|
|
||||||
Map<String, Object> statistics = new HashMap<>();
|
Map<String, Object> statistics = new HashMap<>();
|
||||||
|
|
||||||
@@ -485,7 +495,11 @@ public class DatasetApplicationService {
|
|||||||
* 获取所有数据集的汇总统计信息
|
* 获取所有数据集的汇总统计信息
|
||||||
*/
|
*/
|
||||||
public AllDatasetStatisticsResponse getAllDatasetStatistics() {
|
public AllDatasetStatisticsResponse getAllDatasetStatistics() {
|
||||||
return datasetRepository.getAllDatasetStatistics();
|
if (resourceAccessService.isAdmin()) {
|
||||||
|
return datasetRepository.getAllDatasetStatistics();
|
||||||
|
}
|
||||||
|
String currentUserId = resourceAccessService.requireCurrentUserId();
|
||||||
|
return datasetRepository.getAllDatasetStatisticsByCreatedBy(currentUserId);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.datamate.datamanagement.application;
|
|||||||
|
|
||||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
|
import com.datamate.common.auth.application.ResourceAccessService;
|
||||||
import com.datamate.common.infrastructure.exception.BusinessAssert;
|
import com.datamate.common.infrastructure.exception.BusinessAssert;
|
||||||
import com.datamate.common.infrastructure.exception.CommonErrorCode;
|
import com.datamate.common.infrastructure.exception.CommonErrorCode;
|
||||||
import com.datamate.common.interfaces.PagedResponse;
|
import com.datamate.common.interfaces.PagedResponse;
|
||||||
@@ -40,6 +41,7 @@ import java.util.UUID;
|
|||||||
public class KnowledgeSetApplicationService {
|
public class KnowledgeSetApplicationService {
|
||||||
private final KnowledgeSetRepository knowledgeSetRepository;
|
private final KnowledgeSetRepository knowledgeSetRepository;
|
||||||
private final TagMapper tagMapper;
|
private final TagMapper tagMapper;
|
||||||
|
private final ResourceAccessService resourceAccessService;
|
||||||
|
|
||||||
public KnowledgeSet createKnowledgeSet(CreateKnowledgeSetRequest request) {
|
public KnowledgeSet createKnowledgeSet(CreateKnowledgeSetRequest request) {
|
||||||
BusinessAssert.isTrue(knowledgeSetRepository.findByName(request.getName()) == null,
|
BusinessAssert.isTrue(knowledgeSetRepository.findByName(request.getName()) == null,
|
||||||
@@ -64,6 +66,7 @@ public class KnowledgeSetApplicationService {
|
|||||||
public KnowledgeSet updateKnowledgeSet(String setId, UpdateKnowledgeSetRequest request) {
|
public KnowledgeSet updateKnowledgeSet(String setId, UpdateKnowledgeSetRequest request) {
|
||||||
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
|
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
|
||||||
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
|
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
|
||||||
|
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
|
||||||
BusinessAssert.isTrue(!isReadOnlyStatus(knowledgeSet.getStatus()),
|
BusinessAssert.isTrue(!isReadOnlyStatus(knowledgeSet.getStatus()),
|
||||||
DataManagementErrorCode.KNOWLEDGE_SET_STATUS_ERROR);
|
DataManagementErrorCode.KNOWLEDGE_SET_STATUS_ERROR);
|
||||||
|
|
||||||
@@ -119,6 +122,7 @@ public class KnowledgeSetApplicationService {
|
|||||||
public void deleteKnowledgeSet(String setId) {
|
public void deleteKnowledgeSet(String setId) {
|
||||||
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
|
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
|
||||||
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
|
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
|
||||||
|
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
|
||||||
knowledgeSetRepository.removeById(setId);
|
knowledgeSetRepository.removeById(setId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,13 +130,15 @@ public class KnowledgeSetApplicationService {
|
|||||||
public KnowledgeSet getKnowledgeSet(String setId) {
|
public KnowledgeSet getKnowledgeSet(String setId) {
|
||||||
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
|
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
|
||||||
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
|
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
|
||||||
|
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
|
||||||
return knowledgeSet;
|
return knowledgeSet;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Transactional(readOnly = true)
|
@Transactional(readOnly = true)
|
||||||
public PagedResponse<KnowledgeSetResponse> getKnowledgeSets(KnowledgeSetPagingQuery query) {
|
public PagedResponse<KnowledgeSetResponse> getKnowledgeSets(KnowledgeSetPagingQuery query) {
|
||||||
IPage<KnowledgeSet> page = new Page<>(query.getPage(), query.getSize());
|
IPage<KnowledgeSet> page = new Page<>(query.getPage(), query.getSize());
|
||||||
page = knowledgeSetRepository.findByCriteria(page, query);
|
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
|
||||||
|
page = knowledgeSetRepository.findByCriteria(page, query, ownerFilterUserId);
|
||||||
List<KnowledgeSetResponse> responses = KnowledgeConverter.INSTANCE.convertSetResponses(page.getRecords());
|
List<KnowledgeSetResponse> responses = KnowledgeConverter.INSTANCE.convertSetResponses(page.getRecords());
|
||||||
return PagedResponse.of(responses, page.getCurrent(), page.getTotal(), page.getPages());
|
return PagedResponse.of(responses, page.getCurrent(), page.getTotal(), page.getPages());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,9 +25,11 @@ public interface DatasetRepository extends IRepository<Dataset> {
|
|||||||
|
|
||||||
AllDatasetStatisticsResponse getAllDatasetStatistics();
|
AllDatasetStatisticsResponse getAllDatasetStatistics();
|
||||||
|
|
||||||
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query);
|
AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy);
|
||||||
|
|
||||||
|
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy);
|
||||||
|
|
||||||
long countByParentId(String parentDatasetId);
|
long countByParentId(String parentDatasetId);
|
||||||
|
|
||||||
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit);
|
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,5 +11,5 @@ import com.datamate.datamanagement.interfaces.dto.KnowledgeSetPagingQuery;
|
|||||||
public interface KnowledgeSetRepository extends IRepository<KnowledgeSet> {
|
public interface KnowledgeSetRepository extends IRepository<KnowledgeSet> {
|
||||||
KnowledgeSet findByName(String name);
|
KnowledgeSet findByName(String name);
|
||||||
|
|
||||||
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query);
|
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -51,10 +51,34 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
|
|||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query) {
|
public AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy) {
|
||||||
|
List<Dataset> datasets = lambdaQuery()
|
||||||
|
.eq(Dataset::getCreatedBy, createdBy)
|
||||||
|
.list();
|
||||||
|
long totalFiles = datasets.stream()
|
||||||
|
.map(Dataset::getFileCount)
|
||||||
|
.filter(java.util.Objects::nonNull)
|
||||||
|
.mapToLong(Long::longValue)
|
||||||
|
.sum();
|
||||||
|
long totalSize = datasets.stream()
|
||||||
|
.map(Dataset::getSizeBytes)
|
||||||
|
.filter(java.util.Objects::nonNull)
|
||||||
|
.mapToLong(Long::longValue)
|
||||||
|
.sum();
|
||||||
|
AllDatasetStatisticsResponse response = new AllDatasetStatisticsResponse();
|
||||||
|
response.setTotalDatasets(datasets.size());
|
||||||
|
response.setTotalFiles(totalFiles);
|
||||||
|
response.setTotalSize(totalSize);
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy) {
|
||||||
LambdaQueryWrapper<Dataset> wrapper = new LambdaQueryWrapper<Dataset>()
|
LambdaQueryWrapper<Dataset> wrapper = new LambdaQueryWrapper<Dataset>()
|
||||||
.eq(query.getType() != null, Dataset::getDatasetType, query.getType())
|
.eq(query.getType() != null, Dataset::getDatasetType, query.getType())
|
||||||
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus());
|
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus())
|
||||||
|
.eq(StringUtils.isNotBlank(createdBy), Dataset::getCreatedBy, createdBy);
|
||||||
|
|
||||||
if (query.getParentDatasetId() != null) {
|
if (query.getParentDatasetId() != null) {
|
||||||
if (StringUtils.isBlank(query.getParentDatasetId())) {
|
if (StringUtils.isBlank(query.getParentDatasetId())) {
|
||||||
@@ -92,7 +116,7 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit) {
|
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy) {
|
||||||
if (limit <= 0 || tagNames == null || tagNames.isEmpty()) {
|
if (limit <= 0 || tagNames == null || tagNames.isEmpty()) {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
}
|
}
|
||||||
@@ -109,6 +133,9 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
|
|||||||
if (StringUtils.isNotBlank(excludedDatasetId)) {
|
if (StringUtils.isNotBlank(excludedDatasetId)) {
|
||||||
wrapper.ne(Dataset::getId, excludedDatasetId.trim());
|
wrapper.ne(Dataset::getId, excludedDatasetId.trim());
|
||||||
}
|
}
|
||||||
|
if (StringUtils.isNotBlank(createdBy)) {
|
||||||
|
wrapper.eq(Dataset::getCreatedBy, createdBy);
|
||||||
|
}
|
||||||
wrapper.apply("tags IS NOT NULL AND JSON_VALID(tags) = 1 AND JSON_LENGTH(tags) > 0");
|
wrapper.apply("tags IS NOT NULL AND JSON_VALID(tags) = 1 AND JSON_LENGTH(tags) > 0");
|
||||||
wrapper.and(condition -> {
|
wrapper.and(condition -> {
|
||||||
boolean hasCondition = false;
|
boolean hasCondition = false;
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query) {
|
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy) {
|
||||||
LambdaQueryWrapper<KnowledgeSet> wrapper = new LambdaQueryWrapper<KnowledgeSet>()
|
LambdaQueryWrapper<KnowledgeSet> wrapper = new LambdaQueryWrapper<KnowledgeSet>()
|
||||||
.eq(query.getStatus() != null, KnowledgeSet::getStatus, query.getStatus())
|
.eq(query.getStatus() != null, KnowledgeSet::getStatus, query.getStatus())
|
||||||
.eq(StringUtils.isNotBlank(query.getDomain()), KnowledgeSet::getDomain, query.getDomain())
|
.eq(StringUtils.isNotBlank(query.getDomain()), KnowledgeSet::getDomain, query.getDomain())
|
||||||
@@ -34,7 +34,8 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
|
|||||||
.eq(StringUtils.isNotBlank(query.getSensitivity()), KnowledgeSet::getSensitivity, query.getSensitivity())
|
.eq(StringUtils.isNotBlank(query.getSensitivity()), KnowledgeSet::getSensitivity, query.getSensitivity())
|
||||||
.eq(query.getSourceType() != null, KnowledgeSet::getSourceType, query.getSourceType())
|
.eq(query.getSourceType() != null, KnowledgeSet::getSourceType, query.getSourceType())
|
||||||
.ge(query.getValidFrom() != null, KnowledgeSet::getValidFrom, query.getValidFrom())
|
.ge(query.getValidFrom() != null, KnowledgeSet::getValidFrom, query.getValidFrom())
|
||||||
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo());
|
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo())
|
||||||
|
.eq(StringUtils.isNotBlank(createdBy), KnowledgeSet::getCreatedBy, createdBy);
|
||||||
|
|
||||||
if (StringUtils.isNotBlank(query.getKeyword())) {
|
if (StringUtils.isNotBlank(query.getKeyword())) {
|
||||||
wrapper.and(w -> w.like(KnowledgeSet::getName, query.getKeyword())
|
wrapper.and(w -> w.like(KnowledgeSet::getName, query.getKeyword())
|
||||||
|
|||||||
@@ -2,8 +2,11 @@ package com.datamate.rag.indexer.application;
|
|||||||
|
|
||||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
|
import com.datamate.common.auth.application.ResourceAccessService;
|
||||||
|
import com.datamate.common.infrastructure.exception.BusinessAssert;
|
||||||
import com.datamate.common.infrastructure.exception.BusinessException;
|
import com.datamate.common.infrastructure.exception.BusinessException;
|
||||||
import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode;
|
import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode;
|
||||||
|
import com.datamate.common.infrastructure.exception.SystemErrorCode;
|
||||||
import com.datamate.common.interfaces.PagedResponse;
|
import com.datamate.common.interfaces.PagedResponse;
|
||||||
import com.datamate.common.interfaces.PagingQuery;
|
import com.datamate.common.interfaces.PagingQuery;
|
||||||
import com.datamate.common.setting.domain.entity.ModelConfig;
|
import com.datamate.common.setting.domain.entity.ModelConfig;
|
||||||
@@ -55,6 +58,7 @@ public class KnowledgeBaseService {
|
|||||||
private final ApplicationEventPublisher eventPublisher;
|
private final ApplicationEventPublisher eventPublisher;
|
||||||
private final ModelConfigRepository modelConfigRepository;
|
private final ModelConfigRepository modelConfigRepository;
|
||||||
private final MilvusService milvusService;
|
private final MilvusService milvusService;
|
||||||
|
private final ResourceAccessService resourceAccessService;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 创建知识库
|
* 创建知识库
|
||||||
@@ -77,8 +81,7 @@ public class KnowledgeBaseService {
|
|||||||
*/
|
*/
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
|
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
|
||||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
|
||||||
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
|
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
|
||||||
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
|
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
|
||||||
.collectionName(knowledgeBase.getName())
|
.collectionName(knowledgeBase.getName())
|
||||||
@@ -98,16 +101,14 @@ public class KnowledgeBaseService {
|
|||||||
*/
|
*/
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public void delete(String knowledgeBaseId) {
|
public void delete(String knowledgeBaseId) {
|
||||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
|
||||||
knowledgeBaseRepository.removeById(knowledgeBaseId);
|
knowledgeBaseRepository.removeById(knowledgeBaseId);
|
||||||
ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId);
|
ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId);
|
||||||
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build());
|
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build());
|
||||||
}
|
}
|
||||||
|
|
||||||
public KnowledgeBaseResp getById(String knowledgeBaseId) {
|
public KnowledgeBaseResp getById(String knowledgeBaseId) {
|
||||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
|
||||||
KnowledgeBaseResp resp = getKnowledgeBaseResp(knowledgeBase);
|
KnowledgeBaseResp resp = getKnowledgeBaseResp(knowledgeBase);
|
||||||
resp.setEmbedding(modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()));
|
resp.setEmbedding(modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()));
|
||||||
resp.setChat(modelConfigRepository.getById(knowledgeBase.getChatModel()));
|
resp.setChat(modelConfigRepository.getById(knowledgeBase.getChatModel()));
|
||||||
@@ -133,7 +134,8 @@ public class KnowledgeBaseService {
|
|||||||
|
|
||||||
public PagedResponse<KnowledgeBaseResp> list(KnowledgeBaseQueryReq request) {
|
public PagedResponse<KnowledgeBaseResp> list(KnowledgeBaseQueryReq request) {
|
||||||
IPage<KnowledgeBase> page = new Page<>(request.getPage(), request.getSize());
|
IPage<KnowledgeBase> page = new Page<>(request.getPage(), request.getSize());
|
||||||
page = knowledgeBaseRepository.page(page, request);
|
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
|
||||||
|
page = knowledgeBaseRepository.page(page, request, ownerFilterUserId);
|
||||||
|
|
||||||
// 将 KnowledgeBase 转换为 KnowledgeBaseResp,并计算 fileCount 和 chunkCount
|
// 将 KnowledgeBase 转换为 KnowledgeBaseResp,并计算 fileCount 和 chunkCount
|
||||||
List<KnowledgeBaseResp> respList = page.getRecords().stream().map(this::getKnowledgeBaseResp).toList();
|
List<KnowledgeBaseResp> respList = page.getRecords().stream().map(this::getKnowledgeBaseResp).toList();
|
||||||
@@ -143,8 +145,7 @@ public class KnowledgeBaseService {
|
|||||||
|
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public void addFiles(AddFilesReq request) {
|
public void addFiles(AddFilesReq request) {
|
||||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseId()))
|
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseId());
|
||||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
|
||||||
List<RagFile> ragFiles = request.getFiles().stream().map(fileInfo -> {
|
List<RagFile> ragFiles = request.getFiles().stream().map(fileInfo -> {
|
||||||
RagFile ragFile = new RagFile();
|
RagFile ragFile = new RagFile();
|
||||||
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
|
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
|
||||||
@@ -170,6 +171,7 @@ public class KnowledgeBaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public PagedResponse<RagFile> listFiles(String knowledgeBaseId, RagFileReq request) {
|
public PagedResponse<RagFile> listFiles(String knowledgeBaseId, RagFileReq request) {
|
||||||
|
getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||||
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
|
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
|
||||||
request.setKnowledgeBaseId(knowledgeBaseId);
|
request.setKnowledgeBaseId(knowledgeBaseId);
|
||||||
page = ragFileRepository.page(page, request);
|
page = ragFileRepository.page(page, request);
|
||||||
@@ -177,8 +179,13 @@ public class KnowledgeBaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public PagedResponse<KnowledgeBaseFileSearchResp> searchFiles(KnowledgeBaseFileSearchReq request) {
|
public PagedResponse<KnowledgeBaseFileSearchResp> searchFiles(KnowledgeBaseFileSearchReq request) {
|
||||||
|
boolean admin = resourceAccessService.isAdmin();
|
||||||
|
List<String> scopedKnowledgeBaseIds = resolveSearchScopeKnowledgeBaseIds(request, admin);
|
||||||
|
if (!admin && scopedKnowledgeBaseIds.isEmpty()) {
|
||||||
|
return PagedResponse.of(Collections.emptyList(), request.getPage(), 0L, 0);
|
||||||
|
}
|
||||||
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
|
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
|
||||||
page = ragFileRepository.searchPage(page, request);
|
page = ragFileRepository.searchPage(page, request, scopedKnowledgeBaseIds);
|
||||||
List<RagFile> records = page.getRecords();
|
List<RagFile> records = page.getRecords();
|
||||||
if (records.isEmpty()) {
|
if (records.isEmpty()) {
|
||||||
return PagedResponse.of(Collections.emptyList(), page.getCurrent(), page.getTotal(), page.getPages());
|
return PagedResponse.of(Collections.emptyList(), page.getCurrent(), page.getTotal(), page.getPages());
|
||||||
@@ -213,8 +220,7 @@ public class KnowledgeBaseService {
|
|||||||
|
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) {
|
public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) {
|
||||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
|
||||||
ragFileRepository.removeByIds(request.getIds());
|
ragFileRepository.removeByIds(request.getIds());
|
||||||
milvusService.getMilvusClient().delete(DeleteReq.builder()
|
milvusService.getMilvusClient().delete(DeleteReq.builder()
|
||||||
.collectionName(knowledgeBase.getName())
|
.collectionName(knowledgeBase.getName())
|
||||||
@@ -223,8 +229,7 @@ public class KnowledgeBaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) {
|
public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) {
|
||||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
|
||||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
|
||||||
QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder()
|
QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder()
|
||||||
.collectionName(knowledgeBase.getName())
|
.collectionName(knowledgeBase.getName())
|
||||||
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
|
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
|
||||||
@@ -259,8 +264,7 @@ public class KnowledgeBaseService {
|
|||||||
* @return 检索结果
|
* @return 检索结果
|
||||||
*/
|
*/
|
||||||
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
|
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
|
||||||
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
|
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseIds().getFirst());
|
||||||
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
|
||||||
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
|
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
|
||||||
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
|
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
|
||||||
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
|
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
|
||||||
@@ -273,4 +277,27 @@ public class KnowledgeBaseService {
|
|||||||
});
|
});
|
||||||
return searchResults;
|
return searchResults;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private KnowledgeBase getKnowledgeBaseWithAccessCheck(String knowledgeBaseId) {
|
||||||
|
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
|
||||||
|
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
|
||||||
|
resourceAccessService.assertOwnerAccess(knowledgeBase.getCreatedBy());
|
||||||
|
return knowledgeBase;
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> resolveSearchScopeKnowledgeBaseIds(KnowledgeBaseFileSearchReq request, boolean admin) {
|
||||||
|
if (admin) {
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
String currentUserId = resourceAccessService.requireCurrentUserId();
|
||||||
|
List<String> ownedKnowledgeBaseIds = knowledgeBaseRepository.listIdsByCreatedBy(currentUserId);
|
||||||
|
if (!StringUtils.hasText(request.getKnowledgeBaseId())) {
|
||||||
|
return ownedKnowledgeBaseIds;
|
||||||
|
}
|
||||||
|
BusinessAssert.isTrue(
|
||||||
|
ownedKnowledgeBaseIds.contains(request.getKnowledgeBaseId()),
|
||||||
|
SystemErrorCode.INSUFFICIENT_PERMISSIONS
|
||||||
|
);
|
||||||
|
return Collections.singletonList(request.getKnowledgeBaseId());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import com.baomidou.mybatisplus.extension.repository.IRepository;
|
|||||||
import com.datamate.rag.indexer.domain.model.KnowledgeBase;
|
import com.datamate.rag.indexer.domain.model.KnowledgeBase;
|
||||||
import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
|
import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识库仓储接口
|
* 知识库仓储接口
|
||||||
*
|
*
|
||||||
@@ -19,5 +21,7 @@ public interface KnowledgeBaseRepository extends IRepository<KnowledgeBase> {
|
|||||||
* @param request 查询请求
|
* @param request 查询请求
|
||||||
* @return 知识库分页结果
|
* @return 知识库分页结果
|
||||||
*/
|
*/
|
||||||
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request);
|
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy);
|
||||||
|
|
||||||
|
List<String> listIdsByCreatedBy(String createdBy);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,5 +23,5 @@ public interface RagFileRepository extends IRepository<RagFile> {
|
|||||||
|
|
||||||
IPage<RagFile> page(IPage<RagFile> page, RagFileReq request);
|
IPage<RagFile> page(IPage<RagFile> page, RagFileReq request);
|
||||||
|
|
||||||
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request);
|
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,9 @@ import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
|
|||||||
import org.springframework.stereotype.Repository;
|
import org.springframework.stereotype.Repository;
|
||||||
import org.springframework.util.StringUtils;
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识库仓储实现类
|
* 知识库仓储实现类
|
||||||
*
|
*
|
||||||
@@ -20,12 +23,28 @@ import org.springframework.util.StringUtils;
|
|||||||
public class KnowledgeBaseRepositoryImpl extends CrudRepository<KnowledgeBaseMapper, KnowledgeBase> implements KnowledgeBaseRepository {
|
public class KnowledgeBaseRepositoryImpl extends CrudRepository<KnowledgeBaseMapper, KnowledgeBase> implements KnowledgeBaseRepository {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request) {
|
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy) {
|
||||||
return this.page(page, new LambdaQueryWrapper<KnowledgeBase>()
|
return this.page(page, new LambdaQueryWrapper<KnowledgeBase>()
|
||||||
.like(StringUtils.hasText(request.getName()), KnowledgeBase::getName, request.getName())
|
.like(StringUtils.hasText(request.getName()), KnowledgeBase::getName, request.getName())
|
||||||
.like(StringUtils.hasText(request.getDescription()), KnowledgeBase::getDescription, request.getDescription())
|
.like(StringUtils.hasText(request.getDescription()), KnowledgeBase::getDescription, request.getDescription())
|
||||||
.like(StringUtils.hasText(request.getCreatedBy()), KnowledgeBase::getCreatedBy, request.getCreatedBy())
|
.like(StringUtils.hasText(request.getCreatedBy()), KnowledgeBase::getCreatedBy, request.getCreatedBy())
|
||||||
.like(StringUtils.hasText(request.getUpdatedBy()), KnowledgeBase::getUpdatedBy, request.getUpdatedBy())
|
.like(StringUtils.hasText(request.getUpdatedBy()), KnowledgeBase::getUpdatedBy, request.getUpdatedBy())
|
||||||
|
.eq(StringUtils.hasText(createdBy), KnowledgeBase::getCreatedBy, createdBy)
|
||||||
.orderByDesc(KnowledgeBase::getCreatedAt));
|
.orderByDesc(KnowledgeBase::getCreatedAt));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> listIdsByCreatedBy(String createdBy) {
|
||||||
|
if (!StringUtils.hasText(createdBy)) {
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
return lambdaQuery()
|
||||||
|
.select(KnowledgeBase::getId)
|
||||||
|
.eq(KnowledgeBase::getCreatedBy, createdBy)
|
||||||
|
.list()
|
||||||
|
.stream()
|
||||||
|
.map(KnowledgeBase::getId)
|
||||||
|
.filter(StringUtils::hasText)
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,9 +52,12 @@ public class RagFileRepositoryImpl extends CrudRepository<RagFileMapper, RagFile
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request) {
|
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds) {
|
||||||
return lambdaQuery()
|
return lambdaQuery()
|
||||||
.eq(StringUtils.hasText(request.getKnowledgeBaseId()), RagFile::getKnowledgeBaseId, request.getKnowledgeBaseId())
|
.eq(StringUtils.hasText(request.getKnowledgeBaseId()), RagFile::getKnowledgeBaseId, request.getKnowledgeBaseId())
|
||||||
|
.in(!StringUtils.hasText(request.getKnowledgeBaseId()) && knowledgeBaseIds != null && !knowledgeBaseIds.isEmpty(),
|
||||||
|
RagFile::getKnowledgeBaseId,
|
||||||
|
knowledgeBaseIds)
|
||||||
.like(StringUtils.hasText(request.getFileName()), RagFile::getFileName, request.getFileName())
|
.like(StringUtils.hasText(request.getFileName()), RagFile::getFileName, request.getFileName())
|
||||||
.likeRight(StringUtils.hasText(request.getRelativePath()), RagFile::getRelativePath, normalizeRelativePath(request.getRelativePath()))
|
.likeRight(StringUtils.hasText(request.getRelativePath()), RagFile::getRelativePath, normalizeRelativePath(request.getRelativePath()))
|
||||||
.page(page);
|
.page(page);
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package com.datamate.common.auth.application;
|
||||||
|
|
||||||
|
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
|
||||||
|
import com.datamate.common.infrastructure.exception.BusinessAssert;
|
||||||
|
import com.datamate.common.infrastructure.exception.SystemErrorCode;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 资源访问控制服务(基于请求用户上下文)
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
public class ResourceAccessService {
|
||||||
|
public static final String ADMIN_ROLE_CODE = "ROLE_ADMIN";
|
||||||
|
|
||||||
|
public boolean isAdmin() {
|
||||||
|
return RequestUserContextHolder.hasRole(ADMIN_ROLE_CODE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getCurrentUserId() {
|
||||||
|
return RequestUserContextHolder.getCurrentUserId();
|
||||||
|
}
|
||||||
|
|
||||||
|
public String requireCurrentUserId() {
|
||||||
|
String currentUserId = getCurrentUserId();
|
||||||
|
BusinessAssert.isTrue(StringUtils.hasText(currentUserId), SystemErrorCode.INSUFFICIENT_PERMISSIONS);
|
||||||
|
return currentUserId;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 资源列表查询的 owner 过滤:
|
||||||
|
* - 管理员返回 null(不过滤)
|
||||||
|
* - 非管理员返回当前用户ID
|
||||||
|
*/
|
||||||
|
public String resolveOwnerFilterUserId() {
|
||||||
|
if (isAdmin()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return requireCurrentUserId();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 校验当前用户是否可访问 owner 资源
|
||||||
|
*/
|
||||||
|
public void assertOwnerAccess(String ownerUserId) {
|
||||||
|
if (isAdmin()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
String currentUserId = requireCurrentUserId();
|
||||||
|
BusinessAssert.isTrue(
|
||||||
|
StringUtils.hasText(ownerUserId) && Objects.equals(ownerUserId, currentUserId),
|
||||||
|
SystemErrorCode.INSUFFICIENT_PERMISSIONS
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package com.datamate.common.auth.infrastructure.context;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 请求级用户上下文
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
public class RequestUserContext {
|
||||||
|
private final String userId;
|
||||||
|
private final String username;
|
||||||
|
private final List<String> roles;
|
||||||
|
|
||||||
|
private RequestUserContext(String userId, String username, List<String> roles) {
|
||||||
|
this.userId = userId;
|
||||||
|
this.username = username;
|
||||||
|
this.roles = roles == null ? Collections.emptyList() : List.copyOf(roles);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static RequestUserContext of(String userId, String username, List<String> roles) {
|
||||||
|
return new RequestUserContext(userId, username, roles);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static RequestUserContext empty() {
|
||||||
|
return new RequestUserContext(null, null, Collections.emptyList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean hasRole(String roleCode) {
|
||||||
|
if (!StringUtils.hasText(roleCode)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return roles.stream().anyMatch(role -> StringUtils.hasText(role) && Objects.equals(role.trim(), roleCode));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package com.datamate.common.auth.infrastructure.context;
|
||||||
|
|
||||||
|
import org.springframework.core.NamedThreadLocal;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 请求级用户上下文持有器
|
||||||
|
*/
|
||||||
|
public final class RequestUserContextHolder {
|
||||||
|
private static final ThreadLocal<RequestUserContext> USER_CONTEXT_HOLDER =
|
||||||
|
new NamedThreadLocal<>("datamate-request-user-context");
|
||||||
|
|
||||||
|
private RequestUserContextHolder() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void set(RequestUserContext context) {
|
||||||
|
USER_CONTEXT_HOLDER.set(context == null ? RequestUserContext.empty() : context);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static RequestUserContext get() {
|
||||||
|
RequestUserContext context = USER_CONTEXT_HOLDER.get();
|
||||||
|
return context == null ? RequestUserContext.empty() : context;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static String getCurrentUserId() {
|
||||||
|
return get().getUserId();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<String> getCurrentRoles() {
|
||||||
|
List<String> roles = get().getRoles();
|
||||||
|
return roles == null ? Collections.emptyList() : roles;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean hasRole(String roleCode) {
|
||||||
|
if (!StringUtils.hasText(roleCode)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return getCurrentRoles().stream()
|
||||||
|
.anyMatch(role -> StringUtils.hasText(role) && roleCode.equalsIgnoreCase(role.trim()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void clear() {
|
||||||
|
USER_CONTEXT_HOLDER.remove();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package com.datamate.common.auth.infrastructure.context;
|
||||||
|
|
||||||
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
|
import jakarta.servlet.http.HttpServletResponse;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
import org.springframework.web.servlet.HandlerInterceptor;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从网关透传请求头中提取用户上下文
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class RequestUserContextInterceptor implements HandlerInterceptor {
|
||||||
|
private static final String HEADER_USER_ID = "X-User-Id";
|
||||||
|
private static final String HEADER_USER_NAME = "X-User-Name";
|
||||||
|
private static final String HEADER_USER_ROLES = "X-User-Roles";
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
|
||||||
|
String userId = normalizeValue(request.getHeader(HEADER_USER_ID));
|
||||||
|
String username = normalizeValue(request.getHeader(HEADER_USER_NAME));
|
||||||
|
List<String> roleCodes = parseRoleCodes(request.getHeader(HEADER_USER_ROLES));
|
||||||
|
RequestUserContextHolder.set(RequestUserContext.of(userId, username, roleCodes));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) {
|
||||||
|
RequestUserContextHolder.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
private String normalizeValue(String value) {
|
||||||
|
if (!StringUtils.hasText(value)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return value.trim();
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> parseRoleCodes(String roleHeader) {
|
||||||
|
if (!StringUtils.hasText(roleHeader)) {
|
||||||
|
return Collections.emptyList();
|
||||||
|
}
|
||||||
|
return Arrays.stream(roleHeader.split(","))
|
||||||
|
.map(String::trim)
|
||||||
|
.filter(StringUtils::hasText)
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package com.datamate.common.auth.infrastructure.context;
|
||||||
|
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
|
||||||
|
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 请求用户上下文拦截器注册
|
||||||
|
*/
|
||||||
|
@Configuration
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class RequestUserContextWebMvcConfigurer implements WebMvcConfigurer {
|
||||||
|
private final RequestUserContextInterceptor requestUserContextInterceptor;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void addInterceptors(InterceptorRegistry registry) {
|
||||||
|
registry.addInterceptor(requestUserContextInterceptor).addPathPatterns("/**");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
package com.datamate.common.infrastructure.config;
|
package com.datamate.common.infrastructure.config;
|
||||||
|
|
||||||
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
|
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
|
||||||
|
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.ibatis.reflection.MetaObject;
|
import org.apache.ibatis.reflection.MetaObject;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
|
|
||||||
@@ -44,17 +46,10 @@ public class EntityMetaObjectHandler implements MetaObjectHandler {
|
|||||||
* 获取当前用户(需要根据你的安全框架实现)
|
* 获取当前用户(需要根据你的安全框架实现)
|
||||||
*/
|
*/
|
||||||
private String getCurrentUser() {
|
private String getCurrentUser() {
|
||||||
// todo 这里需要根据你的安全框架实现,例如Spring Security、Shiro等
|
String currentUserId = RequestUserContextHolder.getCurrentUserId();
|
||||||
// 示例:返回默认用户或从SecurityContext获取
|
if (StringUtils.hasText(currentUserId)) {
|
||||||
try {
|
return currentUserId;
|
||||||
// 如果是Spring Security
|
|
||||||
// return SecurityContextHolder.getContext().getAuthentication().getName();
|
|
||||||
|
|
||||||
// 临时返回默认值,请根据实际情况修改
|
|
||||||
return "system";
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("Error getting current user", e);
|
|
||||||
return "unknown";
|
|
||||||
}
|
}
|
||||||
|
return "system";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,51 @@
|
|||||||
import { useState } from "react";
|
import { useEffect, useMemo, useState } from "react";
|
||||||
import { Menu } from "antd";
|
import { Menu } from "antd";
|
||||||
import { SettingOutlined } from "@ant-design/icons";
|
import { SettingOutlined, TeamOutlined } from "@ant-design/icons";
|
||||||
import { Component } from "lucide-react";
|
import { Component } from "lucide-react";
|
||||||
import SystemConfig from "./SystemConfig";
|
import SystemConfig from "./SystemConfig";
|
||||||
import ModelAccess from "./ModelAccess";
|
import ModelAccess from "./ModelAccess";
|
||||||
|
import UserPermissionManagement from "./UserPermissionManagement";
|
||||||
|
import { useAppSelector } from "@/store/hooks";
|
||||||
|
import { hasPermission, PermissionCodes } from "@/auth/permissions";
|
||||||
|
|
||||||
export default function SettingsPage() {
|
export default function SettingsPage() {
|
||||||
const [activeTab, setActiveTab] = useState("model-access");
|
const permissions = useAppSelector((state) => state.auth.permissions);
|
||||||
|
const canManageUsers = hasPermission(permissions, PermissionCodes.userManage);
|
||||||
|
const canViewRoles = hasPermission(permissions, PermissionCodes.roleManage);
|
||||||
|
const canViewPermissions = hasPermission(
|
||||||
|
permissions,
|
||||||
|
PermissionCodes.permissionManage
|
||||||
|
);
|
||||||
|
const tabs = useMemo(() => {
|
||||||
|
const nextTabs = [
|
||||||
|
{
|
||||||
|
key: "model-access",
|
||||||
|
icon: <Component className="w-4 h-4" />,
|
||||||
|
label: "模型接入",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "system-config",
|
||||||
|
icon: <SettingOutlined />,
|
||||||
|
label: "参数配置",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
if (canManageUsers || canViewRoles || canViewPermissions) {
|
||||||
|
nextTabs.push({
|
||||||
|
key: "user-permission",
|
||||||
|
icon: <TeamOutlined />,
|
||||||
|
label: "用户与权限",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return nextTabs;
|
||||||
|
}, [canManageUsers, canViewPermissions, canViewRoles]);
|
||||||
|
const [activeTab, setActiveTab] = useState<string>(tabs[0]?.key ?? "model-access");
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const hasActiveTab = tabs.some((tab) => tab.key === activeTab);
|
||||||
|
if (!hasActiveTab && tabs.length > 0) {
|
||||||
|
setActiveTab(tabs[0].key);
|
||||||
|
}
|
||||||
|
}, [activeTab, tabs]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="h-screen flex">
|
<div className="h-screen flex">
|
||||||
@@ -18,21 +57,10 @@ export default function SettingsPage() {
|
|||||||
<div className="h-full">
|
<div className="h-full">
|
||||||
<Menu
|
<Menu
|
||||||
mode="inline"
|
mode="inline"
|
||||||
items={[
|
items={tabs}
|
||||||
{
|
|
||||||
key: "model-access",
|
|
||||||
icon: <Component className="w-4 h-4" />,
|
|
||||||
label: "模型接入",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
key: "system-config",
|
|
||||||
icon: <SettingOutlined />,
|
|
||||||
label: "参数配置",
|
|
||||||
},
|
|
||||||
]}
|
|
||||||
selectedKeys={[activeTab]}
|
selectedKeys={[activeTab]}
|
||||||
onClick={({ key }) => {
|
onClick={({ key }) => {
|
||||||
setActiveTab(key);
|
setActiveTab(String(key));
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
@@ -41,6 +69,13 @@ export default function SettingsPage() {
|
|||||||
{/* 内容区域,根据 activeTab 渲染不同的组件 */}
|
{/* 内容区域,根据 activeTab 渲染不同的组件 */}
|
||||||
{activeTab === "system-config" && <SystemConfig />}
|
{activeTab === "system-config" && <SystemConfig />}
|
||||||
{activeTab === "model-access" && <ModelAccess />}
|
{activeTab === "model-access" && <ModelAccess />}
|
||||||
|
{activeTab === "user-permission" && (
|
||||||
|
<UserPermissionManagement
|
||||||
|
canManageUsers={canManageUsers}
|
||||||
|
canViewRoles={canViewRoles}
|
||||||
|
canViewPermissions={canViewPermissions}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
321
frontend/src/pages/SettingsPage/UserPermissionManagement.tsx
Normal file
321
frontend/src/pages/SettingsPage/UserPermissionManagement.tsx
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Card,
|
||||||
|
Empty,
|
||||||
|
message,
|
||||||
|
Modal,
|
||||||
|
Select,
|
||||||
|
Space,
|
||||||
|
Table,
|
||||||
|
Tag,
|
||||||
|
Typography,
|
||||||
|
} from "antd";
|
||||||
|
import type { ColumnsType } from "antd/es/table";
|
||||||
|
import {
|
||||||
|
assignUserRolesUsingPut,
|
||||||
|
listAuthPermissionsUsingGet,
|
||||||
|
listAuthRolesUsingGet,
|
||||||
|
listAuthUsersUsingGet,
|
||||||
|
} from "./settings.apis";
|
||||||
|
import type {
|
||||||
|
AuthPermissionInfo,
|
||||||
|
AuthRoleInfo,
|
||||||
|
AuthUserWithRoles,
|
||||||
|
} from "./settings.apis";
|
||||||
|
|
||||||
|
interface ApiResponse<T> {
|
||||||
|
code: string;
|
||||||
|
message: string;
|
||||||
|
data: T;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UserPermissionManagementProps {
|
||||||
|
canManageUsers: boolean;
|
||||||
|
canViewRoles: boolean;
|
||||||
|
canViewPermissions: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function UserPermissionManagement({
|
||||||
|
canManageUsers,
|
||||||
|
canViewRoles,
|
||||||
|
canViewPermissions,
|
||||||
|
}: UserPermissionManagementProps) {
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [users, setUsers] = useState<AuthUserWithRoles[]>([]);
|
||||||
|
const [roles, setRoles] = useState<AuthRoleInfo[]>([]);
|
||||||
|
const [permissions, setPermissions] = useState<AuthPermissionInfo[]>([]);
|
||||||
|
const [editingUser, setEditingUser] = useState<AuthUserWithRoles | null>(null);
|
||||||
|
const [selectedRoleCodes, setSelectedRoleCodes] = useState<string[]>([]);
|
||||||
|
const [submitting, setSubmitting] = useState(false);
|
||||||
|
|
||||||
|
const canShowAnything = canManageUsers || canViewRoles || canViewPermissions;
|
||||||
|
const canAssignRoles = canManageUsers && roles.length > 0;
|
||||||
|
|
||||||
|
const roleNameMap = useMemo(
|
||||||
|
() => new Map(roles.map((role) => [role.roleCode, role.roleName || role.roleCode])),
|
||||||
|
[roles]
|
||||||
|
);
|
||||||
|
const roleCodeToIdMap = useMemo(
|
||||||
|
() => new Map(roles.map((role) => [role.roleCode, role.id])),
|
||||||
|
[roles]
|
||||||
|
);
|
||||||
|
|
||||||
|
const loadData = useCallback(async () => {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const requestTasks: Array<Promise<unknown>> = [];
|
||||||
|
if (canManageUsers || canViewRoles || canViewPermissions) {
|
||||||
|
requestTasks.push(listAuthUsersUsingGet());
|
||||||
|
}
|
||||||
|
if (canManageUsers || canViewRoles) {
|
||||||
|
requestTasks.push(listAuthRolesUsingGet());
|
||||||
|
}
|
||||||
|
if (canViewPermissions) {
|
||||||
|
requestTasks.push(listAuthPermissionsUsingGet());
|
||||||
|
}
|
||||||
|
const responses = await Promise.all(requestTasks);
|
||||||
|
let index = 0;
|
||||||
|
if (canManageUsers || canViewRoles || canViewPermissions) {
|
||||||
|
const userResponse = responses[index++] as ApiResponse<AuthUserWithRoles[]>;
|
||||||
|
setUsers(userResponse?.data ?? []);
|
||||||
|
}
|
||||||
|
if (canManageUsers || canViewRoles) {
|
||||||
|
const roleResponse = responses[index++] as ApiResponse<AuthRoleInfo[]>;
|
||||||
|
setRoles(roleResponse?.data ?? []);
|
||||||
|
} else {
|
||||||
|
setRoles([]);
|
||||||
|
}
|
||||||
|
if (canViewPermissions) {
|
||||||
|
const permissionResponse = responses[index++] as ApiResponse<AuthPermissionInfo[]>;
|
||||||
|
setPermissions(permissionResponse?.data ?? []);
|
||||||
|
} else {
|
||||||
|
setPermissions([]);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
message.error("加载用户权限信息失败");
|
||||||
|
console.error("加载用户权限信息失败:", error);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, [canManageUsers, canViewPermissions, canViewRoles]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!canShowAnything) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
void loadData();
|
||||||
|
}, [canShowAnything, loadData]);
|
||||||
|
|
||||||
|
const userColumns: ColumnsType<AuthUserWithRoles> = [
|
||||||
|
{
|
||||||
|
title: "用户名",
|
||||||
|
dataIndex: "username",
|
||||||
|
key: "username",
|
||||||
|
width: 180,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "姓名",
|
||||||
|
dataIndex: "fullName",
|
||||||
|
key: "fullName",
|
||||||
|
width: 180,
|
||||||
|
render: (value?: string) => value || "-",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "邮箱",
|
||||||
|
dataIndex: "email",
|
||||||
|
key: "email",
|
||||||
|
render: (value?: string) => value || "-",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "状态",
|
||||||
|
dataIndex: "enabled",
|
||||||
|
key: "enabled",
|
||||||
|
width: 120,
|
||||||
|
render: (enabled?: boolean) =>
|
||||||
|
enabled ? <Tag color="green">启用</Tag> : <Tag color="default">禁用</Tag>,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "角色",
|
||||||
|
dataIndex: "roleCodes",
|
||||||
|
key: "roleCodes",
|
||||||
|
render: (roleCodes: string[]) => (
|
||||||
|
<Space wrap>
|
||||||
|
{(roleCodes ?? []).map((roleCode) => (
|
||||||
|
<Tag key={roleCode}>{roleNameMap.get(roleCode) || roleCode}</Tag>
|
||||||
|
))}
|
||||||
|
</Space>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "操作",
|
||||||
|
key: "actions",
|
||||||
|
width: 120,
|
||||||
|
render: (_, record) => (
|
||||||
|
<Button
|
||||||
|
type="link"
|
||||||
|
disabled={!canAssignRoles}
|
||||||
|
onClick={() => {
|
||||||
|
setEditingUser(record);
|
||||||
|
setSelectedRoleCodes(record.roleCodes ?? []);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
分配角色
|
||||||
|
</Button>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const roleColumns: ColumnsType<AuthRoleInfo> = [
|
||||||
|
{ title: "角色编码", dataIndex: "roleCode", key: "roleCode", width: 220 },
|
||||||
|
{ title: "角色名称", dataIndex: "roleName", key: "roleName", width: 180 },
|
||||||
|
{
|
||||||
|
title: "状态",
|
||||||
|
dataIndex: "enabled",
|
||||||
|
key: "enabled",
|
||||||
|
width: 120,
|
||||||
|
render: (enabled?: boolean) =>
|
||||||
|
enabled ? <Tag color="green">启用</Tag> : <Tag color="default">禁用</Tag>,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "描述",
|
||||||
|
dataIndex: "description",
|
||||||
|
key: "description",
|
||||||
|
render: (value?: string) => value || "-",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const permissionColumns: ColumnsType<AuthPermissionInfo> = [
|
||||||
|
{
|
||||||
|
title: "权限编码",
|
||||||
|
dataIndex: "permissionCode",
|
||||||
|
key: "permissionCode",
|
||||||
|
width: 260,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "权限名称",
|
||||||
|
dataIndex: "permissionName",
|
||||||
|
key: "permissionName",
|
||||||
|
width: 200,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "模块",
|
||||||
|
dataIndex: "module",
|
||||||
|
key: "module",
|
||||||
|
width: 140,
|
||||||
|
render: (value?: string) => value || "-",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "动作",
|
||||||
|
dataIndex: "action",
|
||||||
|
key: "action",
|
||||||
|
width: 120,
|
||||||
|
render: (value?: string) => value || "-",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: "接口",
|
||||||
|
key: "api",
|
||||||
|
render: (_, record) =>
|
||||||
|
record.pathPattern ? `${record.method || "ALL"} ${record.pathPattern}` : "-",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const handleAssignRoles = async () => {
|
||||||
|
if (!editingUser) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (selectedRoleCodes.length === 0) {
|
||||||
|
message.warning("请至少选择一个角色");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const roleIds = selectedRoleCodes
|
||||||
|
.map((roleCode) => roleCodeToIdMap.get(roleCode))
|
||||||
|
.filter((roleId): roleId is string => Boolean(roleId));
|
||||||
|
if (roleIds.length !== selectedRoleCodes.length) {
|
||||||
|
message.error("角色映射失败,请刷新后重试");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setSubmitting(true);
|
||||||
|
try {
|
||||||
|
await assignUserRolesUsingPut(editingUser.id, roleIds);
|
||||||
|
message.success("角色分配成功");
|
||||||
|
setEditingUser(null);
|
||||||
|
setSelectedRoleCodes([]);
|
||||||
|
await loadData();
|
||||||
|
} catch (error) {
|
||||||
|
message.error("角色分配失败");
|
||||||
|
console.error("角色分配失败:", error);
|
||||||
|
} finally {
|
||||||
|
setSubmitting(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!canShowAnything) {
|
||||||
|
return <Empty description="当前账号无用户与权限管理权限" />;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Space direction="vertical" size={16} className="w-full">
|
||||||
|
<Card title="用户管理">
|
||||||
|
<Table
|
||||||
|
loading={loading}
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={users}
|
||||||
|
columns={userColumns}
|
||||||
|
pagination={{ pageSize: 10, showSizeChanger: false }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
{canViewRoles && (
|
||||||
|
<Card title="角色列表">
|
||||||
|
<Table
|
||||||
|
loading={loading}
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={roles}
|
||||||
|
columns={roleColumns}
|
||||||
|
pagination={{ pageSize: 8, showSizeChanger: false }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
{canViewPermissions && (
|
||||||
|
<Card title="权限列表">
|
||||||
|
<Table
|
||||||
|
loading={loading}
|
||||||
|
rowKey="id"
|
||||||
|
dataSource={permissions}
|
||||||
|
columns={permissionColumns}
|
||||||
|
pagination={{ pageSize: 10, showSizeChanger: false }}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
<Modal
|
||||||
|
title={`分配角色 - ${editingUser?.username || ""}`}
|
||||||
|
open={Boolean(editingUser)}
|
||||||
|
confirmLoading={submitting}
|
||||||
|
onOk={() => {
|
||||||
|
void handleAssignRoles();
|
||||||
|
}}
|
||||||
|
onCancel={() => {
|
||||||
|
setEditingUser(null);
|
||||||
|
setSelectedRoleCodes([]);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{roles.length === 0 ? (
|
||||||
|
<Typography.Text type="secondary">暂无可分配角色</Typography.Text>
|
||||||
|
) : (
|
||||||
|
<Select
|
||||||
|
mode="multiple"
|
||||||
|
className="w-full"
|
||||||
|
placeholder="请选择角色"
|
||||||
|
value={selectedRoleCodes}
|
||||||
|
onChange={(values) => setSelectedRoleCodes(values)}
|
||||||
|
options={roles.map((role) => ({
|
||||||
|
value: role.roleCode,
|
||||||
|
label: `${role.roleName} (${role.roleCode})`,
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Modal>
|
||||||
|
</Space>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
import { get, post, put, del } from "@/utils/request";
|
import { get, post, put, del } from "@/utils/request";
|
||||||
|
|
||||||
// 模型相关接口
|
// 模型相关接口
|
||||||
export function queryModelProvidersUsingGet(params?: any) {
|
export function queryModelProvidersUsingGet(params?: Record<string, unknown>) {
|
||||||
return get("/api/models/providers", params);
|
return get("/api/models/providers", params);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function queryModelListUsingGet(data: any) {
|
export function queryModelListUsingGet(data: Record<string, unknown>) {
|
||||||
return get("/api/models/list", data);
|
return get("/api/models/list", data);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -15,12 +15,12 @@ export function queryModelDetailByIdUsingGet(id: string | number) {
|
|||||||
|
|
||||||
export function updateModelByIdUsingPut(
|
export function updateModelByIdUsingPut(
|
||||||
id: string | number,
|
id: string | number,
|
||||||
data: any
|
data: Record<string, unknown>
|
||||||
) {
|
) {
|
||||||
return put(`/api/models/${id}`, data);
|
return put(`/api/models/${id}`, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createModelUsingPost(data: any) {
|
export function createModelUsingPost(data: Record<string, unknown>) {
|
||||||
return post("/api/models/create", data);
|
return post("/api/models/create", data);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -28,13 +28,60 @@ export function deleteModelByIdUsingDelete(id: string | number) {
|
|||||||
return del(`/api/models/${id}`);
|
return del(`/api/models/${id}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// 获取系统参数列表
|
// 获取系统参数列表
|
||||||
export function getSysParamList() {
|
export function getSysParamList() {
|
||||||
return get('/api/sys-param/list');
|
return get("/api/sys-param/list");
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新系统参数值
|
// 更新系统参数值
|
||||||
export const updateSysParamValue = async (params: { id: string; paramValue: string }) => {
|
export const updateSysParamValue = async (params: {
|
||||||
|
id: string;
|
||||||
|
paramValue: string;
|
||||||
|
}) => {
|
||||||
return put(`/api/sys-param/${params.id}`, params);
|
return put(`/api/sys-param/${params.id}`, params);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export interface AuthUserWithRoles {
|
||||||
|
id: number;
|
||||||
|
username: string;
|
||||||
|
fullName?: string;
|
||||||
|
email?: string;
|
||||||
|
enabled?: boolean;
|
||||||
|
roleCodes: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AuthRoleInfo {
|
||||||
|
id: string;
|
||||||
|
roleCode: string;
|
||||||
|
roleName: string;
|
||||||
|
description?: string;
|
||||||
|
enabled?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AuthPermissionInfo {
|
||||||
|
id: string;
|
||||||
|
permissionCode: string;
|
||||||
|
permissionName: string;
|
||||||
|
module?: string;
|
||||||
|
action?: string;
|
||||||
|
pathPattern?: string;
|
||||||
|
method?: string;
|
||||||
|
enabled?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户与权限管理接口
|
||||||
|
export function listAuthUsersUsingGet() {
|
||||||
|
return get("/api/auth/users");
|
||||||
|
}
|
||||||
|
|
||||||
|
export function listAuthRolesUsingGet() {
|
||||||
|
return get("/api/auth/roles");
|
||||||
|
}
|
||||||
|
|
||||||
|
export function listAuthPermissionsUsingGet() {
|
||||||
|
return get("/api/auth/permissions");
|
||||||
|
}
|
||||||
|
|
||||||
|
export function assignUserRolesUsingPut(userId: number, roleIds: string[]) {
|
||||||
|
return put(`/api/auth/users/${userId}/roles`, { roleIds });
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,11 @@ from app.module.shared.schema import StandardResponse
|
|||||||
from app.module.dataset import DatasetManagementService
|
from app.module.dataset import DatasetManagementService
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
from ..security import (
|
||||||
|
RequestUserContext,
|
||||||
|
assert_dataset_access,
|
||||||
|
get_request_user_context,
|
||||||
|
)
|
||||||
from ..schema.auto import (
|
from ..schema.auto import (
|
||||||
CreateAutoAnnotationTaskRequest,
|
CreateAutoAnnotationTaskRequest,
|
||||||
AutoAnnotationTaskResponse,
|
AutoAnnotationTaskResponse,
|
||||||
@@ -39,13 +44,14 @@ service = AutoAnnotationTaskService()
|
|||||||
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
|
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
|
||||||
async def list_auto_annotation_tasks(
|
async def list_auto_annotation_tasks(
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""获取自动标注任务列表。
|
"""获取自动标注任务列表。
|
||||||
|
|
||||||
前端当前不传分页参数,这里直接返回所有未删除任务。
|
前端当前不传分页参数,这里直接返回所有未删除任务。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tasks = await service.list_tasks(db)
|
tasks = await service.list_tasks(db, user_context)
|
||||||
return StandardResponse(
|
return StandardResponse(
|
||||||
code=200,
|
code=200,
|
||||||
message="success",
|
message="success",
|
||||||
@@ -57,6 +63,7 @@ async def list_auto_annotation_tasks(
|
|||||||
async def create_auto_annotation_task(
|
async def create_auto_annotation_task(
|
||||||
request: CreateAutoAnnotationTaskRequest,
|
request: CreateAutoAnnotationTaskRequest,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""创建自动标注任务。
|
"""创建自动标注任务。
|
||||||
|
|
||||||
@@ -74,6 +81,7 @@ async def create_auto_annotation_task(
|
|||||||
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
|
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
|
||||||
dataset_name = None
|
dataset_name = None
|
||||||
total_images = 0
|
total_images = 0
|
||||||
|
await assert_dataset_access(db, request.dataset_id, user_context)
|
||||||
try:
|
try:
|
||||||
dm_client = DatasetManagementService(db)
|
dm_client = DatasetManagementService(db)
|
||||||
# Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount
|
# Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount
|
||||||
@@ -106,13 +114,14 @@ async def create_auto_annotation_task(
|
|||||||
async def get_auto_annotation_task_status(
|
async def get_auto_annotation_task_status(
|
||||||
task_id: str = Path(..., description="任务ID"),
|
task_id: str = Path(..., description="任务ID"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""获取单个自动标注任务状态。
|
"""获取单个自动标注任务状态。
|
||||||
|
|
||||||
前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。
|
前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task = await service.get_task(db, task_id)
|
task = await service.get_task(db, task_id, user_context)
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
@@ -127,10 +136,11 @@ async def get_auto_annotation_task_status(
|
|||||||
async def delete_auto_annotation_task(
|
async def delete_auto_annotation_task(
|
||||||
task_id: str = Path(..., description="任务ID"),
|
task_id: str = Path(..., description="任务ID"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
|
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
|
||||||
|
|
||||||
ok = await service.soft_delete_task(db, task_id)
|
ok = await service.soft_delete_task(db, task_id, user_context)
|
||||||
if not ok:
|
if not ok:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
@@ -145,6 +155,7 @@ async def delete_auto_annotation_task(
|
|||||||
async def download_auto_annotation_result(
|
async def download_auto_annotation_result(
|
||||||
task_id: str = Path(..., description="任务ID"),
|
task_id: str = Path(..., description="任务ID"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""下载指定自动标注任务的结果 ZIP。"""
|
"""下载指定自动标注任务的结果 ZIP。"""
|
||||||
|
|
||||||
@@ -154,7 +165,7 @@ async def download_auto_annotation_result(
|
|||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
# 复用服务层获取任务信息
|
# 复用服务层获取任务信息
|
||||||
task = await service.get_task(db, task_id)
|
task = await service.get_task(db, task_id, user_context)
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,10 @@ from app.module.annotation.schema.editor import (
|
|||||||
UpsertAnnotationResponse,
|
UpsertAnnotationResponse,
|
||||||
)
|
)
|
||||||
from app.module.annotation.service.editor import AnnotationEditorService
|
from app.module.annotation.service.editor import AnnotationEditorService
|
||||||
|
from app.module.annotation.security import (
|
||||||
|
RequestUserContext,
|
||||||
|
get_request_user_context,
|
||||||
|
)
|
||||||
from app.module.shared.schema import StandardResponse
|
from app.module.shared.schema import StandardResponse
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -44,8 +48,9 @@ router = APIRouter(
|
|||||||
async def get_editor_project_info(
|
async def get_editor_project_info(
|
||||||
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
service = AnnotationEditorService(db)
|
service = AnnotationEditorService(db, user_context)
|
||||||
info = await service.get_project_info(project_id)
|
info = await service.get_project_info(project_id)
|
||||||
return StandardResponse(code=200, message="success", data=info)
|
return StandardResponse(code=200, message="success", data=info)
|
||||||
|
|
||||||
@@ -64,8 +69,9 @@ async def list_editor_tasks(
|
|||||||
description="是否排除已被转换为TXT的源文档文件(PDF/DOC/DOCX,仅文本数据集生效)",
|
description="是否排除已被转换为TXT的源文档文件(PDF/DOC/DOCX,仅文本数据集生效)",
|
||||||
),
|
),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
service = AnnotationEditorService(db)
|
service = AnnotationEditorService(db, user_context)
|
||||||
result = await service.list_tasks(
|
result = await service.list_tasks(
|
||||||
project_id,
|
project_id,
|
||||||
page=page,
|
page=page,
|
||||||
@@ -86,8 +92,9 @@ async def get_editor_task(
|
|||||||
None, alias="segmentIndex", description="段落索引(分段模式下使用)"
|
None, alias="segmentIndex", description="段落索引(分段模式下使用)"
|
||||||
),
|
),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
service = AnnotationEditorService(db)
|
service = AnnotationEditorService(db, user_context)
|
||||||
task = await service.get_task(project_id, file_id, segment_index=segment_index)
|
task = await service.get_task(project_id, file_id, segment_index=segment_index)
|
||||||
return StandardResponse(code=200, message="success", data=task)
|
return StandardResponse(code=200, message="success", data=task)
|
||||||
|
|
||||||
@@ -103,8 +110,9 @@ async def get_editor_task_segment(
|
|||||||
..., ge=0, alias="segmentIndex", description="段落索引(从0开始)"
|
..., ge=0, alias="segmentIndex", description="段落索引(从0开始)"
|
||||||
),
|
),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
service = AnnotationEditorService(db)
|
service = AnnotationEditorService(db, user_context)
|
||||||
result = await service.get_task_segment(project_id, file_id, segment_index)
|
result = await service.get_task_segment(project_id, file_id, segment_index)
|
||||||
return StandardResponse(code=200, message="success", data=result)
|
return StandardResponse(code=200, message="success", data=result)
|
||||||
|
|
||||||
@@ -118,8 +126,9 @@ async def upsert_editor_annotation(
|
|||||||
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
||||||
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
|
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
service = AnnotationEditorService(db)
|
service = AnnotationEditorService(db, user_context)
|
||||||
result = await service.upsert_annotation(project_id, file_id, request)
|
result = await service.upsert_annotation(project_id, file_id, request)
|
||||||
return StandardResponse(code=200, message="success", data=result)
|
return StandardResponse(code=200, message="success", data=result)
|
||||||
|
|
||||||
@@ -132,11 +141,12 @@ async def check_file_version(
|
|||||||
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
||||||
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
|
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
检查文件是否有新版本
|
检查文件是否有新版本
|
||||||
"""
|
"""
|
||||||
service = AnnotationEditorService(db)
|
service = AnnotationEditorService(db, user_context)
|
||||||
result = await service.check_file_version(project_id, file_id)
|
result = await service.check_file_version(project_id, file_id)
|
||||||
return StandardResponse(code=200, message="success", data=result)
|
return StandardResponse(code=200, message="success", data=result)
|
||||||
|
|
||||||
@@ -149,10 +159,11 @@ async def use_new_version(
|
|||||||
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
|
||||||
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
|
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
使用文件新版本并清空标注
|
使用文件新版本并清空标注
|
||||||
"""
|
"""
|
||||||
service = AnnotationEditorService(db)
|
service = AnnotationEditorService(db, user_context)
|
||||||
result = await service.use_new_version(project_id, file_id)
|
result = await service.use_new_version(project_id, file_id)
|
||||||
return StandardResponse(code=200, message="success", data=result)
|
return StandardResponse(code=200, message="success", data=result)
|
||||||
|
|||||||
@@ -12,6 +12,11 @@ from app.module.shared.schema import StandardResponse, PaginatedData
|
|||||||
from app.module.dataset import DatasetManagementService
|
from app.module.dataset import DatasetManagementService
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
from ..security import (
|
||||||
|
RequestUserContext,
|
||||||
|
assert_dataset_access,
|
||||||
|
get_request_user_context,
|
||||||
|
)
|
||||||
from ..service.mapping import DatasetMappingService
|
from ..service.mapping import DatasetMappingService
|
||||||
from ..service.template import AnnotationTemplateService
|
from ..service.template import AnnotationTemplateService
|
||||||
from ..service.knowledge_sync import KnowledgeSyncService
|
from ..service.knowledge_sync import KnowledgeSyncService
|
||||||
@@ -42,7 +47,9 @@ async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)
|
|||||||
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
|
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
|
||||||
)
|
)
|
||||||
async def create_mapping(
|
async def create_mapping(
|
||||||
request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db)
|
request: DatasetMappingCreateRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
创建数据集映射
|
创建数据集映射
|
||||||
@@ -58,6 +65,8 @@ async def create_mapping(
|
|||||||
mapping_service = DatasetMappingService(db)
|
mapping_service = DatasetMappingService(db)
|
||||||
template_service = AnnotationTemplateService()
|
template_service = AnnotationTemplateService()
|
||||||
|
|
||||||
|
await assert_dataset_access(db, request.dataset_id, user_context)
|
||||||
|
|
||||||
logger.info(f"Create dataset mapping request: {request.dataset_id}")
|
logger.info(f"Create dataset mapping request: {request.dataset_id}")
|
||||||
|
|
||||||
# 从DM服务获取数据集信息
|
# 从DM服务获取数据集信息
|
||||||
@@ -163,7 +172,7 @@ async def create_mapping(
|
|||||||
try:
|
try:
|
||||||
from ..service.editor import AnnotationEditorService
|
from ..service.editor import AnnotationEditorService
|
||||||
|
|
||||||
editor_service = AnnotationEditorService(db)
|
editor_service = AnnotationEditorService(db, user_context)
|
||||||
# 异步预计算切片(不阻塞创建响应)
|
# 异步预计算切片(不阻塞创建响应)
|
||||||
segmentation_result = (
|
segmentation_result = (
|
||||||
await editor_service.precompute_segmentation_for_project(
|
await editor_service.precompute_segmentation_for_project(
|
||||||
@@ -202,6 +211,7 @@ async def list_mappings(
|
|||||||
False, description="是否包含模板详情", alias="includeTemplate"
|
False, description="是否包含模板详情", alias="includeTemplate"
|
||||||
),
|
),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
查询所有映射关系(分页)
|
查询所有映射关系(分页)
|
||||||
@@ -230,6 +240,8 @@ async def list_mappings(
|
|||||||
limit=size,
|
limit=size,
|
||||||
include_deleted=False,
|
include_deleted=False,
|
||||||
include_template=include_template,
|
include_template=include_template,
|
||||||
|
current_user_id=user_context.user_id,
|
||||||
|
is_admin=user_context.is_admin,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 计算总页数
|
# 计算总页数
|
||||||
@@ -256,7 +268,11 @@ async def list_mappings(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
|
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
|
||||||
async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
|
async def get_mapping(
|
||||||
|
mapping_id: str,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
|
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
|
||||||
|
|
||||||
@@ -278,6 +294,7 @@ async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"Mapping not found: {mapping_id}"
|
status_code=404, detail=f"Mapping not found: {mapping_id}"
|
||||||
)
|
)
|
||||||
|
await assert_dataset_access(db, mapping.dataset_id, user_context)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
|
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
|
||||||
@@ -304,6 +321,7 @@ async def get_mappings_by_source(
|
|||||||
True, description="是否包含模板详情", alias="includeTemplate"
|
True, description="是否包含模板详情", alias="includeTemplate"
|
||||||
),
|
),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
|
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
|
||||||
@@ -319,6 +337,7 @@ async def get_mappings_by_source(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
service = DatasetMappingService(db)
|
service = DatasetMappingService(db)
|
||||||
|
await assert_dataset_access(db, dataset_id, user_context)
|
||||||
|
|
||||||
# 计算 skip
|
# 计算 skip
|
||||||
skip = (page - 1) * size
|
skip = (page - 1) * size
|
||||||
@@ -333,6 +352,8 @@ async def get_mappings_by_source(
|
|||||||
skip=skip,
|
skip=skip,
|
||||||
limit=size,
|
limit=size,
|
||||||
include_template=include_template,
|
include_template=include_template,
|
||||||
|
current_user_id=user_context.user_id,
|
||||||
|
is_admin=user_context.is_admin,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 计算总页数
|
# 计算总页数
|
||||||
@@ -364,6 +385,7 @@ async def get_mappings_by_source(
|
|||||||
async def delete_mapping(
|
async def delete_mapping(
|
||||||
project_id: str = Path(..., description="映射UUID(path param)"),
|
project_id: str = Path(..., description="映射UUID(path param)"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
删除映射关系(软删除)
|
删除映射关系(软删除)
|
||||||
@@ -387,6 +409,7 @@ async def delete_mapping(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"Mapping either not found or not specified."
|
status_code=404, detail=f"Mapping either not found or not specified."
|
||||||
)
|
)
|
||||||
|
await assert_dataset_access(db, mapping.dataset_id, user_context)
|
||||||
|
|
||||||
id = mapping.id
|
id = mapping.id
|
||||||
dataset_id = mapping.dataset_id
|
dataset_id = mapping.dataset_id
|
||||||
@@ -428,6 +451,7 @@ async def update_mapping(
|
|||||||
project_id: str = Path(..., description="映射UUID(path param)"),
|
project_id: str = Path(..., description="映射UUID(path param)"),
|
||||||
request: DatasetMappingUpdateRequest = None,
|
request: DatasetMappingUpdateRequest = None,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
更新标注项目信息
|
更新标注项目信息
|
||||||
@@ -456,6 +480,7 @@ async def update_mapping(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail=f"Mapping not found: {project_id}"
|
status_code=404, detail=f"Mapping not found: {project_id}"
|
||||||
)
|
)
|
||||||
|
await assert_dataset_access(db, mapping_orm.dataset_id, user_context)
|
||||||
|
|
||||||
# 构建更新数据
|
# 构建更新数据
|
||||||
update_values = {}
|
update_values = {}
|
||||||
|
|||||||
@@ -10,6 +10,11 @@ from app.module.dataset import DatasetManagementService
|
|||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
|
from ..security import (
|
||||||
|
RequestUserContext,
|
||||||
|
assert_dataset_access,
|
||||||
|
get_request_user_context,
|
||||||
|
)
|
||||||
from ..service.mapping import DatasetMappingService
|
from ..service.mapping import DatasetMappingService
|
||||||
from ..schema import (
|
from ..schema import (
|
||||||
SyncDatasetRequest,
|
SyncDatasetRequest,
|
||||||
@@ -32,7 +37,8 @@ logger = get_logger(__name__)
|
|||||||
@router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
|
@router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
|
||||||
async def sync_dataset_content(
|
async def sync_dataset_content(
|
||||||
request: SyncDatasetRequest,
|
request: SyncDatasetRequest,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Sync Dataset Content (Files and Annotations)
|
Sync Dataset Content (Files and Annotations)
|
||||||
@@ -51,6 +57,7 @@ async def sync_dataset_content(
|
|||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Mapping not found: {request.id}"
|
detail=f"Mapping not found: {request.id}"
|
||||||
)
|
)
|
||||||
|
await assert_dataset_access(db, mapping.dataset_id, user_context)
|
||||||
|
|
||||||
dm_client = DatasetManagementService(db)
|
dm_client = DatasetManagementService(db)
|
||||||
dataset_info = await dm_client.get_dataset(mapping.dataset_id)
|
dataset_info = await dm_client.get_dataset(mapping.dataset_id)
|
||||||
@@ -82,7 +89,8 @@ async def sync_dataset_content(
|
|||||||
@router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse])
|
@router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse])
|
||||||
async def sync_annotations(
|
async def sync_annotations(
|
||||||
request: SyncAnnotationsRequest,
|
request: SyncAnnotationsRequest,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Sync Annotations Only (Bidirectional Support)
|
Sync Annotations Only (Bidirectional Support)
|
||||||
@@ -102,6 +110,7 @@ async def sync_annotations(
|
|||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Mapping not found: {request.id}"
|
detail=f"Mapping not found: {request.id}"
|
||||||
)
|
)
|
||||||
|
await assert_dataset_access(db, mapping.dataset_id, user_context)
|
||||||
|
|
||||||
result = SyncAnnotationsResponse(
|
result = SyncAnnotationsResponse(
|
||||||
id=mapping.id,
|
id=mapping.id,
|
||||||
@@ -156,7 +165,8 @@ async def check_label_studio_connection():
|
|||||||
async def update_file_tags(
|
async def update_file_tags(
|
||||||
request: UpdateFileTagsRequest,
|
request: UpdateFileTagsRequest,
|
||||||
file_id: str = Path(..., description="文件ID"),
|
file_id: str = Path(..., description="文件ID"),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db),
|
||||||
|
user_context: RequestUserContext = Depends(get_request_user_context),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Update File Tags (Partial Update with Auto Format Conversion)
|
Update File Tags (Partial Update with Auto Format Conversion)
|
||||||
@@ -189,6 +199,7 @@ async def update_file_tags(
|
|||||||
raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
|
raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
|
||||||
|
|
||||||
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
|
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
|
||||||
|
await assert_dataset_access(db, dataset_id, user_context)
|
||||||
|
|
||||||
# 查找数据集关联的模板ID
|
# 查找数据集关联的模板ID
|
||||||
from ..service.mapping import DatasetMappingService
|
from ..service.mapping import DatasetMappingService
|
||||||
|
|||||||
69
runtime/datamate-python/app/module/annotation/security.py
Normal file
69
runtime/datamate-python/app/module/annotation/security.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db.models.dataset_management import Dataset
|
||||||
|
|
||||||
|
HEADER_USER_ID = "X-User-Id"
|
||||||
|
HEADER_USER_NAME = "X-User-Name"
|
||||||
|
HEADER_USER_ROLES = "X-User-Roles"
|
||||||
|
ADMIN_ROLE_CODE = "ROLE_ADMIN"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RequestUserContext:
|
||||||
|
user_id: str
|
||||||
|
username: str | None
|
||||||
|
roles: Tuple[str, ...]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_admin(self) -> bool:
|
||||||
|
return any(role.upper() == ADMIN_ROLE_CODE for role in self.roles)
|
||||||
|
|
||||||
|
|
||||||
|
def get_request_user_context(request: Request) -> RequestUserContext:
|
||||||
|
user_id = (request.headers.get(HEADER_USER_ID) or "").strip()
|
||||||
|
username = (request.headers.get(HEADER_USER_NAME) or "").strip() or None
|
||||||
|
role_header = request.headers.get(HEADER_USER_ROLES) or ""
|
||||||
|
roles = tuple(
|
||||||
|
role.strip()
|
||||||
|
for role in role_header.split(",")
|
||||||
|
if role and role.strip()
|
||||||
|
)
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="权限不足:缺少用户身份")
|
||||||
|
return RequestUserContext(user_id=user_id, username=username, roles=roles)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_dataset_owner_access(
|
||||||
|
user_context: RequestUserContext,
|
||||||
|
dataset_owner_user_id: str | None,
|
||||||
|
dataset_id: str,
|
||||||
|
) -> None:
|
||||||
|
if user_context.is_admin:
|
||||||
|
return
|
||||||
|
if not dataset_owner_user_id or dataset_owner_user_id != user_context.user_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=f"无权访问数据集: {dataset_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_dataset_access(
|
||||||
|
db: AsyncSession,
|
||||||
|
dataset_id: str,
|
||||||
|
user_context: RequestUserContext,
|
||||||
|
) -> None:
|
||||||
|
owner_result = await db.execute(
|
||||||
|
select(Dataset.created_by).where(Dataset.id == dataset_id)
|
||||||
|
)
|
||||||
|
dataset_owner = owner_result.scalar_one_or_none()
|
||||||
|
if dataset_owner is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"数据集不存在: {dataset_id}")
|
||||||
|
ensure_dataset_owner_access(user_context, str(dataset_owner), dataset_id)
|
||||||
|
|
||||||
@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.db.models.annotation_management import AutoAnnotationTask
|
from app.db.models.annotation_management import AutoAnnotationTask
|
||||||
from app.db.models.dataset_management import Dataset, DatasetFiles
|
from app.db.models.dataset_management import Dataset, DatasetFiles
|
||||||
|
from app.module.annotation.security import RequestUserContext
|
||||||
|
|
||||||
from ..schema.auto import (
|
from ..schema.auto import (
|
||||||
CreateAutoAnnotationTaskRequest,
|
CreateAutoAnnotationTaskRequest,
|
||||||
@@ -63,13 +64,25 @@ class AutoAnnotationTaskService:
|
|||||||
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
|
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]:
|
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
|
||||||
|
if user_context.is_admin:
|
||||||
|
return query
|
||||||
|
return query.join(
|
||||||
|
Dataset,
|
||||||
|
AutoAnnotationTask.dataset_id == Dataset.id,
|
||||||
|
).where(Dataset.created_by == user_context.user_id)
|
||||||
|
|
||||||
|
async def list_tasks(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_context: RequestUserContext,
|
||||||
|
) -> List[AutoAnnotationTaskResponse]:
|
||||||
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
|
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
|
||||||
|
|
||||||
|
query = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
|
||||||
|
query = self._apply_dataset_scope(query, user_context)
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(AutoAnnotationTask)
|
query.order_by(AutoAnnotationTask.created_at.desc())
|
||||||
.where(AutoAnnotationTask.deleted_at.is_(None))
|
|
||||||
.order_by(AutoAnnotationTask.created_at.desc())
|
|
||||||
)
|
)
|
||||||
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
|
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
|
||||||
|
|
||||||
@@ -87,13 +100,18 @@ class AutoAnnotationTaskService:
|
|||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]:
|
async def get_task(
|
||||||
result = await db.execute(
|
self,
|
||||||
select(AutoAnnotationTask).where(
|
db: AsyncSession,
|
||||||
AutoAnnotationTask.id == task_id,
|
task_id: str,
|
||||||
AutoAnnotationTask.deleted_at.is_(None),
|
user_context: RequestUserContext,
|
||||||
)
|
) -> Optional[AutoAnnotationTaskResponse]:
|
||||||
|
query = select(AutoAnnotationTask).where(
|
||||||
|
AutoAnnotationTask.id == task_id,
|
||||||
|
AutoAnnotationTask.deleted_at.is_(None),
|
||||||
)
|
)
|
||||||
|
query = self._apply_dataset_scope(query, user_context)
|
||||||
|
result = await db.execute(query)
|
||||||
task = result.scalar_one_or_none()
|
task = result.scalar_one_or_none()
|
||||||
if not task:
|
if not task:
|
||||||
return None
|
return None
|
||||||
@@ -138,13 +156,18 @@ class AutoAnnotationTaskService:
|
|||||||
return [task.dataset_id]
|
return [task.dataset_id]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool:
|
async def soft_delete_task(
|
||||||
result = await db.execute(
|
self,
|
||||||
select(AutoAnnotationTask).where(
|
db: AsyncSession,
|
||||||
AutoAnnotationTask.id == task_id,
|
task_id: str,
|
||||||
AutoAnnotationTask.deleted_at.is_(None),
|
user_context: RequestUserContext,
|
||||||
)
|
) -> bool:
|
||||||
|
query = select(AutoAnnotationTask).where(
|
||||||
|
AutoAnnotationTask.id == task_id,
|
||||||
|
AutoAnnotationTask.deleted_at.is_(None),
|
||||||
)
|
)
|
||||||
|
query = self._apply_dataset_scope(query, user_context)
|
||||||
|
result = await db.execute(query)
|
||||||
task = result.scalar_one_or_none()
|
task = result.scalar_one_or_none()
|
||||||
if not task:
|
if not task:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -54,6 +54,10 @@ from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
|
|||||||
from app.module.annotation.service.annotation_text_splitter import (
|
from app.module.annotation.service.annotation_text_splitter import (
|
||||||
AnnotationTextSplitter,
|
AnnotationTextSplitter,
|
||||||
)
|
)
|
||||||
|
from app.module.annotation.security import (
|
||||||
|
RequestUserContext,
|
||||||
|
ensure_dataset_owner_access,
|
||||||
|
)
|
||||||
from app.module.annotation.service.text_fetcher import (
|
from app.module.annotation.service.text_fetcher import (
|
||||||
fetch_text_content_via_download_api,
|
fetch_text_content_via_download_api,
|
||||||
)
|
)
|
||||||
@@ -104,8 +108,9 @@ class AnnotationEditorService:
|
|||||||
# 分段阈值:超过此字符数自动分段
|
# 分段阈值:超过此字符数自动分段
|
||||||
SEGMENT_THRESHOLD = 200
|
SEGMENT_THRESHOLD = 200
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession):
|
def __init__(self, db: AsyncSession, user_context: RequestUserContext):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
self.user_context = user_context
|
||||||
self.template_service = AnnotationTemplateService()
|
self.template_service = AnnotationTemplateService()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -157,14 +162,24 @@ class AnnotationEditorService:
|
|||||||
|
|
||||||
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
|
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
|
||||||
result = await self.db.execute(
|
result = await self.db.execute(
|
||||||
select(LabelingProject).where(
|
select(LabelingProject, Dataset.created_by).join(
|
||||||
|
Dataset,
|
||||||
|
LabelingProject.dataset_id == Dataset.id,
|
||||||
|
).where(
|
||||||
LabelingProject.id == project_id,
|
LabelingProject.id == project_id,
|
||||||
LabelingProject.deleted_at.is_(None),
|
LabelingProject.deleted_at.is_(None),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
project = result.scalar_one_or_none()
|
row = result.first()
|
||||||
if not project:
|
if not row:
|
||||||
raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
|
raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
|
||||||
|
project = row[0]
|
||||||
|
dataset_owner = row[1]
|
||||||
|
ensure_dataset_owner_access(
|
||||||
|
self.user_context,
|
||||||
|
str(dataset_owner) if dataset_owner is not None else None,
|
||||||
|
project.dataset_id,
|
||||||
|
)
|
||||||
return project
|
return project
|
||||||
|
|
||||||
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
|
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:
|
||||||
|
|||||||
@@ -478,7 +478,9 @@ class DatasetMappingService:
|
|||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
include_deleted: bool = False,
|
include_deleted: bool = False,
|
||||||
include_template: bool = False
|
include_template: bool = False,
|
||||||
|
current_user_id: Optional[str] = None,
|
||||||
|
is_admin: bool = False,
|
||||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||||
"""
|
"""
|
||||||
获取所有映射及总数(用于分页)
|
获取所有映射及总数(用于分页)
|
||||||
@@ -495,9 +497,16 @@ class DatasetMappingService:
|
|||||||
query = self._build_query_with_dataset_name()
|
query = self._build_query_with_dataset_name()
|
||||||
if not include_deleted:
|
if not include_deleted:
|
||||||
query = query.where(LabelingProject.deleted_at.is_(None))
|
query = query.where(LabelingProject.deleted_at.is_(None))
|
||||||
|
if not is_admin:
|
||||||
|
query = query.where(Dataset.created_by == current_user_id)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
count_query = select(func.count()).select_from(LabelingProject)
|
count_query = select(func.count()).select_from(LabelingProject)
|
||||||
|
if not is_admin:
|
||||||
|
count_query = count_query.join(
|
||||||
|
Dataset,
|
||||||
|
LabelingProject.dataset_id == Dataset.id,
|
||||||
|
).where(Dataset.created_by == current_user_id)
|
||||||
if not include_deleted:
|
if not include_deleted:
|
||||||
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
|
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
|
||||||
|
|
||||||
@@ -557,7 +566,9 @@ class DatasetMappingService:
|
|||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
include_deleted: bool = False,
|
include_deleted: bool = False,
|
||||||
include_template: bool = False
|
include_template: bool = False,
|
||||||
|
current_user_id: Optional[str] = None,
|
||||||
|
is_admin: bool = False,
|
||||||
) -> Tuple[List[DatasetMappingResponse], int]:
|
) -> Tuple[List[DatasetMappingResponse], int]:
|
||||||
"""
|
"""
|
||||||
根据源数据集ID获取映射关系及总数(用于分页)
|
根据源数据集ID获取映射关系及总数(用于分页)
|
||||||
@@ -578,11 +589,18 @@ class DatasetMappingService:
|
|||||||
|
|
||||||
if not include_deleted:
|
if not include_deleted:
|
||||||
query = query.where(LabelingProject.deleted_at.is_(None))
|
query = query.where(LabelingProject.deleted_at.is_(None))
|
||||||
|
if not is_admin:
|
||||||
|
query = query.where(Dataset.created_by == current_user_id)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数
|
||||||
count_query = select(func.count()).select_from(LabelingProject).where(
|
count_query = select(func.count()).select_from(LabelingProject).where(
|
||||||
LabelingProject.dataset_id == dataset_id
|
LabelingProject.dataset_id == dataset_id
|
||||||
)
|
)
|
||||||
|
if not is_admin:
|
||||||
|
count_query = count_query.join(
|
||||||
|
Dataset,
|
||||||
|
LabelingProject.dataset_id == Dataset.id,
|
||||||
|
).where(Dataset.created_by == current_user_id)
|
||||||
if not include_deleted:
|
if not include_deleted:
|
||||||
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
|
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user