Compare commits

..

7 Commits

Author SHA1 Message Date
329382db47 fix(pdf): 优化PDF文本提取服务异常处理
- 添加FeignException专门处理逻辑
- 实现详细的Feign异常日志记录功能
- 新增响应体解析和根因链构建方法
- 添加异常消息规范化处理
- 改进错误日志的可读性和调试信息完整度
2026-02-06 18:52:51 +08:00
e862925a06 feat(export): 添加逻辑路径构建功能支持文件管理
- 在导出服务中实现_build_logical_path方法用于构建相对路径
- 更新数据集文件记录以包含logical_path字段
- 在比率任务服务中实现build_logical_path静态方法
- 将逻辑路径信息添加到数据集文件记录中
- 规范化路径处理并替换反斜杠为正斜杠
- 添加无效路径验证防止目录遍历安全问题
2026-02-06 18:46:44 +08:00
05752678cc feat(dataset): 添加PDF提取服务中的逻辑路径构建功能
- 移除重复的csv导入语句
- 添加_build_logical_path方法用于构建文件逻辑路径
- 在_create_text_file_record方法中增加logical_path参数
- 更新记录创建调用以传递逻辑路径参数
- 验证逻辑路径不为空并抛出相应异常
- 将逻辑路径存储到数据集文件记录中
2026-02-06 18:30:44 +08:00
0f1dd9ec8d Merge remote-tracking branch 'gitea/lsf' into lsf 2026-02-06 18:29:58 +08:00
38e58ba864 Merge branch 'rbac' into lsf 2026-02-06 15:44:43 +08:00
6a4c4ae3d7 feat(auth): 为数据管理和RAG服务增加资源访问控制
- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证
- 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证
- 修改DatasetRepository接口和实现类,增加按创建者过滤的方法
- 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法
- 在RAG索引器服务中添加知识库访问权限检查和作用域过滤
- 更新实体元对象处理器以使用请求用户上下文获取当前用户
- 在前端设置页面添加用户权限管理功能和角色权限控制
- 为Python标注服务增加用户上下文和数据集访问权限验证
2026-02-06 14:58:46 +08:00
056cee11cc feat(auth): 完善API网关JWT认证和权限控制功能
- 实现网关侧JWT工具类和权限规则匹配器
- 集成JWT认证流程,支持Bearer Token验证
- 添加基于路径和HTTP方法的权限控制机制
- 配置白名单路由规则,优化认证性能
- 更新前端受保护路由组件,实现权限验证
- 添加403禁止访问页面和权限检查逻辑
- 重构登录页面,集成实际认证API调用
- 实现用户信息获取和权限加载功能
- 优化全局异常处理器中的认证错误状态码
- 集成FastJSON2和JJWT依赖库支持
2026-02-06 13:21:20 +08:00
65 changed files with 2631 additions and 250 deletions

View File

@@ -36,6 +36,23 @@
<groupId>com.alibaba.fastjson2</groupId> <groupId>com.alibaba.fastjson2</groupId>
<artifactId>fastjson2</artifactId> <artifactId>fastjson2</artifactId>
</dependency> </dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-api</artifactId>
<version>0.11.5</version>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-impl</artifactId>
<version>0.11.5</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
<version>0.11.5</version>
<scope>runtime</scope>
</dependency>
</dependencies> </dependencies>
<build> <build>

View File

@@ -1,34 +1,124 @@
package com.datamate.gateway.filter; package com.datamate.gateway.filter;
import com.alibaba.fastjson2.JSONObject;
import com.datamate.gateway.security.GatewayJwtUtils;
import com.datamate.gateway.security.PermissionRuleMatcher;
import io.jsonwebtoken.Claims;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain; import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter; import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import java.nio.charset.StandardCharsets;
import java.util.List;
/** /**
* 用户信息过滤器 * 用户信息过滤器
*
*/ */
@Slf4j @Slf4j
@Component @Component
public class UserContextFilter implements GlobalFilter { public class UserContextFilter implements GlobalFilter, Ordered {
@Value("${commercial.switch:false}") private final GatewayJwtUtils gatewayJwtUtils;
private boolean isCommercial; private final PermissionRuleMatcher permissionRuleMatcher;
@Value("${datamate.auth.enabled:true}")
private boolean authEnabled;
public UserContextFilter(GatewayJwtUtils gatewayJwtUtils, PermissionRuleMatcher permissionRuleMatcher) {
this.gatewayJwtUtils = gatewayJwtUtils;
this.permissionRuleMatcher = permissionRuleMatcher;
}
@Override @Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) { public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
if (!isCommercial) { if (!authEnabled) {
return chain.filter(exchange); return chain.filter(exchange);
} }
try { ServerHttpRequest request = exchange.getRequest();
String path = request.getURI().getPath();
HttpMethod method = request.getMethod();
} catch (Exception e) { if (!path.startsWith("/api/")) {
log.error("get current user info error", e);
return chain.filter(exchange); return chain.filter(exchange);
} }
return chain.filter(exchange); if (HttpMethod.OPTIONS.equals(method)) {
return chain.filter(exchange);
}
if (permissionRuleMatcher.isWhitelisted(path)) {
return chain.filter(exchange);
}
String token = extractBearerToken(request.getHeaders().getFirst("Authorization"));
if (!StringUtils.hasText(token)) {
return writeError(exchange, HttpStatus.UNAUTHORIZED, "auth.0003", "未登录或登录状态已失效");
}
Claims claims;
try {
if (!gatewayJwtUtils.validateToken(token)) {
return writeError(exchange, HttpStatus.UNAUTHORIZED, "auth.0003", "登录状态已失效");
}
claims = gatewayJwtUtils.getClaimsFromToken(token);
} catch (Exception ex) {
log.warn("JWT校验失败: {}", ex.getMessage());
return writeError(exchange, HttpStatus.UNAUTHORIZED, "auth.0003", "登录状态已失效");
}
String requiredPermission = permissionRuleMatcher.resolveRequiredPermission(method, path);
if (StringUtils.hasText(requiredPermission)) {
List<String> permissionCodes = gatewayJwtUtils.getStringListClaim(claims, "permissions");
if (!permissionCodes.contains(requiredPermission)) {
return writeError(exchange, HttpStatus.FORBIDDEN, "auth.0006", "权限不足");
}
}
String userId = String.valueOf(claims.get("userId"));
String username = claims.getSubject();
List<String> roles = gatewayJwtUtils.getStringListClaim(claims, "roles");
ServerHttpRequest mutatedRequest = request.mutate()
.header("X-User-Id", userId)
.header("X-User-Name", username)
.header("X-User-Roles", String.join(",", roles))
.build();
return chain.filter(exchange.mutate().request(mutatedRequest).build());
}
@Override
public int getOrder() {
return -200;
}
private String extractBearerToken(String authorizationHeader) {
if (!StringUtils.hasText(authorizationHeader)) {
return null;
}
if (!authorizationHeader.startsWith("Bearer ")) {
return null;
}
String token = authorizationHeader.substring("Bearer ".length()).trim();
return token.isEmpty() ? null : token;
}
private Mono<Void> writeError(ServerWebExchange exchange,
HttpStatus status,
String code,
String message) {
exchange.getResponse().setStatusCode(status);
exchange.getResponse().getHeaders().set("Content-Type", "application/json;charset=UTF-8");
byte[] body = JSONObject.toJSONString(new ErrorBody(code, message, null))
.getBytes(StandardCharsets.UTF_8);
return exchange.getResponse().writeWith(Mono.just(exchange.getResponse().bufferFactory().wrap(body)));
}
private record ErrorBody(String code, String message, Object data) {
} }
} }

View File

@@ -0,0 +1,65 @@
package com.datamate.gateway.security;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.Keys;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.stream.Collectors;
/**
* 网关侧JWT工具
*/
@Component
public class GatewayJwtUtils {
private static final String DEFAULT_SECRET = "datamate-secret-key-for-jwt-token-generation";
@Value("${jwt.secret:" + DEFAULT_SECRET + "}")
private String secret;
public Claims getClaimsFromToken(String token) {
return Jwts.parserBuilder()
.setSigningKey(getSigningKey())
.build()
.parseClaimsJws(token)
.getBody();
}
public boolean validateToken(String token) {
Claims claims = getClaimsFromToken(token);
Date expiration = claims.getExpiration();
return expiration != null && expiration.after(new Date());
}
public List<String> getStringListClaim(Claims claims, String claimName) {
Object claimValue = claims.get(claimName);
if (!(claimValue instanceof Collection<?> values)) {
return Collections.emptyList();
}
return values.stream()
.map(String::valueOf)
.collect(Collectors.toList());
}
private SecretKey getSigningKey() {
String secretValue = StringUtils.hasText(secret) ? secret : DEFAULT_SECRET;
try {
MessageDigest digest = MessageDigest.getInstance("SHA-512");
byte[] keyBytes = digest.digest(secretValue.getBytes(StandardCharsets.UTF_8));
return Keys.hmacShaKeyFor(keyBytes);
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException("Cannot initialize JWT signing key", e);
}
}
}

View File

@@ -0,0 +1,85 @@
package com.datamate.gateway.security;
import lombok.Getter;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
/**
* 权限规则匹配器
*/
@Component
public class PermissionRuleMatcher {
private static final Set<HttpMethod> READ_METHODS = Set.of(HttpMethod.GET, HttpMethod.HEAD);
private static final Set<HttpMethod> WRITE_METHODS = Set.of(HttpMethod.POST, HttpMethod.PUT, HttpMethod.PATCH, HttpMethod.DELETE);
private final AntPathMatcher pathMatcher = new AntPathMatcher();
private final List<String> whiteListPatterns = List.of(
"/api/auth/login",
"/api/auth/login/**"
);
private final List<PermissionRule> rules = buildRules();
public boolean isWhitelisted(String path) {
return whiteListPatterns.stream().anyMatch(pattern -> pathMatcher.match(pattern, path));
}
public String resolveRequiredPermission(HttpMethod method, String path) {
for (PermissionRule rule : rules) {
if (rule.matches(method, path, pathMatcher)) {
return rule.getPermissionCode();
}
}
return null;
}
private List<PermissionRule> buildRules() {
List<PermissionRule> permissionRules = new ArrayList<>();
addModuleRules(permissionRules, "/api/data-management/**", "module:data-management:read", "module:data-management:write");
addModuleRules(permissionRules, "/api/annotation/**", "module:data-annotation:read", "module:data-annotation:write");
addModuleRules(permissionRules, "/api/data-collection/**", "module:data-collection:read", "module:data-collection:write");
addModuleRules(permissionRules, "/api/evaluation/**", "module:data-evaluation:read", "module:data-evaluation:write");
addModuleRules(permissionRules, "/api/synthesis/**", "module:data-synthesis:read", "module:data-synthesis:write");
addModuleRules(permissionRules, "/api/knowledge-base/**", "module:knowledge-base:read", "module:knowledge-base:write");
addModuleRules(permissionRules, "/api/operator-market/**", "module:operator-market:read", "module:operator-market:write");
addModuleRules(permissionRules, "/api/orchestration/**", "module:orchestration:read", "module:orchestration:write");
addModuleRules(permissionRules, "/api/content-generation/**", "module:content-generation:use", "module:content-generation:use");
permissionRules.add(new PermissionRule(READ_METHODS, "/api/auth/users/**", "system:user:manage"));
permissionRules.add(new PermissionRule(WRITE_METHODS, "/api/auth/users/**", "system:user:manage"));
permissionRules.add(new PermissionRule(READ_METHODS, "/api/auth/roles/**", "system:role:manage"));
permissionRules.add(new PermissionRule(WRITE_METHODS, "/api/auth/roles/**", "system:role:manage"));
permissionRules.add(new PermissionRule(READ_METHODS, "/api/auth/permissions/**", "system:permission:manage"));
permissionRules.add(new PermissionRule(WRITE_METHODS, "/api/auth/permissions/**", "system:permission:manage"));
return permissionRules;
}
private void addModuleRules(List<PermissionRule> rules,
String pathPattern,
String readPermissionCode,
String writePermissionCode) {
rules.add(new PermissionRule(READ_METHODS, pathPattern, readPermissionCode));
rules.add(new PermissionRule(WRITE_METHODS, pathPattern, writePermissionCode));
}
@Getter
private static class PermissionRule {
private final Set<HttpMethod> methods;
private final String pathPattern;
private final String permissionCode;
private PermissionRule(Set<HttpMethod> methods, String pathPattern, String permissionCode) {
this.methods = methods;
this.pathPattern = pathPattern;
this.permissionCode = permissionCode;
}
private boolean matches(HttpMethod method, String path, AntPathMatcher matcher) {
return method != null && methods.contains(method) && matcher.match(pathPattern, path);
}
}
}

View File

