Compare commits

..

7 Commits

Author SHA1 Message Date
329382db47 fix(pdf): 优化PDF文本提取服务异常处理
- 添加FeignException专门处理逻辑
- 实现详细的Feign异常日志记录功能
- 新增响应体解析和根因链构建方法
- 添加异常消息规范化处理
- 改进错误日志的可读性和调试信息完整度
2026-02-06 18:52:51 +08:00
e862925a06 feat(export): 添加逻辑路径构建功能支持文件管理
- 在导出服务中实现_build_logical_path方法用于构建相对路径
- 更新数据集文件记录以包含logical_path字段
- 在比率任务服务中实现build_logical_path静态方法
- 将逻辑路径信息添加到数据集文件记录中
- 规范化路径处理并替换反斜杠为正斜杠
- 添加无效路径验证防止目录遍历安全问题
2026-02-06 18:46:44 +08:00
05752678cc feat(dataset): 添加PDF提取服务中的逻辑路径构建功能
- 移除重复的csv导入语句
- 添加_build_logical_path方法用于构建文件逻辑路径
- 在_create_text_file_record方法中增加logical_path参数
- 更新记录创建调用以传递逻辑路径参数
- 验证逻辑路径不为空并抛出相应异常
- 将逻辑路径存储到数据集文件记录中
2026-02-06 18:30:44 +08:00
0f1dd9ec8d Merge remote-tracking branch 'gitea/lsf' into lsf 2026-02-06 18:29:58 +08:00
38e58ba864 Merge branch 'rbac' into lsf 2026-02-06 15:44:43 +08:00
6a4c4ae3d7 feat(auth): 为数据管理和RAG服务增加资源访问控制
- 在DatasetApplicationService中注入ResourceAccessService并添加所有权验证
- 在KnowledgeSetApplicationService中注入ResourceAccessService并添加所有权验证
- 修改DatasetRepository接口和实现类,增加按创建者过滤的方法
- 修改KnowledgeSetRepository接口和实现类,增加按创建者过滤的方法
- 在RAG索引器服务中添加知识库访问权限检查和作用域过滤
- 更新实体元对象处理器以使用请求用户上下文获取当前用户
- 在前端设置页面添加用户权限管理功能和角色权限控制
- 为Python标注服务增加用户上下文和数据集访问权限验证
2026-02-06 14:58:46 +08:00
056cee11cc feat(auth): 完善API网关JWT认证和权限控制功能
- 实现网关侧JWT工具类和权限规则匹配器
- 集成JWT认证流程,支持Bearer Token验证
- 添加基于路径和HTTP方法的权限控制机制
- 配置白名单路由规则,优化认证性能
- 更新前端受保护路由组件,实现权限验证
- 添加403禁止访问页面和权限检查逻辑
- 重构登录页面,集成实际认证API调用
- 实现用户信息获取和权限加载功能
- 优化全局异常处理器中的认证错误状态码
- 集成FastJSON2和JJWT依赖库支持
2026-02-06 13:21:20 +08:00
65 changed files with 2631 additions and 250 deletions

View File

@@ -36,6 +36,23 @@
<groupId>com.alibaba.fastjson2</groupId>
<artifactId>fastjson2</artifactId>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-api</artifactId>
<version>0.11.5</version>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-impl</artifactId>
<version>0.11.5</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
<version>0.11.5</version>
<scope>runtime</scope>
</dependency>
</dependencies>
<build>

View File