@@ -3,6 +3,7 @@ package com.datamate.datamanagement.application;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.domain.utils.ChunksSaver; import com.datamate.common.domain.utils.ChunksSaver;
import com.datamate.common.setting.application.SysParamApplicationService; import com.datamate.common.setting.application.SysParamApplicationService;
import com.datamate.datamanagement.interfaces.dto.*; import com.datamate.datamanagement.interfaces.dto.*;
@@ -64,6 +65,7 @@ public class DatasetApplicationService {
private final CollectionTaskClient collectionTaskClient; private final CollectionTaskClient collectionTaskClient;
private final DatasetFileApplicationService datasetFileApplicationService; private final DatasetFileApplicationService datasetFileApplicationService;
private final SysParamApplicationService sysParamService; private final SysParamApplicationService sysParamService;
private final ResourceAccessService resourceAccessService;
@Value("${datamate.data-management.base-path:/dataset}") @Value("${datamate.data-management.base-path:/dataset}")
private String datasetBasePath; private String datasetBasePath;
@@ -102,6 +104,7 @@ public class DatasetApplicationService {
public Dataset updateDataset(String datasetId, UpdateDatasetRequest updateDatasetRequest) { public Dataset updateDataset(String datasetId, UpdateDatasetRequest updateDatasetRequest) {
Dataset dataset = datasetRepository.getById(datasetId); Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND); BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
if (StringUtils.hasText(updateDatasetRequest.getName())) { if (StringUtils.hasText(updateDatasetRequest.getName())) {
dataset.setName(updateDatasetRequest.getName()); dataset.setName(updateDatasetRequest.getName());
@@ -151,6 +154,7 @@ public class DatasetApplicationService {
public void deleteDataset(String datasetId) { public void deleteDataset(String datasetId) {
Dataset dataset = datasetRepository.getById(datasetId); Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND); BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
long childCount = datasetRepository.countByParentId(datasetId); long childCount = datasetRepository.countByParentId(datasetId);
BusinessAssert.isTrue(childCount == 0, DataManagementErrorCode.DATASET_HAS_CHILDREN); BusinessAssert.isTrue(childCount == 0, DataManagementErrorCode.DATASET_HAS_CHILDREN);
datasetRepository.removeById(datasetId); datasetRepository.removeById(datasetId);
@@ -164,6 +168,7 @@ public class DatasetApplicationService {
public Dataset getDataset(String datasetId) { public Dataset getDataset(String datasetId) {
Dataset dataset = datasetRepository.getById(datasetId); Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND); BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
List<DatasetFile> datasetFiles = datasetFileRepository.findAllVisibleByDatasetId(datasetId); List<DatasetFile> datasetFiles = datasetFileRepository.findAllVisibleByDatasetId(datasetId);
dataset.setFiles(datasetFiles); dataset.setFiles(datasetFiles);
applyVisibleFileCounts(Collections.singletonList(dataset)); applyVisibleFileCounts(Collections.singletonList(dataset));
@@ -176,7 +181,8 @@ public class DatasetApplicationService {
@Transactional(readOnly = true) @Transactional(readOnly = true)
public PagedResponse<DatasetResponse> getDatasets(DatasetPagingQuery query) { public PagedResponse<DatasetResponse> getDatasets(DatasetPagingQuery query) {
IPage<Dataset> page = new Page<>(query.getPage(), query.getSize()); IPage<Dataset> page = new Page<>(query.getPage(), query.getSize());
page = datasetRepository.findByCriteria(page, query); String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = datasetRepository.findByCriteria(page, query, ownerFilterUserId);
String datasetPvcName = getDatasetPvcName(); String datasetPvcName = getDatasetPvcName();
applyVisibleFileCounts(page.getRecords()); applyVisibleFileCounts(page.getRecords());
List<DatasetResponse> datasetResponses = DatasetConverter.INSTANCE.convertToResponse(page.getRecords()); List<DatasetResponse> datasetResponses = DatasetConverter.INSTANCE.convertToResponse(page.getRecords());
@@ -189,6 +195,7 @@ public class DatasetApplicationService {
BusinessAssert.isTrue(StringUtils.hasText(datasetId), CommonErrorCode.PARAM_ERROR); BusinessAssert.isTrue(StringUtils.hasText(datasetId), CommonErrorCode.PARAM_ERROR);
Dataset dataset = datasetRepository.getById(datasetId); Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND); BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
Set<String> sourceTags = normalizeTagNames(dataset.getTags()); Set<String> sourceTags = normalizeTagNames(dataset.getTags());
if (sourceTags.isEmpty()) { if (sourceTags.isEmpty()) {
return Collections.emptyList(); return Collections.emptyList();
@@ -198,10 +205,12 @@ public class DatasetApplicationService {
SIMILAR_DATASET_CANDIDATE_MAX, SIMILAR_DATASET_CANDIDATE_MAX,
Math.max(safeLimit * SIMILAR_DATASET_CANDIDATE_FACTOR, safeLimit) Math.max(safeLimit * SIMILAR_DATASET_CANDIDATE_FACTOR, safeLimit)
); );
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
List<Dataset> candidates = datasetRepository.findSimilarByTags( List<Dataset> candidates = datasetRepository.findSimilarByTags(
new ArrayList<>(sourceTags), new ArrayList<>(sourceTags),
datasetId, datasetId,
candidateLimit candidateLimit,
ownerFilterUserId
); );
if (CollectionUtils.isEmpty(candidates)) { if (CollectionUtils.isEmpty(candidates)) {
return Collections.emptyList(); return Collections.emptyList();
@@ -436,6 +445,7 @@ public class DatasetApplicationService {
if (dataset == null) { if (dataset == null) {
throw new IllegalArgumentException("Dataset not found: " + datasetId); throw new IllegalArgumentException("Dataset not found: " + datasetId);
} }
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
Map<String, Object> statistics = new HashMap<>(); Map<String, Object> statistics = new HashMap<>();
@@ -485,7 +495,11 @@ public class DatasetApplicationService {
* 获取所有数据集的汇总统计信息 * 获取所有数据集的汇总统计信息
*/ */
public AllDatasetStatisticsResponse getAllDatasetStatistics() { public AllDatasetStatisticsResponse getAllDatasetStatistics() {
return datasetRepository.getAllDatasetStatistics(); if (resourceAccessService.isAdmin()) {
return datasetRepository.getAllDatasetStatistics();
}
String currentUserId = resourceAccessService.requireCurrentUserId();
return datasetRepository.getAllDatasetStatisticsByCreatedBy(currentUserId);
} }
/** /**

View File

@@ -2,6 +2,7 @@ package com.datamate.datamanagement.application;
import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessAssert; import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.CommonErrorCode; import com.datamate.common.infrastructure.exception.CommonErrorCode;
import com.datamate.common.interfaces.PagedResponse; import com.datamate.common.interfaces.PagedResponse;
@@ -40,6 +41,7 @@ import java.util.UUID;
public class KnowledgeSetApplicationService { public class KnowledgeSetApplicationService {
private final KnowledgeSetRepository knowledgeSetRepository; private final KnowledgeSetRepository knowledgeSetRepository;
private final TagMapper tagMapper; private final TagMapper tagMapper;
private final ResourceAccessService resourceAccessService;
public KnowledgeSet createKnowledgeSet(CreateKnowledgeSetRequest request) { public KnowledgeSet createKnowledgeSet(CreateKnowledgeSetRequest request) {
BusinessAssert.isTrue(knowledgeSetRepository.findByName(request.getName()) == null, BusinessAssert.isTrue(knowledgeSetRepository.findByName(request.getName()) == null,
@@ -64,6 +66,7 @@ public class KnowledgeSetApplicationService {
public KnowledgeSet updateKnowledgeSet(String setId, UpdateKnowledgeSetRequest request) { public KnowledgeSet updateKnowledgeSet(String setId, UpdateKnowledgeSetRequest request) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId); KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND); BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
BusinessAssert.isTrue(!isReadOnlyStatus(knowledgeSet.getStatus()), BusinessAssert.isTrue(!isReadOnlyStatus(knowledgeSet.getStatus()),
DataManagementErrorCode.KNOWLEDGE_SET_STATUS_ERROR); DataManagementErrorCode.KNOWLEDGE_SET_STATUS_ERROR);
@@ -119,6 +122,7 @@ public class KnowledgeSetApplicationService {
public void deleteKnowledgeSet(String setId) { public void deleteKnowledgeSet(String setId) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId); KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND); BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
knowledgeSetRepository.removeById(setId); knowledgeSetRepository.removeById(setId);
} }
@@ -126,13 +130,15 @@ public class KnowledgeSetApplicationService {
public KnowledgeSet getKnowledgeSet(String setId) { public KnowledgeSet getKnowledgeSet(String setId) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId); KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND); BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
return knowledgeSet; return knowledgeSet;
} }
@Transactional(readOnly = true) @Transactional(readOnly = true)
public PagedResponse<KnowledgeSetResponse> getKnowledgeSets(KnowledgeSetPagingQuery query) { public PagedResponse<KnowledgeSetResponse> getKnowledgeSets(KnowledgeSetPagingQuery query) {
IPage<KnowledgeSet> page = new Page<>(query.getPage(), query.getSize()); IPage<KnowledgeSet> page = new Page<>(query.getPage(), query.getSize());
page = knowledgeSetRepository.findByCriteria(page, query); String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = knowledgeSetRepository.findByCriteria(page, query, ownerFilterUserId);
List<KnowledgeSetResponse> responses = KnowledgeConverter.INSTANCE.convertSetResponses(page.getRecords()); List<KnowledgeSetResponse> responses = KnowledgeConverter.INSTANCE.convertSetResponses(page.getRecords());
return PagedResponse.of(responses, page.getCurrent(), page.getTotal(), page.getPages()); return PagedResponse.of(responses, page.getCurrent(), page.getTotal(), page.getPages());
} }

View File

@@ -4,6 +4,8 @@ import com.datamate.common.infrastructure.common.Response;
import com.datamate.datamanagement.infrastructure.client.PdfTextExtractClient; import com.datamate.datamanagement.infrastructure.client.PdfTextExtractClient;
import com.datamate.datamanagement.infrastructure.client.dto.PdfTextExtractRequest; import com.datamate.datamanagement.infrastructure.client.dto.PdfTextExtractRequest;
import com.datamate.datamanagement.infrastructure.client.dto.PdfTextExtractResponse; import com.datamate.datamanagement.infrastructure.client.dto.PdfTextExtractResponse;
import feign.FeignException;
import feign.Request;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
@@ -47,8 +49,71 @@ public class PdfTextExtractAsyncService {
} else { } else {
log.info("PdfTextExtract succeeded, datasetId={}, fileId={}", datasetId, fileId); log.info("PdfTextExtract succeeded, datasetId={}, fileId={}", datasetId, fileId);
} }
} catch (FeignException feignException) {
logFeignException(datasetId, fileId, feignException);
} catch (Exception e) { } catch (Exception e) {
log.error("PdfTextExtract call failed, datasetId={}, fileId={}", datasetId, fileId, e); log.error("PdfTextExtract call failed, datasetId={}, fileId={}", datasetId, fileId, e);
} }
} }
private void logFeignException(String datasetId, String fileId, FeignException feignException) {
Request request = feignException.request();
String httpMethod = request == null || request.httpMethod() == null
? "UNKNOWN"
: request.httpMethod().name();
String requestUrl = request == null || request.url() == null
? "UNKNOWN"
: request.url();
String responseBody = resolveFeignResponseBody(feignException);
String rootCauseChain = buildCauseChain(feignException, 12);
log.error(
"PdfTextExtract call failed with FeignException, datasetId={}, fileId={}, status={}, method={}, url={}, responseBody=\n{}\nrootCauseChain={}",
datasetId,
fileId,
feignException.status(),
httpMethod,
requestUrl,
responseBody,
rootCauseChain,
feignException
);
}
private String resolveFeignResponseBody(FeignException feignException) {
String responseBody = feignException.contentUTF8();
if (responseBody == null || responseBody.isBlank()) {
responseBody = feignException.getMessage();
}
if (responseBody == null || responseBody.isBlank()) {
return "EMPTY_RESPONSE_BODY";
}
return responseBody;
}
private String buildCauseChain(Throwable throwable, int maxDepth) {
StringBuilder causeChain = new StringBuilder();
Throwable current = throwable;
int depth = 0;
while (current != null && depth < maxDepth) {
if (causeChain.length() > 0) {
causeChain.append(" <- ");
}
causeChain.append(current.getClass().getSimpleName())
.append(": ")
.append(normalizeCauseMessage(current.getMessage()));
current = current.getCause();
depth++;
}
if (current != null) {
causeChain.append(" <- ...");
}
return causeChain.toString();
}
private String normalizeCauseMessage(String message) {
if (message == null || message.isBlank()) {
return "EMPTY_MESSAGE";
}
return message.replace("\r", " ").replace("\n", " ").trim();
}
} }

View File

@@ -25,9 +25,11 @@ public interface DatasetRepository extends IRepository<Dataset> {
AllDatasetStatisticsResponse getAllDatasetStatistics(); AllDatasetStatisticsResponse getAllDatasetStatistics();
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query); AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy);
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy);
long countByParentId(String parentDatasetId); long countByParentId(String parentDatasetId);
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit); List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy);
} }

View File

@@ -11,5 +11,5 @@ import com.datamate.datamanagement.interfaces.dto.KnowledgeSetPagingQuery;
public interface KnowledgeSetRepository extends IRepository<KnowledgeSet> { public interface KnowledgeSetRepository extends IRepository<KnowledgeSet> {
KnowledgeSet findByName(String name); KnowledgeSet findByName(String name);
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query); IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy);
} }

View File

@@ -51,10 +51,34 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
@Override @Override
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query) { public AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy) {
List<Dataset> datasets = lambdaQuery()
.eq(Dataset::getCreatedBy, createdBy)
.list();
long totalFiles = datasets.stream()
.map(Dataset::getFileCount)
.filter(java.util.Objects::nonNull)
.mapToLong(Long::longValue)
.sum();
long totalSize = datasets.stream()
.map(Dataset::getSizeBytes)
.filter(java.util.Objects::nonNull)
.mapToLong(Long::longValue)
.sum();
AllDatasetStatisticsResponse response = new AllDatasetStatisticsResponse();
response.setTotalDatasets(datasets.size());
response.setTotalFiles(totalFiles);
response.setTotalSize(totalSize);
return response;
}
@Override
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy) {
LambdaQueryWrapper<Dataset> wrapper = new LambdaQueryWrapper<Dataset>() LambdaQueryWrapper<Dataset> wrapper = new LambdaQueryWrapper<Dataset>()
.eq(query.getType() != null, Dataset::getDatasetType, query.getType()) .eq(query.getType() != null, Dataset::getDatasetType, query.getType())
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus()); .eq(query.getStatus() != null, Dataset::getStatus, query.getStatus())
.eq(StringUtils.isNotBlank(createdBy), Dataset::getCreatedBy, createdBy);
if (query.getParentDatasetId() != null) { if (query.getParentDatasetId() != null) {
if (StringUtils.isBlank(query.getParentDatasetId())) { if (StringUtils.isBlank(query.getParentDatasetId())) {
@@ -92,7 +116,7 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
} }
@Override @Override
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit) { public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy) {
if (limit <= 0 || tagNames == null || tagNames.isEmpty()) { if (limit <= 0 || tagNames == null || tagNames.isEmpty()) {
return Collections.emptyList(); return Collections.emptyList();
} }
@@ -109,6 +133,9 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
if (StringUtils.isNotBlank(excludedDatasetId)) { if (StringUtils.isNotBlank(excludedDatasetId)) {
wrapper.ne(Dataset::getId, excludedDatasetId.trim()); wrapper.ne(Dataset::getId, excludedDatasetId.trim());
} }
if (StringUtils.isNotBlank(createdBy)) {
wrapper.eq(Dataset::getCreatedBy, createdBy);
}
wrapper.apply("tags IS NOT NULL AND JSON_VALID(tags) = 1 AND JSON_LENGTH(tags) > 0"); wrapper.apply("tags IS NOT NULL AND JSON_VALID(tags) = 1 AND JSON_LENGTH(tags) > 0");
wrapper.and(condition -> { wrapper.and(condition -> {
boolean hasCondition = false; boolean hasCondition = false;

View File

@@ -25,7 +25,7 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
} }
@Override @Override
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query) { public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy) {
LambdaQueryWrapper<KnowledgeSet> wrapper = new LambdaQueryWrapper<KnowledgeSet>() LambdaQueryWrapper<KnowledgeSet> wrapper = new LambdaQueryWrapper<KnowledgeSet>()
.eq(query.getStatus() != null, KnowledgeSet::getStatus, query.getStatus()) .eq(query.getStatus() != null, KnowledgeSet::getStatus, query.getStatus())
.eq(StringUtils.isNotBlank(query.getDomain()), KnowledgeSet::getDomain, query.getDomain()) .eq(StringUtils.isNotBlank(query.getDomain()), KnowledgeSet::getDomain, query.getDomain())
@@ -34,7 +34,8 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
.eq(StringUtils.isNotBlank(query.getSensitivity()), KnowledgeSet::getSensitivity, query.getSensitivity()) .eq(StringUtils.isNotBlank(query.getSensitivity()), KnowledgeSet::getSensitivity, query.getSensitivity())
.eq(query.getSourceType() != null, KnowledgeSet::getSourceType, query.getSourceType()) .eq(query.getSourceType() != null, KnowledgeSet::getSourceType, query.getSourceType())
.ge(query.getValidFrom() != null, KnowledgeSet::getValidFrom, query.getValidFrom()) .ge(query.getValidFrom() != null, KnowledgeSet::getValidFrom, query.getValidFrom())
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo()); .le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo())
.eq(StringUtils.isNotBlank(createdBy), KnowledgeSet::getCreatedBy, createdBy);
if (StringUtils.isNotBlank(query.getKeyword())) { if (StringUtils.isNotBlank(query.getKeyword())) {
wrapper.and(w -> w.like(KnowledgeSet::getName, query.getKeyword()) wrapper.and(w -> w.like(KnowledgeSet::getName, query.getKeyword())

View File

@@ -2,8 +2,11 @@ package com.datamate.rag.indexer.application;
import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.BusinessException; import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode; import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import com.datamate.common.interfaces.PagedResponse; import com.datamate.common.interfaces.PagedResponse;
import com.datamate.common.interfaces.PagingQuery; import com.datamate.common.interfaces.PagingQuery;
import com.datamate.common.setting.domain.entity.ModelConfig; import com.datamate.common.setting.domain.entity.ModelConfig;
@@ -55,6 +58,7 @@ public class KnowledgeBaseService {
private final ApplicationEventPublisher eventPublisher; private final ApplicationEventPublisher eventPublisher;
private final ModelConfigRepository modelConfigRepository; private final ModelConfigRepository modelConfigRepository;
private final MilvusService milvusService; private final MilvusService milvusService;
private final ResourceAccessService resourceAccessService;
/** /**
* 创建知识库 * 创建知识库
@@ -77,8 +81,7 @@ public class KnowledgeBaseService {
*/ */
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) { public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) { if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder() milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
.collectionName(knowledgeBase.getName()) .collectionName(knowledgeBase.getName())
@@ -98,16 +101,14 @@ public class KnowledgeBaseService {
*/ */
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void delete(String knowledgeBaseId) { public void delete(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
knowledgeBaseRepository.removeById(knowledgeBaseId); knowledgeBaseRepository.removeById(knowledgeBaseId);
ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId); ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId);
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build()); milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build());
} }
public KnowledgeBaseResp getById(String knowledgeBaseId) { public KnowledgeBaseResp getById(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBaseResp resp = getKnowledgeBaseResp(knowledgeBase); KnowledgeBaseResp resp = getKnowledgeBaseResp(knowledgeBase);
resp.setEmbedding(modelConfigRepository.getById(knowledgeBase.getEmbeddingModel())); resp.setEmbedding(modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()));
resp.setChat(modelConfigRepository.getById(knowledgeBase.getChatModel())); resp.setChat(modelConfigRepository.getById(knowledgeBase.getChatModel()));
@@ -133,7 +134,8 @@ public class KnowledgeBaseService {
public PagedResponse<KnowledgeBaseResp> list(KnowledgeBaseQueryReq request) { public PagedResponse<KnowledgeBaseResp> list(KnowledgeBaseQueryReq request) {
IPage<KnowledgeBase> page = new Page<>(request.getPage(), request.getSize()); IPage<KnowledgeBase> page = new Page<>(request.getPage(), request.getSize());
page = knowledgeBaseRepository.page(page, request); String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = knowledgeBaseRepository.page(page, request, ownerFilterUserId);
// 将 KnowledgeBase 转换为 KnowledgeBaseResp,并计算 fileCount 和 chunkCount // 将 KnowledgeBase 转换为 KnowledgeBaseResp,并计算 fileCount 和 chunkCount
List<KnowledgeBaseResp> respList = page.getRecords().stream().map(this::getKnowledgeBaseResp).toList(); List<KnowledgeBaseResp> respList = page.getRecords().stream().map(this::getKnowledgeBaseResp).toList();
@@ -143,8 +145,7 @@ public class KnowledgeBaseService {
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void addFiles(AddFilesReq request) { public void addFiles(AddFilesReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseId())) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseId());
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
List<RagFile> ragFiles = request.getFiles().stream().map(fileInfo -> { List<RagFile> ragFiles = request.getFiles().stream().map(fileInfo -> {
RagFile ragFile = new RagFile(); RagFile ragFile = new RagFile();
ragFile.setKnowledgeBaseId(knowledgeBase.getId()); ragFile.setKnowledgeBaseId(knowledgeBase.getId());
@@ -170,6 +171,7 @@ public class KnowledgeBaseService {
} }
public PagedResponse<RagFile> listFiles(String knowledgeBaseId, RagFileReq request) { public PagedResponse<RagFile> listFiles(String knowledgeBaseId, RagFileReq request) {
getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize()); IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
request.setKnowledgeBaseId(knowledgeBaseId); request.setKnowledgeBaseId(knowledgeBaseId);
page = ragFileRepository.page(page, request); page = ragFileRepository.page(page, request);
@@ -177,8 +179,13 @@ public class KnowledgeBaseService {
} }
public PagedResponse<KnowledgeBaseFileSearchResp> searchFiles(KnowledgeBaseFileSearchReq request) { public PagedResponse<KnowledgeBaseFileSearchResp> searchFiles(KnowledgeBaseFileSearchReq request) {
boolean admin = resourceAccessService.isAdmin();
List<String> scopedKnowledgeBaseIds = resolveSearchScopeKnowledgeBaseIds(request, admin);
if (!admin && scopedKnowledgeBaseIds.isEmpty()) {
return PagedResponse.of(Collections.emptyList(), request.getPage(), 0L, 0);
}
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize()); IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
page = ragFileRepository.searchPage(page, request); page = ragFileRepository.searchPage(page, request, scopedKnowledgeBaseIds);
List<RagFile> records = page.getRecords(); List<RagFile> records = page.getRecords();
if (records.isEmpty()) { if (records.isEmpty()) {
return PagedResponse.of(Collections.emptyList(), page.getCurrent(), page.getTotal(), page.getPages()); return PagedResponse.of(Collections.emptyList(), page.getCurrent(), page.getTotal(), page.getPages());
@@ -213,8 +220,7 @@ public class KnowledgeBaseService {
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) { public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
ragFileRepository.removeByIds(request.getIds()); ragFileRepository.removeByIds(request.getIds());
milvusService.getMilvusClient().delete(DeleteReq.builder() milvusService.getMilvusClient().delete(DeleteReq.builder()
.collectionName(knowledgeBase.getName()) .collectionName(knowledgeBase.getName())
@@ -223,8 +229,7 @@ public class KnowledgeBaseService {
} }
public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) { public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId)) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder() QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder()
.collectionName(knowledgeBase.getName()) .collectionName(knowledgeBase.getName())
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"") .filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
@@ -259,8 +264,7 @@ public class KnowledgeBaseService {
* @return 检索结果 * @return 检索结果
*/ */
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) { public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst())) KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseIds().getFirst());
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()); ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig); EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
Embedding embedding = embeddingModel.embed(request.getQuery()).content(); Embedding embedding = embeddingModel.embed(request.getQuery()).content();
@@ -273,4 +277,27 @@ public class KnowledgeBaseService {
}); });
return searchResults; return searchResults;
} }
private KnowledgeBase getKnowledgeBaseWithAccessCheck(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
resourceAccessService.assertOwnerAccess(knowledgeBase.getCreatedBy());
return knowledgeBase;
}
private List<String> resolveSearchScopeKnowledgeBaseIds(KnowledgeBaseFileSearchReq request, boolean admin) {
if (admin) {
return Collections.emptyList();
}
String currentUserId = resourceAccessService.requireCurrentUserId();
List<String> ownedKnowledgeBaseIds = knowledgeBaseRepository.listIdsByCreatedBy(currentUserId);
if (!StringUtils.hasText(request.getKnowledgeBaseId())) {
return ownedKnowledgeBaseIds;
}
BusinessAssert.isTrue(
ownedKnowledgeBaseIds.contains(request.getKnowledgeBaseId()),
SystemErrorCode.INSUFFICIENT_PERMISSIONS
);
return Collections.singletonList(request.getKnowledgeBaseId());
}
} }

View File

@@ -5,6 +5,8 @@ import com.baomidou.mybatisplus.extension.repository.IRepository;
import com.datamate.rag.indexer.domain.model.KnowledgeBase; import com.datamate.rag.indexer.domain.model.KnowledgeBase;
import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq; import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
import java.util.List;
/** /**
* 知识库仓储接口 * 知识库仓储接口
* *
@@ -19,5 +21,7 @@ public interface KnowledgeBaseRepository extends IRepository<KnowledgeBase> {
* @param request 查询请求 * @param request 查询请求
* @return 知识库分页结果 * @return 知识库分页结果
*/ */
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request); IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy);
List<String> listIdsByCreatedBy(String createdBy);
} }

View File

@@ -23,5 +23,5 @@ public interface RagFileRepository extends IRepository<RagFile> {
IPage<RagFile> page(IPage<RagFile> page, RagFileReq request); IPage<RagFile> page(IPage<RagFile> page, RagFileReq request);
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request); IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds);
} }

View File

@@ -10,6 +10,9 @@ import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
/** /**
* 知识库仓储实现类 * 知识库仓储实现类
* *
@@ -20,12 +23,28 @@ import org.springframework.util.StringUtils;
public class KnowledgeBaseRepositoryImpl extends CrudRepository<KnowledgeBaseMapper, KnowledgeBase> implements KnowledgeBaseRepository { public class KnowledgeBaseRepositoryImpl extends CrudRepository<KnowledgeBaseMapper, KnowledgeBase> implements KnowledgeBaseRepository {
@Override @Override
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request) { public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy) {
return this.page(page, new LambdaQueryWrapper<KnowledgeBase>() return this.page(page, new LambdaQueryWrapper<KnowledgeBase>()
.like(StringUtils.hasText(request.getName()), KnowledgeBase::getName, request.getName()) .like(StringUtils.hasText(request.getName()), KnowledgeBase::getName, request.getName())
.like(StringUtils.hasText(request.getDescription()), KnowledgeBase::getDescription, request.getDescription()) .like(StringUtils.hasText(request.getDescription()), KnowledgeBase::getDescription, request.getDescription())
.like(StringUtils.hasText(request.getCreatedBy()), KnowledgeBase::getCreatedBy, request.getCreatedBy()) .like(StringUtils.hasText(request.getCreatedBy()), KnowledgeBase::getCreatedBy, request.getCreatedBy())
.like(StringUtils.hasText(request.getUpdatedBy()), KnowledgeBase::getUpdatedBy, request.getUpdatedBy()) .like(StringUtils.hasText(request.getUpdatedBy()), KnowledgeBase::getUpdatedBy, request.getUpdatedBy())
.eq(StringUtils.hasText(createdBy), KnowledgeBase::getCreatedBy, createdBy)
.orderByDesc(KnowledgeBase::getCreatedAt)); .orderByDesc(KnowledgeBase::getCreatedAt));
} }
@Override
public List<String> listIdsByCreatedBy(String createdBy) {
if (!StringUtils.hasText(createdBy)) {
return Collections.emptyList();
}
return lambdaQuery()
.select(KnowledgeBase::getId)
.eq(KnowledgeBase::getCreatedBy, createdBy)
.list()
.stream()
.map(KnowledgeBase::getId)
.filter(StringUtils::hasText)
.toList();
}
} }

View File

@@ -52,9 +52,12 @@ public class RagFileRepositoryImpl extends CrudRepository<RagFileMapper, RagFile
} }
@Override @Override
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request) { public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds) {
return lambdaQuery() return lambdaQuery()
.eq(StringUtils.hasText(request.getKnowledgeBaseId()), RagFile::getKnowledgeBaseId, request.getKnowledgeBaseId()) .eq(StringUtils.hasText(request.getKnowledgeBaseId()), RagFile::getKnowledgeBaseId, request.getKnowledgeBaseId())
.in(!StringUtils.hasText(request.getKnowledgeBaseId()) && knowledgeBaseIds != null && !knowledgeBaseIds.isEmpty(),
RagFile::getKnowledgeBaseId,
knowledgeBaseIds)
.like(StringUtils.hasText(request.getFileName()), RagFile::getFileName, request.getFileName()) .like(StringUtils.hasText(request.getFileName()), RagFile::getFileName, request.getFileName())
.likeRight(StringUtils.hasText(request.getRelativePath()), RagFile::getRelativePath, normalizeRelativePath(request.getRelativePath())) .likeRight(StringUtils.hasText(request.getRelativePath()), RagFile::getRelativePath, normalizeRelativePath(request.getRelativePath()))
.page(page); .page(page);

View File

@@ -17,6 +17,11 @@
<description>DDD领域通用组件</description> <description>DDD领域通用组件</description>
<dependencies> <dependencies>
<dependency>
<groupId>com.datamate</groupId>
<artifactId>security-common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId> <artifactId>spring-boot-starter</artifactId>

View File

@@ -0,0 +1,203 @@
package com.datamate.common.auth.application;
import com.datamate.common.auth.domain.model.AuthPermissionInfo;
import com.datamate.common.auth.domain.model.AuthRoleInfo;
import com.datamate.common.auth.domain.model.AuthUserAccount;
import com.datamate.common.auth.domain.model.AuthUserSummary;
import com.datamate.common.auth.infrastructure.exception.AuthErrorCode;
import com.datamate.common.auth.infrastructure.persistence.mapper.AuthMapper;
import com.datamate.common.auth.interfaces.rest.dto.AuthCurrentUserResponse;
import com.datamate.common.auth.interfaces.rest.dto.AuthLoginResponse;
import com.datamate.common.auth.interfaces.rest.dto.AuthUserView;
import com.datamate.common.auth.interfaces.rest.dto.AuthUserWithRolesResponse;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.security.JwtUtils;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtException;
import lombok.RequiredArgsConstructor;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.stereotype.Service;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* 认证授权应用服务
*/
@Service
@RequiredArgsConstructor
public class AuthApplicationService {
private static final String TOKEN_TYPE = "Bearer";
private final AuthMapper authMapper;
private final JwtUtils jwtUtils;
private final PasswordEncoder passwordEncoder;
public AuthLoginResponse login(String username, String password) {
AuthUserAccount user = authMapper.findUserByUsername(username);
BusinessAssert.notNull(user, AuthErrorCode.INVALID_CREDENTIALS);
BusinessAssert.isTrue(Boolean.TRUE.equals(user.getEnabled()), AuthErrorCode.ACCOUNT_DISABLED);
BusinessAssert.isTrue(passwordEncoder.matches(password, user.getPasswordHash()), AuthErrorCode.INVALID_CREDENTIALS);
AuthBundle authBundle = loadAuthBundle(user.getId());
String token = buildToken(authBundle);
authMapper.updateLastLoginAt(user.getId());
return new AuthLoginResponse(
token,
TOKEN_TYPE,
computeExpiresInSeconds(token),
toUserView(authBundle.user()),
authBundle.roleCodes(),
authBundle.permissionCodes()
);
}
public AuthCurrentUserResponse getCurrentUser(String token) {
Claims claims = parseClaims(token);
Long userId = parseUserId(claims);
AuthBundle authBundle = loadAuthBundle(userId);
return new AuthCurrentUserResponse(
toUserView(authBundle.user()),
authBundle.roleCodes(),
authBundle.permissionCodes()
);
}
public AuthLoginResponse refreshToken(String token) {
Claims claims = parseClaims(token);
Long userId = parseUserId(claims);
AuthBundle authBundle = loadAuthBundle(userId);
String refreshedToken = buildToken(authBundle);
return new AuthLoginResponse(
refreshedToken,
TOKEN_TYPE,
computeExpiresInSeconds(refreshedToken),
toUserView(authBundle.user()),
authBundle.roleCodes(),
authBundle.permissionCodes()
);
}
public List<AuthUserWithRolesResponse> listUsersWithRoles() {
List<AuthUserSummary> users = authMapper.listUsers();
List<AuthUserWithRolesResponse> responses = new ArrayList<>(users.size());
for (AuthUserSummary user : users) {
List<String> roleCodes = authMapper.findRolesByUserId(user.getId())
.stream()
.map(AuthRoleInfo::getRoleCode)
.filter(Objects::nonNull)
.toList();
responses.add(new AuthUserWithRolesResponse(
user.getId(),
user.getUsername(),
user.getFullName(),
user.getEmail(),
user.getEnabled(),
roleCodes
));
}
return responses;
}
public List<AuthRoleInfo> listRoles() {
return authMapper.listRoles();
}
public List<AuthPermissionInfo> listPermissions() {
return authMapper.listPermissions();
}
public void assignUserRoles(Long userId, List<String> roleIds) {
AuthUserAccount user = authMapper.findUserById(userId);
BusinessAssert.notNull(user, AuthErrorCode.USER_NOT_FOUND);
Set<String> distinctRoleIds = new LinkedHashSet<>(roleIds);
BusinessAssert.notEmpty(distinctRoleIds, AuthErrorCode.ROLE_NOT_FOUND);
int existingRoleCount = authMapper.countRolesByIds(new ArrayList<>(distinctRoleIds));
BusinessAssert.isTrue(existingRoleCount == distinctRoleIds.size(), AuthErrorCode.ROLE_NOT_FOUND);
authMapper.deleteUserRoles(userId);
authMapper.insertUserRoles(userId, new ArrayList<>(distinctRoleIds));
}
private String buildToken(AuthBundle authBundle) {
Map<String, Object> claims = Map.of(
"userId", authBundle.user().getId(),
"roles", authBundle.roleCodes(),
"permissions", authBundle.permissionCodes()
);
return jwtUtils.generateToken(authBundle.user().getUsername(), claims);
}
private AuthBundle loadAuthBundle(Long userId) {
AuthUserAccount user = authMapper.findUserById(userId);
BusinessAssert.notNull(user, AuthErrorCode.USER_NOT_FOUND);
BusinessAssert.isTrue(Boolean.TRUE.equals(user.getEnabled()), AuthErrorCode.ACCOUNT_DISABLED);
List<String> roleCodes = authMapper.findRolesByUserId(userId).stream()
.map(AuthRoleInfo::getRoleCode)
.filter(Objects::nonNull)
.toList();
List<String> permissionCodes = authMapper.findPermissionCodesByUserId(userId).stream()
.filter(Objects::nonNull)
.distinct()
.collect(Collectors.toList());
return new AuthBundle(user, roleCodes, permissionCodes);
}
private Claims parseClaims(String token) {
try {
return jwtUtils.getClaimsFromToken(token);
} catch (JwtException | IllegalArgumentException e) {
throw com.datamate.common.infrastructure.exception.BusinessException.of(AuthErrorCode.TOKEN_INVALID);
}
}
private Long parseUserId(Claims claims) {
Object userIdObject = claims.get("userId");
if (userIdObject instanceof Number number) {
return number.longValue();
}
if (userIdObject instanceof String str) {
try {
return Long.parseLong(str);
} catch (NumberFormatException e) {
throw com.datamate.common.infrastructure.exception.BusinessException.of(AuthErrorCode.TOKEN_INVALID);
}
}
throw com.datamate.common.infrastructure.exception.BusinessException.of(AuthErrorCode.TOKEN_INVALID);
}
private long computeExpiresInSeconds(String token) {
Date expirationDate = jwtUtils.getExpirationDateFromToken(token);
long seconds = Duration.between(new Date().toInstant(), expirationDate.toInstant()).toSeconds();
return Math.max(seconds, 0L);
}
private AuthUserView toUserView(AuthUserAccount user) {
return new AuthUserView(
user.getId(),
user.getUsername(),
user.getFullName(),
user.getEmail(),
user.getAvatarUrl(),
user.getOrganization()
);
}
private record AuthBundle(
AuthUserAccount user,
List<String> roleCodes,
List<String> permissionCodes
) {
}
}

View File

@@ -0,0 +1,58 @@
package com.datamate.common.auth.application;
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.Objects;
/**
* 资源访问控制服务(基于请求用户上下文)
*/
@Service
public class ResourceAccessService {
public static final String ADMIN_ROLE_CODE = "ROLE_ADMIN";
public boolean isAdmin() {
return RequestUserContextHolder.hasRole(ADMIN_ROLE_CODE);
}
public String getCurrentUserId() {
return RequestUserContextHolder.getCurrentUserId();
}
public String requireCurrentUserId() {
String currentUserId = getCurrentUserId();
BusinessAssert.isTrue(StringUtils.hasText(currentUserId), SystemErrorCode.INSUFFICIENT_PERMISSIONS);
return currentUserId;
}
/**
* 资源列表查询的 owner 过滤:
* - 管理员返回 null(不过滤)
* - 非管理员返回当前用户ID
*/
public String resolveOwnerFilterUserId() {
if (isAdmin()) {
return null;
}
return requireCurrentUserId();
}
/**
* 校验当前用户是否可访问 owner 资源
*/
public void assertOwnerAccess(String ownerUserId) {
if (isAdmin()) {
return;
}
String currentUserId = requireCurrentUserId();
BusinessAssert.isTrue(
StringUtils.hasText(ownerUserId) && Objects.equals(ownerUserId, currentUserId),
SystemErrorCode.INSUFFICIENT_PERMISSIONS
);
}
}

View File

@@ -0,0 +1,21 @@
package com.datamate.common.auth.domain.model;
import lombok.Getter;
import lombok.Setter;
/**
* 权限信息
*/
@Getter
@Setter
public class AuthPermissionInfo {
private String id;
private String permissionCode;
private String permissionName;
private String module;
private String action;
private String pathPattern;
private String method;
private Boolean enabled;
}

View File

@@ -0,0 +1,18 @@
package com.datamate.common.auth.domain.model;
import lombok.Getter;
import lombok.Setter;
/**
* 角色信息
*/
@Getter
@Setter
public class AuthRoleInfo {
private String id;
private String roleCode;
private String roleName;
private String description;
private Boolean enabled;
}

View File

@@ -0,0 +1,24 @@
package com.datamate.common.auth.domain.model;
import lombok.Getter;
import lombok.Setter;
import java.time.LocalDateTime;
/**
* 认证用户账户
*/
@Getter
@Setter
public class AuthUserAccount {
private Long id;
private String username;
private String email;
private String passwordHash;
private String fullName;
private String avatarUrl;
private String organization;
private Boolean enabled;
private LocalDateTime lastLoginAt;
}

View File

@@ -0,0 +1,18 @@
package com.datamate.common.auth.domain.model;
import lombok.Getter;
import lombok.Setter;
/**
* 用户摘要
*/
@Getter
@Setter
public class AuthUserSummary {
private Long id;
private String username;
private String email;
private String fullName;
private Boolean enabled;
}

View File

@@ -0,0 +1,18 @@
package com.datamate.common.auth.infrastructure.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
/**
* 认证模块配置
*/
@Configuration
public class AuthConfiguration {
@Bean
public PasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder();
}
}

View File

@@ -0,0 +1,40 @@
package com.datamate.common.auth.infrastructure.context;
import lombok.Getter;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* 请求级用户上下文
*/
@Getter
public class RequestUserContext {
private final String userId;
private final String username;
private final List<String> roles;
private RequestUserContext(String userId, String username, List<String> roles) {
this.userId = userId;
this.username = username;
this.roles = roles == null ? Collections.emptyList() : List.copyOf(roles);
}
public static RequestUserContext of(String userId, String username, List<String> roles) {
return new RequestUserContext(userId, username, roles);
}
public static RequestUserContext empty() {
return new RequestUserContext(null, null, Collections.emptyList());
}
public boolean hasRole(String roleCode) {
if (!StringUtils.hasText(roleCode)) {
return false;
}
return roles.stream().anyMatch(role -> StringUtils.hasText(role) && Objects.equals(role.trim(), roleCode));
}
}

View File

@@ -0,0 +1,49 @@
package com.datamate.common.auth.infrastructure.context;
import org.springframework.core.NamedThreadLocal;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
/**
* 请求级用户上下文持有器
*/
public final class RequestUserContextHolder {
private static final ThreadLocal<RequestUserContext> USER_CONTEXT_HOLDER =
new NamedThreadLocal<>("datamate-request-user-context");
private RequestUserContextHolder() {
}
public static void set(RequestUserContext context) {
USER_CONTEXT_HOLDER.set(context == null ? RequestUserContext.empty() : context);
}
public static RequestUserContext get() {
RequestUserContext context = USER_CONTEXT_HOLDER.get();
return context == null ? RequestUserContext.empty() : context;
}
public static String getCurrentUserId() {
return get().getUserId();
}
public static List<String> getCurrentRoles() {
List<String> roles = get().getRoles();
return roles == null ? Collections.emptyList() : roles;
}
public static boolean hasRole(String roleCode) {
if (!StringUtils.hasText(roleCode)) {
return false;
}
return getCurrentRoles().stream()
.anyMatch(role -> StringUtils.hasText(role) && roleCode.equalsIgnoreCase(role.trim()));
}
public static void clear() {
USER_CONTEXT_HOLDER.remove();
}
}

View File

@@ -0,0 +1,53 @@
package com.datamate.common.auth.infrastructure.context;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.HandlerInterceptor;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* 从网关透传请求头中提取用户上下文
*/
@Component
public class RequestUserContextInterceptor implements HandlerInterceptor {
private static final String HEADER_USER_ID = "X-User-Id";
private static final String HEADER_USER_NAME = "X-User-Name";
private static final String HEADER_USER_ROLES = "X-User-Roles";
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
String userId = normalizeValue(request.getHeader(HEADER_USER_ID));
String username = normalizeValue(request.getHeader(HEADER_USER_NAME));
List<String> roleCodes = parseRoleCodes(request.getHeader(HEADER_USER_ROLES));
RequestUserContextHolder.set(RequestUserContext.of(userId, username, roleCodes));
return true;
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) {
RequestUserContextHolder.clear();
}
private String normalizeValue(String value) {
if (!StringUtils.hasText(value)) {
return null;
}
return value.trim();
}
private List<String> parseRoleCodes(String roleHeader) {
if (!StringUtils.hasText(roleHeader)) {
return Collections.emptyList();
}
return Arrays.stream(roleHeader.split(","))
.map(String::trim)
.filter(StringUtils::hasText)
.toList();
}
}

View File

@@ -0,0 +1,21 @@
package com.datamate.common.auth.infrastructure.context;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
* 请求用户上下文拦截器注册
*/
@Configuration
@RequiredArgsConstructor
public class RequestUserContextWebMvcConfigurer implements WebMvcConfigurer {
private final RequestUserContextInterceptor requestUserContextInterceptor;
@Override
public void addInterceptors(InterceptorRegistry registry) {
registry.addInterceptor(requestUserContextInterceptor).addPathPatterns("/**");
}
}

View File

@@ -0,0 +1,23 @@
package com.datamate.common.auth.infrastructure.exception;
import com.datamate.common.infrastructure.exception.ErrorCode;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 认证授权错误码
*/
@Getter
@AllArgsConstructor
public enum AuthErrorCode implements ErrorCode {
INVALID_CREDENTIALS("auth.0001", "用户名或密码错误"),
ACCOUNT_DISABLED("auth.0002", "账号已被禁用"),
TOKEN_INVALID("auth.0003", "登录状态已失效"),
USER_NOT_FOUND("auth.0004", "用户不存在"),
ROLE_NOT_FOUND("auth.0005", "角色不存在"),
AUTHORIZATION_DENIED("auth.0006", "无权限执行该操作");
private final String code;
private final String message;
}

View File

@@ -0,0 +1,39 @@
package com.datamate.common.auth.infrastructure.persistence.mapper;
import com.datamate.common.auth.domain.model.AuthPermissionInfo;
import com.datamate.common.auth.domain.model.AuthRoleInfo;
import com.datamate.common.auth.domain.model.AuthUserAccount;
import com.datamate.common.auth.domain.model.AuthUserSummary;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
/**
* 认证授权数据访问
*/
@Mapper
public interface AuthMapper {
AuthUserAccount findUserByUsername(@Param("username") String username);
AuthUserAccount findUserById(@Param("userId") Long userId);
int updateLastLoginAt(@Param("userId") Long userId);
List<AuthRoleInfo> findRolesByUserId(@Param("userId") Long userId);
List<String> findPermissionCodesByUserId(@Param("userId") Long userId);
List<AuthUserSummary> listUsers();
List<AuthRoleInfo> listRoles();
List<AuthPermissionInfo> listPermissions();
int countRolesByIds(@Param("roleIds") List<String> roleIds);
int deleteUserRoles(@Param("userId") Long userId);
int insertUserRoles(@Param("userId") Long userId, @Param("roleIds") List<String> roleIds);
}

View File

@@ -0,0 +1,82 @@
package com.datamate.common.auth.interfaces.rest;
import com.datamate.common.auth.application.AuthApplicationService;
import com.datamate.common.auth.domain.model.AuthPermissionInfo;
import com.datamate.common.auth.domain.model.AuthRoleInfo;
import com.datamate.common.auth.interfaces.rest.dto.AssignUserRolesRequest;
import com.datamate.common.auth.interfaces.rest.dto.AuthCurrentUserResponse;
import com.datamate.common.auth.interfaces.rest.dto.AuthLoginResponse;
import com.datamate.common.auth.interfaces.rest.dto.AuthUserWithRolesResponse;
import com.datamate.common.auth.interfaces.rest.dto.LoginRequest;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.auth.infrastructure.exception.AuthErrorCode;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
/**
* 认证授权控制器
*/
@RestController
@RequestMapping("/auth")
@RequiredArgsConstructor
public class AuthController {
private final AuthApplicationService authApplicationService;
@PostMapping("/login")
public AuthLoginResponse login(@RequestBody @Valid LoginRequest loginRequest) {
return authApplicationService.login(loginRequest.username(), loginRequest.password());
}
@GetMapping("/me")
public AuthCurrentUserResponse me(HttpServletRequest request) {
return authApplicationService.getCurrentUser(extractBearerToken(request.getHeader("Authorization")));
}
@PostMapping("/refresh")
public AuthLoginResponse refresh(@RequestHeader("Authorization") String authorization) {
return authApplicationService.refreshToken(extractBearerToken(authorization));
}
@GetMapping("/users")
public List<AuthUserWithRolesResponse> listUsers() {
return authApplicationService.listUsersWithRoles();
}
@PutMapping("/users/{userId}/roles")
public void assignRoles(@PathVariable("userId") Long userId,
@RequestBody @Valid AssignUserRolesRequest request) {
authApplicationService.assignUserRoles(userId, request.roleIds());
}
@GetMapping("/roles")
public List<AuthRoleInfo> listRoles() {
return authApplicationService.listRoles();
}
@GetMapping("/permissions")
public List<AuthPermissionInfo> listPermissions() {
return authApplicationService.listPermissions();
}
private String extractBearerToken(String authorizationHeader) {
BusinessAssert.isTrue(
authorizationHeader != null && authorizationHeader.startsWith("Bearer "),
AuthErrorCode.TOKEN_INVALID
);
String token = authorizationHeader.substring("Bearer ".length()).trim();
BusinessAssert.isTrue(!token.isEmpty(), AuthErrorCode.TOKEN_INVALID);
return token;
}
}

View File

@@ -0,0 +1,14 @@
package com.datamate.common.auth.interfaces.rest.dto;
import jakarta.validation.constraints.NotEmpty;
import java.util.List;
/**
* 用户角色分配请求
*/
public record AssignUserRolesRequest(
@NotEmpty(message = "角色列表不能为空") List<String> roleIds
) {
}

View File

@@ -0,0 +1,14 @@
package com.datamate.common.auth.interfaces.rest.dto;
import java.util.List;
/**
* 当前用户信息响应
*/
public record AuthCurrentUserResponse(
AuthUserView user,
List<String> roles,
List<String> permissions
) {
}

View File

@@ -0,0 +1,17 @@
package com.datamate.common.auth.interfaces.rest.dto;
import java.util.List;
/**
* 登录响应
*/
public record AuthLoginResponse(
String token,
String tokenType,
long expiresInSeconds,
AuthUserView user,
List<String> roles,
List<String> permissions
) {
}

View File

@@ -0,0 +1,15 @@
package com.datamate.common.auth.interfaces.rest.dto;
/**
* 登录用户信息
*/
public record AuthUserView(
Long id,
String username,
String fullName,
String email,
String avatarUrl,
String organization
) {
}

View File

@@ -0,0 +1,17 @@
package com.datamate.common.auth.interfaces.rest.dto;
import java.util.List;
/**
* 用户与角色响应
*/
public record AuthUserWithRolesResponse(
Long id,
String username,
String fullName,
String email,
Boolean enabled,
List<String> roleCodes
) {
}

View File

@@ -0,0 +1,13 @@
package com.datamate.common.auth.interfaces.rest.dto;
import jakarta.validation.constraints.NotBlank;
/**
* 登录请求
*/
public record LoginRequest(
@NotBlank(message = "用户名不能为空") String username,
@NotBlank(message = "密码不能为空") String password
) {
}

View File

@@ -1,9 +1,11 @@
package com.datamate.common.infrastructure.config; package com.datamate.common.infrastructure.config;
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler; import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.MetaObject;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.util.StringUtils;
import java.time.LocalDateTime; import java.time.LocalDateTime;
@@ -44,17 +46,10 @@ public class EntityMetaObjectHandler implements MetaObjectHandler {
* 获取当前用户(需要根据你的安全框架实现) * 获取当前用户(需要根据你的安全框架实现)
*/ */
private String getCurrentUser() { private String getCurrentUser() {
// todo 这里需要根据你的安全框架实现,例如Spring Security、Shiro等 String currentUserId = RequestUserContextHolder.getCurrentUserId();
// 示例:返回默认用户或从SecurityContext获取 if (StringUtils.hasText(currentUserId)) {
try { return currentUserId;
// 如果是Spring Security
// return SecurityContextHolder.getContext().getAuthentication().getName();
// 临时返回默认值,请根据实际情况修改
return "system";
} catch (Exception e) {
log.error("Error getting current user", e);
return "unknown";
} }
return "system";
} }
} }

View File

@@ -3,6 +3,7 @@ package com.datamate.common.infrastructure.config;
import com.datamate.common.infrastructure.common.Response; import com.datamate.common.infrastructure.common.Response;
import com.datamate.common.infrastructure.exception.BusinessException; import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.SystemErrorCode; import com.datamate.common.infrastructure.exception.SystemErrorCode;
import org.springframework.http.HttpStatus;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.validation.BindException; import org.springframework.validation.BindException;
@@ -28,7 +29,8 @@ public class GlobalExceptionHandler {
@ExceptionHandler(BusinessException.class) @ExceptionHandler(BusinessException.class)
public ResponseEntity<Response<?>> handleBusinessException(BusinessException e) { public ResponseEntity<Response<?>> handleBusinessException(BusinessException e) {
log.warn("BusinessException: code={}, message={}", e.getCode(), e.getMessage(), e); log.warn("BusinessException: code={}, message={}", e.getCode(), e.getMessage(), e);
return ResponseEntity.internalServerError().body(Response.error(e.getErrorCodeEnum())); HttpStatus status = resolveBusinessStatus(e.getCode());
return ResponseEntity.status(status).body(Response.error(e.getErrorCodeEnum()));
} }
/** /**
@@ -51,4 +53,17 @@ public class GlobalExceptionHandler {
log.error("SystemException: ", e); log.error("SystemException: ", e);
return ResponseEntity.internalServerError().body(Response.error(SystemErrorCode.SYSTEM_BUSY)); return ResponseEntity.internalServerError().body(Response.error(SystemErrorCode.SYSTEM_BUSY));
} }
private HttpStatus resolveBusinessStatus(String code) {
if (code == null) {
return HttpStatus.INTERNAL_SERVER_ERROR;
}
if (!code.startsWith("auth.")) {
return HttpStatus.INTERNAL_SERVER_ERROR;
}
if ("auth.0006".equals(code)) {
return HttpStatus.FORBIDDEN;
}
return HttpStatus.UNAUTHORIZED;
}
} }

View File

@@ -0,0 +1,120 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.datamate.common.auth.infrastructure.persistence.mapper.AuthMapper">
<select id="findUserByUsername" resultType="com.datamate.common.auth.domain.model.AuthUserAccount">
SELECT id,
username,
email,
password_hash AS passwordHash,
full_name AS fullName,
avatar_url AS avatarUrl,
organization,
enabled,
last_login_at AS lastLoginAt
FROM users
WHERE username = #{username}
LIMIT 1
</select>
<select id="findUserById" resultType="com.datamate.common.auth.domain.model.AuthUserAccount">
SELECT id,
username,
email,
password_hash AS passwordHash,
full_name AS fullName,
avatar_url AS avatarUrl,
organization,
enabled,
last_login_at AS lastLoginAt
FROM users
WHERE id = #{userId}
LIMIT 1
</select>
<update id="updateLastLoginAt">
UPDATE users
SET last_login_at = NOW()
WHERE id = #{userId}
</update>
<select id="findRolesByUserId" resultType="com.datamate.common.auth.domain.model.AuthRoleInfo">
SELECT r.id,
r.role_code AS roleCode,
r.role_name AS roleName,
r.description,
r.enabled
FROM t_auth_roles r
INNER JOIN t_auth_user_roles ur ON ur.role_id = r.id
WHERE ur.user_id = #{userId}
ORDER BY r.role_code
</select>
<select id="findPermissionCodesByUserId" resultType="string">
SELECT DISTINCT p.permission_code
FROM t_auth_permissions p
INNER JOIN t_auth_role_permissions rp ON rp.permission_id = p.id
INNER JOIN t_auth_user_roles ur ON ur.role_id = rp.role_id
WHERE ur.user_id = #{userId}
AND p.enabled = 1
ORDER BY p.permission_code
</select>
<select id="listUsers" resultType="com.datamate.common.auth.domain.model.AuthUserSummary">
SELECT id,
username,
email,
full_name AS fullName,
enabled
FROM users
ORDER BY id ASC
</select>
<select id="listRoles" resultType="com.datamate.common.auth.domain.model.AuthRoleInfo">
SELECT id,
role_code AS roleCode,
role_name AS roleName,
description,
enabled
FROM t_auth_roles
ORDER BY role_code ASC
</select>
<select id="listPermissions" resultType="com.datamate.common.auth.domain.model.AuthPermissionInfo">
SELECT id,
permission_code AS permissionCode,
permission_name AS permissionName,
module,
action,
path_pattern AS pathPattern,
method,
enabled
FROM t_auth_permissions
ORDER BY module ASC, action ASC
</select>
<select id="countRolesByIds" resultType="int">
SELECT COUNT(1)
FROM t_auth_roles
WHERE id IN
<foreach collection="roleIds" item="roleId" open="(" separator="," close=")">
#{roleId}
</foreach>
</select>
<delete id="deleteUserRoles">
DELETE
FROM t_auth_user_roles
WHERE user_id = #{userId}
</delete>
<insert id="insertUserRoles">
INSERT INTO t_auth_user_roles (user_id, role_id)
VALUES
<foreach collection="roleIds" item="roleId" separator=",">
(#{userId}, #{roleId})
</foreach>
</insert>
</mapper>

View File

@@ -3,9 +3,13 @@ package com.datamate.common.security;
import io.jsonwebtoken.*; import io.jsonwebtoken.*;
import io.jsonwebtoken.security.Keys; import io.jsonwebtoken.security.Keys;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.util.StringUtils;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Date; import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@@ -15,15 +19,23 @@ import java.util.Map;
*/ */
@Component @Component
public class JwtUtils { public class JwtUtils {
private static final String DEFAULT_SECRET = "datamate-secret-key-for-jwt-token-generation";
@Value("${jwt.secret:datamate-secret-key-for-jwt-token-generation}") @Value("${jwt.secret:" + DEFAULT_SECRET + "}")
private String secret; private String secret;
@Value("${jwt.expiration:86400}") // 24小时 @Value("${jwt.expiration:86400}") // 24小时
private Long expiration; private Long expiration;
private SecretKey getSigningKey() { private SecretKey getSigningKey() {
return Keys.hmacShaKeyFor(secret.getBytes()); String secretValue = StringUtils.hasText(secret) ? secret : DEFAULT_SECRET;
try {
MessageDigest digest = MessageDigest.getInstance("SHA-512");
byte[] keyBytes = digest.digest(secretValue.getBytes(StandardCharsets.UTF_8));
return Keys.hmacShaKeyFor(keyBytes);
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException("Cannot initialize JWT signing key", e);
}
} }
/** /**
@@ -84,7 +96,18 @@ public class JwtUtils {
public Boolean validateToken(String token, String username) { public Boolean validateToken(String token, String username) {
try { try {
String tokenUsername = getUsernameFromToken(token); String tokenUsername = getUsernameFromToken(token);
return (username.equals(tokenUsername) && !isTokenExpired(token)); return (username.equals(tokenUsername) && validateToken(token));
} catch (JwtException | IllegalArgumentException e) {
return false;
}
}
/**
* 仅校验令牌格式与有效期
*/
public Boolean validateToken(String token) {
try {
return !isTokenExpired(token);
} catch (JwtException | IllegalArgumentException e) { } catch (JwtException | IllegalArgumentException e) {
return false; return false;
} }

View File

@@ -0,0 +1,75 @@
export const PermissionCodes = {
dataManagementRead: "module:data-management:read",
dataManagementWrite: "module:data-management:write",
dataAnnotationRead: "module:data-annotation:read",
dataAnnotationWrite: "module:data-annotation:write",
dataCollectionRead: "module:data-collection:read",
dataCollectionWrite: "module:data-collection:write",
dataEvaluationRead: "module:data-evaluation:read",
dataEvaluationWrite: "module:data-evaluation:write",
dataSynthesisRead: "module:data-synthesis:read",
dataSynthesisWrite: "module:data-synthesis:write",
knowledgeManagementRead: "module:knowledge-management:read",
knowledgeManagementWrite: "module:knowledge-management:write",
knowledgeBaseRead: "module:knowledge-base:read",
knowledgeBaseWrite: "module:knowledge-base:write",
operatorMarketRead: "module:operator-market:read",
operatorMarketWrite: "module:operator-market:write",
orchestrationRead: "module:orchestration:read",
orchestrationWrite: "module:orchestration:write",
contentGenerationUse: "module:content-generation:use",
agentUse: "module:agent:use",
userManage: "system:user:manage",
roleManage: "system:role:manage",
permissionManage: "system:permission:manage",
} as const;
const routePermissionRules: Array<{ prefix: string; permission: string }> = [
{ prefix: "/data/management", permission: PermissionCodes.dataManagementRead },
{ prefix: "/data/annotation", permission: PermissionCodes.dataAnnotationRead },
{ prefix: "/data/collection", permission: PermissionCodes.dataCollectionRead },
{ prefix: "/data/evaluation", permission: PermissionCodes.dataEvaluationRead },
{ prefix: "/data/synthesis", permission: PermissionCodes.dataSynthesisRead },
{ prefix: "/data/knowledge-management", permission: PermissionCodes.knowledgeManagementRead },
{ prefix: "/data/knowledge-base", permission: PermissionCodes.knowledgeBaseRead },
{ prefix: "/data/operator-market", permission: PermissionCodes.operatorMarketRead },
{ prefix: "/data/orchestration", permission: PermissionCodes.orchestrationRead },
{ prefix: "/data/content-generation", permission: PermissionCodes.contentGenerationUse },
{ prefix: "/chat", permission: PermissionCodes.agentUse },
];
const defaultRouteCandidates: Array<{ path: string; permission: string }> = [
{ path: "/data/management", permission: PermissionCodes.dataManagementRead },
{ path: "/data/annotation", permission: PermissionCodes.dataAnnotationRead },
{ path: "/data/knowledge-management", permission: PermissionCodes.knowledgeManagementRead },
{ path: "/data/knowledge-base", permission: PermissionCodes.knowledgeBaseRead },
{ path: "/chat", permission: PermissionCodes.agentUse },
];
export function hasPermission(
userPermissions: string[] | undefined,
requiredPermission?: string | null
): boolean {
if (!requiredPermission) {
return true;
}
return (userPermissions ?? []).includes(requiredPermission);
}
export function resolveRequiredPermissionByPath(pathname: string): string | null {
if (pathname === "/403") {
return null;
}
const matchedRule = routePermissionRules.find((rule) =>
pathname.startsWith(rule.prefix)
);
return matchedRule?.permission ?? null;
}
export function resolveDefaultAuthorizedPath(userPermissions: string[]): string {
const matchedPath = defaultRouteCandidates.find((candidate) =>
hasPermission(userPermissions, candidate.permission)
)?.path;
return matchedPath ?? "/403";
}

View File

@@ -1,20 +1,53 @@
import React from 'react'; import React from 'react';
import { Navigate, useLocation, Outlet } from 'react-router'; import { Navigate, useLocation, Outlet } from 'react-router';
import { useAppSelector } from '@/store/hooks'; import { useAppDispatch, useAppSelector } from '@/store/hooks';
import { fetchCurrentUser, markInitialized } from '@/store/slices/authSlice';
import {
hasPermission,
resolveDefaultAuthorizedPath,
resolveRequiredPermissionByPath,
} from '@/auth/permissions';
interface ProtectedRouteProps { interface ProtectedRouteProps {
children?: React.ReactNode; children?: React.ReactNode;
} }
const ProtectedRoute: React.FC<ProtectedRouteProps> = ({ children }) => { const ProtectedRoute: React.FC<ProtectedRouteProps> = ({ children }) => {
const { isAuthenticated } = useAppSelector((state) => state.auth); const dispatch = useAppDispatch();
const { isAuthenticated, token, initialized, loading, permissions } = useAppSelector(
(state) => state.auth
);
const location = useLocation(); const location = useLocation();
const requiredPermission = resolveRequiredPermissionByPath(location.pathname);
React.useEffect(() => {
if (initialized || loading) {
return;
}
if (!token) {
dispatch(markInitialized());
return;
}
void dispatch(fetchCurrentUser());
}, [dispatch, initialized, loading, token]);
if (!initialized || loading) {
return null;
}
if (!isAuthenticated) { if (!isAuthenticated) {
// Redirect to the login page, but save the current location they were trying to go to // Redirect to the login page, but save the current location they were trying to go to
return <Navigate to="/login" state={{ from: location }} replace />; return <Navigate to="/login" state={{ from: location }} replace />;
} }
if (!hasPermission(permissions, requiredPermission)) {
const fallbackPath = resolveDefaultAuthorizedPath(permissions);
if (location.pathname === fallbackPath) {
return <Navigate to="/403" replace />;
}
return <Navigate to={fallbackPath} replace />;
}
return children ? <>{children}</> : <Outlet />; return children ? <>{children}</> : <Outlet />;
}; };

View File

@@ -0,0 +1,24 @@
import React from "react";
import { Button, Result } from "antd";
import { useNavigate } from "react-router";
const ForbiddenPage: React.FC = () => {
const navigate = useNavigate();
return (
<div className="h-screen w-full flex items-center justify-center bg-[#050b14]">
<Result
status="403"
title="403"
subTitle="你当前账号没有访问该页面的权限。"
extra={
<Button type="primary" onClick={() => navigate("/data/management")}>
</Button>
}
/>
</div>
);
};
export default ForbiddenPage;

View File

@@ -1,4 +1,4 @@
import { memo, useCallback, useEffect, useState } from "react"; import { memo, useCallback, useEffect, useMemo, useState } from "react";
import { Button, Drawer, Menu, Popover } from "antd"; import { Button, Drawer, Menu, Popover } from "antd";
import { import {
CloseOutlined, CloseOutlined,
@@ -14,6 +14,7 @@ import SettingsPage from "../SettingsPage/SettingsPage";
import { useAppSelector, useAppDispatch } from "@/store/hooks"; import { useAppSelector, useAppDispatch } from "@/store/hooks";
import { showSettings, hideSettings } from "@/store/slices/settingsSlice"; import { showSettings, hideSettings } from "@/store/slices/settingsSlice";
import { logout } from "@/store/slices/authSlice"; import { logout } from "@/store/slices/authSlice";
import { hasPermission } from "@/auth/permissions";
const isPathMatch = (currentPath: string, targetPath: string) => const isPathMatch = (currentPath: string, targetPath: string) =>
currentPath === targetPath || currentPath.startsWith(`${targetPath}/`); currentPath === targetPath || currentPath.startsWith(`${targetPath}/`);
@@ -25,13 +26,36 @@ const AsiderAndHeaderLayout = () => {
const [sidebarOpen, setSidebarOpen] = useState(true); const [sidebarOpen, setSidebarOpen] = useState(true);
const [taskCenterVisible, setTaskCenterVisible] = useState(false); const [taskCenterVisible, setTaskCenterVisible] = useState(false);
const settingVisible = useAppSelector((state) => state.settings.visible); const settingVisible = useAppSelector((state) => state.settings.visible);
const permissions = useAppSelector((state) => state.auth.permissions);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const visibleMenuItems = useMemo(
() =>
menuItems
.map((item) => ({
...item,
children: item.children?.filter((subItem) =>
hasPermission(permissions, (subItem as { permissionCode?: string }).permissionCode)
),
}))
.filter((item) => {
const selfVisible = hasPermission(
permissions,
(item as { permissionCode?: string }).permissionCode
);
if (item.children && item.children.length > 0) {
return selfVisible;
}
return selfVisible;
}),
[permissions]
);
// Initialize active item based on current pathname // Initialize active item based on current pathname
const initActiveItem = useCallback(() => { const initActiveItem = useCallback(() => {
const dataPath = pathname.startsWith("/data/") ? pathname.slice(6) : pathname; const dataPath = pathname.startsWith("/data/") ? pathname.slice(6) : pathname;
for (let index = 0; index < menuItems.length; index++) { for (let index = 0; index < visibleMenuItems.length; index++) {
const element = menuItems[index]; const element = visibleMenuItems[index];
if (element.children) { if (element.children) {
for (const subItem of element.children) { for (const subItem of element.children) {
if (isPathMatch(dataPath, subItem.id)) { if (isPathMatch(dataPath, subItem.id)) {
@@ -44,7 +68,8 @@ const AsiderAndHeaderLayout = () => {
return; return;
} }
} }
}, [pathname]); setActiveItem(visibleMenuItems[0]?.id ?? "");
}, [pathname, visibleMenuItems]);
useEffect(() => { useEffect(() => {
initActiveItem(); initActiveItem();
@@ -100,7 +125,7 @@ const AsiderAndHeaderLayout = () => {
<Menu <Menu
mode="inline" mode="inline"
inlineCollapsed={!sidebarOpen} inlineCollapsed={!sidebarOpen}
items={menuItems.map((item) => ({ items={visibleMenuItems.map((item) => ({
key: item.id, key: item.id,
label: item.title, label: item.title,
icon: item.icon ? <item.icon className="w-4 h-4" /> : null, icon: item.icon ? <item.icon className="w-4 h-4" /> : null,

View File

@@ -13,6 +13,7 @@ import {
// Store, // Store,
// Merge, // Merge,
} from "lucide-react"; } from "lucide-react";
import { PermissionCodes } from "@/auth/permissions";
export const menuItems = [ export const menuItems = [
// { // {
@@ -26,6 +27,7 @@ export const menuItems = [
id: "management", id: "management",
title: "数集管理", title: "数集管理",
icon: FolderOpen, icon: FolderOpen,
permissionCode: PermissionCodes.dataManagementRead,
description: "创建、导入和管理数据集", description: "创建、导入和管理数据集",
color: "bg-blue-500", color: "bg-blue-500",
}, },
@@ -33,6 +35,7 @@ export const menuItems = [
id: "annotation", id: "annotation",
title: "数据标注", title: "数据标注",
icon: Tag, icon: Tag,
permissionCode: PermissionCodes.dataAnnotationRead,
description: "对数据进行标注和标记", description: "对数据进行标注和标记",
color: "bg-green-500", color: "bg-green-500",
}, },
@@ -40,6 +43,7 @@ export const menuItems = [
id: "content-generation", id: "content-generation",
title: "内容生成", title: "内容生成",
icon: Sparkles, icon: Sparkles,
permissionCode: PermissionCodes.contentGenerationUse,
description: "智能内容生成与创作", description: "智能内容生成与创作",
color: "bg-purple-500", color: "bg-purple-500",
}, },
@@ -47,6 +51,7 @@ export const menuItems = [
id: "knowledge-management", id: "knowledge-management",
title: "知识管理", title: "知识管理",
icon: Shield, icon: Shield,
permissionCode: PermissionCodes.knowledgeManagementRead,
description: "管理知识集与知识条目", description: "管理知识集与知识条目",
color: "bg-indigo-500", color: "bg-indigo-500",
}, },

View File

@@ -1,9 +1,9 @@
import React, { useState } from 'react'; import React from 'react';
import { useNavigate, useLocation } from 'react-router'; import { useNavigate, useLocation } from 'react-router';
import { Form, Input, Button, Typography, message, Card } from 'antd'; import { Form, Input, Button, Typography, message } from 'antd';
import { UserOutlined, LockOutlined } from '@ant-design/icons'; import { UserOutlined, LockOutlined } from '@ant-design/icons';
import { useAppDispatch, useAppSelector } from '@/store/hooks'; import { useAppDispatch, useAppSelector } from '@/store/hooks';
import { loginLocal } from '@/store/slices/authSlice'; import { loginUser } from '@/store/slices/authSlice';
const { Title, Text } = Typography; const { Title, Text } = Typography;
@@ -11,19 +11,20 @@ const LoginPage: React.FC = () => {
const navigate = useNavigate(); const navigate = useNavigate();
const location = useLocation(); const location = useLocation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { loading, error } = useAppSelector((state) => state.auth); const { loading } = useAppSelector((state) => state.auth);
const [messageApi, contextHolder] = message.useMessage(); const [messageApi, contextHolder] = message.useMessage();
const from = location.state?.from?.pathname || '/data'; const from = location.state?.from?.pathname || '/data';
const onFinish = (values: any) => { const onFinish = async (values: { username: string; password: string }) => {
dispatch(loginLocal(values)); try {
// The reducer updates state synchronously. await dispatch(loginUser(values)).unwrap();
if (values.username === 'admin' && values.password === '123456') { messageApi.success('登录成功');
messageApi.success('登录成功'); navigate(from, { replace: true });
navigate(from, { replace: true }); } catch (loginError) {
} else { const messageText =
messageApi.error('账号或密码错误'); typeof loginError === 'string' ? loginError : '账号或密码错误';
messageApi.error(messageText);
} }
}; };
@@ -59,9 +60,9 @@ const LoginPage: React.FC = () => {
</Text> </Text>
</div> </div>
<Form <Form<{ username: string; password: string }>
name="login" name="login"
initialValues={{ remember: true, username: 'admin', password: '123456' }} initialValues={{ username: 'admin', password: '123456' }}
onFinish={onFinish} onFinish={onFinish}
layout="vertical" layout="vertical"
size="large" size="large"

View File

@@ -1,12 +1,51 @@
import { useState } from "react"; import { useEffect, useMemo, useState } from "react";
import { Menu } from "antd"; import { Menu } from "antd";
import { SettingOutlined } from "@ant-design/icons"; import { SettingOutlined, TeamOutlined } from "@ant-design/icons";
import { Component } from "lucide-react"; import { Component } from "lucide-react";
import SystemConfig from "./SystemConfig"; import SystemConfig from "./SystemConfig";
import ModelAccess from "./ModelAccess"; import ModelAccess from "./ModelAccess";
import UserPermissionManagement from "./UserPermissionManagement";
import { useAppSelector } from "@/store/hooks";
import { hasPermission, PermissionCodes } from "@/auth/permissions";
export default function SettingsPage() { export default function SettingsPage() {
const [activeTab, setActiveTab] = useState("model-access"); const permissions = useAppSelector((state) => state.auth.permissions);
const canManageUsers = hasPermission(permissions, PermissionCodes.userManage);
const canViewRoles = hasPermission(permissions, PermissionCodes.roleManage);
const canViewPermissions = hasPermission(
permissions,
PermissionCodes.permissionManage
);
const tabs = useMemo(() => {
const nextTabs = [
{
key: "model-access",
icon: <Component className="w-4 h-4" />,
label: "模型接入",
},
{
key: "system-config",
icon: <SettingOutlined />,
label: "参数配置",
},
];
if (canManageUsers || canViewRoles || canViewPermissions) {
nextTabs.push({
key: "user-permission",
icon: <TeamOutlined />,
label: "用户与权限",
});
}
return nextTabs;
}, [canManageUsers, canViewPermissions, canViewRoles]);
const [activeTab, setActiveTab] = useState<string>(tabs[0]?.key ?? "model-access");
useEffect(() => {
const hasActiveTab = tabs.some((tab) => tab.key === activeTab);
if (!hasActiveTab && tabs.length > 0) {
setActiveTab(tabs[0].key);
}
}, [activeTab, tabs]);
return ( return (
<div className="h-screen flex"> <div className="h-screen flex">
@@ -18,21 +57,10 @@ export default function SettingsPage() {
<div className="h-full"> <div className="h-full">
<Menu <Menu
mode="inline" mode="inline"
items={[ items={tabs}
{
key: "model-access",
icon: <Component className="w-4 h-4" />,
label: "模型接入",
},
{
key: "system-config",
icon: <SettingOutlined />,
label: "参数配置",
},
]}
selectedKeys={[activeTab]} selectedKeys={[activeTab]}
onClick={({ key }) => { onClick={({ key }) => {
setActiveTab(key); setActiveTab(String(key));
}} }}
/> />
</div> </div>
@@ -41,6 +69,13 @@ export default function SettingsPage() {
{/* 内容区域,根据 activeTab 渲染不同的组件 */} {/* 内容区域,根据 activeTab 渲染不同的组件 */}
{activeTab === "system-config" && <SystemConfig />} {activeTab === "system-config" && <SystemConfig />}
{activeTab === "model-access" && <ModelAccess />} {activeTab === "model-access" && <ModelAccess />}
{activeTab === "user-permission" && (
<UserPermissionManagement
canManageUsers={canManageUsers}
canViewRoles={canViewRoles}
canViewPermissions={canViewPermissions}
/>
)}
</div> </div>
</div> </div>
); );

View File

@@ -0,0 +1,321 @@
import { useCallback, useEffect, useMemo, useState } from "react";
import {
Button,
Card,
Empty,
message,
Modal,
Select,
Space,
Table,
Tag,
Typography,
} from "antd";
import type { ColumnsType } from "antd/es/table";
import {
assignUserRolesUsingPut,
listAuthPermissionsUsingGet,
listAuthRolesUsingGet,
listAuthUsersUsingGet,
} from "./settings.apis";
import type {
AuthPermissionInfo,
AuthRoleInfo,
AuthUserWithRoles,
} from "./settings.apis";
interface ApiResponse<T> {
code: string;
message: string;
data: T;
}
interface UserPermissionManagementProps {
canManageUsers: boolean;
canViewRoles: boolean;
canViewPermissions: boolean;
}
export default function UserPermissionManagement({
canManageUsers,
canViewRoles,
canViewPermissions,
}: UserPermissionManagementProps) {
const [loading, setLoading] = useState(false);
const [users, setUsers] = useState<AuthUserWithRoles[]>([]);
const [roles, setRoles] = useState<AuthRoleInfo[]>([]);
const [permissions, setPermissions] = useState<AuthPermissionInfo[]>([]);
const [editingUser, setEditingUser] = useState<AuthUserWithRoles | null>(null);
const [selectedRoleCodes, setSelectedRoleCodes] = useState<string[]>([]);
const [submitting, setSubmitting] = useState(false);
const canShowAnything = canManageUsers || canViewRoles || canViewPermissions;
const canAssignRoles = canManageUsers && roles.length > 0;
const roleNameMap = useMemo(
() => new Map(roles.map((role) => [role.roleCode, role.roleName || role.roleCode])),
[roles]
);
const roleCodeToIdMap = useMemo(
() => new Map(roles.map((role) => [role.roleCode, role.id])),
[roles]
);
const loadData = useCallback(async () => {
setLoading(true);
try {
const requestTasks: Array<Promise<unknown>> = [];
if (canManageUsers || canViewRoles || canViewPermissions) {
requestTasks.push(listAuthUsersUsingGet());
}
if (canManageUsers || canViewRoles) {
requestTasks.push(listAuthRolesUsingGet());
}
if (canViewPermissions) {
requestTasks.push(listAuthPermissionsUsingGet());
}
const responses = await Promise.all(requestTasks);
let index = 0;
if (canManageUsers || canViewRoles || canViewPermissions) {
const userResponse = responses[index++] as ApiResponse<AuthUserWithRoles[]>;
setUsers(userResponse?.data ?? []);
}
if (canManageUsers || canViewRoles) {
const roleResponse = responses[index++] as ApiResponse<AuthRoleInfo[]>;
setRoles(roleResponse?.data ?? []);
} else {
setRoles([]);
}
if (canViewPermissions) {
const permissionResponse = responses[index++] as ApiResponse<AuthPermissionInfo[]>;
setPermissions(permissionResponse?.data ?? []);
} else {
setPermissions([]);
}
} catch (error) {
message.error("加载用户权限信息失败");
console.error("加载用户权限信息失败:", error);
} finally {
setLoading(false);
}
}, [canManageUsers, canViewPermissions, canViewRoles]);
useEffect(() => {
if (!canShowAnything) {
return;
}
void loadData();
}, [canShowAnything, loadData]);
const userColumns: ColumnsType<AuthUserWithRoles> = [
{
title: "用户名",
dataIndex: "username",
key: "username",
width: 180,
},
{
title: "姓名",
dataIndex: "fullName",
key: "fullName",
width: 180,
render: (value?: string) => value || "-",
},
{
title: "邮箱",
dataIndex: "email",
key: "email",
render: (value?: string) => value || "-",
},
{
title: "状态",
dataIndex: "enabled",
key: "enabled",
width: 120,
render: (enabled?: boolean) =>
enabled ? <Tag color="green"></Tag> : <Tag color="default"></Tag>,
},
{
title: "角色",
dataIndex: "roleCodes",
key: "roleCodes",
render: (roleCodes: string[]) => (
<Space wrap>
{(roleCodes ?? []).map((roleCode) => (
<Tag key={roleCode}>{roleNameMap.get(roleCode) || roleCode}</Tag>
))}
</Space>
),
},
{
title: "操作",
key: "actions",
width: 120,
render: (_, record) => (
<Button
type="link"
disabled={!canAssignRoles}
onClick={() => {
setEditingUser(record);
setSelectedRoleCodes(record.roleCodes ?? []);
}}
>
</Button>
),
},
];
const roleColumns: ColumnsType<AuthRoleInfo> = [
{ title: "角色编码", dataIndex: "roleCode", key: "roleCode", width: 220 },
{ title: "角色名称", dataIndex: "roleName", key: "roleName", width: 180 },
{
title: "状态",
dataIndex: "enabled",
key: "enabled",
width: 120,
render: (enabled?: boolean) =>
enabled ? <Tag color="green"></Tag> : <Tag color="default"></Tag>,
},
{
title: "描述",
dataIndex: "description",
key: "description",
render: (value?: string) => value || "-",
},
];
const permissionColumns: ColumnsType<AuthPermissionInfo> = [
{
title: "权限编码",
dataIndex: "permissionCode",
key: "permissionCode",
width: 260,
},
{
title: "权限名称",
dataIndex: "permissionName",
key: "permissionName",
width: 200,
},
{
title: "模块",
dataIndex: "module",
key: "module",
width: 140,
render: (value?: string) => value || "-",
},
{
title: "动作",
dataIndex: "action",
key: "action",
width: 120,
render: (value?: string) => value || "-",
},
{
title: "接口",
key: "api",
render: (_, record) =>
record.pathPattern ? `${record.method || "ALL"} ${record.pathPattern}` : "-",
},
];
const handleAssignRoles = async () => {
if (!editingUser) {
return;
}
if (selectedRoleCodes.length === 0) {
message.warning("请至少选择一个角色");
return;
}
const roleIds = selectedRoleCodes
.map((roleCode) => roleCodeToIdMap.get(roleCode))
.filter((roleId): roleId is string => Boolean(roleId));
if (roleIds.length !== selectedRoleCodes.length) {
message.error("角色映射失败,请刷新后重试");
return;
}
setSubmitting(true);
try {
await assignUserRolesUsingPut(editingUser.id, roleIds);
message.success("角色分配成功");
setEditingUser(null);
setSelectedRoleCodes([]);
await loadData();
} catch (error) {
message.error("角色分配失败");
console.error("角色分配失败:", error);
} finally {
setSubmitting(false);
}
};
if (!canShowAnything) {
return <Empty description="当前账号无用户与权限管理权限" />;
}
return (
<Space direction="vertical" size={16} className="w-full">
<Card title="用户管理">
<Table
loading={loading}
rowKey="id"
dataSource={users}
columns={userColumns}
pagination={{ pageSize: 10, showSizeChanger: false }}
/>
</Card>
{canViewRoles && (
<Card title="角色列表">
<Table
loading={loading}
rowKey="id"
dataSource={roles}
columns={roleColumns}
pagination={{ pageSize: 8, showSizeChanger: false }}
/>
</Card>
)}
{canViewPermissions && (
<Card title="权限列表">
<Table
loading={loading}
rowKey="id"
dataSource={permissions}
columns={permissionColumns}
pagination={{ pageSize: 10, showSizeChanger: false }}
/>
</Card>
)}
<Modal
title={`分配角色 - ${editingUser?.username || ""}`}
open={Boolean(editingUser)}
confirmLoading={submitting}
onOk={() => {
void handleAssignRoles();
}}
onCancel={() => {
setEditingUser(null);
setSelectedRoleCodes([]);
}}
>
{roles.length === 0 ? (
<Typography.Text type="secondary"></Typography.Text>
) : (
<Select
mode="multiple"
className="w-full"
placeholder="请选择角色"
value={selectedRoleCodes}
onChange={(values) => setSelectedRoleCodes(values)}
options={roles.map((role) => ({
value: role.roleCode,
label: `${role.roleName} (${role.roleCode})`,
}))}
/>
)}
</Modal>
</Space>
);
}

View File

@@ -1,11 +1,11 @@
import { get, post, put, del } from "@/utils/request"; import { get, post, put, del } from "@/utils/request";
// 模型相关接口 // 模型相关接口
export function queryModelProvidersUsingGet(params?: any) { export function queryModelProvidersUsingGet(params?: Record<string, unknown>) {
return get("/api/models/providers", params); return get("/api/models/providers", params);
} }
export function queryModelListUsingGet(data: any) { export function queryModelListUsingGet(data: Record<string, unknown>) {
return get("/api/models/list", data); return get("/api/models/list", data);
} }
@@ -15,12 +15,12 @@ export function queryModelDetailByIdUsingGet(id: string | number) {
export function updateModelByIdUsingPut( export function updateModelByIdUsingPut(
id: string | number, id: string | number,
data: any data: Record<string, unknown>
) { ) {
return put(`/api/models/${id}`, data); return put(`/api/models/${id}`, data);
} }
export function createModelUsingPost(data: any) { export function createModelUsingPost(data: Record<string, unknown>) {
return post("/api/models/create", data); return post("/api/models/create", data);
} }
@@ -28,13 +28,60 @@ export function deleteModelByIdUsingDelete(id: string | number) {
return del(`/api/models/${id}`); return del(`/api/models/${id}`);
} }
// 获取系统参数列表 // 获取系统参数列表
export function getSysParamList() { export function getSysParamList() {
return get('/api/sys-param/list'); return get("/api/sys-param/list");
} }
// 更新系统参数值 // 更新系统参数值
export const updateSysParamValue = async (params: { id: string; paramValue: string }) => { export const updateSysParamValue = async (params: {
id: string;
paramValue: string;
}) => {
return put(`/api/sys-param/${params.id}`, params); return put(`/api/sys-param/${params.id}`, params);
}; };
export interface AuthUserWithRoles {
id: number;
username: string;
fullName?: string;
email?: string;
enabled?: boolean;
roleCodes: string[];
}
export interface AuthRoleInfo {
id: string;
roleCode: string;
roleName: string;
description?: string;
enabled?: boolean;
}
export interface AuthPermissionInfo {
id: string;
permissionCode: string;
permissionName: string;
module?: string;
action?: string;
pathPattern?: string;
method?: string;
enabled?: boolean;
}
// 用户与权限管理接口
export function listAuthUsersUsingGet() {
return get("/api/auth/users");
}
export function listAuthRolesUsingGet() {
return get("/api/auth/roles");
}
export function listAuthPermissionsUsingGet() {
return get("/api/auth/permissions");
}
export function assignUserRolesUsingPut(userId: number, roleIds: string[]) {
return put(`/api/auth/users/${userId}/roles`, { roleIds });
}

View File

@@ -51,6 +51,7 @@ import Home from "@/pages/Home/Home";
import ContentGenerationPage from "@/pages/ContentGeneration/ContentGenerationPage"; import ContentGenerationPage from "@/pages/ContentGeneration/ContentGenerationPage";
import LoginPage from "@/pages/Login/LoginPage"; import LoginPage from "@/pages/Login/LoginPage";
import ProtectedRoute from "@/components/ProtectedRoute"; import ProtectedRoute from "@/components/ProtectedRoute";
import ForbiddenPage from "@/pages/Forbidden/ForbiddenPage";
const router = createBrowserRouter([ const router = createBrowserRouter([
{ {
@@ -64,6 +65,10 @@ const router = createBrowserRouter([
{ {
Component: ProtectedRoute, Component: ProtectedRoute,
children: [ children: [
{
path: "/403",
Component: ForbiddenPage,
},
{ {
path: "/chat", path: "/chat",
Component: withErrorBoundary(AgentPage), Component: withErrorBoundary(AgentPage),
@@ -299,4 +304,4 @@ const router = createBrowserRouter([
} }
]); ]);
export default router; export default router;

View File

@@ -1,66 +1,124 @@
// store/slices/authSlice.js import { createAsyncThunk, createSlice } from "@reduxjs/toolkit";
import { createSlice, createAsyncThunk } from '@reduxjs/toolkit'; import { get, post } from "@/utils/request";
// 异步 thunk interface AuthUserView {
export const loginUser = createAsyncThunk( id: number;
'auth/login', username: string;
async (credentials, { rejectWithValue }) => { fullName?: string;
try { email?: string;
const response = await fetch('/api/auth/login', { avatarUrl?: string;
method: 'POST', organization?: string;
headers: { }
'Content-Type': 'application/json',
},
body: JSON.stringify(credentials),
});
if (!response.ok) { interface AuthLoginPayload {
throw new Error('Login failed'); token: string;
} tokenType: string;
expiresInSeconds: number;
user: AuthUserView;
roles: string[];
permissions: string[];
}
const data = await response.json(); interface AuthCurrentUserPayload {
return data; user: AuthUserView;
} catch (error) { roles: string[];
return rejectWithValue(error.message); permissions: string[];
} }
interface ApiResponse<T> {
code: string;
message: string;
data: T;
}
interface AuthState {
user: AuthUserView | null;
token: string | null;
roles: string[];
permissions: string[];
isAuthenticated: boolean;
initialized: boolean;
loading: boolean;
error: string | null;
}
interface LoginCredentials {
username: string;
password: string;
}
const extractErrorMessage = (error: unknown): string => {
if (error instanceof Error) {
const nestedMessage = (error as { data?: { message?: string } }).data?.message;
return nestedMessage ?? error.message;
} }
); return "登录失败,请稍后重试";
};
export const loginUser = createAsyncThunk<
AuthLoginPayload,
LoginCredentials,
{ rejectValue: string }
>("auth/login", async (credentials, { rejectWithValue }) => {
try {
const response = (await post("/api/auth/login", credentials)) as ApiResponse<AuthLoginPayload>;
if (!response?.data?.token) {
return rejectWithValue(response?.message ?? "登录失败");
}
return response.data;
} catch (error) {
return rejectWithValue(extractErrorMessage(error));
}
});
export const fetchCurrentUser = createAsyncThunk<
AuthCurrentUserPayload,
void,
{ rejectValue: string }
>("auth/fetchCurrentUser", async (_, { rejectWithValue }) => {
try {
const response = (await get("/api/auth/me")) as ApiResponse<AuthCurrentUserPayload>;
if (!response?.data?.user) {
return rejectWithValue(response?.message ?? "用户信息加载失败");
}
return response.data;
} catch (error) {
return rejectWithValue(extractErrorMessage(error));
}
});
const initialToken = localStorage.getItem("token");
const initialState: AuthState = {
user: null,
token: initialToken,
roles: [],
permissions: [],
isAuthenticated: Boolean(initialToken),
initialized: false,
loading: false,
error: null,
};
const authSlice = createSlice({ const authSlice = createSlice({
name: 'auth', name: "auth",
initialState: { initialState,
user: null,
token: localStorage.getItem('token'),
isAuthenticated: !!localStorage.getItem('token'),
loading: false,
error: null,
},
reducers: { reducers: {
logout: (state) => { logout: (state) => {
state.user = null; state.user = null;
state.token = null; state.token = null;
state.roles = [];
state.permissions = [];
state.isAuthenticated = false; state.isAuthenticated = false;
localStorage.removeItem('token'); state.error = null;
state.initialized = true;
localStorage.removeItem("token");
}, },
clearError: (state) => { clearError: (state) => {
state.error = null; state.error = null;
}, },
setToken: (state, action) => { markInitialized: (state) => {
state.token = action.payload; state.initialized = true;
localStorage.setItem('token', action.payload);
},
loginLocal: (state, action) => {
const { username, password } = action.payload;
if (username === 'admin' && password === '123456') {
state.user = { username: 'admin', role: 'admin' };
state.token = 'mock-token-' + Date.now();
state.isAuthenticated = true;
localStorage.setItem('token', state.token);
state.error = null;
} else {
state.error = 'Invalid credentials';
state.isAuthenticated = false;
}
}, },
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
@@ -71,18 +129,52 @@ const authSlice = createSlice({
}) })
.addCase(loginUser.fulfilled, (state, action) => { .addCase(loginUser.fulfilled, (state, action) => {
state.loading = false; state.loading = false;
state.initialized = true;
state.user = action.payload.user; state.user = action.payload.user;
state.token = action.payload.token; state.token = action.payload.token;
state.roles = action.payload.roles ?? [];
state.permissions = action.payload.permissions ?? [];
state.isAuthenticated = true; state.isAuthenticated = true;
localStorage.setItem('token', action.payload.token); state.error = null;
localStorage.setItem("token", action.payload.token);
}) })
.addCase(loginUser.rejected, (state, action) => { .addCase(loginUser.rejected, (state, action) => {
state.loading = false; state.loading = false;
state.error = action.payload; state.initialized = true;
state.user = null;
state.roles = [];
state.permissions = [];
state.isAuthenticated = false; state.isAuthenticated = false;
state.token = null;
state.error = action.payload ?? "登录失败";
localStorage.removeItem("token");
})
.addCase(fetchCurrentUser.pending, (state) => {
state.loading = true;
state.error = null;
})
.addCase(fetchCurrentUser.fulfilled, (state, action) => {
state.loading = false;
state.initialized = true;
state.user = action.payload.user;
state.roles = action.payload.roles ?? [];
state.permissions = action.payload.permissions ?? [];
state.isAuthenticated = Boolean(state.token);
state.error = null;
})
.addCase(fetchCurrentUser.rejected, (state, action) => {
state.loading = false;
state.initialized = true;
state.user = null;
state.roles = [];
state.permissions = [];
state.isAuthenticated = false;
state.token = null;
state.error = action.payload ?? "登录状态已失效";
localStorage.removeItem("token");
}); });
}, },
}); });
export const { logout, clearError, setToken, loginLocal } = authSlice.actions; export const { logout, clearError, markInitialized } = authSlice.actions;
export default authSlice.reducer; export default authSlice.reducer;

View File

@@ -524,8 +524,16 @@ request.addRequestInterceptor((config) => {
// 添加默认响应拦截器 - 错误处理 // 添加默认响应拦截器 - 错误处理
request.addResponseInterceptor((response) => { request.addResponseInterceptor((response) => {
// 可以在这里添加全局错误处理逻辑 if (response.status === 401) {
// 比如token过期自动跳转登录页等 localStorage.removeItem("token");
sessionStorage.removeItem("token");
if (window.location.pathname !== "/login") {
window.location.href = "/login";
}
}
if (response.status === 403 && window.location.pathname !== "/403") {
window.location.href = "/403";
}
return response; return response;
}); });

View File

@@ -17,12 +17,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db.session import get_db from app.db.session import get_db
from app.module.shared.schema import StandardResponse from app.module.shared.schema import StandardResponse
from app.module.dataset import DatasetManagementService from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger from app.core.logging import get_logger
from ..schema.auto import ( from ..security import (
CreateAutoAnnotationTaskRequest, RequestUserContext,
AutoAnnotationTaskResponse, assert_dataset_access,
get_request_user_context,
)
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
) )
from ..service.auto import AutoAnnotationTaskService from ..service.auto import AutoAnnotationTaskService
@@ -37,15 +42,16 @@ service = AutoAnnotationTaskService()
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]]) @router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_auto_annotation_tasks( async def list_auto_annotation_tasks(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""获取自动标注任务列表。 """获取自动标注任务列表。
前端当前不传分页参数,这里直接返回所有未删除任务。 前端当前不传分页参数,这里直接返回所有未删除任务。
""" """
tasks = await service.list_tasks(db) tasks = await service.list_tasks(db, user_context)
return StandardResponse( return StandardResponse(
code=200, code=200,
message="success", message="success",
@@ -54,28 +60,30 @@ async def list_auto_annotation_tasks(
@router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse]) @router.post("", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def create_auto_annotation_task( async def create_auto_annotation_task(
request: CreateAutoAnnotationTaskRequest, request: CreateAutoAnnotationTaskRequest,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""创建自动标注任务。 """创建自动标注任务。
当前仅创建任务记录并置为 pending,实际执行由后续调度/worker 完成。 当前仅创建任务记录并置为 pending,实际执行由后续调度/worker 完成。
""" """
logger.info( logger.info(
"Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s", "Creating auto annotation task: name=%s, dataset_id=%s, config=%s, file_ids=%s",
request.name, request.name,
request.dataset_id, request.dataset_id,
request.config.model_dump(by_alias=True), request.config.model_dump(by_alias=True),
request.file_ids, request.file_ids,
) )
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建 # 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
dataset_name = None dataset_name = None
total_images = 0 total_images = 0
try: await assert_dataset_access(db, request.dataset_id, user_context)
dm_client = DatasetManagementService(db) try:
dm_client = DatasetManagementService(db)
# Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount # Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount
dataset = await dm_client.get_dataset(request.dataset_id) dataset = await dm_client.get_dataset(request.dataset_id)
if dataset is not None: if dataset is not None:
@@ -103,16 +111,17 @@ async def create_auto_annotation_task(
@router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse]) @router.get("/{task_id}/status", response_model=StandardResponse[AutoAnnotationTaskResponse])
async def get_auto_annotation_task_status( async def get_auto_annotation_task_status(
task_id: str = Path(..., description="任务ID"), task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""获取单个自动标注任务状态。 """获取单个自动标注任务状态。
前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。 前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。
""" """
task = await service.get_task(db, task_id) task = await service.get_task(db, task_id, user_context)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
@@ -124,13 +133,14 @@ async def get_auto_annotation_task_status(
@router.delete("/{task_id}", response_model=StandardResponse[bool]) @router.delete("/{task_id}", response_model=StandardResponse[bool])
async def delete_auto_annotation_task( async def delete_auto_annotation_task(
task_id: str = Path(..., description="任务ID"), task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""删除(软删除)自动标注任务,仅标记 deleted_at。""" """删除(软删除)自动标注任务,仅标记 deleted_at。"""
ok = await service.soft_delete_task(db, task_id) ok = await service.soft_delete_task(db, task_id, user_context)
if not ok: if not ok:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
@@ -142,10 +152,11 @@ async def delete_auto_annotation_task(
@router.get("/{task_id}/download") @router.get("/{task_id}/download")
async def download_auto_annotation_result( async def download_auto_annotation_result(
task_id: str = Path(..., description="任务ID"), task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): user_context: RequestUserContext = Depends(get_request_user_context),
):
"""下载指定自动标注任务的结果 ZIP。""" """下载指定自动标注任务的结果 ZIP。"""
import io import io
@@ -154,7 +165,7 @@ async def download_auto_annotation_result(
import tempfile import tempfile
# 复用服务层获取任务信息 # 复用服务层获取任务信息
task = await service.get_task(db, task_id) task = await service.get_task(db, task_id, user_context)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")

View File

@@ -27,6 +27,10 @@ from app.module.annotation.schema.editor import (
UpsertAnnotationResponse, UpsertAnnotationResponse,
) )
from app.module.annotation.service.editor import AnnotationEditorService from app.module.annotation.service.editor import AnnotationEditorService
from app.module.annotation.security import (
RequestUserContext,
get_request_user_context,
)
from app.module.shared.schema import StandardResponse from app.module.shared.schema import StandardResponse
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -44,8 +48,9 @@ router = APIRouter(
async def get_editor_project_info( async def get_editor_project_info(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"), project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
info = await service.get_project_info(project_id) info = await service.get_project_info(project_id)
return StandardResponse(code=200, message="success", data=info) return StandardResponse(code=200, message="success", data=info)
@@ -64,8 +69,9 @@ async def list_editor_tasks(
description="是否排除已被转换为TXT的源文档文件(PDF/DOC/DOCX,仅文本数据集生效)", description="是否排除已被转换为TXT的源文档文件(PDF/DOC/DOCX,仅文本数据集生效)",
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.list_tasks( result = await service.list_tasks(
project_id, project_id,
page=page, page=page,
@@ -86,8 +92,9 @@ async def get_editor_task(
None, alias="segmentIndex", description="段落索引(分段模式下使用)" None, alias="segmentIndex", description="段落索引(分段模式下使用)"
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
task = await service.get_task(project_id, file_id, segment_index=segment_index) task = await service.get_task(project_id, file_id, segment_index=segment_index)
return StandardResponse(code=200, message="success", data=task) return StandardResponse(code=200, message="success", data=task)
@@ -103,8 +110,9 @@ async def get_editor_task_segment(
..., ge=0, alias="segmentIndex", description="段落索引(从0开始)" ..., ge=0, alias="segmentIndex", description="段落索引(从0开始)"
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.get_task_segment(project_id, file_id, segment_index) result = await service.get_task_segment(project_id, file_id, segment_index)
return StandardResponse(code=200, message="success", data=result) return StandardResponse(code=200, message="success", data=result)
@@ -118,8 +126,9 @@ async def upsert_editor_annotation(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"), project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"), file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.upsert_annotation(project_id, file_id, request) result = await service.upsert_annotation(project_id, file_id, request)
return StandardResponse(code=200, message="success", data=result) return StandardResponse(code=200, message="success", data=result)
@@ -132,11 +141,12 @@ async def check_file_version(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"), project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"), file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
检查文件是否有新版本 检查文件是否有新版本
""" """
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.check_file_version(project_id, file_id) result = await service.check_file_version(project_id, file_id)
return StandardResponse(code=200, message="success", data=result) return StandardResponse(code=200, message="success", data=result)
@@ -149,10 +159,11 @@ async def use_new_version(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"), project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"), file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
使用文件新版本并清空标注 使用文件新版本并清空标注
""" """
service = AnnotationEditorService(db) service = AnnotationEditorService(db, user_context)
result = await service.use_new_version(project_id, file_id) result = await service.use_new_version(project_id, file_id)
return StandardResponse(code=200, message="success", data=result) return StandardResponse(code=200, message="success", data=result)

View File

@@ -12,6 +12,11 @@ from app.module.shared.schema import StandardResponse, PaginatedData
from app.module.dataset import DatasetManagementService from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger from app.core.logging import get_logger
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..service.mapping import DatasetMappingService from ..service.mapping import DatasetMappingService
from ..service.template import AnnotationTemplateService from ..service.template import AnnotationTemplateService
from ..service.knowledge_sync import KnowledgeSyncService from ..service.knowledge_sync import KnowledgeSyncService
@@ -42,7 +47,9 @@ async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201 "", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
) )
async def create_mapping( async def create_mapping(
request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db) request: DatasetMappingCreateRequest,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
创建数据集映射 创建数据集映射
@@ -58,6 +65,8 @@ async def create_mapping(
mapping_service = DatasetMappingService(db) mapping_service = DatasetMappingService(db)
template_service = AnnotationTemplateService() template_service = AnnotationTemplateService()
await assert_dataset_access(db, request.dataset_id, user_context)
logger.info(f"Create dataset mapping request: {request.dataset_id}") logger.info(f"Create dataset mapping request: {request.dataset_id}")
# 从DM服务获取数据集信息 # 从DM服务获取数据集信息
@@ -163,7 +172,7 @@ async def create_mapping(
try: try:
from ..service.editor import AnnotationEditorService from ..service.editor import AnnotationEditorService
editor_service = AnnotationEditorService(db) editor_service = AnnotationEditorService(db, user_context)
# 异步预计算切片(不阻塞创建响应) # 异步预计算切片(不阻塞创建响应)
segmentation_result = ( segmentation_result = (
await editor_service.precompute_segmentation_for_project( await editor_service.precompute_segmentation_for_project(
@@ -202,6 +211,7 @@ async def list_mappings(
False, description="是否包含模板详情", alias="includeTemplate" False, description="是否包含模板详情", alias="includeTemplate"
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
查询所有映射关系(分页) 查询所有映射关系(分页)
@@ -230,6 +240,8 @@ async def list_mappings(
limit=size, limit=size,
include_deleted=False, include_deleted=False,
include_template=include_template, include_template=include_template,
current_user_id=user_context.user_id,
is_admin=user_context.is_admin,
) )
# 计算总页数 # 计算总页数
@@ -256,7 +268,11 @@ async def list_mappings(
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse]) @router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)): async def get_mapping(
mapping_id: str,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
""" """
根据 UUID 查询单个映射关系(包含关联的标注模板详情) 根据 UUID 查询单个映射关系(包含关联的标注模板详情)
@@ -278,6 +294,7 @@ async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Mapping not found: {mapping_id}" status_code=404, detail=f"Mapping not found: {mapping_id}"
) )
await assert_dataset_access(db, mapping.dataset_id, user_context)
logger.info( logger.info(
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}" f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
@@ -304,6 +321,7 @@ async def get_mappings_by_source(
True, description="是否包含模板详情", alias="includeTemplate" True, description="是否包含模板详情", alias="includeTemplate"
), ),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
根据源数据集 ID 查询所有映射关系(分页,包含模板详情) 根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
@@ -319,6 +337,7 @@ async def get_mappings_by_source(
""" """
try: try:
service = DatasetMappingService(db) service = DatasetMappingService(db)
await assert_dataset_access(db, dataset_id, user_context)
# 计算 skip # 计算 skip
skip = (page - 1) * size skip = (page - 1) * size
@@ -333,6 +352,8 @@ async def get_mappings_by_source(
skip=skip, skip=skip,
limit=size, limit=size,
include_template=include_template, include_template=include_template,
current_user_id=user_context.user_id,
is_admin=user_context.is_admin,
) )
# 计算总页数 # 计算总页数
@@ -364,6 +385,7 @@ async def get_mappings_by_source(
async def delete_mapping( async def delete_mapping(
project_id: str = Path(..., description="映射UUID(path param)"), project_id: str = Path(..., description="映射UUID(path param)"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
删除映射关系(软删除) 删除映射关系(软删除)
@@ -387,6 +409,7 @@ async def delete_mapping(
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Mapping either not found or not specified." status_code=404, detail=f"Mapping either not found or not specified."
) )
await assert_dataset_access(db, mapping.dataset_id, user_context)
id = mapping.id id = mapping.id
dataset_id = mapping.dataset_id dataset_id = mapping.dataset_id
@@ -428,6 +451,7 @@ async def update_mapping(
project_id: str = Path(..., description="映射UUID(path param)"), project_id: str = Path(..., description="映射UUID(path param)"),
request: DatasetMappingUpdateRequest = None, request: DatasetMappingUpdateRequest = None,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
更新标注项目信息 更新标注项目信息
@@ -456,6 +480,7 @@ async def update_mapping(
raise HTTPException( raise HTTPException(
status_code=404, detail=f"Mapping not found: {project_id}" status_code=404, detail=f"Mapping not found: {project_id}"
) )
await assert_dataset_access(db, mapping_orm.dataset_id, user_context)
# 构建更新数据 # 构建更新数据
update_values = {} update_values = {}

View File

@@ -10,6 +10,11 @@ from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger from app.core.logging import get_logger
from app.core.config import settings from app.core.config import settings
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..service.mapping import DatasetMappingService from ..service.mapping import DatasetMappingService
from ..schema import ( from ..schema import (
SyncDatasetRequest, SyncDatasetRequest,
@@ -32,7 +37,8 @@ logger = get_logger(__name__)
@router.post("/sync", response_model=StandardResponse[SyncDatasetResponse]) @router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
async def sync_dataset_content( async def sync_dataset_content(
request: SyncDatasetRequest, request: SyncDatasetRequest,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
Sync Dataset Content (Files and Annotations) Sync Dataset Content (Files and Annotations)
@@ -51,6 +57,7 @@ async def sync_dataset_content(
status_code=404, status_code=404,
detail=f"Mapping not found: {request.id}" detail=f"Mapping not found: {request.id}"
) )
await assert_dataset_access(db, mapping.dataset_id, user_context)
dm_client = DatasetManagementService(db) dm_client = DatasetManagementService(db)
dataset_info = await dm_client.get_dataset(mapping.dataset_id) dataset_info = await dm_client.get_dataset(mapping.dataset_id)
@@ -82,7 +89,8 @@ async def sync_dataset_content(
@router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse]) @router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse])
async def sync_annotations( async def sync_annotations(
request: SyncAnnotationsRequest, request: SyncAnnotationsRequest,
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
Sync Annotations Only (Bidirectional Support) Sync Annotations Only (Bidirectional Support)
@@ -102,6 +110,7 @@ async def sync_annotations(
status_code=404, status_code=404,
detail=f"Mapping not found: {request.id}" detail=f"Mapping not found: {request.id}"
) )
await assert_dataset_access(db, mapping.dataset_id, user_context)
result = SyncAnnotationsResponse( result = SyncAnnotationsResponse(
id=mapping.id, id=mapping.id,
@@ -156,7 +165,8 @@ async def check_label_studio_connection():
async def update_file_tags( async def update_file_tags(
request: UpdateFileTagsRequest, request: UpdateFileTagsRequest,
file_id: str = Path(..., description="文件ID"), file_id: str = Path(..., description="文件ID"),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
): ):
""" """
Update File Tags (Partial Update with Auto Format Conversion) Update File Tags (Partial Update with Auto Format Conversion)
@@ -189,6 +199,7 @@ async def update_file_tags(
raise HTTPException(status_code=404, detail=f"File not found: {file_id}") raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
await assert_dataset_access(db, dataset_id, user_context)
# 查找数据集关联的模板ID # 查找数据集关联的模板ID
from ..service.mapping import DatasetMappingService from ..service.mapping import DatasetMappingService

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
from fastapi import HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.dataset_management import Dataset
HEADER_USER_ID = "X-User-Id"
HEADER_USER_NAME = "X-User-Name"
HEADER_USER_ROLES = "X-User-Roles"
ADMIN_ROLE_CODE = "ROLE_ADMIN"
@dataclass(frozen=True)
class RequestUserContext:
user_id: str
username: str | None
roles: Tuple[str, ...]
@property
def is_admin(self) -> bool:
return any(role.upper() == ADMIN_ROLE_CODE for role in self.roles)
def get_request_user_context(request: Request) -> RequestUserContext:
user_id = (request.headers.get(HEADER_USER_ID) or "").strip()
username = (request.headers.get(HEADER_USER_NAME) or "").strip() or None
role_header = request.headers.get(HEADER_USER_ROLES) or ""
roles = tuple(
role.strip()
for role in role_header.split(",")
if role and role.strip()
)
if not user_id:
raise HTTPException(status_code=403, detail="权限不足:缺少用户身份")
return RequestUserContext(user_id=user_id, username=username, roles=roles)
def ensure_dataset_owner_access(
user_context: RequestUserContext,
dataset_owner_user_id: str | None,
dataset_id: str,
) -> None:
if user_context.is_admin:
return
if not dataset_owner_user_id or dataset_owner_user_id != user_context.user_id:
raise HTTPException(
status_code=403,
detail=f"无权访问数据集: {dataset_id}",
)
async def assert_dataset_access(
db: AsyncSession,
dataset_id: str,
user_context: RequestUserContext,
) -> None:
owner_result = await db.execute(
select(Dataset.created_by).where(Dataset.id == dataset_id)
)
dataset_owner = owner_result.scalar_one_or_none()
if dataset_owner is None:
raise HTTPException(status_code=404, detail=f"数据集不存在: {dataset_id}")
ensure_dataset_owner_access(user_context, str(dataset_owner), dataset_id)

View File

@@ -5,11 +5,12 @@ from typing import List, Optional
from datetime import datetime from datetime import datetime
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.annotation_management import AutoAnnotationTask from app.db.models.annotation_management import AutoAnnotationTask
from app.db.models.dataset_management import Dataset, DatasetFiles from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.security import RequestUserContext
from ..schema.auto import ( from ..schema.auto import (
CreateAutoAnnotationTaskRequest, CreateAutoAnnotationTaskRequest,
@@ -17,7 +18,7 @@ from ..schema.auto import (
) )
class AutoAnnotationTaskService: class AutoAnnotationTaskService:
"""自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)""" """自动标注任务服务(仅管理任务元数据,真正执行由 runtime 负责)"""
async def create_task( async def create_task(
@@ -63,15 +64,27 @@ class AutoAnnotationTaskService:
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id] resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
return resp return resp
async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]: def _apply_dataset_scope(self, query, user_context: RequestUserContext):
"""获取未软删除的自动标注任务列表,按创建时间倒序。""" if user_context.is_admin:
return query
result = await db.execute( return query.join(
select(AutoAnnotationTask) Dataset,
.where(AutoAnnotationTask.deleted_at.is_(None)) AutoAnnotationTask.dataset_id == Dataset.id,
.order_by(AutoAnnotationTask.created_at.desc()) ).where(Dataset.created_by == user_context.user_id)
)
tasks: List[AutoAnnotationTask] = list(result.scalars().all()) async def list_tasks(
self,
db: AsyncSession,
user_context: RequestUserContext,
) -> List[AutoAnnotationTaskResponse]:
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
query = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(
query.order_by(AutoAnnotationTask.created_at.desc())
)
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
responses: List[AutoAnnotationTaskResponse] = [] responses: List[AutoAnnotationTaskResponse] = []
for task in tasks: for task in tasks:
@@ -87,16 +100,21 @@ class AutoAnnotationTaskService:
return responses return responses
async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]: async def get_task(
result = await db.execute( self,
select(AutoAnnotationTask).where( db: AsyncSession,
AutoAnnotationTask.id == task_id, task_id: str,
AutoAnnotationTask.deleted_at.is_(None), user_context: RequestUserContext,
) ) -> Optional[AutoAnnotationTaskResponse]:
) query = select(AutoAnnotationTask).where(
task = result.scalar_one_or_none() AutoAnnotationTask.id == task_id,
if not task: AutoAnnotationTask.deleted_at.is_(None),
return None )
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return None
resp = AutoAnnotationTaskResponse.model_validate(task) resp = AutoAnnotationTaskResponse.model_validate(task)
try: try:
@@ -138,16 +156,21 @@ class AutoAnnotationTaskService:
return [task.dataset_id] return [task.dataset_id]
return [] return []
async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool: async def soft_delete_task(
result = await db.execute( self,
select(AutoAnnotationTask).where( db: AsyncSession,
AutoAnnotationTask.id == task_id, task_id: str,
AutoAnnotationTask.deleted_at.is_(None), user_context: RequestUserContext,
) ) -> bool:
) query = select(AutoAnnotationTask).where(
task = result.scalar_one_or_none() AutoAnnotationTask.id == task_id,
if not task: AutoAnnotationTask.deleted_at.is_(None),
return False )
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return False
task.deleted_at = datetime.now() task.deleted_at = datetime.now()
await db.commit() await db.commit()

View File

@@ -54,6 +54,10 @@ from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
from app.module.annotation.service.annotation_text_splitter import ( from app.module.annotation.service.annotation_text_splitter import (
AnnotationTextSplitter, AnnotationTextSplitter,
) )
from app.module.annotation.security import (
RequestUserContext,
ensure_dataset_owner_access,
)
from app.module.annotation.service.text_fetcher import ( from app.module.annotation.service.text_fetcher import (
fetch_text_content_via_download_api, fetch_text_content_via_download_api,
) )
@@ -104,8 +108,9 @@ class AnnotationEditorService:
# 分段阈值:超过此字符数自动分段 # 分段阈值:超过此字符数自动分段
SEGMENT_THRESHOLD = 200 SEGMENT_THRESHOLD = 200
def __init__(self, db: AsyncSession): def __init__(self, db: AsyncSession, user_context: RequestUserContext):
self.db = db self.db = db
self.user_context = user_context
self.template_service = AnnotationTemplateService() self.template_service = AnnotationTemplateService()
@staticmethod @staticmethod
@@ -157,14 +162,24 @@ class AnnotationEditorService:
async def _get_project_or_404(self, project_id: str) -> LabelingProject: async def _get_project_or_404(self, project_id: str) -> LabelingProject:
result = await self.db.execute( result = await self.db.execute(
select(LabelingProject).where( select(LabelingProject, Dataset.created_by).join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(
LabelingProject.id == project_id, LabelingProject.id == project_id,
LabelingProject.deleted_at.is_(None), LabelingProject.deleted_at.is_(None),
) )
) )
project = result.scalar_one_or_none() row = result.first()
if not project: if not row:
raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}") raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
project = row[0]
dataset_owner = row[1]
ensure_dataset_owner_access(
self.user_context,
str(dataset_owner) if dataset_owner is not None else None,
project.dataset_id,
)
return project return project
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]: async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:

View File

@@ -478,7 +478,9 @@ class DatasetMappingService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
include_deleted: bool = False, include_deleted: bool = False,
include_template: bool = False include_template: bool = False,
current_user_id: Optional[str] = None,
is_admin: bool = False,
) -> Tuple[List[DatasetMappingResponse], int]: ) -> Tuple[List[DatasetMappingResponse], int]:
""" """
获取所有映射及总数(用于分页) 获取所有映射及总数(用于分页)
@@ -495,9 +497,16 @@ class DatasetMappingService:
query = self._build_query_with_dataset_name() query = self._build_query_with_dataset_name()
if not include_deleted: if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None)) query = query.where(LabelingProject.deleted_at.is_(None))
if not is_admin:
query = query.where(Dataset.created_by == current_user_id)
# 获取总数 # 获取总数
count_query = select(func.count()).select_from(LabelingProject) count_query = select(func.count()).select_from(LabelingProject)
if not is_admin:
count_query = count_query.join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(Dataset.created_by == current_user_id)
if not include_deleted: if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None)) count_query = count_query.where(LabelingProject.deleted_at.is_(None))
@@ -557,7 +566,9 @@ class DatasetMappingService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
include_deleted: bool = False, include_deleted: bool = False,
include_template: bool = False include_template: bool = False,
current_user_id: Optional[str] = None,
is_admin: bool = False,
) -> Tuple[List[DatasetMappingResponse], int]: ) -> Tuple[List[DatasetMappingResponse], int]:
""" """
根据源数据集ID获取映射关系及总数(用于分页) 根据源数据集ID获取映射关系及总数(用于分页)
@@ -578,11 +589,18 @@ class DatasetMappingService:
if not include_deleted: if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None)) query = query.where(LabelingProject.deleted_at.is_(None))
if not is_admin:
query = query.where(Dataset.created_by == current_user_id)
# 获取总数 # 获取总数
count_query = select(func.count()).select_from(LabelingProject).where( count_query = select(func.count()).select_from(LabelingProject).where(
LabelingProject.dataset_id == dataset_id LabelingProject.dataset_id == dataset_id
) )
if not is_admin:
count_query = count_query.join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(Dataset.created_by == current_user_id)
if not include_deleted: if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None)) count_query = count_query.where(LabelingProject.deleted_at.is_(None))

View File

@@ -1,5 +1,4 @@
import csv import csv
import csv
import datetime import datetime
import os import os
from io import StringIO from io import StringIO
@@ -76,6 +75,7 @@ class PdfTextExtractService:
source_path = self._resolve_source_path(file_record) source_path = self._resolve_source_path(file_record)
dataset_path = self._resolve_dataset_path(dataset) dataset_path = self._resolve_dataset_path(dataset)
target_path = self._resolve_target_path(dataset_path, source_path, file_record, file_id, file_type) target_path = self._resolve_target_path(dataset_path, source_path, file_record, file_id, file_type)
logical_path = self._build_logical_path(dataset_path, target_path)
existing_record = await self._find_existing_text_record(dataset_id, target_path) existing_record = await self._find_existing_text_record(dataset_id, target_path)
if existing_record: if existing_record:
@@ -85,7 +85,7 @@ class PdfTextExtractService:
file_size = self._get_file_size(target_path) file_size = self._get_file_size(target_path)
parser_name = PARSER_BY_FILE_TYPE.get(file_type, "") parser_name = PARSER_BY_FILE_TYPE.get(file_type, "")
record = await self._create_text_file_record( record = await self._create_text_file_record(
dataset, file_record, target_path, file_size, parser_name, derived_file_type dataset, file_record, target_path, logical_path, file_size, parser_name, derived_file_type
) )
return self._build_response(dataset_id, file_id, record) return self._build_response(dataset_id, file_id, record)
@@ -94,7 +94,7 @@ class PdfTextExtractService:
self._write_text_file(target_path, text_content) self._write_text_file(target_path, text_content)
file_size = self._get_file_size(target_path) file_size = self._get_file_size(target_path)
record = await self._create_text_file_record( record = await self._create_text_file_record(
dataset, file_record, target_path, file_size, parser_name, derived_file_type dataset, file_record, target_path, logical_path, file_size, parser_name, derived_file_type
) )
return self._build_response(dataset_id, file_id, record) return self._build_response(dataset_id, file_id, record)
@@ -170,6 +170,19 @@ class PdfTextExtractService:
target_dir.mkdir(parents=True, exist_ok=True) target_dir.mkdir(parents=True, exist_ok=True)
return target_dir / output_name return target_dir / output_name
@staticmethod
def _build_logical_path(dataset_path: Path, target_path: Path) -> str:
normalized_dataset_path = dataset_path.resolve()
normalized_target_path = target_path.resolve()
try:
relative_path = normalized_target_path.relative_to(normalized_dataset_path)
except ValueError as exc:
raise HTTPException(status_code=400, detail="解析文件路径超出数据集目录") from exc
logical_path = str(relative_path).replace("\\", "/").strip()
if not logical_path:
raise HTTPException(status_code=500, detail="解析文件逻辑路径为空")
return logical_path
async def _find_existing_text_record(self, dataset_id: str, target_path: Path) -> DatasetFiles | None: async def _find_existing_text_record(self, dataset_id: str, target_path: Path) -> DatasetFiles | None:
result = await self.db.execute( result = await self.db.execute(
select(DatasetFiles).where( select(DatasetFiles).where(
@@ -259,10 +272,12 @@ class PdfTextExtractService:
dataset: Dataset, dataset: Dataset,
source_file: DatasetFiles, source_file: DatasetFiles,
target_path: Path, target_path: Path,
logical_path: str,
file_size: int, file_size: int,
parser_name: str, parser_name: str,
derived_file_type: str, derived_file_type: str,
) -> DatasetFiles: ) -> DatasetFiles:
assert logical_path
assert parser_name assert parser_name
assert derived_file_type assert derived_file_type
metadata = { metadata = {
@@ -275,6 +290,7 @@ class PdfTextExtractService:
dataset_id=dataset.id, # type: ignore[arg-type] dataset_id=dataset.id, # type: ignore[arg-type]
file_name=target_path.name, file_name=target_path.name,
file_path=str(target_path), file_path=str(target_path),
logical_path=logical_path,
file_type=derived_file_type, file_type=derived_file_type,
file_size=file_size, file_size=file_size,
dataset_filemetadata=metadata, dataset_filemetadata=metadata,

View File

@@ -74,6 +74,7 @@ class SynthesisDatasetExporter:
file_path = os.path.join(base_path, archived_file_name) file_path = os.path.join(base_path, archived_file_name)
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
self._write_jsonl(file_path, records) self._write_jsonl(file_path, records)
logical_path = self._build_logical_path(base_path, file_path)
# 计算文件大小 # 计算文件大小
try: try:
@@ -85,6 +86,7 @@ class SynthesisDatasetExporter:
dataset_id=dataset.id, dataset_id=dataset.id,
file_name=archived_file_name, file_name=archived_file_name,
file_path=file_path, file_path=file_path,
logical_path=logical_path,
file_type="jsonl", file_type="jsonl",
file_size=file_size, file_size=file_size,
last_access_time=datetime.datetime.now(), last_access_time=datetime.datetime.now(),
@@ -158,3 +160,12 @@ class SynthesisDatasetExporter:
raise SynthesisExportError("Dataset path is empty") raise SynthesisExportError("Dataset path is empty")
os.makedirs(dataset.path, exist_ok=True) os.makedirs(dataset.path, exist_ok=True)
return dataset.path return dataset.path
@staticmethod
def _build_logical_path(dataset_path: str, file_path: str) -> str:
normalized_dataset_path = os.path.abspath(dataset_path)
normalized_file_path = os.path.abspath(file_path)
relative_path = os.path.relpath(normalized_file_path, normalized_dataset_path).replace("\\", "/").strip()
if relative_path in ("", ".") or relative_path.startswith("../"):
raise SynthesisExportError(f"Invalid logical path generated for file: {file_path}")
return relative_path

View File

@@ -187,11 +187,13 @@ class RatioTaskService:
dst_dir = os.path.dirname(new_path) dst_dir = os.path.dirname(new_path)
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True) await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
await asyncio.to_thread(shutil.copy2, src_path, new_path) await asyncio.to_thread(shutil.copy2, src_path, new_path)
logical_path = RatioTaskService.build_logical_path(dst_prefix, new_path)
file_data = { file_data = {
"dataset_id": target_ds.id, # type: ignore "dataset_id": target_ds.id, # type: ignore
"file_name": file_name, "file_name": file_name,
"file_path": new_path, "file_path": new_path,
"logical_path": logical_path,
"file_type": f.file_type, "file_type": f.file_type,
"file_size": f.file_size, "file_size": f.file_size,
"check_sum": f.check_sum, "check_sum": f.check_sum,
@@ -204,6 +206,15 @@ class RatioTaskService:
session.add(DatasetFiles(**file_record)) session.add(DatasetFiles(**file_record))
existing_paths.add(new_path) existing_paths.add(new_path)
@staticmethod
def build_logical_path(dataset_prefix: str, file_path: str) -> str:
normalized_dataset_prefix = os.path.abspath(dataset_prefix)
normalized_file_path = os.path.abspath(file_path)
relative_path = os.path.relpath(normalized_file_path, normalized_dataset_prefix).replace("\\", "/").strip()
if relative_path in ("", ".") or relative_path.startswith("../"):
raise ValueError(f"Invalid logical path generated for file: {file_path}")
return relative_path
@staticmethod @staticmethod
def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str: def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
file_name = f.file_name file_name = f.file_name

149
scripts/db/zz-auth-init.sql Normal file
View File

@@ -0,0 +1,149 @@
USE datamate;
-- =============================================
-- 认证与授权(RBAC)基础表
-- 注意:该脚本命名为 zz- 前缀,确保在 users 表初始化后执行
-- =============================================
CREATE TABLE IF NOT EXISTS t_auth_roles
(
id VARCHAR(36) PRIMARY KEY COMMENT '角色ID',
role_code VARCHAR(100) NOT NULL COMMENT '角色编码',
role_name VARCHAR(100) NOT NULL COMMENT '角色名称',
description VARCHAR(255) DEFAULT '' COMMENT '角色描述',
enabled TINYINT DEFAULT 1 COMMENT '是否启用:1-启用,0-禁用',
is_built_in TINYINT DEFAULT 1 COMMENT '是否内置:1-是,0-否',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
UNIQUE KEY uk_auth_role_code (role_code)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4 COMMENT ='角色表';
CREATE TABLE IF NOT EXISTS t_auth_permissions
(
id VARCHAR(36) PRIMARY KEY COMMENT '权限ID',
permission_code VARCHAR(120) NOT NULL COMMENT '权限编码',
permission_name VARCHAR(120) NOT NULL COMMENT '权限名称',
module VARCHAR(100) NOT NULL COMMENT '模块',
action VARCHAR(50) NOT NULL COMMENT '动作',
path_pattern VARCHAR(255) DEFAULT '' COMMENT '路径模式',
method VARCHAR(20) DEFAULT '' COMMENT 'HTTP方法',
enabled TINYINT DEFAULT 1 COMMENT '是否启用:1-启用,0-禁用',
is_built_in TINYINT DEFAULT 1 COMMENT '是否内置:1-是,0-否',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
UNIQUE KEY uk_auth_permission_code (permission_code),
INDEX idx_auth_permission_module_action (module, action)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4 COMMENT ='权限表';
CREATE TABLE IF NOT EXISTS t_auth_role_permissions
(
id BIGINT PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
role_id VARCHAR(36) NOT NULL COMMENT '角色ID',
permission_id VARCHAR(36) NOT NULL COMMENT '权限ID',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
UNIQUE KEY uk_auth_role_permission (role_id, permission_id),
INDEX idx_auth_role_permission_role (role_id),
INDEX idx_auth_role_permission_permission (permission_id),
CONSTRAINT fk_auth_rp_role FOREIGN KEY (role_id) REFERENCES t_auth_roles (id) ON DELETE CASCADE,
CONSTRAINT fk_auth_rp_permission FOREIGN KEY (permission_id) REFERENCES t_auth_permissions (id) ON DELETE CASCADE
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4 COMMENT ='角色权限关系表';
CREATE TABLE IF NOT EXISTS t_auth_user_roles
(
id BIGINT PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
user_id BIGINT NOT NULL COMMENT '用户ID(users.id)',
role_id VARCHAR(36) NOT NULL COMMENT '角色ID',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
UNIQUE KEY uk_auth_user_role (user_id, role_id),
INDEX idx_auth_user_role_user (user_id),
INDEX idx_auth_user_role_role (role_id),
CONSTRAINT fk_auth_ur_user FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE,
CONSTRAINT fk_auth_ur_role FOREIGN KEY (role_id) REFERENCES t_auth_roles (id) ON DELETE CASCADE
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4 COMMENT ='用户角色关系表';
-- =============================================
-- 角色初始化
-- =============================================
INSERT IGNORE INTO t_auth_roles (id, role_code, role_name, description, enabled, is_built_in)
VALUES ('role-admin', 'ROLE_ADMIN', '系统管理员', '拥有平台全部权限', 1, 1),
('role-data-editor', 'ROLE_DATA_EDITOR', '数据运营', '拥有业务模块读写权限', 1, 1),
('role-knowledge-user', 'ROLE_KNOWLEDGE_USER', '知识用户', '以知识管理为主的业务权限', 1, 1);
-- =============================================
-- 权限初始化(接口级)
-- =============================================
INSERT IGNORE INTO t_auth_permissions (id, permission_code, permission_name, module, action, path_pattern, method, enabled, is_built_in)
VALUES ('perm-dm-read', 'module:data-management:read', '数据管理读取', 'data-management', 'read', '/api/data-management/**', 'GET', 1, 1),
('perm-dm-write', 'module:data-management:write', '数据管理写入', 'data-management', 'write', '/api/data-management/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-da-read', 'module:data-annotation:read', '数据标注读取', 'data-annotation', 'read', '/api/annotation/**', 'GET', 1, 1),
('perm-da-write', 'module:data-annotation:write', '数据标注写入', 'data-annotation', 'write', '/api/annotation/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-dc-read', 'module:data-collection:read', '数据归集读取', 'data-collection', 'read', '/api/data-collection/**', 'GET', 1, 1),
('perm-dc-write', 'module:data-collection:write', '数据归集写入', 'data-collection', 'write', '/api/data-collection/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-de-read', 'module:data-evaluation:read', '数据评估读取', 'data-evaluation', 'read', '/api/evaluation/**', 'GET', 1, 1),
('perm-de-write', 'module:data-evaluation:write', '数据评估写入', 'data-evaluation', 'write', '/api/evaluation/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-ds-read', 'module:data-synthesis:read', '数据合成读取', 'data-synthesis', 'read', '/api/synthesis/**', 'GET', 1, 1),
('perm-ds-write', 'module:data-synthesis:write', '数据合成写入', 'data-synthesis', 'write', '/api/synthesis/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-km-read', 'module:knowledge-management:read', '知识管理读取', 'knowledge-management', 'read', '/api/data-management/knowledge/**', 'GET', 1, 1),
('perm-km-write', 'module:knowledge-management:write', '知识管理写入', 'knowledge-management', 'write', '/api/data-management/knowledge/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-kb-read', 'module:knowledge-base:read', '知识库读取', 'knowledge-base', 'read', '/api/knowledge-base/**', 'GET', 1, 1),
('perm-kb-write', 'module:knowledge-base:write', '知识库写入', 'knowledge-base', 'write', '/api/knowledge-base/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-om-read', 'module:operator-market:read', '算子市场读取', 'operator-market', 'read', '/api/operator-market/**', 'GET', 1, 1),
('perm-om-write', 'module:operator-market:write', '算子市场写入', 'operator-market', 'write', '/api/operator-market/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-orch-read', 'module:orchestration:read', '流程编排读取', 'orchestration', 'read', '/api/orchestration/**', 'GET', 1, 1),
('perm-orch-write', 'module:orchestration:write', '流程编排写入', 'orchestration', 'write', '/api/orchestration/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-agent-use', 'module:agent:use', '对话助手使用', 'agent', 'use', '/chat/**', 'GET', 1, 1),
('perm-content-use', 'module:content-generation:use', '内容生成功能使用', 'content-generation', 'use', '/api/content-generation/**', 'POST,PUT,PATCH', 1, 1),
('perm-user-manage', 'system:user:manage', '用户管理', 'system', 'manage-user', '/api/auth/users/**', 'GET,POST,PUT,PATCH,DELETE', 1, 1),
('perm-role-manage', 'system:role:manage', '角色管理', 'system', 'manage-role', '/api/auth/roles/**', 'GET,POST,PUT,PATCH,DELETE', 1, 1),
('perm-perm-manage', 'system:permission:manage', '权限管理', 'system', 'manage-permission', '/api/auth/permissions/**', 'GET,POST,PUT,PATCH,DELETE', 1, 1);
-- 管理员拥有所有权限
INSERT IGNORE INTO t_auth_role_permissions (role_id, permission_id)
SELECT 'role-admin', p.id
FROM t_auth_permissions p;
-- 数据运营拥有业务模块读写权限(不含系统管理)
INSERT IGNORE INTO t_auth_role_permissions (role_id, permission_id)
SELECT 'role-data-editor', p.id
FROM t_auth_permissions p
WHERE p.permission_code IN (
'module:data-management:read', 'module:data-management:write',
'module:data-annotation:read', 'module:data-annotation:write',
'module:data-collection:read', 'module:data-collection:write',
'module:data-evaluation:read', 'module:data-evaluation:write',
'module:data-synthesis:read', 'module:data-synthesis:write',
'module:knowledge-management:read', 'module:knowledge-management:write',
'module:knowledge-base:read', 'module:knowledge-base:write',
'module:operator-market:read', 'module:operator-market:write',
'module:orchestration:read', 'module:orchestration:write',
'module:agent:use', 'module:content-generation:use'
);
-- 知识用户拥有知识相关权限及必要数据读取权限
INSERT IGNORE INTO t_auth_role_permissions (role_id, permission_id)
SELECT 'role-knowledge-user', p.id
FROM t_auth_permissions p
WHERE p.permission_code IN (
'module:data-management:read',
'module:knowledge-management:read', 'module:knowledge-management:write',
'module:knowledge-base:read', 'module:knowledge-base:write',
'module:agent:use'
);
-- =============================================
-- 用户角色初始化(绑定到已有 users)
-- =============================================
INSERT IGNORE INTO t_auth_user_roles (user_id, role_id)
SELECT u.id, 'role-admin'
FROM users u
WHERE u.username = 'admin';
INSERT IGNORE INTO t_auth_user_roles (user_id, role_id)
SELECT u.id, 'role-knowledge-user'
FROM users u
WHERE u.username = 'knowledge_user';