@@ -1,34 +1,124 @@
package com.datamate.gateway.filter;
import com.alibaba.fastjson2.JSONObject;
import com.datamate.gateway.security.GatewayJwtUtils;
import com.datamate.gateway.security.PermissionRuleMatcher;
import io.jsonwebtoken.Claims;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.nio.charset.StandardCharsets;
import java.util.List;
/**
* 用户信息过滤器
*
*/
@Slf4j
@Component
public class UserContextFilter implements GlobalFilter {
@Value("${commercial.switch:false}")
private boolean isCommercial;
public class UserContextFilter implements GlobalFilter, Ordered {
private final GatewayJwtUtils gatewayJwtUtils;
private final PermissionRuleMatcher permissionRuleMatcher;
@Value("${datamate.auth.enabled:true}")
private boolean authEnabled;
public UserContextFilter(GatewayJwtUtils gatewayJwtUtils, PermissionRuleMatcher permissionRuleMatcher) {
this.gatewayJwtUtils = gatewayJwtUtils;
this.permissionRuleMatcher = permissionRuleMatcher;
}
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
if (!isCommercial) {
if (!authEnabled) {
return chain.filter(exchange);
}
try {
ServerHttpRequest request = exchange.getRequest();
String path = request.getURI().getPath();
HttpMethod method = request.getMethod();
} catch (Exception e) {
log.error("get current user info error", e);
if (!path.startsWith("/api/")) {
return chain.filter(exchange);
}
if (HttpMethod.OPTIONS.equals(method)) {
return chain.filter(exchange);
}
if (permissionRuleMatcher.isWhitelisted(path)) {
return chain.filter(exchange);
}
String token = extractBearerToken(request.getHeaders().getFirst("Authorization"));
if (!StringUtils.hasText(token)) {
return writeError(exchange, HttpStatus.UNAUTHORIZED, "auth.0003", "未登录或登录状态已失效");
}
Claims claims;
try {
if (!gatewayJwtUtils.validateToken(token)) {
return writeError(exchange, HttpStatus.UNAUTHORIZED, "auth.0003", "登录状态已失效");
}
claims = gatewayJwtUtils.getClaimsFromToken(token);
} catch (Exception ex) {
log.warn("JWT校验失败: {}", ex.getMessage());
return writeError(exchange, HttpStatus.UNAUTHORIZED, "auth.0003", "登录状态已失效");
}
String requiredPermission = permissionRuleMatcher.resolveRequiredPermission(method, path);
if (StringUtils.hasText(requiredPermission)) {
List<String> permissionCodes = gatewayJwtUtils.getStringListClaim(claims, "permissions");
if (!permissionCodes.contains(requiredPermission)) {
return writeError(exchange, HttpStatus.FORBIDDEN, "auth.0006", "权限不足");
}
}
String userId = String.valueOf(claims.get("userId"));
String username = claims.getSubject();
List<String> roles = gatewayJwtUtils.getStringListClaim(claims, "roles");
ServerHttpRequest mutatedRequest = request.mutate()
.header("X-User-Id", userId)
.header("X-User-Name", username)
.header("X-User-Roles", String.join(",", roles))
.build();
return chain.filter(exchange.mutate().request(mutatedRequest).build());
}
@Override
public int getOrder() {
return -200;
}
private String extractBearerToken(String authorizationHeader) {
if (!StringUtils.hasText(authorizationHeader)) {
return null;
}
if (!authorizationHeader.startsWith("Bearer ")) {
return null;
}
String token = authorizationHeader.substring("Bearer ".length()).trim();
return token.isEmpty() ? null : token;
}
private Mono<Void> writeError(ServerWebExchange exchange,
HttpStatus status,
String code,
String message) {
exchange.getResponse().setStatusCode(status);
exchange.getResponse().getHeaders().set("Content-Type", "application/json;charset=UTF-8");
byte[] body = JSONObject.toJSONString(new ErrorBody(code, message, null))
.getBytes(StandardCharsets.UTF_8);
return exchange.getResponse().writeWith(Mono.just(exchange.getResponse().bufferFactory().wrap(body)));
}
private record ErrorBody(String code, String message, Object data) {
}
}

View File

@@ -0,0 +1,65 @@
package com.datamate.gateway.security;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.Keys;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.stream.Collectors;
/**
* 网关侧JWT工具
*/
@Component
public class GatewayJwtUtils {
private static final String DEFAULT_SECRET = "datamate-secret-key-for-jwt-token-generation";
@Value("${jwt.secret:" + DEFAULT_SECRET + "}")
private String secret;
public Claims getClaimsFromToken(String token) {
return Jwts.parserBuilder()
.setSigningKey(getSigningKey())
.build()
.parseClaimsJws(token)
.getBody();
}
public boolean validateToken(String token) {
Claims claims = getClaimsFromToken(token);
Date expiration = claims.getExpiration();
return expiration != null && expiration.after(new Date());
}
public List<String> getStringListClaim(Claims claims, String claimName) {
Object claimValue = claims.get(claimName);
if (!(claimValue instanceof Collection<?> values)) {
return Collections.emptyList();
}
return values.stream()
.map(String::valueOf)
.collect(Collectors.toList());
}
private SecretKey getSigningKey() {
String secretValue = StringUtils.hasText(secret) ? secret : DEFAULT_SECRET;
try {
MessageDigest digest = MessageDigest.getInstance("SHA-512");
byte[] keyBytes = digest.digest(secretValue.getBytes(StandardCharsets.UTF_8));
return Keys.hmacShaKeyFor(keyBytes);
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException("Cannot initialize JWT signing key", e);
}
}
}

View File

@@ -0,0 +1,85 @@
package com.datamate.gateway.security;
import lombok.Getter;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
/**
* 权限规则匹配器
*/
@Component
public class PermissionRuleMatcher {
private static final Set<HttpMethod> READ_METHODS = Set.of(HttpMethod.GET, HttpMethod.HEAD);
private static final Set<HttpMethod> WRITE_METHODS = Set.of(HttpMethod.POST, HttpMethod.PUT, HttpMethod.PATCH, HttpMethod.DELETE);
private final AntPathMatcher pathMatcher = new AntPathMatcher();
private final List<String> whiteListPatterns = List.of(
"/api/auth/login",
"/api/auth/login/**"
);
private final List<PermissionRule> rules = buildRules();
public boolean isWhitelisted(String path) {
return whiteListPatterns.stream().anyMatch(pattern -> pathMatcher.match(pattern, path));
}
public String resolveRequiredPermission(HttpMethod method, String path) {
for (PermissionRule rule : rules) {
if (rule.matches(method, path, pathMatcher)) {
return rule.getPermissionCode();
}
}
return null;
}
private List<PermissionRule> buildRules() {
List<PermissionRule> permissionRules = new ArrayList<>();
addModuleRules(permissionRules, "/api/data-management/**", "module:data-management:read", "module:data-management:write");
addModuleRules(permissionRules, "/api/annotation/**", "module:data-annotation:read", "module:data-annotation:write");
addModuleRules(permissionRules, "/api/data-collection/**", "module:data-collection:read", "module:data-collection:write");
addModuleRules(permissionRules, "/api/evaluation/**", "module:data-evaluation:read", "module:data-evaluation:write");
addModuleRules(permissionRules, "/api/synthesis/**", "module:data-synthesis:read", "module:data-synthesis:write");
addModuleRules(permissionRules, "/api/knowledge-base/**", "module:knowledge-base:read", "module:knowledge-base:write");
addModuleRules(permissionRules, "/api/operator-market/**", "module:operator-market:read", "module:operator-market:write");
addModuleRules(permissionRules, "/api/orchestration/**", "module:orchestration:read", "module:orchestration:write");
addModuleRules(permissionRules, "/api/content-generation/**", "module:content-generation:use", "module:content-generation:use");
permissionRules.add(new PermissionRule(READ_METHODS, "/api/auth/users/**", "system:user:manage"));
permissionRules.add(new PermissionRule(WRITE_METHODS, "/api/auth/users/**", "system:user:manage"));
permissionRules.add(new PermissionRule(READ_METHODS, "/api/auth/roles/**", "system:role:manage"));
permissionRules.add(new PermissionRule(WRITE_METHODS, "/api/auth/roles/**", "system:role:manage"));
permissionRules.add(new PermissionRule(READ_METHODS, "/api/auth/permissions/**", "system:permission:manage"));
permissionRules.add(new PermissionRule(WRITE_METHODS, "/api/auth/permissions/**", "system:permission:manage"));
return permissionRules;
}
private void addModuleRules(List<PermissionRule> rules,
String pathPattern,
String readPermissionCode,
String writePermissionCode) {
rules.add(new PermissionRule(READ_METHODS, pathPattern, readPermissionCode));
rules.add(new PermissionRule(WRITE_METHODS, pathPattern, writePermissionCode));
}
@Getter
private static class PermissionRule {
private final Set<HttpMethod> methods;
private final String pathPattern;
private final String permissionCode;
private PermissionRule(Set<HttpMethod> methods, String pathPattern, String permissionCode) {
this.methods = methods;
this.pathPattern = pathPattern;
this.permissionCode = permissionCode;
}
private boolean matches(HttpMethod method, String path, AntPathMatcher matcher) {
return method != null && methods.contains(method) && matcher.match(pathPattern, path);
}
}
}

View File

@@ -3,6 +3,7 @@ package com.datamate.datamanagement.application;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.domain.utils.ChunksSaver;
import com.datamate.common.setting.application.SysParamApplicationService;
import com.datamate.datamanagement.interfaces.dto.*;
@@ -64,6 +65,7 @@ public class DatasetApplicationService {
private final CollectionTaskClient collectionTaskClient;
private final DatasetFileApplicationService datasetFileApplicationService;
private final SysParamApplicationService sysParamService;
private final ResourceAccessService resourceAccessService;
@Value("${datamate.data-management.base-path:/dataset}")
private String datasetBasePath;
@@ -102,6 +104,7 @@ public class DatasetApplicationService {
public Dataset updateDataset(String datasetId, UpdateDatasetRequest updateDatasetRequest) {
Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
if (StringUtils.hasText(updateDatasetRequest.getName())) {
dataset.setName(updateDatasetRequest.getName());
@@ -151,6 +154,7 @@ public class DatasetApplicationService {
public void deleteDataset(String datasetId) {
Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
long childCount = datasetRepository.countByParentId(datasetId);
BusinessAssert.isTrue(childCount == 0, DataManagementErrorCode.DATASET_HAS_CHILDREN);
datasetRepository.removeById(datasetId);
@@ -164,6 +168,7 @@ public class DatasetApplicationService {
public Dataset getDataset(String datasetId) {
Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
List<DatasetFile> datasetFiles = datasetFileRepository.findAllVisibleByDatasetId(datasetId);
dataset.setFiles(datasetFiles);
applyVisibleFileCounts(Collections.singletonList(dataset));
@@ -176,7 +181,8 @@ public class DatasetApplicationService {
@Transactional(readOnly = true)
public PagedResponse<DatasetResponse> getDatasets(DatasetPagingQuery query) {
IPage<Dataset> page = new Page<>(query.getPage(), query.getSize());
page = datasetRepository.findByCriteria(page, query);
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = datasetRepository.findByCriteria(page, query, ownerFilterUserId);
String datasetPvcName = getDatasetPvcName();
applyVisibleFileCounts(page.getRecords());
List<DatasetResponse> datasetResponses = DatasetConverter.INSTANCE.convertToResponse(page.getRecords());
@@ -189,6 +195,7 @@ public class DatasetApplicationService {
BusinessAssert.isTrue(StringUtils.hasText(datasetId), CommonErrorCode.PARAM_ERROR);
Dataset dataset = datasetRepository.getById(datasetId);
BusinessAssert.notNull(dataset, DataManagementErrorCode.DATASET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
Set<String> sourceTags = normalizeTagNames(dataset.getTags());
if (sourceTags.isEmpty()) {
return Collections.emptyList();
@@ -198,10 +205,12 @@ public class DatasetApplicationService {
SIMILAR_DATASET_CANDIDATE_MAX,
Math.max(safeLimit * SIMILAR_DATASET_CANDIDATE_FACTOR, safeLimit)
);
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
List<Dataset> candidates = datasetRepository.findSimilarByTags(
new ArrayList<>(sourceTags),
datasetId,
candidateLimit
candidateLimit,
ownerFilterUserId
);
if (CollectionUtils.isEmpty(candidates)) {
return Collections.emptyList();
@@ -436,6 +445,7 @@ public class DatasetApplicationService {
if (dataset == null) {
throw new IllegalArgumentException("Dataset not found: " + datasetId);
}
resourceAccessService.assertOwnerAccess(dataset.getCreatedBy());
Map<String, Object> statistics = new HashMap<>();
@@ -485,8 +495,12 @@ public class DatasetApplicationService {
* 获取所有数据集的汇总统计信息
*/
public AllDatasetStatisticsResponse getAllDatasetStatistics() {
if (resourceAccessService.isAdmin()) {
return datasetRepository.getAllDatasetStatistics();
}
String currentUserId = resourceAccessService.requireCurrentUserId();
return datasetRepository.getAllDatasetStatisticsByCreatedBy(currentUserId);
}
/**
* 异步处理数据源文件扫描

View File

@@ -2,6 +2,7 @@ package com.datamate.datamanagement.application;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.CommonErrorCode;
import com.datamate.common.interfaces.PagedResponse;
@@ -40,6 +41,7 @@ import java.util.UUID;
public class KnowledgeSetApplicationService {
private final KnowledgeSetRepository knowledgeSetRepository;
private final TagMapper tagMapper;
private final ResourceAccessService resourceAccessService;
public KnowledgeSet createKnowledgeSet(CreateKnowledgeSetRequest request) {
BusinessAssert.isTrue(knowledgeSetRepository.findByName(request.getName()) == null,
@@ -64,6 +66,7 @@ public class KnowledgeSetApplicationService {
public KnowledgeSet updateKnowledgeSet(String setId, UpdateKnowledgeSetRequest request) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
BusinessAssert.isTrue(!isReadOnlyStatus(knowledgeSet.getStatus()),
DataManagementErrorCode.KNOWLEDGE_SET_STATUS_ERROR);
@@ -119,6 +122,7 @@ public class KnowledgeSetApplicationService {
public void deleteKnowledgeSet(String setId) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
knowledgeSetRepository.removeById(setId);
}
@@ -126,13 +130,15 @@ public class KnowledgeSetApplicationService {
public KnowledgeSet getKnowledgeSet(String setId) {
KnowledgeSet knowledgeSet = knowledgeSetRepository.getById(setId);
BusinessAssert.notNull(knowledgeSet, DataManagementErrorCode.KNOWLEDGE_SET_NOT_FOUND);
resourceAccessService.assertOwnerAccess(knowledgeSet.getCreatedBy());
return knowledgeSet;
}
@Transactional(readOnly = true)
public PagedResponse<KnowledgeSetResponse> getKnowledgeSets(KnowledgeSetPagingQuery query) {
IPage<KnowledgeSet> page = new Page<>(query.getPage(), query.getSize());
page = knowledgeSetRepository.findByCriteria(page, query);
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = knowledgeSetRepository.findByCriteria(page, query, ownerFilterUserId);
List<KnowledgeSetResponse> responses = KnowledgeConverter.INSTANCE.convertSetResponses(page.getRecords());
return PagedResponse.of(responses, page.getCurrent(), page.getTotal(), page.getPages());
}

View File

@@ -4,6 +4,8 @@ import com.datamate.common.infrastructure.common.Response;
import com.datamate.datamanagement.infrastructure.client.PdfTextExtractClient;
import com.datamate.datamanagement.infrastructure.client.dto.PdfTextExtractRequest;
import com.datamate.datamanagement.infrastructure.client.dto.PdfTextExtractResponse;
import feign.FeignException;
import feign.Request;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
@@ -47,8 +49,71 @@ public class PdfTextExtractAsyncService {
} else {
log.info("PdfTextExtract succeeded, datasetId={}, fileId={}", datasetId, fileId);
}
} catch (FeignException feignException) {
logFeignException(datasetId, fileId, feignException);
} catch (Exception e) {
log.error("PdfTextExtract call failed, datasetId={}, fileId={}", datasetId, fileId, e);
}
}
private void logFeignException(String datasetId, String fileId, FeignException feignException) {
Request request = feignException.request();
String httpMethod = request == null || request.httpMethod() == null
? "UNKNOWN"
: request.httpMethod().name();
String requestUrl = request == null || request.url() == null
? "UNKNOWN"
: request.url();
String responseBody = resolveFeignResponseBody(feignException);
String rootCauseChain = buildCauseChain(feignException, 12);
log.error(
"PdfTextExtract call failed with FeignException, datasetId={}, fileId={}, status={}, method={}, url={}, responseBody=\n{}\nrootCauseChain={}",
datasetId,
fileId,
feignException.status(),
httpMethod,
requestUrl,
responseBody,
rootCauseChain,
feignException
);
}
private String resolveFeignResponseBody(FeignException feignException) {
String responseBody = feignException.contentUTF8();
if (responseBody == null || responseBody.isBlank()) {
responseBody = feignException.getMessage();
}
if (responseBody == null || responseBody.isBlank()) {
return "EMPTY_RESPONSE_BODY";
}
return responseBody;
}
private String buildCauseChain(Throwable throwable, int maxDepth) {
StringBuilder causeChain = new StringBuilder();
Throwable current = throwable;
int depth = 0;
while (current != null && depth < maxDepth) {
if (causeChain.length() > 0) {
causeChain.append(" <- ");
}
causeChain.append(current.getClass().getSimpleName())
.append(": ")
.append(normalizeCauseMessage(current.getMessage()));
current = current.getCause();
depth++;
}
if (current != null) {
causeChain.append(" <- ...");
}
return causeChain.toString();
}
private String normalizeCauseMessage(String message) {
if (message == null || message.isBlank()) {
return "EMPTY_MESSAGE";
}
return message.replace("\r", " ").replace("\n", " ").trim();
}
}

View File

@@ -25,9 +25,11 @@ public interface DatasetRepository extends IRepository<Dataset> {
AllDatasetStatisticsResponse getAllDatasetStatistics();
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query);
AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy);
IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy);
long countByParentId(String parentDatasetId);
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit);
List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy);
}

View File

@@ -11,5 +11,5 @@ import com.datamate.datamanagement.interfaces.dto.KnowledgeSetPagingQuery;
public interface KnowledgeSetRepository extends IRepository<KnowledgeSet> {
KnowledgeSet findByName(String name);
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query);
IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy);
}

View File

@@ -51,10 +51,34 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
@Override
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query) {
public AllDatasetStatisticsResponse getAllDatasetStatisticsByCreatedBy(String createdBy) {
List<Dataset> datasets = lambdaQuery()
.eq(Dataset::getCreatedBy, createdBy)
.list();
long totalFiles = datasets.stream()
.map(Dataset::getFileCount)
.filter(java.util.Objects::nonNull)
.mapToLong(Long::longValue)
.sum();
long totalSize = datasets.stream()
.map(Dataset::getSizeBytes)
.filter(java.util.Objects::nonNull)
.mapToLong(Long::longValue)
.sum();
AllDatasetStatisticsResponse response = new AllDatasetStatisticsResponse();
response.setTotalDatasets(datasets.size());
response.setTotalFiles(totalFiles);
response.setTotalSize(totalSize);
return response;
}
@Override
public IPage<Dataset> findByCriteria(IPage<Dataset> page, DatasetPagingQuery query, String createdBy) {
LambdaQueryWrapper<Dataset> wrapper = new LambdaQueryWrapper<Dataset>()
.eq(query.getType() != null, Dataset::getDatasetType, query.getType())
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus());
.eq(query.getStatus() != null, Dataset::getStatus, query.getStatus())
.eq(StringUtils.isNotBlank(createdBy), Dataset::getCreatedBy, createdBy);
if (query.getParentDatasetId() != null) {
if (StringUtils.isBlank(query.getParentDatasetId())) {
@@ -92,7 +116,7 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
}
@Override
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit) {
public List<Dataset> findSimilarByTags(List<String> tagNames, String excludedDatasetId, int limit, String createdBy) {
if (limit <= 0 || tagNames == null || tagNames.isEmpty()) {
return Collections.emptyList();
}
@@ -109,6 +133,9 @@ public class DatasetRepositoryImpl extends CrudRepository<DatasetMapper, Dataset
if (StringUtils.isNotBlank(excludedDatasetId)) {
wrapper.ne(Dataset::getId, excludedDatasetId.trim());
}
if (StringUtils.isNotBlank(createdBy)) {
wrapper.eq(Dataset::getCreatedBy, createdBy);
}
wrapper.apply("tags IS NOT NULL AND JSON_VALID(tags) = 1 AND JSON_LENGTH(tags) > 0");
wrapper.and(condition -> {
boolean hasCondition = false;

View File

@@ -25,7 +25,7 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
}
@Override
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query) {
public IPage<KnowledgeSet> findByCriteria(IPage<KnowledgeSet> page, KnowledgeSetPagingQuery query, String createdBy) {
LambdaQueryWrapper<KnowledgeSet> wrapper = new LambdaQueryWrapper<KnowledgeSet>()
.eq(query.getStatus() != null, KnowledgeSet::getStatus, query.getStatus())
.eq(StringUtils.isNotBlank(query.getDomain()), KnowledgeSet::getDomain, query.getDomain())
@@ -34,7 +34,8 @@ public class KnowledgeSetRepositoryImpl extends CrudRepository<KnowledgeSetMappe
.eq(StringUtils.isNotBlank(query.getSensitivity()), KnowledgeSet::getSensitivity, query.getSensitivity())
.eq(query.getSourceType() != null, KnowledgeSet::getSourceType, query.getSourceType())
.ge(query.getValidFrom() != null, KnowledgeSet::getValidFrom, query.getValidFrom())
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo());
.le(query.getValidTo() != null, KnowledgeSet::getValidTo, query.getValidTo())
.eq(StringUtils.isNotBlank(createdBy), KnowledgeSet::getCreatedBy, createdBy);
if (StringUtils.isNotBlank(query.getKeyword())) {
wrapper.and(w -> w.like(KnowledgeSet::getName, query.getKeyword())

View File

@@ -2,8 +2,11 @@ package com.datamate.rag.indexer.application;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.KnowledgeBaseErrorCode;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import com.datamate.common.interfaces.PagedResponse;
import com.datamate.common.interfaces.PagingQuery;
import com.datamate.common.setting.domain.entity.ModelConfig;
@@ -55,6 +58,7 @@ public class KnowledgeBaseService {
private final ApplicationEventPublisher eventPublisher;
private final ModelConfigRepository modelConfigRepository;
private final MilvusService milvusService;
private final ResourceAccessService resourceAccessService;
/**
* 创建知识库
@@ -77,8 +81,7 @@ public class KnowledgeBaseService {
*/
@Transactional(rollbackFor = Exception.class)
public void update(String knowledgeBaseId, KnowledgeBaseUpdateReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
if (StringUtils.hasText(request.getName()) && !knowledgeBase.getName().equals(request.getName())) {
milvusService.getMilvusClient().renameCollection(RenameCollectionReq.builder()
.collectionName(knowledgeBase.getName())
@@ -98,16 +101,14 @@ public class KnowledgeBaseService {
*/
@Transactional(rollbackFor = Exception.class)
public void delete(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
knowledgeBaseRepository.removeById(knowledgeBaseId);
ragFileRepository.removeByKnowledgeBaseId(knowledgeBaseId);
milvusService.getMilvusClient().dropCollection(DropCollectionReq.builder().collectionName(knowledgeBase.getName()).build());
}
public KnowledgeBaseResp getById(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
KnowledgeBaseResp resp = getKnowledgeBaseResp(knowledgeBase);
resp.setEmbedding(modelConfigRepository.getById(knowledgeBase.getEmbeddingModel()));
resp.setChat(modelConfigRepository.getById(knowledgeBase.getChatModel()));
@@ -133,7 +134,8 @@ public class KnowledgeBaseService {
public PagedResponse<KnowledgeBaseResp> list(KnowledgeBaseQueryReq request) {
IPage<KnowledgeBase> page = new Page<>(request.getPage(), request.getSize());
page = knowledgeBaseRepository.page(page, request);
String ownerFilterUserId = resourceAccessService.resolveOwnerFilterUserId();
page = knowledgeBaseRepository.page(page, request, ownerFilterUserId);
// 将 KnowledgeBase 转换为 KnowledgeBaseResp,并计算 fileCount 和 chunkCount
List<KnowledgeBaseResp> respList = page.getRecords().stream().map(this::getKnowledgeBaseResp).toList();
@@ -143,8 +145,7 @@ public class KnowledgeBaseService {
@Transactional(rollbackFor = Exception.class)
public void addFiles(AddFilesReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseId()))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseId());
List<RagFile> ragFiles = request.getFiles().stream().map(fileInfo -> {
RagFile ragFile = new RagFile();
ragFile.setKnowledgeBaseId(knowledgeBase.getId());
@@ -170,6 +171,7 @@ public class KnowledgeBaseService {
}
public PagedResponse<RagFile> listFiles(String knowledgeBaseId, RagFileReq request) {
getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
request.setKnowledgeBaseId(knowledgeBaseId);
page = ragFileRepository.page(page, request);
@@ -177,8 +179,13 @@ public class KnowledgeBaseService {
}
public PagedResponse<KnowledgeBaseFileSearchResp> searchFiles(KnowledgeBaseFileSearchReq request) {
boolean admin = resourceAccessService.isAdmin();
List<String> scopedKnowledgeBaseIds = resolveSearchScopeKnowledgeBaseIds(request, admin);
if (!admin && scopedKnowledgeBaseIds.isEmpty()) {
return PagedResponse.of(Collections.emptyList(), request.getPage(), 0L, 0);
}
IPage<RagFile> page = new Page<>(request.getPage(), request.getSize());
page = ragFileRepository.searchPage(page, request);
page = ragFileRepository.searchPage(page, request, scopedKnowledgeBaseIds);
List<RagFile> records = page.getRecords();
if (records.isEmpty()) {
return PagedResponse.of(Collections.emptyList(), page.getCurrent(), page.getTotal(), page.getPages());
@@ -213,8 +220,7 @@ public class KnowledgeBaseService {
@Transactional(rollbackFor = Exception.class)
public void deleteFiles(String knowledgeBaseId, DeleteFilesReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
ragFileRepository.removeByIds(request.getIds());
milvusService.getMilvusClient().delete(DeleteReq.builder()
.collectionName(knowledgeBase.getName())
@@ -223,8 +229,7 @@ public class KnowledgeBaseService {
}
public PagedResponse<RagChunk> getChunks(String knowledgeBaseId, String ragFileId, PagingQuery pagingQuery) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(knowledgeBaseId);
QueryResp results = milvusService.getMilvusClient().query(QueryReq.builder()
.collectionName(knowledgeBase.getName())
.filter("metadata[\"rag_file_id\"] == \"" + ragFileId + "\"")
@@ -259,8 +264,7 @@ public class KnowledgeBaseService {
* @return 检索结果
*/
public List<SearchResp.SearchResult> retrieve(RetrieveReq request) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(request.getKnowledgeBaseIds().getFirst()))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
KnowledgeBase knowledgeBase = getKnowledgeBaseWithAccessCheck(request.getKnowledgeBaseIds().getFirst());
ModelConfig modelConfig = modelConfigRepository.getById(knowledgeBase.getEmbeddingModel());
EmbeddingModel embeddingModel = ModelClient.invokeEmbeddingModel(modelConfig);
Embedding embedding = embeddingModel.embed(request.getQuery()).content();
@@ -273,4 +277,27 @@ public class KnowledgeBaseService {
});
return searchResults;
}
private KnowledgeBase getKnowledgeBaseWithAccessCheck(String knowledgeBaseId) {
KnowledgeBase knowledgeBase = Optional.ofNullable(knowledgeBaseRepository.getById(knowledgeBaseId))
.orElseThrow(() -> BusinessException.of(KnowledgeBaseErrorCode.KNOWLEDGE_BASE_NOT_FOUND));
resourceAccessService.assertOwnerAccess(knowledgeBase.getCreatedBy());
return knowledgeBase;
}
private List<String> resolveSearchScopeKnowledgeBaseIds(KnowledgeBaseFileSearchReq request, boolean admin) {
if (admin) {
return Collections.emptyList();
}
String currentUserId = resourceAccessService.requireCurrentUserId();
List<String> ownedKnowledgeBaseIds = knowledgeBaseRepository.listIdsByCreatedBy(currentUserId);
if (!StringUtils.hasText(request.getKnowledgeBaseId())) {
return ownedKnowledgeBaseIds;
}
BusinessAssert.isTrue(
ownedKnowledgeBaseIds.contains(request.getKnowledgeBaseId()),
SystemErrorCode.INSUFFICIENT_PERMISSIONS
);
return Collections.singletonList(request.getKnowledgeBaseId());
}
}

View File

@@ -5,6 +5,8 @@ import com.baomidou.mybatisplus.extension.repository.IRepository;
import com.datamate.rag.indexer.domain.model.KnowledgeBase;
import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
import java.util.List;
/**
* 知识库仓储接口
*
@@ -19,5 +21,7 @@ public interface KnowledgeBaseRepository extends IRepository<KnowledgeBase> {
* @param request 查询请求
* @return 知识库分页结果
*/
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request);
IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy);
List<String> listIdsByCreatedBy(String createdBy);
}

View File

@@ -23,5 +23,5 @@ public interface RagFileRepository extends IRepository<RagFile> {
IPage<RagFile> page(IPage<RagFile> page, RagFileReq request);
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request);
IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds);
}

View File

@@ -10,6 +10,9 @@ import com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq;
import org.springframework.stereotype.Repository;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
/**
* 知识库仓储实现类
*
@@ -20,12 +23,28 @@ import org.springframework.util.StringUtils;
public class KnowledgeBaseRepositoryImpl extends CrudRepository<KnowledgeBaseMapper, KnowledgeBase> implements KnowledgeBaseRepository {
@Override
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request) {
public IPage<KnowledgeBase> page(IPage<KnowledgeBase> page, KnowledgeBaseQueryReq request, String createdBy) {
return this.page(page, new LambdaQueryWrapper<KnowledgeBase>()
.like(StringUtils.hasText(request.getName()), KnowledgeBase::getName, request.getName())
.like(StringUtils.hasText(request.getDescription()), KnowledgeBase::getDescription, request.getDescription())
.like(StringUtils.hasText(request.getCreatedBy()), KnowledgeBase::getCreatedBy, request.getCreatedBy())
.like(StringUtils.hasText(request.getUpdatedBy()), KnowledgeBase::getUpdatedBy, request.getUpdatedBy())
.eq(StringUtils.hasText(createdBy), KnowledgeBase::getCreatedBy, createdBy)
.orderByDesc(KnowledgeBase::getCreatedAt));
}
@Override
public List<String> listIdsByCreatedBy(String createdBy) {
if (!StringUtils.hasText(createdBy)) {
return Collections.emptyList();
}
return lambdaQuery()
.select(KnowledgeBase::getId)
.eq(KnowledgeBase::getCreatedBy, createdBy)
.list()
.stream()
.map(KnowledgeBase::getId)
.filter(StringUtils::hasText)
.toList();
}
}

View File

@@ -52,9 +52,12 @@ public class RagFileRepositoryImpl extends CrudRepository<RagFileMapper, RagFile
}
@Override
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request) {
public IPage<RagFile> searchPage(IPage<RagFile> page, KnowledgeBaseFileSearchReq request, List<String> knowledgeBaseIds) {
return lambdaQuery()
.eq(StringUtils.hasText(request.getKnowledgeBaseId()), RagFile::getKnowledgeBaseId, request.getKnowledgeBaseId())
.in(!StringUtils.hasText(request.getKnowledgeBaseId()) && knowledgeBaseIds != null && !knowledgeBaseIds.isEmpty(),
RagFile::getKnowledgeBaseId,
knowledgeBaseIds)
.like(StringUtils.hasText(request.getFileName()), RagFile::getFileName, request.getFileName())
.likeRight(StringUtils.hasText(request.getRelativePath()), RagFile::getRelativePath, normalizeRelativePath(request.getRelativePath()))
.page(page);

View File

@@ -17,6 +17,11 @@
<description>DDD领域通用组件</description>
<dependencies>
<dependency>
<groupId>com.datamate</groupId>
<artifactId>security-common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter</artifactId>

View File

@@ -0,0 +1,203 @@
package com.datamate.common.auth.application;
import com.datamate.common.auth.domain.model.AuthPermissionInfo;
import com.datamate.common.auth.domain.model.AuthRoleInfo;
import com.datamate.common.auth.domain.model.AuthUserAccount;
import com.datamate.common.auth.domain.model.AuthUserSummary;
import com.datamate.common.auth.infrastructure.exception.AuthErrorCode;
import com.datamate.common.auth.infrastructure.persistence.mapper.AuthMapper;
import com.datamate.common.auth.interfaces.rest.dto.AuthCurrentUserResponse;
import com.datamate.common.auth.interfaces.rest.dto.AuthLoginResponse;
import com.datamate.common.auth.interfaces.rest.dto.AuthUserView;
import com.datamate.common.auth.interfaces.rest.dto.AuthUserWithRolesResponse;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.security.JwtUtils;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtException;
import lombok.RequiredArgsConstructor;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.stereotype.Service;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* 认证授权应用服务
*/
@Service
@RequiredArgsConstructor
public class AuthApplicationService {
private static final String TOKEN_TYPE = "Bearer";
private final AuthMapper authMapper;
private final JwtUtils jwtUtils;
private final PasswordEncoder passwordEncoder;
public AuthLoginResponse login(String username, String password) {
AuthUserAccount user = authMapper.findUserByUsername(username);
BusinessAssert.notNull(user, AuthErrorCode.INVALID_CREDENTIALS);
BusinessAssert.isTrue(Boolean.TRUE.equals(user.getEnabled()), AuthErrorCode.ACCOUNT_DISABLED);
BusinessAssert.isTrue(passwordEncoder.matches(password, user.getPasswordHash()), AuthErrorCode.INVALID_CREDENTIALS);
AuthBundle authBundle = loadAuthBundle(user.getId());
String token = buildToken(authBundle);
authMapper.updateLastLoginAt(user.getId());
return new AuthLoginResponse(
token,
TOKEN_TYPE,
computeExpiresInSeconds(token),
toUserView(authBundle.user()),
authBundle.roleCodes(),
authBundle.permissionCodes()
);
}
public AuthCurrentUserResponse getCurrentUser(String token) {
Claims claims = parseClaims(token);
Long userId = parseUserId(claims);
AuthBundle authBundle = loadAuthBundle(userId);
return new AuthCurrentUserResponse(
toUserView(authBundle.user()),
authBundle.roleCodes(),
authBundle.permissionCodes()
);
}
public AuthLoginResponse refreshToken(String token) {
Claims claims = parseClaims(token);
Long userId = parseUserId(claims);
AuthBundle authBundle = loadAuthBundle(userId);
String refreshedToken = buildToken(authBundle);
return new AuthLoginResponse(
refreshedToken,
TOKEN_TYPE,
computeExpiresInSeconds(refreshedToken),
toUserView(authBundle.user()),
authBundle.roleCodes(),
authBundle.permissionCodes()
);
}
public List<AuthUserWithRolesResponse> listUsersWithRoles() {
List<AuthUserSummary> users = authMapper.listUsers();
List<AuthUserWithRolesResponse> responses = new ArrayList<>(users.size());
for (AuthUserSummary user : users) {
List<String> roleCodes = authMapper.findRolesByUserId(user.getId())
.stream()
.map(AuthRoleInfo::getRoleCode)
.filter(Objects::nonNull)
.toList();
responses.add(new AuthUserWithRolesResponse(
user.getId(),
user.getUsername(),
user.getFullName(),
user.getEmail(),
user.getEnabled(),
roleCodes
));
}
return responses;
}
public List<AuthRoleInfo> listRoles() {
return authMapper.listRoles();
}
public List<AuthPermissionInfo> listPermissions() {
return authMapper.listPermissions();
}
public void assignUserRoles(Long userId, List<String> roleIds) {
AuthUserAccount user = authMapper.findUserById(userId);
BusinessAssert.notNull(user, AuthErrorCode.USER_NOT_FOUND);
Set<String> distinctRoleIds = new LinkedHashSet<>(roleIds);
BusinessAssert.notEmpty(distinctRoleIds, AuthErrorCode.ROLE_NOT_FOUND);
int existingRoleCount = authMapper.countRolesByIds(new ArrayList<>(distinctRoleIds));
BusinessAssert.isTrue(existingRoleCount == distinctRoleIds.size(), AuthErrorCode.ROLE_NOT_FOUND);
authMapper.deleteUserRoles(userId);
authMapper.insertUserRoles(userId, new ArrayList<>(distinctRoleIds));
}
private String buildToken(AuthBundle authBundle) {
Map<String, Object> claims = Map.of(
"userId", authBundle.user().getId(),
"roles", authBundle.roleCodes(),
"permissions", authBundle.permissionCodes()
);
return jwtUtils.generateToken(authBundle.user().getUsername(), claims);
}
private AuthBundle loadAuthBundle(Long userId) {
AuthUserAccount user = authMapper.findUserById(userId);
BusinessAssert.notNull(user, AuthErrorCode.USER_NOT_FOUND);
BusinessAssert.isTrue(Boolean.TRUE.equals(user.getEnabled()), AuthErrorCode.ACCOUNT_DISABLED);
List<String> roleCodes = authMapper.findRolesByUserId(userId).stream()
.map(AuthRoleInfo::getRoleCode)
.filter(Objects::nonNull)
.toList();
List<String> permissionCodes = authMapper.findPermissionCodesByUserId(userId).stream()
.filter(Objects::nonNull)
.distinct()
.collect(Collectors.toList());
return new AuthBundle(user, roleCodes, permissionCodes);
}
private Claims parseClaims(String token) {
try {
return jwtUtils.getClaimsFromToken(token);
} catch (JwtException | IllegalArgumentException e) {
throw com.datamate.common.infrastructure.exception.BusinessException.of(AuthErrorCode.TOKEN_INVALID);
}
}
private Long parseUserId(Claims claims) {
Object userIdObject = claims.get("userId");
if (userIdObject instanceof Number number) {
return number.longValue();
}
if (userIdObject instanceof String str) {
try {
return Long.parseLong(str);
} catch (NumberFormatException e) {
throw com.datamate.common.infrastructure.exception.BusinessException.of(AuthErrorCode.TOKEN_INVALID);
}
}
throw com.datamate.common.infrastructure.exception.BusinessException.of(AuthErrorCode.TOKEN_INVALID);
}
private long computeExpiresInSeconds(String token) {
Date expirationDate = jwtUtils.getExpirationDateFromToken(token);
long seconds = Duration.between(new Date().toInstant(), expirationDate.toInstant()).toSeconds();
return Math.max(seconds, 0L);
}
private AuthUserView toUserView(AuthUserAccount user) {
return new AuthUserView(
user.getId(),
user.getUsername(),
user.getFullName(),
user.getEmail(),
user.getAvatarUrl(),
user.getOrganization()
);
}
private record AuthBundle(
AuthUserAccount user,
List<String> roleCodes,
List<String> permissionCodes
) {
}
}

View File

@@ -0,0 +1,58 @@
package com.datamate.common.auth.application;
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import java.util.Objects;
/**
* 资源访问控制服务(基于请求用户上下文)
*/
@Service
public class ResourceAccessService {
public static final String ADMIN_ROLE_CODE = "ROLE_ADMIN";
public boolean isAdmin() {
return RequestUserContextHolder.hasRole(ADMIN_ROLE_CODE);
}
public String getCurrentUserId() {
return RequestUserContextHolder.getCurrentUserId();
}
public String requireCurrentUserId() {
String currentUserId = getCurrentUserId();
BusinessAssert.isTrue(StringUtils.hasText(currentUserId), SystemErrorCode.INSUFFICIENT_PERMISSIONS);
return currentUserId;
}
/**
* 资源列表查询的 owner 过滤:
* - 管理员返回 null(不过滤)
* - 非管理员返回当前用户ID
*/
public String resolveOwnerFilterUserId() {
if (isAdmin()) {
return null;
}
return requireCurrentUserId();
}
/**
* 校验当前用户是否可访问 owner 资源
*/
public void assertOwnerAccess(String ownerUserId) {
if (isAdmin()) {
return;
}
String currentUserId = requireCurrentUserId();
BusinessAssert.isTrue(
StringUtils.hasText(ownerUserId) && Objects.equals(ownerUserId, currentUserId),
SystemErrorCode.INSUFFICIENT_PERMISSIONS
);
}
}

View File

@@ -0,0 +1,21 @@
package com.datamate.common.auth.domain.model;
import lombok.Getter;
import lombok.Setter;
/**
* 权限信息
*/
@Getter
@Setter
public class AuthPermissionInfo {
private String id;
private String permissionCode;
private String permissionName;
private String module;
private String action;
private String pathPattern;
private String method;
private Boolean enabled;
}

View File

@@ -0,0 +1,18 @@
package com.datamate.common.auth.domain.model;
import lombok.Getter;
import lombok.Setter;
/**
* 角色信息
*/
@Getter
@Setter
public class AuthRoleInfo {
private String id;
private String roleCode;
private String roleName;
private String description;
private Boolean enabled;
}

View File

@@ -0,0 +1,24 @@
package com.datamate.common.auth.domain.model;
import lombok.Getter;
import lombok.Setter;
import java.time.LocalDateTime;
/**
* 认证用户账户
*/
@Getter
@Setter
public class AuthUserAccount {
private Long id;
private String username;
private String email;
private String passwordHash;
private String fullName;
private String avatarUrl;
private String organization;
private Boolean enabled;
private LocalDateTime lastLoginAt;
}

View File

@@ -0,0 +1,18 @@
package com.datamate.common.auth.domain.model;
import lombok.Getter;
import lombok.Setter;
/**
* 用户摘要
*/
@Getter
@Setter
public class AuthUserSummary {
private Long id;
private String username;
private String email;
private String fullName;
private Boolean enabled;
}

View File

@@ -0,0 +1,18 @@
package com.datamate.common.auth.infrastructure.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
/**
* 认证模块配置
*/
@Configuration
public class AuthConfiguration {
@Bean
public PasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder();
}
}

View File

@@ -0,0 +1,40 @@
package com.datamate.common.auth.infrastructure.context;
import lombok.Getter;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* 请求级用户上下文
*/
@Getter
public class RequestUserContext {
private final String userId;
private final String username;
private final List<String> roles;
private RequestUserContext(String userId, String username, List<String> roles) {
this.userId = userId;
this.username = username;
this.roles = roles == null ? Collections.emptyList() : List.copyOf(roles);
}
public static RequestUserContext of(String userId, String username, List<String> roles) {
return new RequestUserContext(userId, username, roles);
}
public static RequestUserContext empty() {
return new RequestUserContext(null, null, Collections.emptyList());
}
public boolean hasRole(String roleCode) {
if (!StringUtils.hasText(roleCode)) {
return false;
}
return roles.stream().anyMatch(role -> StringUtils.hasText(role) && Objects.equals(role.trim(), roleCode));
}
}

View File

@@ -0,0 +1,49 @@
package com.datamate.common.auth.infrastructure.context;
import org.springframework.core.NamedThreadLocal;
import org.springframework.util.StringUtils;
import java.util.Collections;
import java.util.List;
/**
* 请求级用户上下文持有器
*/
public final class RequestUserContextHolder {
private static final ThreadLocal<RequestUserContext> USER_CONTEXT_HOLDER =
new NamedThreadLocal<>("datamate-request-user-context");
private RequestUserContextHolder() {
}
public static void set(RequestUserContext context) {
USER_CONTEXT_HOLDER.set(context == null ? RequestUserContext.empty() : context);
}
public static RequestUserContext get() {
RequestUserContext context = USER_CONTEXT_HOLDER.get();
return context == null ? RequestUserContext.empty() : context;
}
public static String getCurrentUserId() {
return get().getUserId();
}
public static List<String> getCurrentRoles() {
List<String> roles = get().getRoles();
return roles == null ? Collections.emptyList() : roles;
}
public static boolean hasRole(String roleCode) {
if (!StringUtils.hasText(roleCode)) {
return false;
}
return getCurrentRoles().stream()
.anyMatch(role -> StringUtils.hasText(role) && roleCode.equalsIgnoreCase(role.trim()));
}
public static void clear() {
USER_CONTEXT_HOLDER.remove();
}
}

View File

@@ -0,0 +1,53 @@
package com.datamate.common.auth.infrastructure.context;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.HandlerInterceptor;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
* 从网关透传请求头中提取用户上下文
*/
@Component
public class RequestUserContextInterceptor implements HandlerInterceptor {
private static final String HEADER_USER_ID = "X-User-Id";
private static final String HEADER_USER_NAME = "X-User-Name";
private static final String HEADER_USER_ROLES = "X-User-Roles";
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
String userId = normalizeValue(request.getHeader(HEADER_USER_ID));
String username = normalizeValue(request.getHeader(HEADER_USER_NAME));
List<String> roleCodes = parseRoleCodes(request.getHeader(HEADER_USER_ROLES));
RequestUserContextHolder.set(RequestUserContext.of(userId, username, roleCodes));
return true;
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) {
RequestUserContextHolder.clear();
}
private String normalizeValue(String value) {
if (!StringUtils.hasText(value)) {
return null;
}
return value.trim();
}
private List<String> parseRoleCodes(String roleHeader) {
if (!StringUtils.hasText(roleHeader)) {
return Collections.emptyList();
}
return Arrays.stream(roleHeader.split(","))
.map(String::trim)
.filter(StringUtils::hasText)
.toList();
}
}

View File

@@ -0,0 +1,21 @@
package com.datamate.common.auth.infrastructure.context;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
/**
* 请求用户上下文拦截器注册
*/
@Configuration
@RequiredArgsConstructor
public class RequestUserContextWebMvcConfigurer implements WebMvcConfigurer {
private final RequestUserContextInterceptor requestUserContextInterceptor;
@Override
public void addInterceptors(InterceptorRegistry registry) {
registry.addInterceptor(requestUserContextInterceptor).addPathPatterns("/**");
}
}

View File

@@ -0,0 +1,23 @@
package com.datamate.common.auth.infrastructure.exception;
import com.datamate.common.infrastructure.exception.ErrorCode;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 认证授权错误码
*/
@Getter
@AllArgsConstructor
public enum AuthErrorCode implements ErrorCode {
INVALID_CREDENTIALS("auth.0001", "用户名或密码错误"),
ACCOUNT_DISABLED("auth.0002", "账号已被禁用"),
TOKEN_INVALID("auth.0003", "登录状态已失效"),
USER_NOT_FOUND("auth.0004", "用户不存在"),
ROLE_NOT_FOUND("auth.0005", "角色不存在"),
AUTHORIZATION_DENIED("auth.0006", "无权限执行该操作");
private final String code;
private final String message;
}

View File

@@ -0,0 +1,39 @@
package com.datamate.common.auth.infrastructure.persistence.mapper;
import com.datamate.common.auth.domain.model.AuthPermissionInfo;
import com.datamate.common.auth.domain.model.AuthRoleInfo;
import com.datamate.common.auth.domain.model.AuthUserAccount;
import com.datamate.common.auth.domain.model.AuthUserSummary;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
/**
* 认证授权数据访问
*/
@Mapper
public interface AuthMapper {
AuthUserAccount findUserByUsername(@Param("username") String username);
AuthUserAccount findUserById(@Param("userId") Long userId);
int updateLastLoginAt(@Param("userId") Long userId);
List<AuthRoleInfo> findRolesByUserId(@Param("userId") Long userId);
List<String> findPermissionCodesByUserId(@Param("userId") Long userId);
List<AuthUserSummary> listUsers();
List<AuthRoleInfo> listRoles();
List<AuthPermissionInfo> listPermissions();
int countRolesByIds(@Param("roleIds") List<String> roleIds);
int deleteUserRoles(@Param("userId") Long userId);
int insertUserRoles(@Param("userId") Long userId, @Param("roleIds") List<String> roleIds);
}

View File

@@ -0,0 +1,82 @@
package com.datamate.common.auth.interfaces.rest;
import com.datamate.common.auth.application.AuthApplicationService;
import com.datamate.common.auth.domain.model.AuthPermissionInfo;
import com.datamate.common.auth.domain.model.AuthRoleInfo;
import com.datamate.common.auth.interfaces.rest.dto.AssignUserRolesRequest;
import com.datamate.common.auth.interfaces.rest.dto.AuthCurrentUserResponse;
import com.datamate.common.auth.interfaces.rest.dto.AuthLoginResponse;
import com.datamate.common.auth.interfaces.rest.dto.AuthUserWithRolesResponse;
import com.datamate.common.auth.interfaces.rest.dto.LoginRequest;
import com.datamate.common.infrastructure.exception.BusinessAssert;
import com.datamate.common.auth.infrastructure.exception.AuthErrorCode;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
/**
* 认证授权控制器
*/
@RestController
@RequestMapping("/auth")
@RequiredArgsConstructor
public class AuthController {
private final AuthApplicationService authApplicationService;
@PostMapping("/login")
public AuthLoginResponse login(@RequestBody @Valid LoginRequest loginRequest) {
return authApplicationService.login(loginRequest.username(), loginRequest.password());
}
@GetMapping("/me")
public AuthCurrentUserResponse me(HttpServletRequest request) {
return authApplicationService.getCurrentUser(extractBearerToken(request.getHeader("Authorization")));
}
@PostMapping("/refresh")
public AuthLoginResponse refresh(@RequestHeader("Authorization") String authorization) {
return authApplicationService.refreshToken(extractBearerToken(authorization));
}
@GetMapping("/users")
public List<AuthUserWithRolesResponse> listUsers() {
return authApplicationService.listUsersWithRoles();
}
@PutMapping("/users/{userId}/roles")
public void assignRoles(@PathVariable("userId") Long userId,
@RequestBody @Valid AssignUserRolesRequest request) {
authApplicationService.assignUserRoles(userId, request.roleIds());
}
@GetMapping("/roles")
public List<AuthRoleInfo> listRoles() {
return authApplicationService.listRoles();
}
@GetMapping("/permissions")
public List<AuthPermissionInfo> listPermissions() {
return authApplicationService.listPermissions();
}
private String extractBearerToken(String authorizationHeader) {
BusinessAssert.isTrue(
authorizationHeader != null && authorizationHeader.startsWith("Bearer "),
AuthErrorCode.TOKEN_INVALID
);
String token = authorizationHeader.substring("Bearer ".length()).trim();
BusinessAssert.isTrue(!token.isEmpty(), AuthErrorCode.TOKEN_INVALID);
return token;
}
}

View File

@@ -0,0 +1,14 @@
package com.datamate.common.auth.interfaces.rest.dto;
import jakarta.validation.constraints.NotEmpty;
import java.util.List;
/**
* 用户角色分配请求
*/
public record AssignUserRolesRequest(
@NotEmpty(message = "角色列表不能为空") List<String> roleIds
) {
}

View File

@@ -0,0 +1,14 @@
package com.datamate.common.auth.interfaces.rest.dto;
import java.util.List;
/**
* 当前用户信息响应
*/
public record AuthCurrentUserResponse(
AuthUserView user,
List<String> roles,
List<String> permissions
) {
}

View File

@@ -0,0 +1,17 @@
package com.datamate.common.auth.interfaces.rest.dto;
import java.util.List;
/**
* 登录响应
*/
public record AuthLoginResponse(
String token,
String tokenType,
long expiresInSeconds,
AuthUserView user,
List<String> roles,
List<String> permissions
) {
}

View File

@@ -0,0 +1,15 @@
package com.datamate.common.auth.interfaces.rest.dto;
/**
* 登录用户信息
*/
public record AuthUserView(
Long id,
String username,
String fullName,
String email,
String avatarUrl,
String organization
) {
}

View File

@@ -0,0 +1,17 @@
package com.datamate.common.auth.interfaces.rest.dto;
import java.util.List;
/**
* 用户与角色响应
*/
public record AuthUserWithRolesResponse(
Long id,
String username,
String fullName,
String email,
Boolean enabled,
List<String> roleCodes
) {
}

View File

@@ -0,0 +1,13 @@
package com.datamate.common.auth.interfaces.rest.dto;
import jakarta.validation.constraints.NotBlank;
/**
* 登录请求
*/
public record LoginRequest(
@NotBlank(message = "用户名不能为空") String username,
@NotBlank(message = "密码不能为空") String password
) {
}

View File

@@ -1,9 +1,11 @@
package com.datamate.common.infrastructure.config;
import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
import com.datamate.common.auth.infrastructure.context.RequestUserContextHolder;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.reflection.MetaObject;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.StringUtils;
import java.time.LocalDateTime;
@@ -44,17 +46,10 @@ public class EntityMetaObjectHandler implements MetaObjectHandler {
* 获取当前用户(需要根据你的安全框架实现)
*/
private String getCurrentUser() {
// todo 这里需要根据你的安全框架实现,例如Spring Security、Shiro等
// 示例:返回默认用户或从SecurityContext获取
try {
// 如果是Spring Security
// return SecurityContextHolder.getContext().getAuthentication().getName();
// 临时返回默认值,请根据实际情况修改
return "system";
} catch (Exception e) {
log.error("Error getting current user", e);
return "unknown";
String currentUserId = RequestUserContextHolder.getCurrentUserId();
if (StringUtils.hasText(currentUserId)) {
return currentUserId;
}
return "system";
}
}

View File

@@ -3,6 +3,7 @@ package com.datamate.common.infrastructure.config;
import com.datamate.common.infrastructure.common.Response;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import org.springframework.http.HttpStatus;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.BindException;
@@ -28,7 +29,8 @@ public class GlobalExceptionHandler {
@ExceptionHandler(BusinessException.class)
public ResponseEntity<Response<?>> handleBusinessException(BusinessException e) {
log.warn("BusinessException: code={}, message={}", e.getCode(), e.getMessage(), e);
return ResponseEntity.internalServerError().body(Response.error(e.getErrorCodeEnum()));
HttpStatus status = resolveBusinessStatus(e.getCode());
return ResponseEntity.status(status).body(Response.error(e.getErrorCodeEnum()));
}
/**
@@ -51,4 +53,17 @@ public class GlobalExceptionHandler {
log.error("SystemException: ", e);
return ResponseEntity.internalServerError().body(Response.error(SystemErrorCode.SYSTEM_BUSY));
}
private HttpStatus resolveBusinessStatus(String code) {
if (code == null) {
return HttpStatus.INTERNAL_SERVER_ERROR;
}
if (!code.startsWith("auth.")) {
return HttpStatus.INTERNAL_SERVER_ERROR;
}
if ("auth.0006".equals(code)) {
return HttpStatus.FORBIDDEN;
}
return HttpStatus.UNAUTHORIZED;
}
}

View File

@@ -0,0 +1,120 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.datamate.common.auth.infrastructure.persistence.mapper.AuthMapper">
<select id="findUserByUsername" resultType="com.datamate.common.auth.domain.model.AuthUserAccount">
SELECT id,
username,
email,
password_hash AS passwordHash,
full_name AS fullName,
avatar_url AS avatarUrl,
organization,
enabled,
last_login_at AS lastLoginAt
FROM users
WHERE username = #{username}
LIMIT 1
</select>
<select id="findUserById" resultType="com.datamate.common.auth.domain.model.AuthUserAccount">
SELECT id,
username,
email,
password_hash AS passwordHash,
full_name AS fullName,
avatar_url AS avatarUrl,
organization,
enabled,
last_login_at AS lastLoginAt
FROM users
WHERE id = #{userId}
LIMIT 1
</select>
<update id="updateLastLoginAt">
UPDATE users
SET last_login_at = NOW()
WHERE id = #{userId}
</update>
<select id="findRolesByUserId" resultType="com.datamate.common.auth.domain.model.AuthRoleInfo">
SELECT r.id,
r.role_code AS roleCode,
r.role_name AS roleName,
r.description,
r.enabled
FROM t_auth_roles r
INNER JOIN t_auth_user_roles ur ON ur.role_id = r.id
WHERE ur.user_id = #{userId}
ORDER BY r.role_code
</select>
<select id="findPermissionCodesByUserId" resultType="string">
SELECT DISTINCT p.permission_code
FROM t_auth_permissions p
INNER JOIN t_auth_role_permissions rp ON rp.permission_id = p.id
INNER JOIN t_auth_user_roles ur ON ur.role_id = rp.role_id
WHERE ur.user_id = #{userId}
AND p.enabled = 1
ORDER BY p.permission_code
</select>
<select id="listUsers" resultType="com.datamate.common.auth.domain.model.AuthUserSummary">
SELECT id,
username,
email,
full_name AS fullName,
enabled
FROM users
ORDER BY id ASC
</select>
<select id="listRoles" resultType="com.datamate.common.auth.domain.model.AuthRoleInfo">
SELECT id,
role_code AS roleCode,
role_name AS roleName,
description,
enabled
FROM t_auth_roles
ORDER BY role_code ASC
</select>
<select id="listPermissions" resultType="com.datamate.common.auth.domain.model.AuthPermissionInfo">
SELECT id,
permission_code AS permissionCode,
permission_name AS permissionName,
module,
action,
path_pattern AS pathPattern,
method,
enabled
FROM t_auth_permissions
ORDER BY module ASC, action ASC
</select>
<select id="countRolesByIds" resultType="int">
SELECT COUNT(1)
FROM t_auth_roles
WHERE id IN
<foreach collection="roleIds" item="roleId" open="(" separator="," close=")">
#{roleId}
</foreach>
</select>
<delete id="deleteUserRoles">
DELETE
FROM t_auth_user_roles
WHERE user_id = #{userId}
</delete>
<insert id="insertUserRoles">
INSERT INTO t_auth_user_roles (user_id, role_id)
VALUES
<foreach collection="roleIds" item="roleId" separator=",">
(#{userId}, #{roleId})
</foreach>
</insert>
</mapper>

View File

@@ -3,9 +3,13 @@ package com.datamate.common.security;
import io.jsonwebtoken.*;
import io.jsonwebtoken.security.Keys;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.util.StringUtils;
import org.springframework.stereotype.Component;
import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
@@ -15,15 +19,23 @@ import java.util.Map;
*/
@Component
public class JwtUtils {
private static final String DEFAULT_SECRET = "datamate-secret-key-for-jwt-token-generation";
@Value("${jwt.secret:datamate-secret-key-for-jwt-token-generation}")
@Value("${jwt.secret:" + DEFAULT_SECRET + "}")
private String secret;
@Value("${jwt.expiration:86400}") // 24小时
private Long expiration;
private SecretKey getSigningKey() {
return Keys.hmacShaKeyFor(secret.getBytes());
String secretValue = StringUtils.hasText(secret) ? secret : DEFAULT_SECRET;
try {
MessageDigest digest = MessageDigest.getInstance("SHA-512");
byte[] keyBytes = digest.digest(secretValue.getBytes(StandardCharsets.UTF_8));
return Keys.hmacShaKeyFor(keyBytes);
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException("Cannot initialize JWT signing key", e);
}
}
/**
@@ -84,7 +96,18 @@ public class JwtUtils {
public Boolean validateToken(String token, String username) {
try {
String tokenUsername = getUsernameFromToken(token);
return (username.equals(tokenUsername) && !isTokenExpired(token));
return (username.equals(tokenUsername) && validateToken(token));
} catch (JwtException | IllegalArgumentException e) {
return false;
}
}
/**
* 仅校验令牌格式与有效期
*/
public Boolean validateToken(String token) {
try {
return !isTokenExpired(token);
} catch (JwtException | IllegalArgumentException e) {
return false;
}

View File

@@ -0,0 +1,75 @@
export const PermissionCodes = {
dataManagementRead: "module:data-management:read",
dataManagementWrite: "module:data-management:write",
dataAnnotationRead: "module:data-annotation:read",
dataAnnotationWrite: "module:data-annotation:write",
dataCollectionRead: "module:data-collection:read",
dataCollectionWrite: "module:data-collection:write",
dataEvaluationRead: "module:data-evaluation:read",
dataEvaluationWrite: "module:data-evaluation:write",
dataSynthesisRead: "module:data-synthesis:read",
dataSynthesisWrite: "module:data-synthesis:write",
knowledgeManagementRead: "module:knowledge-management:read",
knowledgeManagementWrite: "module:knowledge-management:write",
knowledgeBaseRead: "module:knowledge-base:read",
knowledgeBaseWrite: "module:knowledge-base:write",
operatorMarketRead: "module:operator-market:read",
operatorMarketWrite: "module:operator-market:write",
orchestrationRead: "module:orchestration:read",
orchestrationWrite: "module:orchestration:write",
contentGenerationUse: "module:content-generation:use",
agentUse: "module:agent:use",
userManage: "system:user:manage",
roleManage: "system:role:manage",
permissionManage: "system:permission:manage",
} as const;
const routePermissionRules: Array<{ prefix: string; permission: string }> = [
{ prefix: "/data/management", permission: PermissionCodes.dataManagementRead },
{ prefix: "/data/annotation", permission: PermissionCodes.dataAnnotationRead },
{ prefix: "/data/collection", permission: PermissionCodes.dataCollectionRead },
{ prefix: "/data/evaluation", permission: PermissionCodes.dataEvaluationRead },
{ prefix: "/data/synthesis", permission: PermissionCodes.dataSynthesisRead },
{ prefix: "/data/knowledge-management", permission: PermissionCodes.knowledgeManagementRead },
{ prefix: "/data/knowledge-base", permission: PermissionCodes.knowledgeBaseRead },
{ prefix: "/data/operator-market", permission: PermissionCodes.operatorMarketRead },
{ prefix: "/data/orchestration", permission: PermissionCodes.orchestrationRead },
{ prefix: "/data/content-generation", permission: PermissionCodes.contentGenerationUse },
{ prefix: "/chat", permission: PermissionCodes.agentUse },
];
const defaultRouteCandidates: Array<{ path: string; permission: string }> = [
{ path: "/data/management", permission: PermissionCodes.dataManagementRead },
{ path: "/data/annotation", permission: PermissionCodes.dataAnnotationRead },
{ path: "/data/knowledge-management", permission: PermissionCodes.knowledgeManagementRead },
{ path: "/data/knowledge-base", permission: PermissionCodes.knowledgeBaseRead },
{ path: "/chat", permission: PermissionCodes.agentUse },
];
export function hasPermission(
userPermissions: string[] | undefined,
requiredPermission?: string | null
): boolean {
if (!requiredPermission) {
return true;
}
return (userPermissions ?? []).includes(requiredPermission);
}
export function resolveRequiredPermissionByPath(pathname: string): string | null {
if (pathname === "/403") {
return null;
}
const matchedRule = routePermissionRules.find((rule) =>
pathname.startsWith(rule.prefix)
);
return matchedRule?.permission ?? null;
}
export function resolveDefaultAuthorizedPath(userPermissions: string[]): string {
const matchedPath = defaultRouteCandidates.find((candidate) =>
hasPermission(userPermissions, candidate.permission)
)?.path;
return matchedPath ?? "/403";
}

View File

@@ -1,20 +1,53 @@
import React from 'react';
import { Navigate, useLocation, Outlet } from 'react-router';
import { useAppSelector } from '@/store/hooks';
import { useAppDispatch, useAppSelector } from '@/store/hooks';
import { fetchCurrentUser, markInitialized } from '@/store/slices/authSlice';
import {
hasPermission,
resolveDefaultAuthorizedPath,
resolveRequiredPermissionByPath,
} from '@/auth/permissions';
interface ProtectedRouteProps {
children?: React.ReactNode;
}
const ProtectedRoute: React.FC<ProtectedRouteProps> = ({ children }) => {
const { isAuthenticated } = useAppSelector((state) => state.auth);
const dispatch = useAppDispatch();
const { isAuthenticated, token, initialized, loading, permissions } = useAppSelector(
(state) => state.auth
);
const location = useLocation();
const requiredPermission = resolveRequiredPermissionByPath(location.pathname);
React.useEffect(() => {
if (initialized || loading) {
return;
}
if (!token) {
dispatch(markInitialized());
return;
}
void dispatch(fetchCurrentUser());
}, [dispatch, initialized, loading, token]);
if (!initialized || loading) {
return null;
}
if (!isAuthenticated) {
// Redirect to the login page, but save the current location they were trying to go to
return <Navigate to="/login" state={{ from: location }} replace />;
}
if (!hasPermission(permissions, requiredPermission)) {
const fallbackPath = resolveDefaultAuthorizedPath(permissions);
if (location.pathname === fallbackPath) {
return <Navigate to="/403" replace />;
}
return <Navigate to={fallbackPath} replace />;
}
return children ? <>{children}</> : <Outlet />;
};

View File

@@ -0,0 +1,24 @@
import React from "react";
import { Button, Result } from "antd";
import { useNavigate } from "react-router";
const ForbiddenPage: React.FC = () => {
const navigate = useNavigate();
return (
<div className="h-screen w-full flex items-center justify-center bg-[#050b14]">
<Result
status="403"
title="403"
subTitle="你当前账号没有访问该页面的权限。"
extra={
<Button type="primary" onClick={() => navigate("/data/management")}>
</Button>
}
/>
</div>
);
};
export default ForbiddenPage;

View File

@@ -1,4 +1,4 @@
import { memo, useCallback, useEffect, useState } from "react";
import { memo, useCallback, useEffect, useMemo, useState } from "react";
import { Button, Drawer, Menu, Popover } from "antd";
import {
CloseOutlined,
@@ -14,6 +14,7 @@ import SettingsPage from "../SettingsPage/SettingsPage";
import { useAppSelector, useAppDispatch } from "@/store/hooks";
import { showSettings, hideSettings } from "@/store/slices/settingsSlice";
import { logout } from "@/store/slices/authSlice";
import { hasPermission } from "@/auth/permissions";
const isPathMatch = (currentPath: string, targetPath: string) =>
currentPath === targetPath || currentPath.startsWith(`${targetPath}/`);
@@ -25,13 +26,36 @@ const AsiderAndHeaderLayout = () => {
const [sidebarOpen, setSidebarOpen] = useState(true);
const [taskCenterVisible, setTaskCenterVisible] = useState(false);
const settingVisible = useAppSelector((state) => state.settings.visible);
const permissions = useAppSelector((state) => state.auth.permissions);
const dispatch = useAppDispatch();
const visibleMenuItems = useMemo(
() =>
menuItems
.map((item) => ({
...item,
children: item.children?.filter((subItem) =>
hasPermission(permissions, (subItem as { permissionCode?: string }).permissionCode)
),
}))
.filter((item) => {
const selfVisible = hasPermission(
permissions,
(item as { permissionCode?: string }).permissionCode
);
if (item.children && item.children.length > 0) {
return selfVisible;
}
return selfVisible;
}),
[permissions]
);
// Initialize active item based on current pathname
const initActiveItem = useCallback(() => {
const dataPath = pathname.startsWith("/data/") ? pathname.slice(6) : pathname;
for (let index = 0; index < menuItems.length; index++) {
const element = menuItems[index];
for (let index = 0; index < visibleMenuItems.length; index++) {
const element = visibleMenuItems[index];
if (element.children) {
for (const subItem of element.children) {
if (isPathMatch(dataPath, subItem.id)) {
@@ -44,7 +68,8 @@ const AsiderAndHeaderLayout = () => {
return;
}
}
}, [pathname]);
setActiveItem(visibleMenuItems[0]?.id ?? "");
}, [pathname, visibleMenuItems]);
useEffect(() => {
initActiveItem();
@@ -100,7 +125,7 @@ const AsiderAndHeaderLayout = () => {
<Menu
mode="inline"
inlineCollapsed={!sidebarOpen}
items={menuItems.map((item) => ({
items={visibleMenuItems.map((item) => ({
key: item.id,
label: item.title,
icon: item.icon ? <item.icon className="w-4 h-4" /> : null,

View File

@@ -13,6 +13,7 @@ import {
// Store,
// Merge,
} from "lucide-react";
import { PermissionCodes } from "@/auth/permissions";
export const menuItems = [
// {
@@ -26,6 +27,7 @@ export const menuItems = [
id: "management",
title: "数集管理",
icon: FolderOpen,
permissionCode: PermissionCodes.dataManagementRead,
description: "创建、导入和管理数据集",
color: "bg-blue-500",
},
@@ -33,6 +35,7 @@ export const menuItems = [
id: "annotation",
title: "数据标注",
icon: Tag,
permissionCode: PermissionCodes.dataAnnotationRead,
description: "对数据进行标注和标记",
color: "bg-green-500",
},
@@ -40,6 +43,7 @@ export const menuItems = [
id: "content-generation",
title: "内容生成",
icon: Sparkles,
permissionCode: PermissionCodes.contentGenerationUse,
description: "智能内容生成与创作",
color: "bg-purple-500",
},
@@ -47,6 +51,7 @@ export const menuItems = [
id: "knowledge-management",
title: "知识管理",
icon: Shield,
permissionCode: PermissionCodes.knowledgeManagementRead,
description: "管理知识集与知识条目",
color: "bg-indigo-500",
},

View File

@@ -1,9 +1,9 @@
import React, { useState } from 'react';
import React from 'react';
import { useNavigate, useLocation } from 'react-router';
import { Form, Input, Button, Typography, message, Card } from 'antd';
import { Form, Input, Button, Typography, message } from 'antd';
import { UserOutlined, LockOutlined } from '@ant-design/icons';
import { useAppDispatch, useAppSelector } from '@/store/hooks';
import { loginLocal } from '@/store/slices/authSlice';
import { loginUser } from '@/store/slices/authSlice';
const { Title, Text } = Typography;
@@ -11,19 +11,20 @@ const LoginPage: React.FC = () => {
const navigate = useNavigate();
const location = useLocation();
const dispatch = useAppDispatch();
const { loading, error } = useAppSelector((state) => state.auth);
const { loading } = useAppSelector((state) => state.auth);
const [messageApi, contextHolder] = message.useMessage();
const from = location.state?.from?.pathname || '/data';
const onFinish = (values: any) => {
dispatch(loginLocal(values));
// The reducer updates state synchronously.
if (values.username === 'admin' && values.password === '123456') {
const onFinish = async (values: { username: string; password: string }) => {
try {
await dispatch(loginUser(values)).unwrap();
messageApi.success('登录成功');
navigate(from, { replace: true });
} else {
messageApi.error('账号或密码错误');
} catch (loginError) {
const messageText =
typeof loginError === 'string' ? loginError : '账号或密码错误';
messageApi.error(messageText);
}
};
@@ -59,9 +60,9 @@ const LoginPage: React.FC = () => {
</Text>
</div>
<Form
<Form<{ username: string; password: string }>
name="login"
initialValues={{ remember: true, username: 'admin', password: '123456' }}
initialValues={{ username: 'admin', password: '123456' }}
onFinish={onFinish}
layout="vertical"
size="large"

View File

@@ -1,12 +1,51 @@
import { useState } from "react";
import { useEffect, useMemo, useState } from "react";
import { Menu } from "antd";
import { SettingOutlined } from "@ant-design/icons";
import { SettingOutlined, TeamOutlined } from "@ant-design/icons";
import { Component } from "lucide-react";
import SystemConfig from "./SystemConfig";
import ModelAccess from "./ModelAccess";
import UserPermissionManagement from "./UserPermissionManagement";
import { useAppSelector } from "@/store/hooks";
import { hasPermission, PermissionCodes } from "@/auth/permissions";
export default function SettingsPage() {
const [activeTab, setActiveTab] = useState("model-access");
const permissions = useAppSelector((state) => state.auth.permissions);
const canManageUsers = hasPermission(permissions, PermissionCodes.userManage);
const canViewRoles = hasPermission(permissions, PermissionCodes.roleManage);
const canViewPermissions = hasPermission(
permissions,
PermissionCodes.permissionManage
);
const tabs = useMemo(() => {
const nextTabs = [
{
key: "model-access",
icon: <Component className="w-4 h-4" />,
label: "模型接入",
},
{
key: "system-config",
icon: <SettingOutlined />,
label: "参数配置",
},
];
if (canManageUsers || canViewRoles || canViewPermissions) {
nextTabs.push({
key: "user-permission",
icon: <TeamOutlined />,
label: "用户与权限",
});
}
return nextTabs;
}, [canManageUsers, canViewPermissions, canViewRoles]);
const [activeTab, setActiveTab] = useState<string>(tabs[0]?.key ?? "model-access");
useEffect(() => {
const hasActiveTab = tabs.some((tab) => tab.key === activeTab);
if (!hasActiveTab && tabs.length > 0) {
setActiveTab(tabs[0].key);
}
}, [activeTab, tabs]);
return (
<div className="h-screen flex">
@@ -18,21 +57,10 @@ export default function SettingsPage() {
<div className="h-full">
<Menu
mode="inline"
items={[
{
key: "model-access",
icon: <Component className="w-4 h-4" />,
label: "模型接入",
},
{
key: "system-config",
icon: <SettingOutlined />,
label: "参数配置",
},
]}
items={tabs}
selectedKeys={[activeTab]}
onClick={({ key }) => {
setActiveTab(key);
setActiveTab(String(key));
}}
/>
</div>
@@ -41,6 +69,13 @@ export default function SettingsPage() {
{/* 内容区域,根据 activeTab 渲染不同的组件 */}
{activeTab === "system-config" && <SystemConfig />}
{activeTab === "model-access" && <ModelAccess />}
{activeTab === "user-permission" && (
<UserPermissionManagement
canManageUsers={canManageUsers}
canViewRoles={canViewRoles}
canViewPermissions={canViewPermissions}
/>
)}
</div>
</div>
);

View File

@@ -0,0 +1,321 @@
import { useCallback, useEffect, useMemo, useState } from "react";
import {
Button,
Card,
Empty,
message,
Modal,
Select,
Space,
Table,
Tag,
Typography,
} from "antd";
import type { ColumnsType } from "antd/es/table";
import {
assignUserRolesUsingPut,
listAuthPermissionsUsingGet,
listAuthRolesUsingGet,
listAuthUsersUsingGet,
} from "./settings.apis";
import type {
AuthPermissionInfo,
AuthRoleInfo,
AuthUserWithRoles,
} from "./settings.apis";
interface ApiResponse<T> {
code: string;
message: string;
data: T;
}
interface UserPermissionManagementProps {
canManageUsers: boolean;
canViewRoles: boolean;
canViewPermissions: boolean;
}
export default function UserPermissionManagement({
canManageUsers,
canViewRoles,
canViewPermissions,
}: UserPermissionManagementProps) {
const [loading, setLoading] = useState(false);
const [users, setUsers] = useState<AuthUserWithRoles[]>([]);
const [roles, setRoles] = useState<AuthRoleInfo[]>([]);
const [permissions, setPermissions] = useState<AuthPermissionInfo[]>([]);
const [editingUser, setEditingUser] = useState<AuthUserWithRoles | null>(null);
const [selectedRoleCodes, setSelectedRoleCodes] = useState<string[]>([]);
const [submitting, setSubmitting] = useState(false);
const canShowAnything = canManageUsers || canViewRoles || canViewPermissions;
const canAssignRoles = canManageUsers && roles.length > 0;
const roleNameMap = useMemo(
() => new Map(roles.map((role) => [role.roleCode, role.roleName || role.roleCode])),
[roles]
);
const roleCodeToIdMap = useMemo(
() => new Map(roles.map((role) => [role.roleCode, role.id])),
[roles]
);
const loadData = useCallback(async () => {
setLoading(true);
try {
const requestTasks: Array<Promise<unknown>> = [];
if (canManageUsers || canViewRoles || canViewPermissions) {
requestTasks.push(listAuthUsersUsingGet());
}
if (canManageUsers || canViewRoles) {
requestTasks.push(listAuthRolesUsingGet());
}
if (canViewPermissions) {
requestTasks.push(listAuthPermissionsUsingGet());
}
const responses = await Promise.all(requestTasks);
let index = 0;
if (canManageUsers || canViewRoles || canViewPermissions) {
const userResponse = responses[index++] as ApiResponse<AuthUserWithRoles[]>;
setUsers(userResponse?.data ?? []);
}
if (canManageUsers || canViewRoles) {
const roleResponse = responses[index++] as ApiResponse<AuthRoleInfo[]>;
setRoles(roleResponse?.data ?? []);
} else {
setRoles([]);
}
if (canViewPermissions) {
const permissionResponse = responses[index++] as ApiResponse<AuthPermissionInfo[]>;
setPermissions(permissionResponse?.data ?? []);
} else {
setPermissions([]);
}
} catch (error) {
message.error("加载用户权限信息失败");
console.error("加载用户权限信息失败:", error);
} finally {
setLoading(false);
}
}, [canManageUsers, canViewPermissions, canViewRoles]);
useEffect(() => {
if (!canShowAnything) {
return;
}
void loadData();
}, [canShowAnything, loadData]);
const userColumns: ColumnsType<AuthUserWithRoles> = [
{
title: "用户名",
dataIndex: "username",
key: "username",
width: 180,
},
{
title: "姓名",
dataIndex: "fullName",
key: "fullName",
width: 180,
render: (value?: string) => value || "-",
},
{
title: "邮箱",
dataIndex: "email",
key: "email",
render: (value?: string) => value || "-",
},
{
title: "状态",
dataIndex: "enabled",
key: "enabled",
width: 120,
render: (enabled?: boolean) =>
enabled ? <Tag color="green"></Tag> : <Tag color="default"></Tag>,
},
{
title: "角色",
dataIndex: "roleCodes",
key: "roleCodes",
render: (roleCodes: string[]) => (
<Space wrap>
{(roleCodes ?? []).map((roleCode) => (
<Tag key={roleCode}>{roleNameMap.get(roleCode) || roleCode}</Tag>
))}
</Space>
),
},
{
title: "操作",
key: "actions",
width: 120,
render: (_, record) => (
<Button
type="link"
disabled={!canAssignRoles}
onClick={() => {
setEditingUser(record);
setSelectedRoleCodes(record.roleCodes ?? []);
}}
>
</Button>
),
},
];
const roleColumns: ColumnsType<AuthRoleInfo> = [
{ title: "角色编码", dataIndex: "roleCode", key: "roleCode", width: 220 },
{ title: "角色名称", dataIndex: "roleName", key: "roleName", width: 180 },
{
title: "状态",
dataIndex: "enabled",
key: "enabled",
width: 120,
render: (enabled?: boolean) =>
enabled ? <Tag color="green"></Tag> : <Tag color="default"></Tag>,
},
{
title: "描述",
dataIndex: "description",
key: "description",
render: (value?: string) => value || "-",
},
];
const permissionColumns: ColumnsType<AuthPermissionInfo> = [
{
title: "权限编码",
dataIndex: "permissionCode",
key: "permissionCode",
width: 260,
},
{
title: "权限名称",
dataIndex: "permissionName",
key: "permissionName",
width: 200,
},
{
title: "模块",
dataIndex: "module",
key: "module",
width: 140,
render: (value?: string) => value || "-",
},
{
title: "动作",
dataIndex: "action",
key: "action",
width: 120,
render: (value?: string) => value || "-",
},
{
title: "接口",
key: "api",
render: (_, record) =>
record.pathPattern ? `${record.method || "ALL"} ${record.pathPattern}` : "-",
},
];
const handleAssignRoles = async () => {
if (!editingUser) {
return;
}
if (selectedRoleCodes.length === 0) {
message.warning("请至少选择一个角色");
return;
}
const roleIds = selectedRoleCodes
.map((roleCode) => roleCodeToIdMap.get(roleCode))
.filter((roleId): roleId is string => Boolean(roleId));
if (roleIds.length !== selectedRoleCodes.length) {
message.error("角色映射失败,请刷新后重试");
return;
}
setSubmitting(true);
try {
await assignUserRolesUsingPut(editingUser.id, roleIds);
message.success("角色分配成功");
setEditingUser(null);
setSelectedRoleCodes([]);
await loadData();
} catch (error) {
message.error("角色分配失败");
console.error("角色分配失败:", error);
} finally {
setSubmitting(false);
}
};
if (!canShowAnything) {
return <Empty description="当前账号无用户与权限管理权限" />;
}
return (
<Space direction="vertical" size={16} className="w-full">
<Card title="用户管理">
<Table
loading={loading}
rowKey="id"
dataSource={users}
columns={userColumns}
pagination={{ pageSize: 10, showSizeChanger: false }}
/>
</Card>
{canViewRoles && (
<Card title="角色列表">
<Table
loading={loading}
rowKey="id"
dataSource={roles}
columns={roleColumns}
pagination={{ pageSize: 8, showSizeChanger: false }}
/>
</Card>
)}
{canViewPermissions && (
<Card title="权限列表">
<Table
loading={loading}
rowKey="id"
dataSource={permissions}
columns={permissionColumns}
pagination={{ pageSize: 10, showSizeChanger: false }}
/>
</Card>
)}
<Modal
title={`分配角色 - ${editingUser?.username || ""}`}
open={Boolean(editingUser)}
confirmLoading={submitting}
onOk={() => {
void handleAssignRoles();
}}
onCancel={() => {
setEditingUser(null);
setSelectedRoleCodes([]);
}}
>
{roles.length === 0 ? (
<Typography.Text type="secondary"></Typography.Text>
) : (
<Select
mode="multiple"
className="w-full"
placeholder="请选择角色"
value={selectedRoleCodes}
onChange={(values) => setSelectedRoleCodes(values)}
options={roles.map((role) => ({
value: role.roleCode,
label: `${role.roleName} (${role.roleCode})`,
}))}
/>
)}
</Modal>
</Space>
);
}

View File

@@ -1,11 +1,11 @@
import { get, post, put, del } from "@/utils/request";
// 模型相关接口
export function queryModelProvidersUsingGet(params?: any) {
export function queryModelProvidersUsingGet(params?: Record<string, unknown>) {
return get("/api/models/providers", params);
}
export function queryModelListUsingGet(data: any) {
export function queryModelListUsingGet(data: Record<string, unknown>) {
return get("/api/models/list", data);
}
@@ -15,12 +15,12 @@ export function queryModelDetailByIdUsingGet(id: string | number) {
export function updateModelByIdUsingPut(
id: string | number,
data: any
data: Record<string, unknown>
) {
return put(`/api/models/${id}`, data);
}
export function createModelUsingPost(data: any) {
export function createModelUsingPost(data: Record<string, unknown>) {
return post("/api/models/create", data);
}
@@ -28,13 +28,60 @@ export function deleteModelByIdUsingDelete(id: string | number) {
return del(`/api/models/${id}`);
}
// 获取系统参数列表
export function getSysParamList() {
return get('/api/sys-param/list');
return get("/api/sys-param/list");
}
// 更新系统参数值
export const updateSysParamValue = async (params: { id: string; paramValue: string }) => {
export const updateSysParamValue = async (params: {
id: string;
paramValue: string;
}) => {
return put(`/api/sys-param/${params.id}`, params);
};
export interface AuthUserWithRoles {
id: number;
username: string;
fullName?: string;
email?: string;
enabled?: boolean;
roleCodes: string[];
}
export interface AuthRoleInfo {
id: string;
roleCode: string;
roleName: string;
description?: string;
enabled?: boolean;
}
export interface AuthPermissionInfo {
id: string;
permissionCode: string;
permissionName: string;
module?: string;
action?: string;
pathPattern?: string;
method?: string;
enabled?: boolean;
}
// 用户与权限管理接口
export function listAuthUsersUsingGet() {
return get("/api/auth/users");
}
export function listAuthRolesUsingGet() {
return get("/api/auth/roles");
}
export function listAuthPermissionsUsingGet() {
return get("/api/auth/permissions");
}
export function assignUserRolesUsingPut(userId: number, roleIds: string[]) {
return put(`/api/auth/users/${userId}/roles`, { roleIds });
}

View File

@@ -51,6 +51,7 @@ import Home from "@/pages/Home/Home";
import ContentGenerationPage from "@/pages/ContentGeneration/ContentGenerationPage";
import LoginPage from "@/pages/Login/LoginPage";
import ProtectedRoute from "@/components/ProtectedRoute";
import ForbiddenPage from "@/pages/Forbidden/ForbiddenPage";
const router = createBrowserRouter([
{
@@ -64,6 +65,10 @@ const router = createBrowserRouter([
{
Component: ProtectedRoute,
children: [
{
path: "/403",
Component: ForbiddenPage,
},
{
path: "/chat",
Component: withErrorBoundary(AgentPage),

View File

@@ -1,66 +1,124 @@
// store/slices/authSlice.js
import { createSlice, createAsyncThunk } from '@reduxjs/toolkit';
import { createAsyncThunk, createSlice } from "@reduxjs/toolkit";
import { get, post } from "@/utils/request";
// 异步 thunk
export const loginUser = createAsyncThunk(
'auth/login',
async (credentials, { rejectWithValue }) => {
interface AuthUserView {
id: number;
username: string;
fullName?: string;
email?: string;
avatarUrl?: string;
organization?: string;
}
interface AuthLoginPayload {
token: string;
tokenType: string;
expiresInSeconds: number;
user: AuthUserView;
roles: string[];
permissions: string[];
}
interface AuthCurrentUserPayload {
user: AuthUserView;
roles: string[];
permissions: string[];
}
interface ApiResponse<T> {
code: string;
message: string;
data: T;
}
interface AuthState {
user: AuthUserView | null;
token: string | null;
roles: string[];
permissions: string[];
isAuthenticated: boolean;
initialized: boolean;
loading: boolean;
error: string | null;
}
interface LoginCredentials {
username: string;
password: string;
}
const extractErrorMessage = (error: unknown): string => {
if (error instanceof Error) {
const nestedMessage = (error as { data?: { message?: string } }).data?.message;
return nestedMessage ?? error.message;
}
return "登录失败,请稍后重试";
};
export const loginUser = createAsyncThunk<
AuthLoginPayload,
LoginCredentials,
{ rejectValue: string }
>("auth/login", async (credentials, { rejectWithValue }) => {
try {
const response = await fetch('/api/auth/login', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(credentials),
});
if (!response.ok) {
throw new Error('Login failed');
const response = (await post("/api/auth/login", credentials)) as ApiResponse<AuthLoginPayload>;
if (!response?.data?.token) {
return rejectWithValue(response?.message ?? "登录失败");
}
const data = await response.json();
return data;
return response.data;
} catch (error) {
return rejectWithValue(error.message);
return rejectWithValue(extractErrorMessage(error));
}
}
);
});
const authSlice = createSlice({
name: 'auth',
initialState: {
export const fetchCurrentUser = createAsyncThunk<
AuthCurrentUserPayload,
void,
{ rejectValue: string }
>("auth/fetchCurrentUser", async (_, { rejectWithValue }) => {
try {
const response = (await get("/api/auth/me")) as ApiResponse<AuthCurrentUserPayload>;
if (!response?.data?.user) {
return rejectWithValue(response?.message ?? "用户信息加载失败");
}
return response.data;
} catch (error) {
return rejectWithValue(extractErrorMessage(error));
}
});
const initialToken = localStorage.getItem("token");
const initialState: AuthState = {
user: null,
token: localStorage.getItem('token'),
isAuthenticated: !!localStorage.getItem('token'),
token: initialToken,
roles: [],
permissions: [],
isAuthenticated: Boolean(initialToken),
initialized: false,
loading: false,
error: null,
},
};
const authSlice = createSlice({
name: "auth",
initialState,
reducers: {
logout: (state) => {
state.user = null;
state.token = null;
state.roles = [];
state.permissions = [];
state.isAuthenticated = false;
localStorage.removeItem('token');
state.error = null;
state.initialized = true;
localStorage.removeItem("token");
},
clearError: (state) => {
state.error = null;
},
setToken: (state, action) => {
state.token = action.payload;
localStorage.setItem('token', action.payload);
},
loginLocal: (state, action) => {
const { username, password } = action.payload;
if (username === 'admin' && password === '123456') {
state.user = { username: 'admin', role: 'admin' };
state.token = 'mock-token-' + Date.now();
state.isAuthenticated = true;
localStorage.setItem('token', state.token);
state.error = null;
} else {
state.error = 'Invalid credentials';
state.isAuthenticated = false;
}
markInitialized: (state) => {
state.initialized = true;
},
},
extraReducers: (builder) => {
@@ -71,18 +129,52 @@ const authSlice = createSlice({
})
.addCase(loginUser.fulfilled, (state, action) => {
state.loading = false;
state.initialized = true;
state.user = action.payload.user;
state.token = action.payload.token;
state.roles = action.payload.roles ?? [];
state.permissions = action.payload.permissions ?? [];
state.isAuthenticated = true;
localStorage.setItem('token', action.payload.token);
state.error = null;
localStorage.setItem("token", action.payload.token);
})
.addCase(loginUser.rejected, (state, action) => {
state.loading = false;
state.error = action.payload;
state.initialized = true;
state.user = null;
state.roles = [];
state.permissions = [];
state.isAuthenticated = false;
state.token = null;
state.error = action.payload ?? "登录失败";
localStorage.removeItem("token");
})
.addCase(fetchCurrentUser.pending, (state) => {
state.loading = true;
state.error = null;
})
.addCase(fetchCurrentUser.fulfilled, (state, action) => {
state.loading = false;
state.initialized = true;
state.user = action.payload.user;
state.roles = action.payload.roles ?? [];
state.permissions = action.payload.permissions ?? [];
state.isAuthenticated = Boolean(state.token);
state.error = null;
})
.addCase(fetchCurrentUser.rejected, (state, action) => {
state.loading = false;
state.initialized = true;
state.user = null;
state.roles = [];
state.permissions = [];
state.isAuthenticated = false;
state.token = null;
state.error = action.payload ?? "登录状态已失效";
localStorage.removeItem("token");
});
},
});
export const { logout, clearError, setToken, loginLocal } = authSlice.actions;
export const { logout, clearError, markInitialized } = authSlice.actions;
export default authSlice.reducer;

View File

@@ -524,8 +524,16 @@ request.addRequestInterceptor((config) => {
// 添加默认响应拦截器 - 错误处理
request.addResponseInterceptor((response) => {
// 可以在这里添加全局错误处理逻辑
// 比如token过期自动跳转登录页等
if (response.status === 401) {
localStorage.removeItem("token");
sessionStorage.removeItem("token");
if (window.location.pathname !== "/login") {
window.location.href = "/login";
}
}
if (response.status === 403 && window.location.pathname !== "/403") {
window.location.href = "/403";
}
return response;
});

View File

@@ -20,6 +20,11 @@ from app.module.shared.schema import StandardResponse
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
AutoAnnotationTaskResponse,
@@ -39,13 +44,14 @@ service = AutoAnnotationTaskService()
@router.get("", response_model=StandardResponse[List[AutoAnnotationTaskResponse]])
async def list_auto_annotation_tasks(
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""获取自动标注任务列表。
前端当前不传分页参数,这里直接返回所有未删除任务。
"""
tasks = await service.list_tasks(db)
tasks = await service.list_tasks(db, user_context)
return StandardResponse(
code=200,
message="success",
@@ -57,6 +63,7 @@ async def list_auto_annotation_tasks(
async def create_auto_annotation_task(
request: CreateAutoAnnotationTaskRequest,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""创建自动标注任务。
@@ -74,6 +81,7 @@ async def create_auto_annotation_task(
# 尝试获取数据集名称和文件数量用于冗余字段,失败时不阻塞任务创建
dataset_name = None
total_images = 0
await assert_dataset_access(db, request.dataset_id, user_context)
try:
dm_client = DatasetManagementService(db)
# Service.get_dataset 返回 DatasetResponse,包含 name 和 fileCount
@@ -106,13 +114,14 @@ async def create_auto_annotation_task(
async def get_auto_annotation_task_status(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""获取单个自动标注任务状态。
前端当前主要通过列表轮询,这里提供按 ID 查询的补充接口。
"""
task = await service.get_task(db, task_id)
task = await service.get_task(db, task_id, user_context)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
@@ -127,10 +136,11 @@ async def get_auto_annotation_task_status(
async def delete_auto_annotation_task(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""删除(软删除)自动标注任务,仅标记 deleted_at。"""
ok = await service.soft_delete_task(db, task_id)
ok = await service.soft_delete_task(db, task_id, user_context)
if not ok:
raise HTTPException(status_code=404, detail="Task not found")
@@ -145,6 +155,7 @@ async def delete_auto_annotation_task(
async def download_auto_annotation_result(
task_id: str = Path(..., description="任务ID"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""下载指定自动标注任务的结果 ZIP。"""
@@ -154,7 +165,7 @@ async def download_auto_annotation_result(
import tempfile
# 复用服务层获取任务信息
task = await service.get_task(db, task_id)
task = await service.get_task(db, task_id, user_context)
if not task:
raise HTTPException(status_code=404, detail="Task not found")

View File

@@ -27,6 +27,10 @@ from app.module.annotation.schema.editor import (
UpsertAnnotationResponse,
)
from app.module.annotation.service.editor import AnnotationEditorService
from app.module.annotation.security import (
RequestUserContext,
get_request_user_context,
)
from app.module.shared.schema import StandardResponse
logger = get_logger(__name__)
@@ -44,8 +48,9 @@ router = APIRouter(
async def get_editor_project_info(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
info = await service.get_project_info(project_id)
return StandardResponse(code=200, message="success", data=info)
@@ -64,8 +69,9 @@ async def list_editor_tasks(
description="是否排除已被转换为TXT的源文档文件(PDF/DOC/DOCX,仅文本数据集生效)",
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.list_tasks(
project_id,
page=page,
@@ -86,8 +92,9 @@ async def get_editor_task(
None, alias="segmentIndex", description="段落索引(分段模式下使用)"
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
task = await service.get_task(project_id, file_id, segment_index=segment_index)
return StandardResponse(code=200, message="success", data=task)
@@ -103,8 +110,9 @@ async def get_editor_task_segment(
..., ge=0, alias="segmentIndex", description="段落索引(从0开始)"
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.get_task_segment(project_id, file_id, segment_index)
return StandardResponse(code=200, message="success", data=result)
@@ -118,8 +126,9 @@ async def upsert_editor_annotation(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.upsert_annotation(project_id, file_id, request)
return StandardResponse(code=200, message="success", data=result)
@@ -132,11 +141,12 @@ async def check_file_version(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
检查文件是否有新版本
"""
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.check_file_version(project_id, file_id)
return StandardResponse(code=200, message="success", data=result)
@@ -149,10 +159,11 @@ async def use_new_version(
project_id: str = Path(..., description="标注项目ID(t_dm_labeling_projects.id)"),
file_id: str = Path(..., description="文件ID(t_dm_dataset_files.id)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
使用文件新版本并清空标注
"""
service = AnnotationEditorService(db)
service = AnnotationEditorService(db, user_context)
result = await service.use_new_version(project_id, file_id)
return StandardResponse(code=200, message="success", data=result)

View File

@@ -12,6 +12,11 @@ from app.module.shared.schema import StandardResponse, PaginatedData
from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..service.mapping import DatasetMappingService
from ..service.template import AnnotationTemplateService
from ..service.knowledge_sync import KnowledgeSyncService
@@ -42,7 +47,9 @@ async def login_label_studio(mapping_id: str, db: AsyncSession = Depends(get_db)
"", response_model=StandardResponse[DatasetMappingCreateResponse], status_code=201
)
async def create_mapping(
request: DatasetMappingCreateRequest, db: AsyncSession = Depends(get_db)
request: DatasetMappingCreateRequest,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
创建数据集映射
@@ -58,6 +65,8 @@ async def create_mapping(
mapping_service = DatasetMappingService(db)
template_service = AnnotationTemplateService()
await assert_dataset_access(db, request.dataset_id, user_context)
logger.info(f"Create dataset mapping request: {request.dataset_id}")
# 从DM服务获取数据集信息
@@ -163,7 +172,7 @@ async def create_mapping(
try:
from ..service.editor import AnnotationEditorService
editor_service = AnnotationEditorService(db)
editor_service = AnnotationEditorService(db, user_context)
# 异步预计算切片(不阻塞创建响应)
segmentation_result = (
await editor_service.precompute_segmentation_for_project(
@@ -202,6 +211,7 @@ async def list_mappings(
False, description="是否包含模板详情", alias="includeTemplate"
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
查询所有映射关系(分页)
@@ -230,6 +240,8 @@ async def list_mappings(
limit=size,
include_deleted=False,
include_template=include_template,
current_user_id=user_context.user_id,
is_admin=user_context.is_admin,
)
# 计算总页数
@@ -256,7 +268,11 @@ async def list_mappings(
@router.get("/{mapping_id}", response_model=StandardResponse[DatasetMappingResponse])
async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
async def get_mapping(
mapping_id: str,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
根据 UUID 查询单个映射关系(包含关联的标注模板详情)
@@ -278,6 +294,7 @@ async def get_mapping(mapping_id: str, db: AsyncSession = Depends(get_db)):
raise HTTPException(
status_code=404, detail=f"Mapping not found: {mapping_id}"
)
await assert_dataset_access(db, mapping.dataset_id, user_context)
logger.info(
f"Found mapping: {mapping.id}, template_included: {mapping.template is not None}"
@@ -304,6 +321,7 @@ async def get_mappings_by_source(
True, description="是否包含模板详情", alias="includeTemplate"
),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
根据源数据集 ID 查询所有映射关系(分页,包含模板详情)
@@ -319,6 +337,7 @@ async def get_mappings_by_source(
"""
try:
service = DatasetMappingService(db)
await assert_dataset_access(db, dataset_id, user_context)
# 计算 skip
skip = (page - 1) * size
@@ -333,6 +352,8 @@ async def get_mappings_by_source(
skip=skip,
limit=size,
include_template=include_template,
current_user_id=user_context.user_id,
is_admin=user_context.is_admin,
)
# 计算总页数
@@ -364,6 +385,7 @@ async def get_mappings_by_source(
async def delete_mapping(
project_id: str = Path(..., description="映射UUID(path param)"),
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
删除映射关系(软删除)
@@ -387,6 +409,7 @@ async def delete_mapping(
raise HTTPException(
status_code=404, detail=f"Mapping either not found or not specified."
)
await assert_dataset_access(db, mapping.dataset_id, user_context)
id = mapping.id
dataset_id = mapping.dataset_id
@@ -428,6 +451,7 @@ async def update_mapping(
project_id: str = Path(..., description="映射UUID(path param)"),
request: DatasetMappingUpdateRequest = None,
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
更新标注项目信息
@@ -456,6 +480,7 @@ async def update_mapping(
raise HTTPException(
status_code=404, detail=f"Mapping not found: {project_id}"
)
await assert_dataset_access(db, mapping_orm.dataset_id, user_context)
# 构建更新数据
update_values = {}

View File

@@ -10,6 +10,11 @@ from app.module.dataset import DatasetManagementService
from app.core.logging import get_logger
from app.core.config import settings
from ..security import (
RequestUserContext,
assert_dataset_access,
get_request_user_context,
)
from ..service.mapping import DatasetMappingService
from ..schema import (
SyncDatasetRequest,
@@ -32,7 +37,8 @@ logger = get_logger(__name__)
@router.post("/sync", response_model=StandardResponse[SyncDatasetResponse])
async def sync_dataset_content(
request: SyncDatasetRequest,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
Sync Dataset Content (Files and Annotations)
@@ -51,6 +57,7 @@ async def sync_dataset_content(
status_code=404,
detail=f"Mapping not found: {request.id}"
)
await assert_dataset_access(db, mapping.dataset_id, user_context)
dm_client = DatasetManagementService(db)
dataset_info = await dm_client.get_dataset(mapping.dataset_id)
@@ -82,7 +89,8 @@ async def sync_dataset_content(
@router.post("/annotation/sync", response_model=StandardResponse[SyncAnnotationsResponse])
async def sync_annotations(
request: SyncAnnotationsRequest,
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
Sync Annotations Only (Bidirectional Support)
@@ -102,6 +110,7 @@ async def sync_annotations(
status_code=404,
detail=f"Mapping not found: {request.id}"
)
await assert_dataset_access(db, mapping.dataset_id, user_context)
result = SyncAnnotationsResponse(
id=mapping.id,
@@ -156,7 +165,8 @@ async def check_label_studio_connection():
async def update_file_tags(
request: UpdateFileTagsRequest,
file_id: str = Path(..., description="文件ID"),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
user_context: RequestUserContext = Depends(get_request_user_context),
):
"""
Update File Tags (Partial Update with Auto Format Conversion)
@@ -189,6 +199,7 @@ async def update_file_tags(
raise HTTPException(status_code=404, detail=f"File not found: {file_id}")
dataset_id = str(file_record.dataset_id) # type: ignore - Convert Column to str
await assert_dataset_access(db, dataset_id, user_context)
# 查找数据集关联的模板ID
from ..service.mapping import DatasetMappingService

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
from fastapi import HTTPException, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.dataset_management import Dataset
HEADER_USER_ID = "X-User-Id"
HEADER_USER_NAME = "X-User-Name"
HEADER_USER_ROLES = "X-User-Roles"
ADMIN_ROLE_CODE = "ROLE_ADMIN"
@dataclass(frozen=True)
class RequestUserContext:
user_id: str
username: str | None
roles: Tuple[str, ...]
@property
def is_admin(self) -> bool:
return any(role.upper() == ADMIN_ROLE_CODE for role in self.roles)
def get_request_user_context(request: Request) -> RequestUserContext:
user_id = (request.headers.get(HEADER_USER_ID) or "").strip()
username = (request.headers.get(HEADER_USER_NAME) or "").strip() or None
role_header = request.headers.get(HEADER_USER_ROLES) or ""
roles = tuple(
role.strip()
for role in role_header.split(",")
if role and role.strip()
)
if not user_id:
raise HTTPException(status_code=403, detail="权限不足:缺少用户身份")
return RequestUserContext(user_id=user_id, username=username, roles=roles)
def ensure_dataset_owner_access(
user_context: RequestUserContext,
dataset_owner_user_id: str | None,
dataset_id: str,
) -> None:
if user_context.is_admin:
return
if not dataset_owner_user_id or dataset_owner_user_id != user_context.user_id:
raise HTTPException(
status_code=403,
detail=f"无权访问数据集: {dataset_id}",
)
async def assert_dataset_access(
db: AsyncSession,
dataset_id: str,
user_context: RequestUserContext,
) -> None:
owner_result = await db.execute(
select(Dataset.created_by).where(Dataset.id == dataset_id)
)
dataset_owner = owner_result.scalar_one_or_none()
if dataset_owner is None:
raise HTTPException(status_code=404, detail=f"数据集不存在: {dataset_id}")
ensure_dataset_owner_access(user_context, str(dataset_owner), dataset_id)

View File

@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models.annotation_management import AutoAnnotationTask
from app.db.models.dataset_management import Dataset, DatasetFiles
from app.module.annotation.security import RequestUserContext
from ..schema.auto import (
CreateAutoAnnotationTaskRequest,
@@ -63,13 +64,25 @@ class AutoAnnotationTaskService:
resp.source_datasets = [dataset_name] if dataset_name else [request.dataset_id]
return resp
async def list_tasks(self, db: AsyncSession) -> List[AutoAnnotationTaskResponse]:
def _apply_dataset_scope(self, query, user_context: RequestUserContext):
if user_context.is_admin:
return query
return query.join(
Dataset,
AutoAnnotationTask.dataset_id == Dataset.id,
).where(Dataset.created_by == user_context.user_id)
async def list_tasks(
self,
db: AsyncSession,
user_context: RequestUserContext,
) -> List[AutoAnnotationTaskResponse]:
"""获取未软删除的自动标注任务列表,按创建时间倒序。"""
query = select(AutoAnnotationTask).where(AutoAnnotationTask.deleted_at.is_(None))
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(
select(AutoAnnotationTask)
.where(AutoAnnotationTask.deleted_at.is_(None))
.order_by(AutoAnnotationTask.created_at.desc())
query.order_by(AutoAnnotationTask.created_at.desc())
)
tasks: List[AutoAnnotationTask] = list(result.scalars().all())
@@ -87,13 +100,18 @@ class AutoAnnotationTaskService:
return responses
async def get_task(self, db: AsyncSession, task_id: str) -> Optional[AutoAnnotationTaskResponse]:
result = await db.execute(
select(AutoAnnotationTask).where(
async def get_task(
self,
db: AsyncSession,
task_id: str,
user_context: RequestUserContext,
) -> Optional[AutoAnnotationTaskResponse]:
query = select(AutoAnnotationTask).where(
AutoAnnotationTask.id == task_id,
AutoAnnotationTask.deleted_at.is_(None),
)
)
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return None
@@ -138,13 +156,18 @@ class AutoAnnotationTaskService:
return [task.dataset_id]
return []
async def soft_delete_task(self, db: AsyncSession, task_id: str) -> bool:
result = await db.execute(
select(AutoAnnotationTask).where(
async def soft_delete_task(
self,
db: AsyncSession,
task_id: str,
user_context: RequestUserContext,
) -> bool:
query = select(AutoAnnotationTask).where(
AutoAnnotationTask.id == task_id,
AutoAnnotationTask.deleted_at.is_(None),
)
)
query = self._apply_dataset_scope(query, user_context)
result = await db.execute(query)
task = result.scalar_one_or_none()
if not task:
return False

View File

@@ -54,6 +54,10 @@ from app.module.annotation.service.knowledge_sync import KnowledgeSyncService
from app.module.annotation.service.annotation_text_splitter import (
AnnotationTextSplitter,
)
from app.module.annotation.security import (
RequestUserContext,
ensure_dataset_owner_access,
)
from app.module.annotation.service.text_fetcher import (
fetch_text_content_via_download_api,
)
@@ -104,8 +108,9 @@ class AnnotationEditorService:
# 分段阈值:超过此字符数自动分段
SEGMENT_THRESHOLD = 200
def __init__(self, db: AsyncSession):
def __init__(self, db: AsyncSession, user_context: RequestUserContext):
self.db = db
self.user_context = user_context
self.template_service = AnnotationTemplateService()
@staticmethod
@@ -157,14 +162,24 @@ class AnnotationEditorService:
async def _get_project_or_404(self, project_id: str) -> LabelingProject:
result = await self.db.execute(
select(LabelingProject).where(
select(LabelingProject, Dataset.created_by).join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(
LabelingProject.id == project_id,
LabelingProject.deleted_at.is_(None),
)
)
project = result.scalar_one_or_none()
if not project:
row = result.first()
if not row:
raise HTTPException(status_code=404, detail=f"标注项目不存在: {project_id}")
project = row[0]
dataset_owner = row[1]
ensure_dataset_owner_access(
self.user_context,
str(dataset_owner) if dataset_owner is not None else None,
project.dataset_id,
)
return project
async def _get_dataset_type(self, dataset_id: str) -> Optional[str]:

View File

@@ -478,7 +478,9 @@ class DatasetMappingService:
skip: int = 0,
limit: int = 100,
include_deleted: bool = False,
include_template: bool = False
include_template: bool = False,
current_user_id: Optional[str] = None,
is_admin: bool = False,
) -> Tuple[List[DatasetMappingResponse], int]:
"""
获取所有映射及总数(用于分页)
@@ -495,9 +497,16 @@ class DatasetMappingService:
query = self._build_query_with_dataset_name()
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
if not is_admin:
query = query.where(Dataset.created_by == current_user_id)
# 获取总数
count_query = select(func.count()).select_from(LabelingProject)
if not is_admin:
count_query = count_query.join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(Dataset.created_by == current_user_id)
if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None))
@@ -557,7 +566,9 @@ class DatasetMappingService:
skip: int = 0,
limit: int = 100,
include_deleted: bool = False,
include_template: bool = False
include_template: bool = False,
current_user_id: Optional[str] = None,
is_admin: bool = False,
) -> Tuple[List[DatasetMappingResponse], int]:
"""
根据源数据集ID获取映射关系及总数(用于分页)
@@ -578,11 +589,18 @@ class DatasetMappingService:
if not include_deleted:
query = query.where(LabelingProject.deleted_at.is_(None))
if not is_admin:
query = query.where(Dataset.created_by == current_user_id)
# 获取总数
count_query = select(func.count()).select_from(LabelingProject).where(
LabelingProject.dataset_id == dataset_id
)
if not is_admin:
count_query = count_query.join(
Dataset,
LabelingProject.dataset_id == Dataset.id,
).where(Dataset.created_by == current_user_id)
if not include_deleted:
count_query = count_query.where(LabelingProject.deleted_at.is_(None))

View File

@@ -1,5 +1,4 @@
import csv
import csv
import datetime
import os
from io import StringIO
@@ -76,6 +75,7 @@ class PdfTextExtractService:
source_path = self._resolve_source_path(file_record)
dataset_path = self._resolve_dataset_path(dataset)
target_path = self._resolve_target_path(dataset_path, source_path, file_record, file_id, file_type)
logical_path = self._build_logical_path(dataset_path, target_path)
existing_record = await self._find_existing_text_record(dataset_id, target_path)
if existing_record:
@@ -85,7 +85,7 @@ class PdfTextExtractService:
file_size = self._get_file_size(target_path)
parser_name = PARSER_BY_FILE_TYPE.get(file_type, "")
record = await self._create_text_file_record(
dataset, file_record, target_path, file_size, parser_name, derived_file_type
dataset, file_record, target_path, logical_path, file_size, parser_name, derived_file_type
)
return self._build_response(dataset_id, file_id, record)
@@ -94,7 +94,7 @@ class PdfTextExtractService:
self._write_text_file(target_path, text_content)
file_size = self._get_file_size(target_path)
record = await self._create_text_file_record(
dataset, file_record, target_path, file_size, parser_name, derived_file_type
dataset, file_record, target_path, logical_path, file_size, parser_name, derived_file_type
)
return self._build_response(dataset_id, file_id, record)
@@ -170,6 +170,19 @@ class PdfTextExtractService:
target_dir.mkdir(parents=True, exist_ok=True)
return target_dir / output_name
@staticmethod
def _build_logical_path(dataset_path: Path, target_path: Path) -> str:
normalized_dataset_path = dataset_path.resolve()
normalized_target_path = target_path.resolve()
try:
relative_path = normalized_target_path.relative_to(normalized_dataset_path)
except ValueError as exc:
raise HTTPException(status_code=400, detail="解析文件路径超出数据集目录") from exc
logical_path = str(relative_path).replace("\\", "/").strip()
if not logical_path:
raise HTTPException(status_code=500, detail="解析文件逻辑路径为空")
return logical_path
async def _find_existing_text_record(self, dataset_id: str, target_path: Path) -> DatasetFiles | None:
result = await self.db.execute(
select(DatasetFiles).where(
@@ -259,10 +272,12 @@ class PdfTextExtractService:
dataset: Dataset,
source_file: DatasetFiles,
target_path: Path,
logical_path: str,
file_size: int,
parser_name: str,
derived_file_type: str,
) -> DatasetFiles:
assert logical_path
assert parser_name
assert derived_file_type
metadata = {
@@ -275,6 +290,7 @@ class PdfTextExtractService:
dataset_id=dataset.id, # type: ignore[arg-type]
file_name=target_path.name,
file_path=str(target_path),
logical_path=logical_path,
file_type=derived_file_type,
file_size=file_size,
dataset_filemetadata=metadata,

View File

@@ -74,6 +74,7 @@ class SynthesisDatasetExporter:
file_path = os.path.join(base_path, archived_file_name)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
self._write_jsonl(file_path, records)
logical_path = self._build_logical_path(base_path, file_path)
# 计算文件大小
try:
@@ -85,6 +86,7 @@ class SynthesisDatasetExporter:
dataset_id=dataset.id,
file_name=archived_file_name,
file_path=file_path,
logical_path=logical_path,
file_type="jsonl",
file_size=file_size,
last_access_time=datetime.datetime.now(),
@@ -158,3 +160,12 @@ class SynthesisDatasetExporter:
raise SynthesisExportError("Dataset path is empty")
os.makedirs(dataset.path, exist_ok=True)
return dataset.path
@staticmethod
def _build_logical_path(dataset_path: str, file_path: str) -> str:
normalized_dataset_path = os.path.abspath(dataset_path)
normalized_file_path = os.path.abspath(file_path)
relative_path = os.path.relpath(normalized_file_path, normalized_dataset_path).replace("\\", "/").strip()
if relative_path in ("", ".") or relative_path.startswith("../"):
raise SynthesisExportError(f"Invalid logical path generated for file: {file_path}")
return relative_path

View File

@@ -187,11 +187,13 @@ class RatioTaskService:
dst_dir = os.path.dirname(new_path)
await asyncio.to_thread(os.makedirs, dst_dir, exist_ok=True)
await asyncio.to_thread(shutil.copy2, src_path, new_path)
logical_path = RatioTaskService.build_logical_path(dst_prefix, new_path)
file_data = {
"dataset_id": target_ds.id, # type: ignore
"file_name": file_name,
"file_path": new_path,
"logical_path": logical_path,
"file_type": f.file_type,
"file_size": f.file_size,
"check_sum": f.check_sum,
@@ -204,6 +206,15 @@ class RatioTaskService:
session.add(DatasetFiles(**file_record))
existing_paths.add(new_path)
@staticmethod
def build_logical_path(dataset_prefix: str, file_path: str) -> str:
normalized_dataset_prefix = os.path.abspath(dataset_prefix)
normalized_file_path = os.path.abspath(file_path)
relative_path = os.path.relpath(normalized_file_path, normalized_dataset_prefix).replace("\\", "/").strip()
if relative_path in ("", ".") or relative_path.startswith("../"):
raise ValueError(f"Invalid logical path generated for file: {file_path}")
return relative_path
@staticmethod
def get_new_file_name(dst_prefix: str, existing_paths: set[Any], f) -> str:
file_name = f.file_name

149
scripts/db/zz-auth-init.sql Normal file
View File

@@ -0,0 +1,149 @@
USE datamate;
-- =============================================
-- 认证与授权(RBAC)基础表
-- 注意:该脚本命名为 zz- 前缀,确保在 users 表初始化后执行
-- =============================================
CREATE TABLE IF NOT EXISTS t_auth_roles
(
id VARCHAR(36) PRIMARY KEY COMMENT '角色ID',
role_code VARCHAR(100) NOT NULL COMMENT '角色编码',
role_name VARCHAR(100) NOT NULL COMMENT '角色名称',
description VARCHAR(255) DEFAULT '' COMMENT '角色描述',
enabled TINYINT DEFAULT 1 COMMENT '是否启用:1-启用,0-禁用',
is_built_in TINYINT DEFAULT 1 COMMENT '是否内置:1-是,0-否',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
UNIQUE KEY uk_auth_role_code (role_code)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4 COMMENT ='角色表';
CREATE TABLE IF NOT EXISTS t_auth_permissions
(
id VARCHAR(36) PRIMARY KEY COMMENT '权限ID',
permission_code VARCHAR(120) NOT NULL COMMENT '权限编码',
permission_name VARCHAR(120) NOT NULL COMMENT '权限名称',
module VARCHAR(100) NOT NULL COMMENT '模块',
action VARCHAR(50) NOT NULL COMMENT '动作',
path_pattern VARCHAR(255) DEFAULT '' COMMENT '路径模式',
method VARCHAR(20) DEFAULT '' COMMENT 'HTTP方法',
enabled TINYINT DEFAULT 1 COMMENT '是否启用:1-启用,0-禁用',
is_built_in TINYINT DEFAULT 1 COMMENT '是否内置:1-是,0-否',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
UNIQUE KEY uk_auth_permission_code (permission_code),
INDEX idx_auth_permission_module_action (module, action)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4 COMMENT ='权限表';
CREATE TABLE IF NOT EXISTS t_auth_role_permissions
(
id BIGINT PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
role_id VARCHAR(36) NOT NULL COMMENT '角色ID',
permission_id VARCHAR(36) NOT NULL COMMENT '权限ID',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
UNIQUE KEY uk_auth_role_permission (role_id, permission_id),
INDEX idx_auth_role_permission_role (role_id),
INDEX idx_auth_role_permission_permission (permission_id),
CONSTRAINT fk_auth_rp_role FOREIGN KEY (role_id) REFERENCES t_auth_roles (id) ON DELETE CASCADE,
CONSTRAINT fk_auth_rp_permission FOREIGN KEY (permission_id) REFERENCES t_auth_permissions (id) ON DELETE CASCADE
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4 COMMENT ='角色权限关系表';
CREATE TABLE IF NOT EXISTS t_auth_user_roles
(
id BIGINT PRIMARY KEY AUTO_INCREMENT COMMENT '主键',
user_id BIGINT NOT NULL COMMENT '用户ID(users.id)',
role_id VARCHAR(36) NOT NULL COMMENT '角色ID',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
UNIQUE KEY uk_auth_user_role (user_id, role_id),
INDEX idx_auth_user_role_user (user_id),
INDEX idx_auth_user_role_role (role_id),
CONSTRAINT fk_auth_ur_user FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE,
CONSTRAINT fk_auth_ur_role FOREIGN KEY (role_id) REFERENCES t_auth_roles (id) ON DELETE CASCADE
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4 COMMENT ='用户角色关系表';
-- =============================================
-- 角色初始化
-- =============================================
INSERT IGNORE INTO t_auth_roles (id, role_code, role_name, description, enabled, is_built_in)
VALUES ('role-admin', 'ROLE_ADMIN', '系统管理员', '拥有平台全部权限', 1, 1),
('role-data-editor', 'ROLE_DATA_EDITOR', '数据运营', '拥有业务模块读写权限', 1, 1),
('role-knowledge-user', 'ROLE_KNOWLEDGE_USER', '知识用户', '以知识管理为主的业务权限', 1, 1);
-- =============================================
-- 权限初始化(接口级)
-- =============================================
INSERT IGNORE INTO t_auth_permissions (id, permission_code, permission_name, module, action, path_pattern, method, enabled, is_built_in)
VALUES ('perm-dm-read', 'module:data-management:read', '数据管理读取', 'data-management', 'read', '/api/data-management/**', 'GET', 1, 1),
('perm-dm-write', 'module:data-management:write', '数据管理写入', 'data-management', 'write', '/api/data-management/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-da-read', 'module:data-annotation:read', '数据标注读取', 'data-annotation', 'read', '/api/annotation/**', 'GET', 1, 1),
('perm-da-write', 'module:data-annotation:write', '数据标注写入', 'data-annotation', 'write', '/api/annotation/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-dc-read', 'module:data-collection:read', '数据归集读取', 'data-collection', 'read', '/api/data-collection/**', 'GET', 1, 1),
('perm-dc-write', 'module:data-collection:write', '数据归集写入', 'data-collection', 'write', '/api/data-collection/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-de-read', 'module:data-evaluation:read', '数据评估读取', 'data-evaluation', 'read', '/api/evaluation/**', 'GET', 1, 1),
('perm-de-write', 'module:data-evaluation:write', '数据评估写入', 'data-evaluation', 'write', '/api/evaluation/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-ds-read', 'module:data-synthesis:read', '数据合成读取', 'data-synthesis', 'read', '/api/synthesis/**', 'GET', 1, 1),
('perm-ds-write', 'module:data-synthesis:write', '数据合成写入', 'data-synthesis', 'write', '/api/synthesis/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-km-read', 'module:knowledge-management:read', '知识管理读取', 'knowledge-management', 'read', '/api/data-management/knowledge/**', 'GET', 1, 1),
('perm-km-write', 'module:knowledge-management:write', '知识管理写入', 'knowledge-management', 'write', '/api/data-management/knowledge/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-kb-read', 'module:knowledge-base:read', '知识库读取', 'knowledge-base', 'read', '/api/knowledge-base/**', 'GET', 1, 1),
('perm-kb-write', 'module:knowledge-base:write', '知识库写入', 'knowledge-base', 'write', '/api/knowledge-base/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-om-read', 'module:operator-market:read', '算子市场读取', 'operator-market', 'read', '/api/operator-market/**', 'GET', 1, 1),
('perm-om-write', 'module:operator-market:write', '算子市场写入', 'operator-market', 'write', '/api/operator-market/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-orch-read', 'module:orchestration:read', '流程编排读取', 'orchestration', 'read', '/api/orchestration/**', 'GET', 1, 1),
('perm-orch-write', 'module:orchestration:write', '流程编排写入', 'orchestration', 'write', '/api/orchestration/**', 'POST,PUT,PATCH,DELETE', 1, 1),
('perm-agent-use', 'module:agent:use', '对话助手使用', 'agent', 'use', '/chat/**', 'GET', 1, 1),
('perm-content-use', 'module:content-generation:use', '内容生成功能使用', 'content-generation', 'use', '/api/content-generation/**', 'POST,PUT,PATCH', 1, 1),
('perm-user-manage', 'system:user:manage', '用户管理', 'system', 'manage-user', '/api/auth/users/**', 'GET,POST,PUT,PATCH,DELETE', 1, 1),
('perm-role-manage', 'system:role:manage', '角色管理', 'system', 'manage-role', '/api/auth/roles/**', 'GET,POST,PUT,PATCH,DELETE', 1, 1),
('perm-perm-manage', 'system:permission:manage', '权限管理', 'system', 'manage-permission', '/api/auth/permissions/**', 'GET,POST,PUT,PATCH,DELETE', 1, 1);
-- 管理员拥有所有权限
INSERT IGNORE INTO t_auth_role_permissions (role_id, permission_id)
SELECT 'role-admin', p.id
FROM t_auth_permissions p;
-- 数据运营拥有业务模块读写权限(不含系统管理)
INSERT IGNORE INTO t_auth_role_permissions (role_id, permission_id)
SELECT 'role-data-editor', p.id
FROM t_auth_permissions p
WHERE p.permission_code IN (
'module:data-management:read', 'module:data-management:write',
'module:data-annotation:read', 'module:data-annotation:write',
'module:data-collection:read', 'module:data-collection:write',
'module:data-evaluation:read', 'module:data-evaluation:write',
'module:data-synthesis:read', 'module:data-synthesis:write',
'module:knowledge-management:read', 'module:knowledge-management:write',
'module:knowledge-base:read', 'module:knowledge-base:write',
'module:operator-market:read', 'module:operator-market:write',
'module:orchestration:read', 'module:orchestration:write',
'module:agent:use', 'module:content-generation:use'
);
-- 知识用户拥有知识相关权限及必要数据读取权限
INSERT IGNORE INTO t_auth_role_permissions (role_id, permission_id)
SELECT 'role-knowledge-user', p.id
FROM t_auth_permissions p
WHERE p.permission_code IN (
'module:data-management:read',
'module:knowledge-management:read', 'module:knowledge-management:write',
'module:knowledge-base:read', 'module:knowledge-base:write',
'module:agent:use'
);
-- =============================================
-- 用户角色初始化(绑定到已有 users)
-- =============================================
INSERT IGNORE INTO t_auth_user_roles (user_id, role_id)
SELECT u.id, 'role-admin'
FROM users u
WHERE u.username = 'admin';
INSERT IGNORE INTO t_auth_user_roles (user_id, role_id)
SELECT u.id, 'role-knowledge-user'
FROM users u
WHERE u.username = 'knowledge_user';