You've already forked DataMate
Compare commits
17 Commits
444f8cd015
...
lsf
| Author | SHA1 | Date | |
|---|---|---|---|
| 75f9b95093 | |||
| ca37bc5a3b | |||
| e62a8369d4 | |||
| 6de41f1a5b | |||
| 24e59b87f2 | |||
| 1b2ed5335e | |||
| a5d8997c22 | |||
| e9e4cf3b1c | |||
| 9800517378 | |||
| 3a9afe3480 | |||
| afcb8783aa | |||
| 9b6ff59a11 | |||
| 39338df808 | |||
| 0ed7dcbee7 | |||
| 7abdafc338 | |||
| cca463e7d1 | |||
| 20446bf57d |
26
Makefile
26
Makefile
@@ -211,8 +211,9 @@ endif
|
|||||||
.PHONY: install
|
.PHONY: install
|
||||||
install:
|
install:
|
||||||
ifeq ($(origin INSTALLER), undefined)
|
ifeq ($(origin INSTALLER), undefined)
|
||||||
$(call prompt-installer,datamate-$$INSTALLER-install milvus-$$INSTALLER-install)
|
$(call prompt-installer,neo4j-$$INSTALLER-install datamate-$$INSTALLER-install milvus-$$INSTALLER-install)
|
||||||
else
|
else
|
||||||
|
$(MAKE) neo4j-$(INSTALLER)-install
|
||||||
$(MAKE) datamate-$(INSTALLER)-install
|
$(MAKE) datamate-$(INSTALLER)-install
|
||||||
$(MAKE) milvus-$(INSTALLER)-install
|
$(MAKE) milvus-$(INSTALLER)-install
|
||||||
endif
|
endif
|
||||||
@@ -228,7 +229,7 @@ endif
|
|||||||
.PHONY: uninstall
|
.PHONY: uninstall
|
||||||
uninstall:
|
uninstall:
|
||||||
ifeq ($(origin INSTALLER), undefined)
|
ifeq ($(origin INSTALLER), undefined)
|
||||||
$(call prompt-uninstaller,label-studio-$$INSTALLER-uninstall milvus-$$INSTALLER-uninstall deer-flow-$$INSTALLER-uninstall datamate-$$INSTALLER-uninstall)
|
$(call prompt-uninstaller,label-studio-$$INSTALLER-uninstall milvus-$$INSTALLER-uninstall neo4j-$$INSTALLER-uninstall deer-flow-$$INSTALLER-uninstall datamate-$$INSTALLER-uninstall)
|
||||||
else
|
else
|
||||||
@if [ "$(INSTALLER)" = "docker" ]; then \
|
@if [ "$(INSTALLER)" = "docker" ]; then \
|
||||||
echo "Delete volumes? (This will remove all data)"; \
|
echo "Delete volumes? (This will remove all data)"; \
|
||||||
@@ -240,6 +241,7 @@ else
|
|||||||
fi
|
fi
|
||||||
@$(MAKE) label-studio-$(INSTALLER)-uninstall DELETE_VOLUMES_CHOICE=$$DELETE_VOLUMES_CHOICE; \
|
@$(MAKE) label-studio-$(INSTALLER)-uninstall DELETE_VOLUMES_CHOICE=$$DELETE_VOLUMES_CHOICE; \
|
||||||
$(MAKE) milvus-$(INSTALLER)-uninstall DELETE_VOLUMES_CHOICE=$$DELETE_VOLUMES_CHOICE; \
|
$(MAKE) milvus-$(INSTALLER)-uninstall DELETE_VOLUMES_CHOICE=$$DELETE_VOLUMES_CHOICE; \
|
||||||
|
$(MAKE) neo4j-$(INSTALLER)-uninstall DELETE_VOLUMES_CHOICE=$$DELETE_VOLUMES_CHOICE; \
|
||||||
$(MAKE) deer-flow-$(INSTALLER)-uninstall DELETE_VOLUMES_CHOICE=$$DELETE_VOLUMES_CHOICE; \
|
$(MAKE) deer-flow-$(INSTALLER)-uninstall DELETE_VOLUMES_CHOICE=$$DELETE_VOLUMES_CHOICE; \
|
||||||
$(MAKE) datamate-$(INSTALLER)-uninstall DELETE_VOLUMES_CHOICE=$$DELETE_VOLUMES_CHOICE
|
$(MAKE) datamate-$(INSTALLER)-uninstall DELETE_VOLUMES_CHOICE=$$DELETE_VOLUMES_CHOICE
|
||||||
endif
|
endif
|
||||||
@@ -247,7 +249,7 @@ endif
|
|||||||
# ========== Docker Install/Uninstall Targets ==========
|
# ========== Docker Install/Uninstall Targets ==========
|
||||||
|
|
||||||
# Valid service targets for docker install/uninstall
|
# Valid service targets for docker install/uninstall
|
||||||
VALID_SERVICE_TARGETS := datamate backend frontend runtime mineru "deer-flow" milvus "label-studio" "data-juicer" dj
|
VALID_SERVICE_TARGETS := datamate backend frontend runtime mineru "deer-flow" milvus neo4j "label-studio" "data-juicer" dj
|
||||||
|
|
||||||
# Generic docker service install target
|
# Generic docker service install target
|
||||||
.PHONY: %-docker-install
|
.PHONY: %-docker-install
|
||||||
@@ -272,6 +274,8 @@ VALID_SERVICE_TARGETS := datamate backend frontend runtime mineru "deer-flow" mi
|
|||||||
REGISTRY=$(REGISTRY) docker compose -f deployment/docker/deer-flow/docker-compose.yml up -d; \
|
REGISTRY=$(REGISTRY) docker compose -f deployment/docker/deer-flow/docker-compose.yml up -d; \
|
||||||
elif [ "$*" = "milvus" ]; then \
|
elif [ "$*" = "milvus" ]; then \
|
||||||
docker compose -f deployment/docker/milvus/docker-compose.yml up -d; \
|
docker compose -f deployment/docker/milvus/docker-compose.yml up -d; \
|
||||||
|
elif [ "$*" = "neo4j" ]; then \
|
||||||
|
docker compose -f deployment/docker/neo4j/docker-compose.yml up -d; \
|
||||||
elif [ "$*" = "data-juicer" ] || [ "$*" = "dj" ]; then \
|
elif [ "$*" = "data-juicer" ] || [ "$*" = "dj" ]; then \
|
||||||
REGISTRY=$(REGISTRY) && docker compose -f deployment/docker/datamate/docker-compose.yml up -d datamate-data-juicer; \
|
REGISTRY=$(REGISTRY) && docker compose -f deployment/docker/datamate/docker-compose.yml up -d datamate-data-juicer; \
|
||||||
else \
|
else \
|
||||||
@@ -311,6 +315,12 @@ VALID_SERVICE_TARGETS := datamate backend frontend runtime mineru "deer-flow" mi
|
|||||||
else \
|
else \
|
||||||
docker compose -f deployment/docker/milvus/docker-compose.yml down; \
|
docker compose -f deployment/docker/milvus/docker-compose.yml down; \
|
||||||
fi; \
|
fi; \
|
||||||
|
elif [ "$*" = "neo4j" ]; then \
|
||||||
|
if [ "$(DELETE_VOLUMES_CHOICE)" = "1" ]; then \
|
||||||
|
docker compose -f deployment/docker/neo4j/docker-compose.yml down -v; \
|
||||||
|
else \
|
||||||
|
docker compose -f deployment/docker/neo4j/docker-compose.yml down; \
|
||||||
|
fi; \
|
||||||
elif [ "$*" = "data-juicer" ] || [ "$*" = "dj" ]; then \
|
elif [ "$*" = "data-juicer" ] || [ "$*" = "dj" ]; then \
|
||||||
$(call docker-compose-service,datamate-data-juicer,down,deployment/docker/datamate); \
|
$(call docker-compose-service,datamate-data-juicer,down,deployment/docker/datamate); \
|
||||||
else \
|
else \
|
||||||
@@ -320,7 +330,7 @@ VALID_SERVICE_TARGETS := datamate backend frontend runtime mineru "deer-flow" mi
|
|||||||
# ========== Kubernetes Install/Uninstall Targets ==========
|
# ========== Kubernetes Install/Uninstall Targets ==========
|
||||||
|
|
||||||
# Valid k8s targets
|
# Valid k8s targets
|
||||||
VALID_K8S_TARGETS := mineru datamate deer-flow milvus label-studio data-juicer dj
|
VALID_K8S_TARGETS := mineru datamate deer-flow milvus neo4j label-studio data-juicer dj
|
||||||
|
|
||||||
# Generic k8s install target
|
# Generic k8s install target
|
||||||
.PHONY: %-k8s-install
|
.PHONY: %-k8s-install
|
||||||
@@ -333,7 +343,9 @@ VALID_K8S_TARGETS := mineru datamate deer-flow milvus label-studio data-juicer d
|
|||||||
done; \
|
done; \
|
||||||
exit 1; \
|
exit 1; \
|
||||||
fi
|
fi
|
||||||
@if [ "$*" = "label-studio" ]; then \
|
@if [ "$*" = "neo4j" ]; then \
|
||||||
|
echo "Skipping Neo4j: no Helm chart available. Use 'make neo4j-docker-install' or provide an external Neo4j instance."; \
|
||||||
|
elif [ "$*" = "label-studio" ]; then \
|
||||||
helm upgrade label-studio deployment/helm/label-studio/ -n $(NAMESPACE) --install; \
|
helm upgrade label-studio deployment/helm/label-studio/ -n $(NAMESPACE) --install; \
|
||||||
elif [ "$*" = "mineru" ]; then \
|
elif [ "$*" = "mineru" ]; then \
|
||||||
kubectl apply -f deployment/kubernetes/mineru/deploy.yaml -n $(NAMESPACE); \
|
kubectl apply -f deployment/kubernetes/mineru/deploy.yaml -n $(NAMESPACE); \
|
||||||
@@ -362,7 +374,9 @@ VALID_K8S_TARGETS := mineru datamate deer-flow milvus label-studio data-juicer d
|
|||||||
done; \
|
done; \
|
||||||
exit 1; \
|
exit 1; \
|
||||||
fi
|
fi
|
||||||
@if [ "$*" = "mineru" ]; then \
|
@if [ "$*" = "neo4j" ]; then \
|
||||||
|
echo "Skipping Neo4j: no Helm chart available. Use 'make neo4j-docker-uninstall' or manage your external Neo4j instance."; \
|
||||||
|
elif [ "$*" = "mineru" ]; then \
|
||||||
kubectl delete -f deployment/kubernetes/mineru/deploy.yaml -n $(NAMESPACE); \
|
kubectl delete -f deployment/kubernetes/mineru/deploy.yaml -n $(NAMESPACE); \
|
||||||
elif [ "$*" = "datamate" ]; then \
|
elif [ "$*" = "datamate" ]; then \
|
||||||
helm uninstall datamate -n $(NAMESPACE) --ignore-not-found; \
|
helm uninstall datamate -n $(NAMESPACE) --ignore-not-found; \
|
||||||
|
|||||||
@@ -37,6 +37,14 @@ public class ApiGatewayApplication {
|
|||||||
.route("data-collection", r -> r.path("/api/data-collection/**")
|
.route("data-collection", r -> r.path("/api/data-collection/**")
|
||||||
.uri("http://datamate-backend-python:18000"))
|
.uri("http://datamate-backend-python:18000"))
|
||||||
|
|
||||||
|
// 知识图谱抽取服务路由
|
||||||
|
.route("kg-extraction", r -> r.path("/api/kg/**")
|
||||||
|
.uri("http://datamate-backend-python:18000"))
|
||||||
|
|
||||||
|
// GraphRAG 融合查询服务路由
|
||||||
|
.route("graphrag", r -> r.path("/api/graphrag/**")
|
||||||
|
.uri("http://datamate-backend-python:18000"))
|
||||||
|
|
||||||
.route("deer-flow-frontend", r -> r.path("/chat/**")
|
.route("deer-flow-frontend", r -> r.path("/chat/**")
|
||||||
.uri("http://deer-flow-frontend:3000"))
|
.uri("http://deer-flow-frontend:3000"))
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ public class PermissionRuleMatcher {
|
|||||||
addModuleRules(permissionRules, "/api/orchestration/**", "module:orchestration:read", "module:orchestration:write");
|
addModuleRules(permissionRules, "/api/orchestration/**", "module:orchestration:read", "module:orchestration:write");
|
||||||
addModuleRules(permissionRules, "/api/content-generation/**", "module:content-generation:use", "module:content-generation:use");
|
addModuleRules(permissionRules, "/api/content-generation/**", "module:content-generation:use", "module:content-generation:use");
|
||||||
addModuleRules(permissionRules, "/api/task-meta/**", "module:task-coordination:read", "module:task-coordination:write");
|
addModuleRules(permissionRules, "/api/task-meta/**", "module:task-coordination:read", "module:task-coordination:write");
|
||||||
|
addModuleRules(permissionRules, "/api/knowledge-graph/**", "module:knowledge-graph:read", "module:knowledge-graph:write");
|
||||||
|
addModuleRules(permissionRules, "/api/graphrag/**", "module:knowledge-base:read", "module:knowledge-base:write");
|
||||||
|
|
||||||
permissionRules.add(new PermissionRule(READ_METHODS, "/api/auth/users/**", "system:user:manage"));
|
permissionRules.add(new PermissionRule(READ_METHODS, "/api/auth/users/**", "system:user:manage"));
|
||||||
permissionRules.add(new PermissionRule(WRITE_METHODS, "/api/auth/users/**", "system:user:manage"));
|
permissionRules.add(new PermissionRule(WRITE_METHODS, "/api/auth/users/**", "system:user:manage"));
|
||||||
|
|||||||
@@ -266,6 +266,12 @@ public class KnowledgeItemApplicationService {
|
|||||||
response.setTotalKnowledgeSets(totalSets);
|
response.setTotalKnowledgeSets(totalSets);
|
||||||
|
|
||||||
List<String> accessibleSetIds = knowledgeSetRepository.listSetIdsByCriteria(baseQuery, ownerFilterUserId, excludeConfidential);
|
List<String> accessibleSetIds = knowledgeSetRepository.listSetIdsByCriteria(baseQuery, ownerFilterUserId, excludeConfidential);
|
||||||
|
if (CollectionUtils.isEmpty(accessibleSetIds)) {
|
||||||
|
response.setTotalFiles(0L);
|
||||||
|
response.setTotalSize(0L);
|
||||||
|
response.setTotalTags(0L);
|
||||||
|
return response;
|
||||||
|
}
|
||||||
List<KnowledgeSet> accessibleSets = knowledgeSetRepository.listByIds(accessibleSetIds);
|
List<KnowledgeSet> accessibleSets = knowledgeSetRepository.listByIds(accessibleSetIds);
|
||||||
if (CollectionUtils.isEmpty(accessibleSets)) {
|
if (CollectionUtils.isEmpty(accessibleSets)) {
|
||||||
response.setTotalFiles(0L);
|
response.setTotalFiles(0L);
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ public class DataManagementConfig {
|
|||||||
/**
|
/**
|
||||||
* 缓存管理器
|
* 缓存管理器
|
||||||
*/
|
*/
|
||||||
@Bean
|
@Bean("dataManagementCacheManager")
|
||||||
public CacheManager cacheManager() {
|
public CacheManager dataManagementCacheManager() {
|
||||||
return new ConcurrentMapCacheManager("datasets", "datasetFiles", "tags");
|
return new ConcurrentMapCacheManager("datasets", "datasetFiles", "tags");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,219 @@
|
|||||||
|
package com.datamate.knowledgegraph.application;
|
||||||
|
|
||||||
|
import com.datamate.common.infrastructure.exception.BusinessException;
|
||||||
|
import com.datamate.common.infrastructure.exception.SystemErrorCode;
|
||||||
|
import com.datamate.common.interfaces.PagedResponse;
|
||||||
|
import com.datamate.knowledgegraph.domain.model.EditReview;
|
||||||
|
import com.datamate.knowledgegraph.domain.repository.EditReviewRepository;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
||||||
|
import com.datamate.knowledgegraph.interfaces.dto.*;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 编辑审核业务服务。
|
||||||
|
* <p>
|
||||||
|
* 提供编辑审核的提交、审批、拒绝和查询功能。
|
||||||
|
* 审批通过后自动调用对应的实体/关系 CRUD 服务执行变更。
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class EditReviewService {
|
||||||
|
|
||||||
|
private static final long MAX_SKIP = 100_000L;
|
||||||
|
private static final Pattern UUID_PATTERN = Pattern.compile(
|
||||||
|
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
|
||||||
|
);
|
||||||
|
private static final ObjectMapper MAPPER = new ObjectMapper();
|
||||||
|
|
||||||
|
private final EditReviewRepository reviewRepository;
|
||||||
|
private final GraphEntityService entityService;
|
||||||
|
private final GraphRelationService relationService;
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public EditReviewVO submitReview(String graphId, SubmitReviewRequest request, String submittedBy) {
|
||||||
|
validateGraphId(graphId);
|
||||||
|
|
||||||
|
EditReview review = EditReview.builder()
|
||||||
|
.graphId(graphId)
|
||||||
|
.operationType(request.getOperationType())
|
||||||
|
.entityId(request.getEntityId())
|
||||||
|
.relationId(request.getRelationId())
|
||||||
|
.payload(request.getPayload())
|
||||||
|
.status("PENDING")
|
||||||
|
.submittedBy(submittedBy)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
EditReview saved = reviewRepository.save(review);
|
||||||
|
log.info("Review submitted: id={}, graphId={}, type={}, by={}",
|
||||||
|
saved.getId(), graphId, request.getOperationType(), submittedBy);
|
||||||
|
return toVO(saved);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public EditReviewVO approveReview(String graphId, String reviewId, String reviewedBy, String comment) {
|
||||||
|
validateGraphId(graphId);
|
||||||
|
|
||||||
|
EditReview review = reviewRepository.findById(reviewId, graphId)
|
||||||
|
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.REVIEW_NOT_FOUND));
|
||||||
|
|
||||||
|
if (!"PENDING".equals(review.getStatus())) {
|
||||||
|
throw BusinessException.of(KnowledgeGraphErrorCode.REVIEW_ALREADY_PROCESSED);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the change
|
||||||
|
applyChange(review);
|
||||||
|
|
||||||
|
// Update review status
|
||||||
|
review.setStatus("APPROVED");
|
||||||
|
review.setReviewedBy(reviewedBy);
|
||||||
|
review.setReviewComment(comment);
|
||||||
|
review.setReviewedAt(LocalDateTime.now());
|
||||||
|
reviewRepository.save(review);
|
||||||
|
|
||||||
|
log.info("Review approved: id={}, graphId={}, type={}, by={}",
|
||||||
|
reviewId, graphId, review.getOperationType(), reviewedBy);
|
||||||
|
return toVO(review);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public EditReviewVO rejectReview(String graphId, String reviewId, String reviewedBy, String comment) {
|
||||||
|
validateGraphId(graphId);
|
||||||
|
|
||||||
|
EditReview review = reviewRepository.findById(reviewId, graphId)
|
||||||
|
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.REVIEW_NOT_FOUND));
|
||||||
|
|
||||||
|
if (!"PENDING".equals(review.getStatus())) {
|
||||||
|
throw BusinessException.of(KnowledgeGraphErrorCode.REVIEW_ALREADY_PROCESSED);
|
||||||
|
}
|
||||||
|
|
||||||
|
review.setStatus("REJECTED");
|
||||||
|
review.setReviewedBy(reviewedBy);
|
||||||
|
review.setReviewComment(comment);
|
||||||
|
review.setReviewedAt(LocalDateTime.now());
|
||||||
|
reviewRepository.save(review);
|
||||||
|
|
||||||
|
log.info("Review rejected: id={}, graphId={}, type={}, by={}",
|
||||||
|
reviewId, graphId, review.getOperationType(), reviewedBy);
|
||||||
|
return toVO(review);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PagedResponse<EditReviewVO> listPendingReviews(String graphId, int page, int size) {
|
||||||
|
validateGraphId(graphId);
|
||||||
|
|
||||||
|
int safePage = Math.max(0, page);
|
||||||
|
int safeSize = Math.max(1, Math.min(size, 200));
|
||||||
|
long skip = (long) safePage * safeSize;
|
||||||
|
if (skip > MAX_SKIP) {
|
||||||
|
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "分页偏移量过大");
|
||||||
|
}
|
||||||
|
|
||||||
|
List<EditReview> reviews = reviewRepository.findPendingByGraphId(graphId, skip, safeSize);
|
||||||
|
long total = reviewRepository.countPendingByGraphId(graphId);
|
||||||
|
long totalPages = safeSize > 0 ? (total + safeSize - 1) / safeSize : 0;
|
||||||
|
|
||||||
|
List<EditReviewVO> content = reviews.stream().map(EditReviewService::toVO).toList();
|
||||||
|
return PagedResponse.of(content, safePage, total, totalPages);
|
||||||
|
}
|
||||||
|
|
||||||
|
public PagedResponse<EditReviewVO> listReviews(String graphId, String status, int page, int size) {
|
||||||
|
validateGraphId(graphId);
|
||||||
|
|
||||||
|
int safePage = Math.max(0, page);
|
||||||
|
int safeSize = Math.max(1, Math.min(size, 200));
|
||||||
|
long skip = (long) safePage * safeSize;
|
||||||
|
if (skip > MAX_SKIP) {
|
||||||
|
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "分页偏移量过大");
|
||||||
|
}
|
||||||
|
|
||||||
|
List<EditReview> reviews = reviewRepository.findByGraphId(graphId, status, skip, safeSize);
|
||||||
|
long total = reviewRepository.countByGraphId(graphId, status);
|
||||||
|
long totalPages = safeSize > 0 ? (total + safeSize - 1) / safeSize : 0;
|
||||||
|
|
||||||
|
List<EditReviewVO> content = reviews.stream().map(EditReviewService::toVO).toList();
|
||||||
|
return PagedResponse.of(content, safePage, total, totalPages);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// 执行变更
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
private void applyChange(EditReview review) {
|
||||||
|
String graphId = review.getGraphId();
|
||||||
|
String type = review.getOperationType();
|
||||||
|
|
||||||
|
try {
|
||||||
|
switch (type) {
|
||||||
|
case "CREATE_ENTITY" -> {
|
||||||
|
CreateEntityRequest req = MAPPER.readValue(review.getPayload(), CreateEntityRequest.class);
|
||||||
|
entityService.createEntity(graphId, req);
|
||||||
|
}
|
||||||
|
case "UPDATE_ENTITY" -> {
|
||||||
|
UpdateEntityRequest req = MAPPER.readValue(review.getPayload(), UpdateEntityRequest.class);
|
||||||
|
entityService.updateEntity(graphId, review.getEntityId(), req);
|
||||||
|
}
|
||||||
|
case "DELETE_ENTITY" -> {
|
||||||
|
entityService.deleteEntity(graphId, review.getEntityId());
|
||||||
|
}
|
||||||
|
case "BATCH_DELETE_ENTITY" -> {
|
||||||
|
BatchDeleteRequest req = MAPPER.readValue(review.getPayload(), BatchDeleteRequest.class);
|
||||||
|
entityService.batchDeleteEntities(graphId, req.getIds());
|
||||||
|
}
|
||||||
|
case "CREATE_RELATION" -> {
|
||||||
|
CreateRelationRequest req = MAPPER.readValue(review.getPayload(), CreateRelationRequest.class);
|
||||||
|
relationService.createRelation(graphId, req);
|
||||||
|
}
|
||||||
|
case "UPDATE_RELATION" -> {
|
||||||
|
UpdateRelationRequest req = MAPPER.readValue(review.getPayload(), UpdateRelationRequest.class);
|
||||||
|
relationService.updateRelation(graphId, review.getRelationId(), req);
|
||||||
|
}
|
||||||
|
case "DELETE_RELATION" -> {
|
||||||
|
relationService.deleteRelation(graphId, review.getRelationId());
|
||||||
|
}
|
||||||
|
case "BATCH_DELETE_RELATION" -> {
|
||||||
|
BatchDeleteRequest req = MAPPER.readValue(review.getPayload(), BatchDeleteRequest.class);
|
||||||
|
relationService.batchDeleteRelations(graphId, req.getIds());
|
||||||
|
}
|
||||||
|
default -> throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "未知操作类型: " + type);
|
||||||
|
}
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "变更载荷解析失败: " + e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// 转换
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
private static EditReviewVO toVO(EditReview review) {
|
||||||
|
return EditReviewVO.builder()
|
||||||
|
.id(review.getId())
|
||||||
|
.graphId(review.getGraphId())
|
||||||
|
.operationType(review.getOperationType())
|
||||||
|
.entityId(review.getEntityId())
|
||||||
|
.relationId(review.getRelationId())
|
||||||
|
.payload(review.getPayload())
|
||||||
|
.status(review.getStatus())
|
||||||
|
.submittedBy(review.getSubmittedBy())
|
||||||
|
.reviewedBy(review.getReviewedBy())
|
||||||
|
.reviewComment(review.getReviewComment())
|
||||||
|
.createdAt(review.getCreatedAt())
|
||||||
|
.reviewedAt(review.getReviewedAt())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void validateGraphId(String graphId) {
|
||||||
|
if (graphId == null || !UUID_PATTERN.matcher(graphId).matches()) {
|
||||||
|
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "graphId 格式无效");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,17 +5,22 @@ import com.datamate.common.infrastructure.exception.SystemErrorCode;
|
|||||||
import com.datamate.common.interfaces.PagedResponse;
|
import com.datamate.common.interfaces.PagedResponse;
|
||||||
import com.datamate.knowledgegraph.domain.model.GraphEntity;
|
import com.datamate.knowledgegraph.domain.model.GraphEntity;
|
||||||
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.cache.RedisCacheConfig;
|
||||||
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
||||||
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.CreateEntityRequest;
|
import com.datamate.knowledgegraph.interfaces.dto.CreateEntityRequest;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.UpdateEntityRequest;
|
import com.datamate.knowledgegraph.interfaces.dto.UpdateEntityRequest;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.cache.annotation.Cacheable;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@@ -32,6 +37,7 @@ public class GraphEntityService {
|
|||||||
|
|
||||||
private final GraphEntityRepository entityRepository;
|
private final GraphEntityRepository entityRepository;
|
||||||
private final KnowledgeGraphProperties properties;
|
private final KnowledgeGraphProperties properties;
|
||||||
|
private final GraphCacheService cacheService;
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public GraphEntity createEntity(String graphId, CreateEntityRequest request) {
|
public GraphEntity createEntity(String graphId, CreateEntityRequest request) {
|
||||||
@@ -49,15 +55,25 @@ public class GraphEntityService {
|
|||||||
.createdAt(LocalDateTime.now())
|
.createdAt(LocalDateTime.now())
|
||||||
.updatedAt(LocalDateTime.now())
|
.updatedAt(LocalDateTime.now())
|
||||||
.build();
|
.build();
|
||||||
return entityRepository.save(entity);
|
GraphEntity saved = entityRepository.save(entity);
|
||||||
|
cacheService.evictEntityCaches(graphId, saved.getId());
|
||||||
|
cacheService.evictSearchCaches(graphId);
|
||||||
|
return saved;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Cacheable(value = RedisCacheConfig.CACHE_ENTITIES,
|
||||||
|
key = "T(com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService).cacheKey(#graphId, #entityId)",
|
||||||
|
unless = "#result == null",
|
||||||
|
cacheManager = "knowledgeGraphCacheManager")
|
||||||
public GraphEntity getEntity(String graphId, String entityId) {
|
public GraphEntity getEntity(String graphId, String entityId) {
|
||||||
validateGraphId(graphId);
|
validateGraphId(graphId);
|
||||||
return entityRepository.findByIdAndGraphId(entityId, graphId)
|
return entityRepository.findByIdAndGraphId(entityId, graphId)
|
||||||
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.ENTITY_NOT_FOUND));
|
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.ENTITY_NOT_FOUND));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Cacheable(value = RedisCacheConfig.CACHE_ENTITIES,
|
||||||
|
key = "T(com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService).cacheKey(#graphId, 'list')",
|
||||||
|
cacheManager = "knowledgeGraphCacheManager")
|
||||||
public List<GraphEntity> listEntities(String graphId) {
|
public List<GraphEntity> listEntities(String graphId) {
|
||||||
validateGraphId(graphId);
|
validateGraphId(graphId);
|
||||||
return entityRepository.findByGraphId(graphId);
|
return entityRepository.findByGraphId(graphId);
|
||||||
@@ -135,8 +151,14 @@ public class GraphEntityService {
|
|||||||
if (request.getProperties() != null) {
|
if (request.getProperties() != null) {
|
||||||
entity.setProperties(request.getProperties());
|
entity.setProperties(request.getProperties());
|
||||||
}
|
}
|
||||||
|
if (request.getConfidence() != null) {
|
||||||
|
entity.setConfidence(request.getConfidence());
|
||||||
|
}
|
||||||
entity.setUpdatedAt(LocalDateTime.now());
|
entity.setUpdatedAt(LocalDateTime.now());
|
||||||
return entityRepository.save(entity);
|
GraphEntity saved = entityRepository.save(entity);
|
||||||
|
cacheService.evictEntityCaches(graphId, entityId);
|
||||||
|
cacheService.evictSearchCaches(graphId);
|
||||||
|
return saved;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
@@ -144,6 +166,8 @@ public class GraphEntityService {
|
|||||||
validateGraphId(graphId);
|
validateGraphId(graphId);
|
||||||
GraphEntity entity = getEntity(graphId, entityId);
|
GraphEntity entity = getEntity(graphId, entityId);
|
||||||
entityRepository.delete(entity);
|
entityRepository.delete(entity);
|
||||||
|
cacheService.evictEntityCaches(graphId, entityId);
|
||||||
|
cacheService.evictSearchCaches(graphId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<GraphEntity> getNeighbors(String graphId, String entityId, int depth, int limit) {
|
public List<GraphEntity> getNeighbors(String graphId, String entityId, int depth, int limit) {
|
||||||
@@ -153,6 +177,28 @@ public class GraphEntityService {
|
|||||||
return entityRepository.findNeighbors(graphId, entityId, clampedDepth, clampedLimit);
|
return entityRepository.findNeighbors(graphId, entityId, clampedDepth, clampedLimit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public Map<String, Object> batchDeleteEntities(String graphId, List<String> entityIds) {
|
||||||
|
validateGraphId(graphId);
|
||||||
|
int deleted = 0;
|
||||||
|
List<String> failedIds = new ArrayList<>();
|
||||||
|
for (String entityId : entityIds) {
|
||||||
|
try {
|
||||||
|
deleteEntity(graphId, entityId);
|
||||||
|
deleted++;
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("Batch delete: failed to delete entity {}: {}", entityId, e.getMessage());
|
||||||
|
failedIds.add(entityId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Map<String, Object> result = Map.of(
|
||||||
|
"deleted", deleted,
|
||||||
|
"total", entityIds.size(),
|
||||||
|
"failedIds", failedIds
|
||||||
|
);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
public long countEntities(String graphId) {
|
public long countEntities(String graphId) {
|
||||||
validateGraphId(graphId);
|
validateGraphId(graphId);
|
||||||
return entityRepository.countByGraphId(graphId);
|
return entityRepository.countByGraphId(graphId);
|
||||||
|
|||||||
@@ -6,23 +6,32 @@ import com.datamate.common.infrastructure.exception.SystemErrorCode;
|
|||||||
import com.datamate.common.interfaces.PagedResponse;
|
import com.datamate.common.interfaces.PagedResponse;
|
||||||
import com.datamate.knowledgegraph.domain.model.GraphEntity;
|
import com.datamate.knowledgegraph.domain.model.GraphEntity;
|
||||||
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.cache.RedisCacheConfig;
|
||||||
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
||||||
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.*;
|
import com.datamate.knowledgegraph.interfaces.dto.*;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.neo4j.driver.Driver;
|
||||||
|
import org.neo4j.driver.Record;
|
||||||
|
import org.neo4j.driver.Session;
|
||||||
|
import org.neo4j.driver.TransactionConfig;
|
||||||
import org.neo4j.driver.Value;
|
import org.neo4j.driver.Value;
|
||||||
import org.neo4j.driver.types.MapAccessor;
|
import org.neo4j.driver.types.MapAccessor;
|
||||||
|
import org.springframework.cache.annotation.Cacheable;
|
||||||
import org.springframework.data.neo4j.core.Neo4jClient;
|
import org.springframework.data.neo4j.core.Neo4jClient;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.function.Function;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识图谱查询服务。
|
* 知识图谱查询服务。
|
||||||
* <p>
|
* <p>
|
||||||
* 提供图遍历(N 跳邻居、最短路径、子图提取)和全文搜索功能。
|
* 提供图遍历(N 跳邻居、最短路径、所有路径、子图提取、子图导出)和全文搜索功能。
|
||||||
* 使用 {@link Neo4jClient} 执行复杂 Cypher 查询。
|
* 使用 {@link Neo4jClient} 执行复杂 Cypher 查询。
|
||||||
* <p>
|
* <p>
|
||||||
* 查询结果根据用户权限进行过滤:
|
* 查询结果根据用户权限进行过滤:
|
||||||
@@ -48,6 +57,7 @@ public class GraphQueryService {
|
|||||||
);
|
);
|
||||||
|
|
||||||
private final Neo4jClient neo4jClient;
|
private final Neo4jClient neo4jClient;
|
||||||
|
private final Driver neo4jDriver;
|
||||||
private final GraphEntityRepository entityRepository;
|
private final GraphEntityRepository entityRepository;
|
||||||
private final KnowledgeGraphProperties properties;
|
private final KnowledgeGraphProperties properties;
|
||||||
private final ResourceAccessService resourceAccessService;
|
private final ResourceAccessService resourceAccessService;
|
||||||
@@ -62,6 +72,9 @@ public class GraphQueryService {
|
|||||||
* @param depth 跳数(1-3,由配置上限约束)
|
* @param depth 跳数(1-3,由配置上限约束)
|
||||||
* @param limit 返回节点数上限
|
* @param limit 返回节点数上限
|
||||||
*/
|
*/
|
||||||
|
@Cacheable(value = RedisCacheConfig.CACHE_QUERIES,
|
||||||
|
key = "T(com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService).cacheKey(#graphId, #entityId, #depth, #limit, @resourceAccessService.resolveOwnerFilterUserId(), @resourceAccessService.canViewConfidential())",
|
||||||
|
cacheManager = "knowledgeGraphCacheManager")
|
||||||
public SubgraphVO getNeighborGraph(String graphId, String entityId, int depth, int limit) {
|
public SubgraphVO getNeighborGraph(String graphId, String entityId, int depth, int limit) {
|
||||||
validateGraphId(graphId);
|
validateGraphId(graphId);
|
||||||
String filterUserId = resolveOwnerFilter();
|
String filterUserId = resolveOwnerFilter();
|
||||||
@@ -225,6 +238,7 @@ public class GraphQueryService {
|
|||||||
" (t:Entity {graph_id: $graphId, id: $targetId}), " +
|
" (t:Entity {graph_id: $graphId, id: $targetId}), " +
|
||||||
" path = shortestPath((s)-[:" + REL_TYPE + "*1.." + clampedDepth + "]-(t)) " +
|
" path = shortestPath((s)-[:" + REL_TYPE + "*1.." + clampedDepth + "]-(t)) " +
|
||||||
"WHERE ALL(n IN nodes(path) WHERE n.graph_id = $graphId) " +
|
"WHERE ALL(n IN nodes(path) WHERE n.graph_id = $graphId) " +
|
||||||
|
" AND ALL(r IN relationships(path) WHERE r.graph_id = $graphId) " +
|
||||||
permFilter +
|
permFilter +
|
||||||
"RETURN " +
|
"RETURN " +
|
||||||
" [n IN nodes(path) | {id: n.id, name: n.name, type: n.type, description: n.description}] AS pathNodes, " +
|
" [n IN nodes(path) | {id: n.id, name: n.name, type: n.type, description: n.description}] AS pathNodes, " +
|
||||||
@@ -244,6 +258,106 @@ public class GraphQueryService {
|
|||||||
.build());
|
.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// 所有路径
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查询两个实体之间的所有路径。
|
||||||
|
*
|
||||||
|
* @param maxDepth 最大搜索深度(由配置上限约束)
|
||||||
|
* @param maxPaths 返回路径数上限
|
||||||
|
* @return 所有路径结果,按路径长度升序排列
|
||||||
|
*/
|
||||||
|
public AllPathsVO findAllPaths(String graphId, String sourceId, String targetId, int maxDepth, int maxPaths) {
|
||||||
|
validateGraphId(graphId);
|
||||||
|
String filterUserId = resolveOwnerFilter();
|
||||||
|
boolean excludeConfidential = filterUserId != null && !resourceAccessService.canViewConfidential();
|
||||||
|
|
||||||
|
// 校验两个实体存在 + 权限
|
||||||
|
GraphEntity sourceEntity = entityRepository.findByIdAndGraphId(sourceId, graphId)
|
||||||
|
.orElseThrow(() -> BusinessException.of(
|
||||||
|
KnowledgeGraphErrorCode.ENTITY_NOT_FOUND, "源实体不存在"));
|
||||||
|
|
||||||
|
if (filterUserId != null) {
|
||||||
|
assertEntityAccess(sourceEntity, filterUserId, excludeConfidential);
|
||||||
|
}
|
||||||
|
|
||||||
|
entityRepository.findByIdAndGraphId(targetId, graphId)
|
||||||
|
.ifPresentOrElse(
|
||||||
|
targetEntity -> {
|
||||||
|
if (filterUserId != null && !sourceId.equals(targetId)) {
|
||||||
|
assertEntityAccess(targetEntity, filterUserId, excludeConfidential);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
() -> { throw BusinessException.of(
|
||||||
|
KnowledgeGraphErrorCode.ENTITY_NOT_FOUND, "目标实体不存在"); }
|
||||||
|
);
|
||||||
|
|
||||||
|
if (sourceId.equals(targetId)) {
|
||||||
|
EntitySummaryVO node = EntitySummaryVO.builder()
|
||||||
|
.id(sourceEntity.getId())
|
||||||
|
.name(sourceEntity.getName())
|
||||||
|
.type(sourceEntity.getType())
|
||||||
|
.description(sourceEntity.getDescription())
|
||||||
|
.build();
|
||||||
|
PathVO singlePath = PathVO.builder()
|
||||||
|
.nodes(List.of(node))
|
||||||
|
.edges(List.of())
|
||||||
|
.pathLength(0)
|
||||||
|
.build();
|
||||||
|
return AllPathsVO.builder()
|
||||||
|
.paths(List.of(singlePath))
|
||||||
|
.pathCount(1)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
int clampedDepth = Math.max(1, Math.min(maxDepth, properties.getMaxDepth()));
|
||||||
|
int clampedMaxPaths = Math.max(1, Math.min(maxPaths, properties.getMaxNodesPerQuery()));
|
||||||
|
|
||||||
|
String permFilter = "";
|
||||||
|
if (filterUserId != null) {
|
||||||
|
StringBuilder pf = new StringBuilder("AND ALL(n IN nodes(path) WHERE ");
|
||||||
|
pf.append("(n.type IN ['User', 'Org', 'Field'] OR n.`properties.created_by` = $filterUserId)");
|
||||||
|
if (excludeConfidential) {
|
||||||
|
pf.append(" AND (toUpper(trim(n.`properties.sensitivity`)) IS NULL OR toUpper(trim(n.`properties.sensitivity`)) <> 'CONFIDENTIAL')");
|
||||||
|
}
|
||||||
|
pf.append(") ");
|
||||||
|
permFilter = pf.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("graphId", graphId);
|
||||||
|
params.put("sourceId", sourceId);
|
||||||
|
params.put("targetId", targetId);
|
||||||
|
params.put("maxPaths", clampedMaxPaths);
|
||||||
|
if (filterUserId != null) {
|
||||||
|
params.put("filterUserId", filterUserId);
|
||||||
|
}
|
||||||
|
|
||||||
|
String cypher =
|
||||||
|
"MATCH (s:Entity {graph_id: $graphId, id: $sourceId}), " +
|
||||||
|
" (t:Entity {graph_id: $graphId, id: $targetId}), " +
|
||||||
|
" path = (s)-[:" + REL_TYPE + "*1.." + clampedDepth + "]-(t) " +
|
||||||
|
"WHERE ALL(n IN nodes(path) WHERE n.graph_id = $graphId) " +
|
||||||
|
" AND ALL(r IN relationships(path) WHERE r.graph_id = $graphId) " +
|
||||||
|
permFilter +
|
||||||
|
"RETURN " +
|
||||||
|
" [n IN nodes(path) | {id: n.id, name: n.name, type: n.type, description: n.description}] AS pathNodes, " +
|
||||||
|
" [r IN relationships(path) | {id: r.id, relation_type: r.relation_type, weight: r.weight, " +
|
||||||
|
" source: startNode(r).id, target: endNode(r).id}] AS pathEdges, " +
|
||||||
|
" length(path) AS pathLength " +
|
||||||
|
"ORDER BY length(path) ASC " +
|
||||||
|
"LIMIT $maxPaths";
|
||||||
|
|
||||||
|
List<PathVO> paths = queryWithTimeout(cypher, params, record -> mapPathRecord(record));
|
||||||
|
|
||||||
|
return AllPathsVO.builder()
|
||||||
|
.paths(paths)
|
||||||
|
.pathCount(paths.size())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
// 子图提取
|
// 子图提取
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
@@ -313,6 +427,140 @@ public class GraphQueryService {
|
|||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// 子图导出
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 导出指定实体集合的子图,支持深度扩展。
|
||||||
|
*
|
||||||
|
* @param entityIds 种子实体 ID 列表
|
||||||
|
* @param depth 扩展深度(0=仅种子实体,1=含 1 跳邻居,以此类推)
|
||||||
|
* @return 包含完整属性的子图导出结果
|
||||||
|
*/
|
||||||
|
public SubgraphExportVO exportSubgraph(String graphId, List<String> entityIds, int depth) {
|
||||||
|
validateGraphId(graphId);
|
||||||
|
String filterUserId = resolveOwnerFilter();
|
||||||
|
boolean excludeConfidential = filterUserId != null && !resourceAccessService.canViewConfidential();
|
||||||
|
|
||||||
|
if (entityIds == null || entityIds.isEmpty()) {
|
||||||
|
return SubgraphExportVO.builder()
|
||||||
|
.nodes(List.of())
|
||||||
|
.edges(List.of())
|
||||||
|
.nodeCount(0)
|
||||||
|
.edgeCount(0)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
int maxNodes = properties.getMaxNodesPerQuery();
|
||||||
|
if (entityIds.size() > maxNodes) {
|
||||||
|
throw BusinessException.of(KnowledgeGraphErrorCode.MAX_NODES_EXCEEDED,
|
||||||
|
"实体数量超出限制(最大 " + maxNodes + ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
int clampedDepth = Math.max(0, Math.min(depth, properties.getMaxDepth()));
|
||||||
|
List<GraphEntity> entities;
|
||||||
|
|
||||||
|
if (clampedDepth == 0) {
|
||||||
|
// 仅种子实体
|
||||||
|
entities = entityRepository.findByGraphIdAndIdIn(graphId, entityIds);
|
||||||
|
} else {
|
||||||
|
// 扩展邻居:先查询扩展后的节点 ID 集合
|
||||||
|
Set<String> expandedIds = expandNeighborIds(graphId, entityIds, clampedDepth,
|
||||||
|
filterUserId, excludeConfidential, maxNodes);
|
||||||
|
entities = expandedIds.isEmpty()
|
||||||
|
? List.of()
|
||||||
|
: entityRepository.findByGraphIdAndIdIn(graphId, new ArrayList<>(expandedIds));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 权限过滤
|
||||||
|
if (filterUserId != null) {
|
||||||
|
entities = entities.stream()
|
||||||
|
.filter(e -> isEntityAccessible(e, filterUserId, excludeConfidential))
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (entities.isEmpty()) {
|
||||||
|
return SubgraphExportVO.builder()
|
||||||
|
.nodes(List.of())
|
||||||
|
.edges(List.of())
|
||||||
|
.nodeCount(0)
|
||||||
|
.edgeCount(0)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
List<ExportNodeVO> nodes = entities.stream()
|
||||||
|
.map(e -> ExportNodeVO.builder()
|
||||||
|
.id(e.getId())
|
||||||
|
.name(e.getName())
|
||||||
|
.type(e.getType())
|
||||||
|
.description(e.getDescription())
|
||||||
|
.properties(e.getProperties() != null ? e.getProperties() : Map.of())
|
||||||
|
.build())
|
||||||
|
.toList();
|
||||||
|
|
||||||
|
List<String> nodeIds = entities.stream().map(GraphEntity::getId).toList();
|
||||||
|
List<ExportEdgeVO> edges = queryExportEdgesBetween(graphId, nodeIds);
|
||||||
|
|
||||||
|
return SubgraphExportVO.builder()
|
||||||
|
.nodes(nodes)
|
||||||
|
.edges(edges)
|
||||||
|
.nodeCount(nodes.size())
|
||||||
|
.edgeCount(edges.size())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将子图导出结果转换为 GraphML XML 格式。
|
||||||
|
*/
|
||||||
|
public String convertToGraphML(SubgraphExportVO exportVO) {
|
||||||
|
StringBuilder xml = new StringBuilder();
|
||||||
|
xml.append("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n");
|
||||||
|
xml.append("<graphml xmlns=\"http://graphml.graphstruct.org/graphml\"\n");
|
||||||
|
xml.append(" xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n");
|
||||||
|
xml.append(" xsi:schemaLocation=\"http://graphml.graphstruct.org/graphml ");
|
||||||
|
xml.append("http://graphml.graphstruct.org/xmlns/1.0/graphml.xsd\">\n");
|
||||||
|
|
||||||
|
// Key 定义
|
||||||
|
xml.append(" <key id=\"name\" for=\"node\" attr.name=\"name\" attr.type=\"string\"/>\n");
|
||||||
|
xml.append(" <key id=\"type\" for=\"node\" attr.name=\"type\" attr.type=\"string\"/>\n");
|
||||||
|
xml.append(" <key id=\"description\" for=\"node\" attr.name=\"description\" attr.type=\"string\"/>\n");
|
||||||
|
xml.append(" <key id=\"relationType\" for=\"edge\" attr.name=\"relationType\" attr.type=\"string\"/>\n");
|
||||||
|
xml.append(" <key id=\"weight\" for=\"edge\" attr.name=\"weight\" attr.type=\"double\"/>\n");
|
||||||
|
|
||||||
|
xml.append(" <graph id=\"G\" edgedefault=\"directed\">\n");
|
||||||
|
|
||||||
|
// 节点
|
||||||
|
if (exportVO.getNodes() != null) {
|
||||||
|
for (ExportNodeVO node : exportVO.getNodes()) {
|
||||||
|
xml.append(" <node id=\"").append(escapeXml(node.getId())).append("\">\n");
|
||||||
|
appendGraphMLData(xml, "name", node.getName());
|
||||||
|
appendGraphMLData(xml, "type", node.getType());
|
||||||
|
appendGraphMLData(xml, "description", node.getDescription());
|
||||||
|
xml.append(" </node>\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 边
|
||||||
|
if (exportVO.getEdges() != null) {
|
||||||
|
for (ExportEdgeVO edge : exportVO.getEdges()) {
|
||||||
|
xml.append(" <edge id=\"").append(escapeXml(edge.getId()))
|
||||||
|
.append("\" source=\"").append(escapeXml(edge.getSourceEntityId()))
|
||||||
|
.append("\" target=\"").append(escapeXml(edge.getTargetEntityId()))
|
||||||
|
.append("\">\n");
|
||||||
|
appendGraphMLData(xml, "relationType", edge.getRelationType());
|
||||||
|
if (edge.getWeight() != null) {
|
||||||
|
appendGraphMLData(xml, "weight", String.valueOf(edge.getWeight()));
|
||||||
|
}
|
||||||
|
xml.append(" </edge>\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
xml.append(" </graph>\n");
|
||||||
|
xml.append("</graphml>\n");
|
||||||
|
return xml.toString();
|
||||||
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
// 全文搜索
|
// 全文搜索
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
@@ -325,6 +573,9 @@ public class GraphQueryService {
|
|||||||
*
|
*
|
||||||
* @param query 搜索关键词(支持 Lucene 查询语法)
|
* @param query 搜索关键词(支持 Lucene 查询语法)
|
||||||
*/
|
*/
|
||||||
|
@Cacheable(value = RedisCacheConfig.CACHE_SEARCH,
|
||||||
|
key = "T(com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService).cacheKey(#graphId, #query, #page, #size, @resourceAccessService.resolveOwnerFilterUserId(), @resourceAccessService.canViewConfidential())",
|
||||||
|
cacheManager = "knowledgeGraphCacheManager")
|
||||||
public PagedResponse<SearchHitVO> fulltextSearch(String graphId, String query, int page, int size) {
|
public PagedResponse<SearchHitVO> fulltextSearch(String graphId, String query, int page, int size) {
|
||||||
validateGraphId(graphId);
|
validateGraphId(graphId);
|
||||||
String filterUserId = resolveOwnerFilter();
|
String filterUserId = resolveOwnerFilter();
|
||||||
@@ -581,9 +832,159 @@ public class GraphQueryService {
|
|||||||
return (v == null || v.isNull()) ? null : v.asDouble();
|
return (v == null || v.isNull()) ? null : v.asDouble();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查询指定节点集合之间的所有边(导出用,包含完整属性)。
|
||||||
|
*/
|
||||||
|
private List<ExportEdgeVO> queryExportEdgesBetween(String graphId, List<String> nodeIds) {
|
||||||
|
if (nodeIds.size() < 2) {
|
||||||
|
return List.of();
|
||||||
|
}
|
||||||
|
|
||||||
|
return neo4jClient
|
||||||
|
.query(
|
||||||
|
"MATCH (s:Entity {graph_id: $graphId})-[r:" + REL_TYPE + " {graph_id: $graphId}]->(t:Entity {graph_id: $graphId}) " +
|
||||||
|
"WHERE s.id IN $nodeIds AND t.id IN $nodeIds " +
|
||||||
|
"RETURN r.id AS id, s.id AS sourceEntityId, t.id AS targetEntityId, " +
|
||||||
|
"r.relation_type AS relationType, r.weight AS weight, " +
|
||||||
|
"r.confidence AS confidence, r.source_id AS sourceId"
|
||||||
|
)
|
||||||
|
.bindAll(Map.of("graphId", graphId, "nodeIds", nodeIds))
|
||||||
|
.fetchAs(ExportEdgeVO.class)
|
||||||
|
.mappedBy((ts, record) -> ExportEdgeVO.builder()
|
||||||
|
.id(record.get("id").asString(null))
|
||||||
|
.sourceEntityId(record.get("sourceEntityId").asString(null))
|
||||||
|
.targetEntityId(record.get("targetEntityId").asString(null))
|
||||||
|
.relationType(record.get("relationType").asString(null))
|
||||||
|
.weight(record.get("weight").isNull() ? null : record.get("weight").asDouble())
|
||||||
|
.confidence(record.get("confidence").isNull() ? null : record.get("confidence").asDouble())
|
||||||
|
.sourceId(record.get("sourceId").asString(null))
|
||||||
|
.build())
|
||||||
|
.all()
|
||||||
|
.stream().toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从种子实体扩展 N 跳邻居,返回所有节点 ID(含种子)。
|
||||||
|
* <p>
|
||||||
|
* 使用事务超时保护,防止深度扩展导致组合爆炸。
|
||||||
|
* 结果总数严格不超过 maxNodes(含种子节点)。
|
||||||
|
*/
|
||||||
|
private Set<String> expandNeighborIds(String graphId, List<String> seedIds, int depth,
|
||||||
|
String filterUserId, boolean excludeConfidential, int maxNodes) {
|
||||||
|
String permFilter = "";
|
||||||
|
if (filterUserId != null) {
|
||||||
|
StringBuilder pf = new StringBuilder("AND ALL(n IN nodes(p) WHERE ");
|
||||||
|
pf.append("(n.type IN ['User', 'Org', 'Field'] OR n.`properties.created_by` = $filterUserId)");
|
||||||
|
if (excludeConfidential) {
|
||||||
|
pf.append(" AND (toUpper(trim(n.`properties.sensitivity`)) IS NULL OR toUpper(trim(n.`properties.sensitivity`)) <> 'CONFIDENTIAL')");
|
||||||
|
}
|
||||||
|
pf.append(") ");
|
||||||
|
permFilter = pf.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("graphId", graphId);
|
||||||
|
params.put("seedIds", seedIds);
|
||||||
|
params.put("maxNodes", maxNodes);
|
||||||
|
if (filterUserId != null) {
|
||||||
|
params.put("filterUserId", filterUserId);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 种子节点在 Cypher 中纳入 LIMIT 约束,确保总数不超过 maxNodes
|
||||||
|
String cypher =
|
||||||
|
"MATCH (seed:Entity {graph_id: $graphId}) " +
|
||||||
|
"WHERE seed.id IN $seedIds " +
|
||||||
|
"WITH collect(DISTINCT seed) AS seeds " +
|
||||||
|
"UNWIND seeds AS s " +
|
||||||
|
"OPTIONAL MATCH p = (s)-[:" + REL_TYPE + "*1.." + depth + "]-(neighbor:Entity) " +
|
||||||
|
"WHERE ALL(n IN nodes(p) WHERE n.graph_id = $graphId) " +
|
||||||
|
" AND ALL(r IN relationships(p) WHERE r.graph_id = $graphId) " +
|
||||||
|
permFilter +
|
||||||
|
"WITH seeds + collect(DISTINCT neighbor) AS allNodes " +
|
||||||
|
"UNWIND allNodes AS node " +
|
||||||
|
"WITH DISTINCT node " +
|
||||||
|
"WHERE node IS NOT NULL " +
|
||||||
|
"RETURN node.id AS id " +
|
||||||
|
"LIMIT $maxNodes";
|
||||||
|
|
||||||
|
List<String> ids = queryWithTimeout(cypher, params,
|
||||||
|
record -> record.get("id").asString(null));
|
||||||
|
|
||||||
|
return new LinkedHashSet<>(ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void appendGraphMLData(StringBuilder xml, String key, String value) {
|
||||||
|
if (value != null) {
|
||||||
|
xml.append(" <data key=\"").append(key).append("\">")
|
||||||
|
.append(escapeXml(value))
|
||||||
|
.append("</data>\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String escapeXml(String text) {
|
||||||
|
if (text == null) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return text.replace("&", "&")
|
||||||
|
.replace("<", "<")
|
||||||
|
.replace(">", ">")
|
||||||
|
.replace("\"", """)
|
||||||
|
.replace("'", "'");
|
||||||
|
}
|
||||||
|
|
||||||
private void validateGraphId(String graphId) {
|
private void validateGraphId(String graphId) {
|
||||||
if (graphId == null || !UUID_PATTERN.matcher(graphId).matches()) {
|
if (graphId == null || !UUID_PATTERN.matcher(graphId).matches()) {
|
||||||
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "graphId 格式无效");
|
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "graphId 格式无效");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用 Neo4j Driver 直接执行查询,附带事务级超时保护。
|
||||||
|
* <p>
|
||||||
|
* 用于路径枚举等可能触发组合爆炸的高开销查询,
|
||||||
|
* 超时后 Neo4j 服务端会主动终止事务,避免资源耗尽。
|
||||||
|
*/
|
||||||
|
private <T> List<T> queryWithTimeout(String cypher, Map<String, Object> params,
|
||||||
|
Function<Record, T> mapper) {
|
||||||
|
int timeoutSeconds = properties.getQueryTimeoutSeconds();
|
||||||
|
TransactionConfig txConfig = TransactionConfig.builder()
|
||||||
|
.withTimeout(Duration.ofSeconds(timeoutSeconds))
|
||||||
|
.build();
|
||||||
|
try (Session session = neo4jDriver.session()) {
|
||||||
|
return session.executeRead(tx -> {
|
||||||
|
var result = tx.run(cypher, params);
|
||||||
|
List<T> items = new ArrayList<>();
|
||||||
|
while (result.hasNext()) {
|
||||||
|
items.add(mapper.apply(result.next()));
|
||||||
|
}
|
||||||
|
return items;
|
||||||
|
}, txConfig);
|
||||||
|
} catch (Exception e) {
|
||||||
|
if (isTransactionTimeout(e)) {
|
||||||
|
log.warn("图查询超时({}秒): {}", timeoutSeconds, cypher.substring(0, Math.min(cypher.length(), 120)));
|
||||||
|
throw BusinessException.of(KnowledgeGraphErrorCode.QUERY_TIMEOUT,
|
||||||
|
"查询超时(" + timeoutSeconds + "秒),请缩小搜索范围或减少深度");
|
||||||
|
}
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 判断异常是否为 Neo4j 事务超时。
|
||||||
|
*/
|
||||||
|
private static boolean isTransactionTimeout(Exception e) {
|
||||||
|
// Neo4j 事务超时时抛出的异常链中通常包含 "terminated" 或 "timeout"
|
||||||
|
Throwable current = e;
|
||||||
|
while (current != null) {
|
||||||
|
String msg = current.getMessage();
|
||||||
|
if (msg != null) {
|
||||||
|
String lower = msg.toLowerCase(Locale.ROOT);
|
||||||
|
if (lower.contains("transaction has been terminated") || lower.contains("timed out")) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current = current.getCause();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import com.datamate.common.interfaces.PagedResponse;
|
|||||||
import com.datamate.knowledgegraph.domain.model.RelationDetail;
|
import com.datamate.knowledgegraph.domain.model.RelationDetail;
|
||||||
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
||||||
import com.datamate.knowledgegraph.domain.repository.GraphRelationRepository;
|
import com.datamate.knowledgegraph.domain.repository.GraphRelationRepository;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService;
|
||||||
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.CreateRelationRequest;
|
import com.datamate.knowledgegraph.interfaces.dto.CreateRelationRequest;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.RelationVO;
|
import com.datamate.knowledgegraph.interfaces.dto.RelationVO;
|
||||||
@@ -15,7 +16,9 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
@@ -43,6 +46,7 @@ public class GraphRelationService {
|
|||||||
|
|
||||||
private final GraphRelationRepository relationRepository;
|
private final GraphRelationRepository relationRepository;
|
||||||
private final GraphEntityRepository entityRepository;
|
private final GraphEntityRepository entityRepository;
|
||||||
|
private final GraphCacheService cacheService;
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public RelationVO createRelation(String graphId, CreateRelationRequest request) {
|
public RelationVO createRelation(String graphId, CreateRelationRequest request) {
|
||||||
@@ -73,6 +77,7 @@ public class GraphRelationService {
|
|||||||
log.info("Relation created: id={}, graphId={}, type={}, source={} -> target={}",
|
log.info("Relation created: id={}, graphId={}, type={}, source={} -> target={}",
|
||||||
detail.getId(), graphId, request.getRelationType(),
|
detail.getId(), graphId, request.getRelationType(),
|
||||||
request.getSourceEntityId(), request.getTargetEntityId());
|
request.getSourceEntityId(), request.getTargetEntityId());
|
||||||
|
cacheService.evictEntityCaches(graphId, request.getSourceEntityId());
|
||||||
return toVO(detail);
|
return toVO(detail);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +170,7 @@ public class GraphRelationService {
|
|||||||
).orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND));
|
).orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND));
|
||||||
|
|
||||||
log.info("Relation updated: id={}, graphId={}", relationId, graphId);
|
log.info("Relation updated: id={}, graphId={}", relationId, graphId);
|
||||||
|
cacheService.evictEntityCaches(graphId, detail.getSourceEntityId());
|
||||||
return toVO(detail);
|
return toVO(detail);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,8 +178,8 @@ public class GraphRelationService {
|
|||||||
public void deleteRelation(String graphId, String relationId) {
|
public void deleteRelation(String graphId, String relationId) {
|
||||||
validateGraphId(graphId);
|
validateGraphId(graphId);
|
||||||
|
|
||||||
// 确认关系存在
|
// 确认关系存在并保留关系两端实体 ID,用于精准缓存失效
|
||||||
relationRepository.findByIdAndGraphId(relationId, graphId)
|
RelationDetail detail = relationRepository.findByIdAndGraphId(relationId, graphId)
|
||||||
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND));
|
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND));
|
||||||
|
|
||||||
long deleted = relationRepository.deleteByIdAndGraphId(relationId, graphId);
|
long deleted = relationRepository.deleteByIdAndGraphId(relationId, graphId);
|
||||||
@@ -181,6 +187,33 @@ public class GraphRelationService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND);
|
throw BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND);
|
||||||
}
|
}
|
||||||
log.info("Relation deleted: id={}, graphId={}", relationId, graphId);
|
log.info("Relation deleted: id={}, graphId={}", relationId, graphId);
|
||||||
|
cacheService.evictEntityCaches(graphId, detail.getSourceEntityId());
|
||||||
|
if (detail.getTargetEntityId() != null
|
||||||
|
&& !detail.getTargetEntityId().equals(detail.getSourceEntityId())) {
|
||||||
|
cacheService.evictEntityCaches(graphId, detail.getTargetEntityId());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Transactional
|
||||||
|
public Map<String, Object> batchDeleteRelations(String graphId, List<String> relationIds) {
|
||||||
|
validateGraphId(graphId);
|
||||||
|
int deleted = 0;
|
||||||
|
List<String> failedIds = new ArrayList<>();
|
||||||
|
for (String relationId : relationIds) {
|
||||||
|
try {
|
||||||
|
deleteRelation(graphId, relationId);
|
||||||
|
deleted++;
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("Batch delete: failed to delete relation {}: {}", relationId, e.getMessage());
|
||||||
|
failedIds.add(relationId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Map<String, Object> result = Map.of(
|
||||||
|
"deleted", deleted,
|
||||||
|
"total", relationIds.size(),
|
||||||
|
"failedIds", failedIds
|
||||||
|
);
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import com.datamate.common.infrastructure.exception.SystemErrorCode;
|
|||||||
import com.datamate.knowledgegraph.domain.model.SyncMetadata;
|
import com.datamate.knowledgegraph.domain.model.SyncMetadata;
|
||||||
import com.datamate.knowledgegraph.domain.model.SyncResult;
|
import com.datamate.knowledgegraph.domain.model.SyncResult;
|
||||||
import com.datamate.knowledgegraph.domain.repository.SyncHistoryRepository;
|
import com.datamate.knowledgegraph.domain.repository.SyncHistoryRepository;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService;
|
||||||
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient;
|
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient;
|
||||||
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.DatasetDTO;
|
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.DatasetDTO;
|
||||||
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.WorkflowDTO;
|
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.WorkflowDTO;
|
||||||
@@ -56,6 +57,7 @@ public class GraphSyncService {
|
|||||||
private final DataManagementClient dataManagementClient;
|
private final DataManagementClient dataManagementClient;
|
||||||
private final KnowledgeGraphProperties properties;
|
private final KnowledgeGraphProperties properties;
|
||||||
private final SyncHistoryRepository syncHistoryRepository;
|
private final SyncHistoryRepository syncHistoryRepository;
|
||||||
|
private final GraphCacheService cacheService;
|
||||||
|
|
||||||
/** 同 graphId 互斥锁,防止并发同步。 */
|
/** 同 graphId 互斥锁,防止并发同步。 */
|
||||||
private final ConcurrentHashMap<String, ReentrantLock> graphLocks = new ConcurrentHashMap<>();
|
private final ConcurrentHashMap<String, ReentrantLock> graphLocks = new ConcurrentHashMap<>();
|
||||||
@@ -93,7 +95,15 @@ public class GraphSyncService {
|
|||||||
|
|
||||||
Set<String> usernames = extractUsernames(datasets, workflows, jobs, labelTasks, knowledgeSets);
|
Set<String> usernames = extractUsernames(datasets, workflows, jobs, labelTasks, knowledgeSets);
|
||||||
resultMap.put("User", stepService.upsertUserEntities(graphId, usernames, syncId));
|
resultMap.put("User", stepService.upsertUserEntities(graphId, usernames, syncId));
|
||||||
resultMap.put("Org", stepService.upsertOrgEntities(graphId, syncId));
|
|
||||||
|
Map<String, String> userOrgMap = fetchMapWithRetry(syncId, "user-orgs",
|
||||||
|
() -> dataManagementClient.fetchUserOrganizationMap());
|
||||||
|
boolean orgMapDegraded = (userOrgMap == null);
|
||||||
|
if (orgMapDegraded) {
|
||||||
|
log.warn("[{}] Org map fetch degraded, using empty map; Org purge will be skipped", syncId);
|
||||||
|
userOrgMap = Collections.emptyMap();
|
||||||
|
}
|
||||||
|
resultMap.put("Org", stepService.upsertOrgEntities(graphId, userOrgMap, syncId));
|
||||||
resultMap.put("Workflow", stepService.upsertWorkflowEntities(graphId, workflows, syncId));
|
resultMap.put("Workflow", stepService.upsertWorkflowEntities(graphId, workflows, syncId));
|
||||||
resultMap.put("Job", stepService.upsertJobEntities(graphId, jobs, syncId));
|
resultMap.put("Job", stepService.upsertJobEntities(graphId, jobs, syncId));
|
||||||
resultMap.put("LabelTask", stepService.upsertLabelTaskEntities(graphId, labelTasks, syncId));
|
resultMap.put("LabelTask", stepService.upsertLabelTaskEntities(graphId, labelTasks, syncId));
|
||||||
@@ -130,6 +140,14 @@ public class GraphSyncService {
|
|||||||
resultMap.get("User").setPurged(
|
resultMap.get("User").setPurged(
|
||||||
stepService.purgeStaleEntities(graphId, "User", activeUserIds, syncId));
|
stepService.purgeStaleEntities(graphId, "User", activeUserIds, syncId));
|
||||||
|
|
||||||
|
if (!orgMapDegraded) {
|
||||||
|
Set<String> activeOrgSourceIds = buildActiveOrgSourceIds(userOrgMap);
|
||||||
|
resultMap.get("Org").setPurged(
|
||||||
|
stepService.purgeStaleEntities(graphId, "Org", activeOrgSourceIds, syncId));
|
||||||
|
} else {
|
||||||
|
log.info("[{}] Skipping Org purge due to degraded org map fetch", syncId);
|
||||||
|
}
|
||||||
|
|
||||||
Set<String> activeWorkflowIds = workflows.stream()
|
Set<String> activeWorkflowIds = workflows.stream()
|
||||||
.filter(Objects::nonNull)
|
.filter(Objects::nonNull)
|
||||||
.map(WorkflowDTO::getId)
|
.map(WorkflowDTO::getId)
|
||||||
@@ -169,7 +187,12 @@ public class GraphSyncService {
|
|||||||
// 关系构建(MERGE 幂等)
|
// 关系构建(MERGE 幂等)
|
||||||
resultMap.put("HAS_FIELD", stepService.mergeHasFieldRelations(graphId, syncId));
|
resultMap.put("HAS_FIELD", stepService.mergeHasFieldRelations(graphId, syncId));
|
||||||
resultMap.put("DERIVED_FROM", stepService.mergeDerivedFromRelations(graphId, syncId));
|
resultMap.put("DERIVED_FROM", stepService.mergeDerivedFromRelations(graphId, syncId));
|
||||||
resultMap.put("BELONGS_TO", stepService.mergeBelongsToRelations(graphId, syncId));
|
if (!orgMapDegraded) {
|
||||||
|
resultMap.put("BELONGS_TO", stepService.mergeBelongsToRelations(graphId, userOrgMap, syncId));
|
||||||
|
} else {
|
||||||
|
log.info("[{}] Skipping BELONGS_TO relation build due to degraded org map fetch", syncId);
|
||||||
|
resultMap.put("BELONGS_TO", SyncResult.builder().syncType("BELONGS_TO").build());
|
||||||
|
}
|
||||||
resultMap.put("USES_DATASET", stepService.mergeUsesDatasetRelations(graphId, syncId));
|
resultMap.put("USES_DATASET", stepService.mergeUsesDatasetRelations(graphId, syncId));
|
||||||
resultMap.put("PRODUCES", stepService.mergeProducesRelations(graphId, syncId));
|
resultMap.put("PRODUCES", stepService.mergeProducesRelations(graphId, syncId));
|
||||||
resultMap.put("ASSIGNED_TO", stepService.mergeAssignedToRelations(graphId, syncId));
|
resultMap.put("ASSIGNED_TO", stepService.mergeAssignedToRelations(graphId, syncId));
|
||||||
@@ -196,6 +219,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] Full sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] Full sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "全量同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "全量同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -251,7 +275,15 @@ public class GraphSyncService {
|
|||||||
|
|
||||||
Set<String> usernames = extractUsernames(datasets, workflows, jobs, labelTasks, knowledgeSets);
|
Set<String> usernames = extractUsernames(datasets, workflows, jobs, labelTasks, knowledgeSets);
|
||||||
resultMap.put("User", stepService.upsertUserEntities(graphId, usernames, syncId));
|
resultMap.put("User", stepService.upsertUserEntities(graphId, usernames, syncId));
|
||||||
resultMap.put("Org", stepService.upsertOrgEntities(graphId, syncId));
|
|
||||||
|
Map<String, String> userOrgMap = fetchMapWithRetry(syncId, "user-orgs",
|
||||||
|
() -> dataManagementClient.fetchUserOrganizationMap());
|
||||||
|
boolean orgMapDegraded = (userOrgMap == null);
|
||||||
|
if (orgMapDegraded) {
|
||||||
|
log.warn("[{}] Org map fetch degraded in incremental sync, using empty map", syncId);
|
||||||
|
userOrgMap = Collections.emptyMap();
|
||||||
|
}
|
||||||
|
resultMap.put("Org", stepService.upsertOrgEntities(graphId, userOrgMap, syncId));
|
||||||
resultMap.put("Workflow", stepService.upsertWorkflowEntities(graphId, workflows, syncId));
|
resultMap.put("Workflow", stepService.upsertWorkflowEntities(graphId, workflows, syncId));
|
||||||
resultMap.put("Job", stepService.upsertJobEntities(graphId, jobs, syncId));
|
resultMap.put("Job", stepService.upsertJobEntities(graphId, jobs, syncId));
|
||||||
resultMap.put("LabelTask", stepService.upsertLabelTaskEntities(graphId, labelTasks, syncId));
|
resultMap.put("LabelTask", stepService.upsertLabelTaskEntities(graphId, labelTasks, syncId));
|
||||||
@@ -263,7 +295,14 @@ public class GraphSyncService {
|
|||||||
// 关系构建(MERGE 幂等)- 增量同步时只处理变更实体相关的关系
|
// 关系构建(MERGE 幂等)- 增量同步时只处理变更实体相关的关系
|
||||||
resultMap.put("HAS_FIELD", stepService.mergeHasFieldRelations(graphId, syncId, changedEntityIds));
|
resultMap.put("HAS_FIELD", stepService.mergeHasFieldRelations(graphId, syncId, changedEntityIds));
|
||||||
resultMap.put("DERIVED_FROM", stepService.mergeDerivedFromRelations(graphId, syncId, changedEntityIds));
|
resultMap.put("DERIVED_FROM", stepService.mergeDerivedFromRelations(graphId, syncId, changedEntityIds));
|
||||||
resultMap.put("BELONGS_TO", stepService.mergeBelongsToRelations(graphId, syncId, changedEntityIds));
|
if (!orgMapDegraded) {
|
||||||
|
// BELONGS_TO 依赖全量 userOrgMap,组织映射变更可能影响全部 User/Dataset。
|
||||||
|
// 增量同步下也执行全量 BELONGS_TO 重建,避免漏更新。
|
||||||
|
resultMap.put("BELONGS_TO", stepService.mergeBelongsToRelations(graphId, userOrgMap, syncId));
|
||||||
|
} else {
|
||||||
|
log.info("[{}] Skipping BELONGS_TO relation build due to degraded org map fetch", syncId);
|
||||||
|
resultMap.put("BELONGS_TO", SyncResult.builder().syncType("BELONGS_TO").build());
|
||||||
|
}
|
||||||
resultMap.put("USES_DATASET", stepService.mergeUsesDatasetRelations(graphId, syncId, changedEntityIds));
|
resultMap.put("USES_DATASET", stepService.mergeUsesDatasetRelations(graphId, syncId, changedEntityIds));
|
||||||
resultMap.put("PRODUCES", stepService.mergeProducesRelations(graphId, syncId, changedEntityIds));
|
resultMap.put("PRODUCES", stepService.mergeProducesRelations(graphId, syncId, changedEntityIds));
|
||||||
resultMap.put("ASSIGNED_TO", stepService.mergeAssignedToRelations(graphId, syncId, changedEntityIds));
|
resultMap.put("ASSIGNED_TO", stepService.mergeAssignedToRelations(graphId, syncId, changedEntityIds));
|
||||||
@@ -298,6 +337,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] Incremental sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] Incremental sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "增量同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "增量同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -331,6 +371,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] Dataset sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] Dataset sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "数据集同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "数据集同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -367,6 +408,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] Field sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] Field sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "字段同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "字段同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -401,6 +443,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] User sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] User sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "用户同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "用户同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -411,7 +454,22 @@ public class GraphSyncService {
|
|||||||
LocalDateTime startedAt = LocalDateTime.now();
|
LocalDateTime startedAt = LocalDateTime.now();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
SyncResult result = stepService.upsertOrgEntities(graphId, syncId);
|
Map<String, String> userOrgMap = fetchMapWithRetry(syncId, "user-orgs",
|
||||||
|
() -> dataManagementClient.fetchUserOrganizationMap());
|
||||||
|
boolean orgMapDegraded = (userOrgMap == null);
|
||||||
|
if (orgMapDegraded) {
|
||||||
|
log.warn("[{}] Org map fetch degraded, using empty map; Org purge will be skipped", syncId);
|
||||||
|
userOrgMap = Collections.emptyMap();
|
||||||
|
}
|
||||||
|
SyncResult result = stepService.upsertOrgEntities(graphId, userOrgMap, syncId);
|
||||||
|
|
||||||
|
if (!orgMapDegraded) {
|
||||||
|
Set<String> activeOrgSourceIds = buildActiveOrgSourceIds(userOrgMap);
|
||||||
|
result.setPurged(stepService.purgeStaleEntities(graphId, "Org", activeOrgSourceIds, syncId));
|
||||||
|
} else {
|
||||||
|
log.info("[{}] Skipping Org purge due to degraded org map fetch", syncId);
|
||||||
|
}
|
||||||
|
|
||||||
saveSyncHistory(SyncMetadata.fromResults(
|
saveSyncHistory(SyncMetadata.fromResults(
|
||||||
syncId, graphId, SyncMetadata.TYPE_ORGS, startedAt, List.of(result)));
|
syncId, graphId, SyncMetadata.TYPE_ORGS, startedAt, List.of(result)));
|
||||||
return result;
|
return result;
|
||||||
@@ -423,6 +481,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] Org sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] Org sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "组织同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "组织同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -432,7 +491,8 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeHasFieldRelations(graphId, syncId);
|
SyncResult result = stepService.mergeHasFieldRelations(graphId, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -440,6 +500,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"HAS_FIELD 关系构建失败,syncId=" + syncId);
|
"HAS_FIELD 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -449,7 +510,8 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeDerivedFromRelations(graphId, syncId);
|
SyncResult result = stepService.mergeDerivedFromRelations(graphId, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -457,6 +519,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"DERIVED_FROM 关系构建失败,syncId=" + syncId);
|
"DERIVED_FROM 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -466,7 +529,14 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeBelongsToRelations(graphId, syncId);
|
Map<String, String> userOrgMap = fetchMapWithRetry(syncId, "user-orgs",
|
||||||
|
() -> dataManagementClient.fetchUserOrganizationMap());
|
||||||
|
if (userOrgMap == null) {
|
||||||
|
log.warn("[{}] Org map fetch degraded, skipping BELONGS_TO relation build to preserve existing relations", syncId);
|
||||||
|
return SyncResult.builder().syncType("BELONGS_TO").build();
|
||||||
|
}
|
||||||
|
SyncResult result = stepService.mergeBelongsToRelations(graphId, userOrgMap, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -474,6 +544,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"BELONGS_TO 关系构建失败,syncId=" + syncId);
|
"BELONGS_TO 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -507,6 +578,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] Workflow sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] Workflow sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "工作流同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "工作流同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -536,6 +608,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] Job sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] Job sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "作业同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "作业同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -565,6 +638,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] LabelTask sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] LabelTask sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "标注任务同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "标注任务同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -594,6 +668,7 @@ public class GraphSyncService {
|
|||||||
log.error("[{}] KnowledgeSet sync failed for graphId={}", syncId, graphId, e);
|
log.error("[{}] KnowledgeSet sync failed for graphId={}", syncId, graphId, e);
|
||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "知识集同步失败,syncId=" + syncId);
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "知识集同步失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -607,7 +682,8 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeUsesDatasetRelations(graphId, syncId);
|
SyncResult result = stepService.mergeUsesDatasetRelations(graphId, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -615,6 +691,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"USES_DATASET 关系构建失败,syncId=" + syncId);
|
"USES_DATASET 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -624,7 +701,8 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeProducesRelations(graphId, syncId);
|
SyncResult result = stepService.mergeProducesRelations(graphId, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -632,6 +710,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"PRODUCES 关系构建失败,syncId=" + syncId);
|
"PRODUCES 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -641,7 +720,8 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeAssignedToRelations(graphId, syncId);
|
SyncResult result = stepService.mergeAssignedToRelations(graphId, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -649,6 +729,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"ASSIGNED_TO 关系构建失败,syncId=" + syncId);
|
"ASSIGNED_TO 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -658,7 +739,8 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeTriggersRelations(graphId, syncId);
|
SyncResult result = stepService.mergeTriggersRelations(graphId, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -666,6 +748,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"TRIGGERS 关系构建失败,syncId=" + syncId);
|
"TRIGGERS 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -675,7 +758,8 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeDependsOnRelations(graphId, syncId);
|
SyncResult result = stepService.mergeDependsOnRelations(graphId, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -683,6 +767,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"DEPENDS_ON 关系构建失败,syncId=" + syncId);
|
"DEPENDS_ON 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -692,7 +777,8 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeImpactsRelations(graphId, syncId);
|
SyncResult result = stepService.mergeImpactsRelations(graphId, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -700,6 +786,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"IMPACTS 关系构建失败,syncId=" + syncId);
|
"IMPACTS 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -709,7 +796,8 @@ public class GraphSyncService {
|
|||||||
String syncId = UUID.randomUUID().toString();
|
String syncId = UUID.randomUUID().toString();
|
||||||
ReentrantLock lock = acquireLock(graphId, syncId);
|
ReentrantLock lock = acquireLock(graphId, syncId);
|
||||||
try {
|
try {
|
||||||
return stepService.mergeSourcedFromRelations(graphId, syncId);
|
SyncResult result = stepService.mergeSourcedFromRelations(graphId, syncId);
|
||||||
|
return result;
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
throw e;
|
throw e;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -717,6 +805,7 @@ public class GraphSyncService {
|
|||||||
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
|
||||||
"SOURCED_FROM 关系构建失败,syncId=" + syncId);
|
"SOURCED_FROM 关系构建失败,syncId=" + syncId);
|
||||||
} finally {
|
} finally {
|
||||||
|
cacheService.evictGraphCaches(graphId);
|
||||||
releaseLock(graphId, lock);
|
releaseLock(graphId, lock);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -819,6 +908,54 @@ public class GraphSyncService {
|
|||||||
"拉取" + resourceName + "失败(已重试 " + maxRetries + " 次),syncId=" + syncId);
|
"拉取" + resourceName + "失败(已重试 " + maxRetries + " 次),syncId=" + syncId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 带重试的 Map 拉取方法。失败时返回 {@code null} 表示降级。
|
||||||
|
* <p>
|
||||||
|
* 调用方需检查返回值是否为 null,并在降级时跳过依赖完整数据的操作
|
||||||
|
* (如 purge),以避免基于不完整快照误删数据。
|
||||||
|
*/
|
||||||
|
private <K, V> Map<K, V> fetchMapWithRetry(String syncId, String resourceName,
|
||||||
|
java.util.function.Supplier<Map<K, V>> fetcher) {
|
||||||
|
int maxRetries = properties.getSync().getMaxRetries();
|
||||||
|
long retryInterval = properties.getSync().getRetryInterval();
|
||||||
|
Exception lastException = null;
|
||||||
|
|
||||||
|
for (int attempt = 1; attempt <= maxRetries; attempt++) {
|
||||||
|
try {
|
||||||
|
return fetcher.get();
|
||||||
|
} catch (Exception e) {
|
||||||
|
lastException = e;
|
||||||
|
log.warn("[{}] {} fetch attempt {}/{} failed: {}",
|
||||||
|
syncId, resourceName, attempt, maxRetries, e.getMessage());
|
||||||
|
if (attempt < maxRetries) {
|
||||||
|
try {
|
||||||
|
Thread.sleep(retryInterval * attempt);
|
||||||
|
} catch (InterruptedException ie) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "同步被中断");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.warn("[{}] All {} fetch attempts for {} failed, returning null (degraded)",
|
||||||
|
syncId, maxRetries, resourceName, lastException);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据 userOrgMap 计算活跃的 Org source_id 集合(含 "未分配" 兜底组织)。
|
||||||
|
*/
|
||||||
|
private Set<String> buildActiveOrgSourceIds(Map<String, String> userOrgMap) {
|
||||||
|
Set<String> activeOrgSourceIds = new LinkedHashSet<>();
|
||||||
|
activeOrgSourceIds.add("org:unassigned");
|
||||||
|
for (String org : userOrgMap.values()) {
|
||||||
|
if (org != null && !org.isBlank()) {
|
||||||
|
activeOrgSourceIds.add("org:" + GraphSyncStepService.normalizeOrgCode(org.trim()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return activeOrgSourceIds;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 从所有实体类型中提取用户名。
|
* 从所有实体类型中提取用户名。
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ public class GraphSyncStepService {
|
|||||||
|
|
||||||
private static final String SOURCE_TYPE_SYNC = "SYNC";
|
private static final String SOURCE_TYPE_SYNC = "SYNC";
|
||||||
private static final String REL_TYPE = "RELATED_TO";
|
private static final String REL_TYPE = "RELATED_TO";
|
||||||
|
static final String DEFAULT_ORG_NAME = "未分配";
|
||||||
|
|
||||||
private final GraphEntityRepository entityRepository;
|
private final GraphEntityRepository entityRepository;
|
||||||
final Neo4jClient neo4jClient; // 改为包级别访问,供GraphSyncService使用
|
final Neo4jClient neo4jClient; // 改为包级别访问,供GraphSyncService使用
|
||||||
@@ -143,18 +144,35 @@ public class GraphSyncStepService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public SyncResult upsertOrgEntities(String graphId, String syncId) {
|
public SyncResult upsertOrgEntities(String graphId, Map<String, String> userOrgMap, String syncId) {
|
||||||
SyncResult result = beginResult("Org", syncId);
|
SyncResult result = beginResult("Org", syncId);
|
||||||
|
|
||||||
|
// 提取去重的组织名称;null/blank 归入 "未分配"
|
||||||
|
Set<String> orgNames = new LinkedHashSet<>();
|
||||||
|
orgNames.add(DEFAULT_ORG_NAME);
|
||||||
|
for (String org : userOrgMap.values()) {
|
||||||
|
if (org != null && !org.isBlank()) {
|
||||||
|
orgNames.add(org.trim());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (String orgName : orgNames) {
|
||||||
try {
|
try {
|
||||||
|
String orgCode = normalizeOrgCode(orgName);
|
||||||
|
String sourceId = "org:" + orgCode;
|
||||||
Map<String, Object> props = new HashMap<>();
|
Map<String, Object> props = new HashMap<>();
|
||||||
props.put("org_code", "DEFAULT");
|
props.put("org_code", orgCode);
|
||||||
props.put("level", 1);
|
props.put("level", 1);
|
||||||
upsertEntity(graphId, "org:default", "Org", "默认组织",
|
|
||||||
"系统默认组织(待对接组织服务后更新)", props, result);
|
String description = DEFAULT_ORG_NAME.equals(orgName)
|
||||||
|
? "未分配组织(用户无组织信息时使用)"
|
||||||
|
: "组织:" + orgName;
|
||||||
|
|
||||||
|
upsertEntity(graphId, sourceId, "Org", orgName, description, props, result);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("[{}] Failed to upsert default org", syncId, e);
|
log.warn("[{}] Failed to upsert org: {}", syncId, orgName, e);
|
||||||
result.addError("org:default");
|
result.addError("org:" + orgName);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return endResult(result);
|
return endResult(result);
|
||||||
}
|
}
|
||||||
@@ -547,33 +565,52 @@ public class GraphSyncStepService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public SyncResult mergeBelongsToRelations(String graphId, String syncId) {
|
public SyncResult mergeBelongsToRelations(String graphId, Map<String, String> userOrgMap, String syncId) {
|
||||||
return mergeBelongsToRelations(graphId, syncId, null);
|
return mergeBelongsToRelations(graphId, userOrgMap, syncId, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public SyncResult mergeBelongsToRelations(String graphId, String syncId, Set<String> changedEntityIds) {
|
public SyncResult mergeBelongsToRelations(String graphId, Map<String, String> userOrgMap,
|
||||||
|
String syncId, Set<String> changedEntityIds) {
|
||||||
SyncResult result = beginResult("BELONGS_TO", syncId);
|
SyncResult result = beginResult("BELONGS_TO", syncId);
|
||||||
|
|
||||||
Optional<GraphEntity> defaultOrgOpt = entityRepository.findByGraphIdAndSourceIdAndType(
|
// 构建 org sourceId → entityId 映射
|
||||||
graphId, "org:default", "Org");
|
Map<String, String> orgMap = buildSourceIdToEntityIdMap(graphId, "Org");
|
||||||
if (defaultOrgOpt.isEmpty()) {
|
|
||||||
log.warn("[{}] Default org not found, skipping BELONGS_TO", syncId);
|
String unassignedOrgEntityId = orgMap.get("org:unassigned");
|
||||||
|
if (orgMap.isEmpty() || unassignedOrgEntityId == null) {
|
||||||
|
log.warn("[{}] No org entities found (or unassigned org missing), skipping BELONGS_TO", syncId);
|
||||||
result.addError("belongs_to:org_missing");
|
result.addError("belongs_to:org_missing");
|
||||||
return endResult(result);
|
return endResult(result);
|
||||||
}
|
}
|
||||||
String orgId = defaultOrgOpt.get().getId();
|
|
||||||
|
|
||||||
// User → Org
|
|
||||||
List<GraphEntity> users = entityRepository.findByGraphIdAndType(graphId, "User");
|
|
||||||
if (changedEntityIds != null) {
|
if (changedEntityIds != null) {
|
||||||
users = users.stream()
|
log.debug("[{}] BELONGS_TO rebuild ignores changedEntityIds(size={}) due to org map dependency",
|
||||||
.filter(user -> changedEntityIds.contains(user.getId()))
|
syncId, changedEntityIds.size());
|
||||||
.toList();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User → Org(通过 userOrgMap 查找对应组织)
|
||||||
|
List<GraphEntity> users = entityRepository.findByGraphIdAndType(graphId, "User");
|
||||||
|
|
||||||
|
// Dataset → Org(通过创建者的组织)
|
||||||
|
List<GraphEntity> datasets = entityRepository.findByGraphIdAndType(graphId, "Dataset");
|
||||||
|
|
||||||
|
// 删除受影响实体的旧 BELONGS_TO 关系,避免组织变更后遗留过时关系
|
||||||
|
Set<String> affectedEntityIds = new LinkedHashSet<>();
|
||||||
|
users.forEach(u -> affectedEntityIds.add(u.getId()));
|
||||||
|
datasets.forEach(d -> affectedEntityIds.add(d.getId()));
|
||||||
|
if (!affectedEntityIds.isEmpty()) {
|
||||||
|
deleteOutgoingRelations(graphId, "BELONGS_TO", affectedEntityIds, syncId);
|
||||||
|
}
|
||||||
|
|
||||||
for (GraphEntity user : users) {
|
for (GraphEntity user : users) {
|
||||||
try {
|
try {
|
||||||
boolean created = mergeRelation(graphId, user.getId(), orgId,
|
Object usernameObj = user.getProperties() != null ? user.getProperties().get("username") : null;
|
||||||
|
String username = usernameObj != null ? usernameObj.toString() : null;
|
||||||
|
|
||||||
|
String orgEntityId = resolveOrgEntityId(username, userOrgMap, orgMap, unassignedOrgEntityId);
|
||||||
|
|
||||||
|
boolean created = mergeRelation(graphId, user.getId(), orgEntityId,
|
||||||
"BELONGS_TO", "{\"membership_type\":\"PRIMARY\"}", syncId);
|
"BELONGS_TO", "{\"membership_type\":\"PRIMARY\"}", syncId);
|
||||||
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
|
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -582,16 +619,15 @@ public class GraphSyncStepService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dataset → Org
|
// Dataset → Org(通过创建者的组织)
|
||||||
List<GraphEntity> datasets = entityRepository.findByGraphIdAndType(graphId, "Dataset");
|
|
||||||
if (changedEntityIds != null) {
|
|
||||||
datasets = datasets.stream()
|
|
||||||
.filter(dataset -> changedEntityIds.contains(dataset.getId()))
|
|
||||||
.toList();
|
|
||||||
}
|
|
||||||
for (GraphEntity dataset : datasets) {
|
for (GraphEntity dataset : datasets) {
|
||||||
try {
|
try {
|
||||||
boolean created = mergeRelation(graphId, dataset.getId(), orgId,
|
Object createdByObj = dataset.getProperties() != null ? dataset.getProperties().get("created_by") : null;
|
||||||
|
String createdBy = createdByObj != null ? createdByObj.toString() : null;
|
||||||
|
|
||||||
|
String orgEntityId = resolveOrgEntityId(createdBy, userOrgMap, orgMap, unassignedOrgEntityId);
|
||||||
|
|
||||||
|
boolean created = mergeRelation(graphId, dataset.getId(), orgEntityId,
|
||||||
"BELONGS_TO", "{\"membership_type\":\"PRIMARY\"}", syncId);
|
"BELONGS_TO", "{\"membership_type\":\"PRIMARY\"}", syncId);
|
||||||
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
|
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -1236,4 +1272,56 @@ public class GraphSyncStepService {
|
|||||||
.filter(e -> e.getSourceId() != null)
|
.filter(e -> e.getSourceId() != null)
|
||||||
.collect(Collectors.toMap(GraphEntity::getSourceId, GraphEntity::getId, (a, b) -> a));
|
.collect(Collectors.toMap(GraphEntity::getSourceId, GraphEntity::getId, (a, b) -> a));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 组织名称转换为 source_id 片段。
|
||||||
|
* <p>
|
||||||
|
* 直接使用 trim 后的原始名称,避免归一化导致不同组织碰撞
|
||||||
|
* (如 "Org A" 和 "Org_A" 在 lowercase+regex 归一化下会合并为同一编码)。
|
||||||
|
* Neo4j 属性值支持任意 Unicode 字符串,无需额外编码。
|
||||||
|
*/
|
||||||
|
static String normalizeOrgCode(String orgName) {
|
||||||
|
if (DEFAULT_ORG_NAME.equals(orgName)) {
|
||||||
|
return "unassigned";
|
||||||
|
}
|
||||||
|
return orgName.trim();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 删除指定实体的出向关系(按关系类型)。
|
||||||
|
* <p>
|
||||||
|
* 用于在重建 BELONGS_TO 等关系前清除旧关系,
|
||||||
|
* 确保组织变更等场景下不会遗留过时的关系。
|
||||||
|
*/
|
||||||
|
private void deleteOutgoingRelations(String graphId, String relationType,
|
||||||
|
Set<String> entityIds, String syncId) {
|
||||||
|
log.debug("[{}] Deleting existing {} relations for {} entities",
|
||||||
|
syncId, relationType, entityIds.size());
|
||||||
|
neo4jClient.query(
|
||||||
|
"MATCH (e:Entity {graph_id: $graphId})" +
|
||||||
|
"-[r:RELATED_TO {graph_id: $graphId, relation_type: $relationType}]->()" +
|
||||||
|
" WHERE e.id IN $entityIds DELETE r"
|
||||||
|
).bindAll(Map.of(
|
||||||
|
"graphId", graphId,
|
||||||
|
"relationType", relationType,
|
||||||
|
"entityIds", new ArrayList<>(entityIds)
|
||||||
|
)).run();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据用户名查找对应组织实体 ID,未找到时降级到未分配组织。
|
||||||
|
*/
|
||||||
|
private String resolveOrgEntityId(String username, Map<String, String> userOrgMap,
|
||||||
|
Map<String, String> orgMap, String unassignedOrgEntityId) {
|
||||||
|
if (username == null || username.isBlank()) {
|
||||||
|
return unassignedOrgEntityId;
|
||||||
|
}
|
||||||
|
String orgName = userOrgMap.get(username);
|
||||||
|
if (orgName == null || orgName.isBlank()) {
|
||||||
|
return unassignedOrgEntityId;
|
||||||
|
}
|
||||||
|
String orgCode = normalizeOrgCode(orgName.trim());
|
||||||
|
String orgEntityId = orgMap.get("org:" + orgCode);
|
||||||
|
return orgEntityId != null ? orgEntityId : unassignedOrgEntityId;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,95 @@
|
|||||||
|
package com.datamate.knowledgegraph.application;
|
||||||
|
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.data.neo4j.core.Neo4jClient;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 索引健康检查服务。
|
||||||
|
* <p>
|
||||||
|
* 提供 Neo4j 索引状态查询,用于运维监控和启动验证。
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class IndexHealthService {
|
||||||
|
|
||||||
|
private final Neo4jClient neo4jClient;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有索引状态信息。
|
||||||
|
*
|
||||||
|
* @return 索引名称到状态的映射列表,每项包含 name, state, type, entityType, labelsOrTypes, properties
|
||||||
|
*/
|
||||||
|
public List<Map<String, Object>> getIndexStatus() {
|
||||||
|
return neo4jClient
|
||||||
|
.query("SHOW INDEXES YIELD name, state, type, entityType, labelsOrTypes, properties " +
|
||||||
|
"RETURN name, state, type, entityType, labelsOrTypes, properties " +
|
||||||
|
"ORDER BY name")
|
||||||
|
.fetchAs(Map.class)
|
||||||
|
.mappedBy((ts, record) -> {
|
||||||
|
Map<String, Object> info = new java.util.LinkedHashMap<>();
|
||||||
|
info.put("name", record.get("name").asString(null));
|
||||||
|
info.put("state", record.get("state").asString(null));
|
||||||
|
info.put("type", record.get("type").asString(null));
|
||||||
|
info.put("entityType", record.get("entityType").asString(null));
|
||||||
|
var labelsOrTypes = record.get("labelsOrTypes");
|
||||||
|
info.put("labelsOrTypes", labelsOrTypes.isNull() ? List.of() : labelsOrTypes.asList(v -> v.asString(null)));
|
||||||
|
var properties = record.get("properties");
|
||||||
|
info.put("properties", properties.isNull() ? List.of() : properties.asList(v -> v.asString(null)));
|
||||||
|
return info;
|
||||||
|
})
|
||||||
|
.all()
|
||||||
|
.stream()
|
||||||
|
.map(m -> (Map<String, Object>) m)
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查是否存在非 ONLINE 状态的索引。
|
||||||
|
*
|
||||||
|
* @return true 表示所有索引健康(ONLINE 状态)
|
||||||
|
*/
|
||||||
|
public boolean allIndexesOnline() {
|
||||||
|
List<Map<String, Object>> indexes = getIndexStatus();
|
||||||
|
if (indexes.isEmpty()) {
|
||||||
|
log.warn("No indexes found in Neo4j database");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (Map<String, Object> idx : indexes) {
|
||||||
|
String state = (String) idx.get("state");
|
||||||
|
if (!"ONLINE".equals(state)) {
|
||||||
|
log.warn("Index '{}' is in state '{}' (expected ONLINE)", idx.get("name"), state);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取数据库统计信息(节点数、关系数)。
|
||||||
|
*
|
||||||
|
* @return 包含 nodeCount 和 relationshipCount 的映射
|
||||||
|
*/
|
||||||
|
public Map<String, Long> getDatabaseStats() {
|
||||||
|
Long nodeCount = neo4jClient
|
||||||
|
.query("MATCH (n:Entity) RETURN count(n) AS cnt")
|
||||||
|
.fetchAs(Long.class)
|
||||||
|
.mappedBy((ts, record) -> record.get("cnt").asLong())
|
||||||
|
.one()
|
||||||
|
.orElse(0L);
|
||||||
|
|
||||||
|
Long relCount = neo4jClient
|
||||||
|
.query("MATCH ()-[r:RELATED_TO]->() RETURN count(r) AS cnt")
|
||||||
|
.fetchAs(Long.class)
|
||||||
|
.mappedBy((ts, record) -> record.get("cnt").asLong())
|
||||||
|
.one()
|
||||||
|
.orElse(0L);
|
||||||
|
|
||||||
|
return Map.of("nodeCount", nodeCount, "relationshipCount", relCount);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package com.datamate.knowledgegraph.domain.model;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 知识图谱编辑审核记录。
|
||||||
|
* <p>
|
||||||
|
* 在 Neo4j 中作为 {@code EditReview} 节点存储,
|
||||||
|
* 记录实体/关系的增删改请求及审核状态。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class EditReview {
|
||||||
|
|
||||||
|
private String id;
|
||||||
|
|
||||||
|
/** 所属图谱 ID */
|
||||||
|
private String graphId;
|
||||||
|
|
||||||
|
/** 操作类型:CREATE_ENTITY, UPDATE_ENTITY, DELETE_ENTITY, BATCH_DELETE_ENTITY, CREATE_RELATION, UPDATE_RELATION, DELETE_RELATION, BATCH_DELETE_RELATION */
|
||||||
|
private String operationType;
|
||||||
|
|
||||||
|
/** 目标实体 ID(实体操作时非空) */
|
||||||
|
private String entityId;
|
||||||
|
|
||||||
|
/** 目标关系 ID(关系操作时非空) */
|
||||||
|
private String relationId;
|
||||||
|
|
||||||
|
/** 变更载荷(JSON 序列化的请求体) */
|
||||||
|
private String payload;
|
||||||
|
|
||||||
|
/** 审核状态:PENDING, APPROVED, REJECTED */
|
||||||
|
@Builder.Default
|
||||||
|
private String status = "PENDING";
|
||||||
|
|
||||||
|
/** 提交人 ID */
|
||||||
|
private String submittedBy;
|
||||||
|
|
||||||
|
/** 审核人 ID */
|
||||||
|
private String reviewedBy;
|
||||||
|
|
||||||
|
/** 审核意见 */
|
||||||
|
private String reviewComment;
|
||||||
|
|
||||||
|
private LocalDateTime createdAt;
|
||||||
|
|
||||||
|
private LocalDateTime reviewedAt;
|
||||||
|
}
|
||||||
@@ -0,0 +1,193 @@
|
|||||||
|
package com.datamate.knowledgegraph.domain.repository;
|
||||||
|
|
||||||
|
import com.datamate.knowledgegraph.domain.model.EditReview;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.neo4j.driver.Value;
|
||||||
|
import org.neo4j.driver.types.MapAccessor;
|
||||||
|
import org.springframework.data.neo4j.core.Neo4jClient;
|
||||||
|
import org.springframework.stereotype.Repository;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 编辑审核仓储。
|
||||||
|
* <p>
|
||||||
|
* 使用 {@code Neo4jClient} 管理 {@code EditReview} 节点。
|
||||||
|
*/
|
||||||
|
@Repository
|
||||||
|
@Slf4j
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class EditReviewRepository {
|
||||||
|
|
||||||
|
private final Neo4jClient neo4jClient;
|
||||||
|
|
||||||
|
public EditReview save(EditReview review) {
|
||||||
|
if (review.getId() == null) {
|
||||||
|
review.setId(UUID.randomUUID().toString());
|
||||||
|
}
|
||||||
|
if (review.getCreatedAt() == null) {
|
||||||
|
review.setCreatedAt(LocalDateTime.now());
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("id", review.getId());
|
||||||
|
params.put("graphId", review.getGraphId());
|
||||||
|
params.put("operationType", review.getOperationType());
|
||||||
|
params.put("entityId", review.getEntityId() != null ? review.getEntityId() : "");
|
||||||
|
params.put("relationId", review.getRelationId() != null ? review.getRelationId() : "");
|
||||||
|
params.put("payload", review.getPayload() != null ? review.getPayload() : "");
|
||||||
|
params.put("status", review.getStatus());
|
||||||
|
params.put("submittedBy", review.getSubmittedBy() != null ? review.getSubmittedBy() : "");
|
||||||
|
params.put("reviewedBy", review.getReviewedBy() != null ? review.getReviewedBy() : "");
|
||||||
|
params.put("reviewComment", review.getReviewComment() != null ? review.getReviewComment() : "");
|
||||||
|
params.put("createdAt", review.getCreatedAt());
|
||||||
|
|
||||||
|
// reviewed_at 为 null 时(PENDING 状态)不写入 SET,避免 null 参数导致属性缺失
|
||||||
|
String reviewedAtSet = "";
|
||||||
|
if (review.getReviewedAt() != null) {
|
||||||
|
reviewedAtSet = ", r.reviewed_at = $reviewedAt";
|
||||||
|
params.put("reviewedAt", review.getReviewedAt());
|
||||||
|
}
|
||||||
|
|
||||||
|
neo4jClient
|
||||||
|
.query(
|
||||||
|
"MERGE (r:EditReview {id: $id}) " +
|
||||||
|
"SET r.graph_id = $graphId, " +
|
||||||
|
" r.operation_type = $operationType, " +
|
||||||
|
" r.entity_id = $entityId, " +
|
||||||
|
" r.relation_id = $relationId, " +
|
||||||
|
" r.payload = $payload, " +
|
||||||
|
" r.status = $status, " +
|
||||||
|
" r.submitted_by = $submittedBy, " +
|
||||||
|
" r.reviewed_by = $reviewedBy, " +
|
||||||
|
" r.review_comment = $reviewComment, " +
|
||||||
|
" r.created_at = $createdAt" +
|
||||||
|
reviewedAtSet + " " +
|
||||||
|
"RETURN r"
|
||||||
|
)
|
||||||
|
.bindAll(params)
|
||||||
|
.run();
|
||||||
|
|
||||||
|
return review;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Optional<EditReview> findById(String reviewId, String graphId) {
|
||||||
|
return neo4jClient
|
||||||
|
.query("MATCH (r:EditReview {id: $id, graph_id: $graphId}) RETURN r")
|
||||||
|
.bindAll(Map.of("id", reviewId, "graphId", graphId))
|
||||||
|
.fetchAs(EditReview.class)
|
||||||
|
.mappedBy((typeSystem, record) -> mapRecord(record))
|
||||||
|
.one();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<EditReview> findPendingByGraphId(String graphId, long skip, int size) {
|
||||||
|
return neo4jClient
|
||||||
|
.query(
|
||||||
|
"MATCH (r:EditReview {graph_id: $graphId, status: 'PENDING'}) " +
|
||||||
|
"RETURN r ORDER BY r.created_at DESC SKIP $skip LIMIT $size"
|
||||||
|
)
|
||||||
|
.bindAll(Map.of("graphId", graphId, "skip", skip, "size", size))
|
||||||
|
.fetchAs(EditReview.class)
|
||||||
|
.mappedBy((typeSystem, record) -> mapRecord(record))
|
||||||
|
.all()
|
||||||
|
.stream().toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
public long countPendingByGraphId(String graphId) {
|
||||||
|
return neo4jClient
|
||||||
|
.query("MATCH (r:EditReview {graph_id: $graphId, status: 'PENDING'}) RETURN count(r) AS cnt")
|
||||||
|
.bindAll(Map.of("graphId", graphId))
|
||||||
|
.fetchAs(Long.class)
|
||||||
|
.mappedBy((typeSystem, record) -> record.get("cnt").asLong())
|
||||||
|
.one()
|
||||||
|
.orElse(0L);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<EditReview> findByGraphId(String graphId, String status, long skip, int size) {
|
||||||
|
String statusFilter = (status != null && !status.isBlank())
|
||||||
|
? "AND r.status = $status "
|
||||||
|
: "";
|
||||||
|
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("graphId", graphId);
|
||||||
|
params.put("status", status != null ? status : "");
|
||||||
|
params.put("skip", skip);
|
||||||
|
params.put("size", size);
|
||||||
|
|
||||||
|
return neo4jClient
|
||||||
|
.query(
|
||||||
|
"MATCH (r:EditReview {graph_id: $graphId}) " +
|
||||||
|
"WHERE true " + statusFilter +
|
||||||
|
"RETURN r ORDER BY r.created_at DESC SKIP $skip LIMIT $size"
|
||||||
|
)
|
||||||
|
.bindAll(params)
|
||||||
|
.fetchAs(EditReview.class)
|
||||||
|
.mappedBy((typeSystem, record) -> mapRecord(record))
|
||||||
|
.all()
|
||||||
|
.stream().toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
public long countByGraphId(String graphId, String status) {
|
||||||
|
String statusFilter = (status != null && !status.isBlank())
|
||||||
|
? "AND r.status = $status "
|
||||||
|
: "";
|
||||||
|
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("graphId", graphId);
|
||||||
|
params.put("status", status != null ? status : "");
|
||||||
|
|
||||||
|
return neo4jClient
|
||||||
|
.query(
|
||||||
|
"MATCH (r:EditReview {graph_id: $graphId}) " +
|
||||||
|
"WHERE true " + statusFilter +
|
||||||
|
"RETURN count(r) AS cnt"
|
||||||
|
)
|
||||||
|
.bindAll(params)
|
||||||
|
.fetchAs(Long.class)
|
||||||
|
.mappedBy((typeSystem, record) -> record.get("cnt").asLong())
|
||||||
|
.one()
|
||||||
|
.orElse(0L);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// 内部映射
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
private EditReview mapRecord(MapAccessor record) {
|
||||||
|
Value r = record.get("r");
|
||||||
|
|
||||||
|
return EditReview.builder()
|
||||||
|
.id(getStringOrNull(r, "id"))
|
||||||
|
.graphId(getStringOrNull(r, "graph_id"))
|
||||||
|
.operationType(getStringOrNull(r, "operation_type"))
|
||||||
|
.entityId(getStringOrEmpty(r, "entity_id"))
|
||||||
|
.relationId(getStringOrEmpty(r, "relation_id"))
|
||||||
|
.payload(getStringOrNull(r, "payload"))
|
||||||
|
.status(getStringOrNull(r, "status"))
|
||||||
|
.submittedBy(getStringOrEmpty(r, "submitted_by"))
|
||||||
|
.reviewedBy(getStringOrEmpty(r, "reviewed_by"))
|
||||||
|
.reviewComment(getStringOrEmpty(r, "review_comment"))
|
||||||
|
.createdAt(getLocalDateTimeOrNull(r, "created_at"))
|
||||||
|
.reviewedAt(getLocalDateTimeOrNull(r, "reviewed_at"))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String getStringOrNull(Value value, String key) {
|
||||||
|
Value v = value.get(key);
|
||||||
|
return (v == null || v.isNull()) ? null : v.asString();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String getStringOrEmpty(Value value, String key) {
|
||||||
|
Value v = value.get(key);
|
||||||
|
if (v == null || v.isNull()) return null;
|
||||||
|
String s = v.asString();
|
||||||
|
return s.isEmpty() ? null : s;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static LocalDateTime getLocalDateTimeOrNull(Value value, String key) {
|
||||||
|
Value v = value.get(key);
|
||||||
|
return (v == null || v.isNull()) ? null : v.asLocalDateTime();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.cache;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.beans.factory.annotation.Qualifier;
|
||||||
|
import org.springframework.cache.Cache;
|
||||||
|
import org.springframework.cache.CacheManager;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 图谱缓存管理服务。
|
||||||
|
* <p>
|
||||||
|
* 提供缓存失效操作,在写操作(增删改)后由 Service 层调用,
|
||||||
|
* 确保缓存与数据库的最终一致性。
|
||||||
|
* <p>
|
||||||
|
* 当 {@link StringRedisTemplate} 可用时,使用按 graphId 前缀的细粒度失效,
|
||||||
|
* 避免跨图谱缓存刷新;否则退化为清空整个缓存区域。
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class GraphCacheService {
|
||||||
|
|
||||||
|
private static final String KEY_PREFIX = "datamate:";
|
||||||
|
|
||||||
|
private final CacheManager cacheManager;
|
||||||
|
private StringRedisTemplate redisTemplate;
|
||||||
|
|
||||||
|
public GraphCacheService(@Qualifier("knowledgeGraphCacheManager") CacheManager cacheManager) {
|
||||||
|
this.cacheManager = cacheManager;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Autowired(required = false)
|
||||||
|
public void setRedisTemplate(StringRedisTemplate redisTemplate) {
|
||||||
|
this.redisTemplate = redisTemplate;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 失效指定图谱的全部缓存。
|
||||||
|
* <p>
|
||||||
|
* 在 sync、批量操作后调用,确保缓存一致性。
|
||||||
|
* 当 Redis 可用时仅失效该 graphId 的缓存条目,避免影响其他图谱。
|
||||||
|
*/
|
||||||
|
public void evictGraphCaches(String graphId) {
|
||||||
|
log.debug("Evicting all caches for graph_id={}", graphId);
|
||||||
|
evictByGraphPrefix(RedisCacheConfig.CACHE_ENTITIES, graphId);
|
||||||
|
evictByGraphPrefix(RedisCacheConfig.CACHE_QUERIES, graphId);
|
||||||
|
evictByGraphPrefix(RedisCacheConfig.CACHE_SEARCH, graphId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 失效指定实体相关的缓存。
|
||||||
|
* <p>
|
||||||
|
* 在单实体增删改后调用。精确失效该实体缓存和 list 缓存,
|
||||||
|
* 并清除该图谱的查询缓存(因邻居关系可能变化)。
|
||||||
|
*/
|
||||||
|
public void evictEntityCaches(String graphId, String entityId) {
|
||||||
|
log.debug("Evicting entity caches: graph_id={}, entity_id={}", graphId, entityId);
|
||||||
|
// 精确失效具体实体和 list 缓存
|
||||||
|
evictKey(RedisCacheConfig.CACHE_ENTITIES, cacheKey(graphId, entityId));
|
||||||
|
evictKey(RedisCacheConfig.CACHE_ENTITIES, cacheKey(graphId, "list"));
|
||||||
|
// 按 graphId 前缀失效查询缓存
|
||||||
|
evictByGraphPrefix(RedisCacheConfig.CACHE_QUERIES, graphId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 失效指定图谱的搜索缓存。
|
||||||
|
* <p>
|
||||||
|
* 在实体名称/描述变更后调用。
|
||||||
|
*/
|
||||||
|
public void evictSearchCaches(String graphId) {
|
||||||
|
log.debug("Evicting search caches for graph_id={}", graphId);
|
||||||
|
evictByGraphPrefix(RedisCacheConfig.CACHE_SEARCH, graphId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 失效所有搜索缓存(无 graphId 上下文时使用)。
|
||||||
|
*/
|
||||||
|
public void evictSearchCaches() {
|
||||||
|
log.debug("Evicting all search caches");
|
||||||
|
evictCache(RedisCacheConfig.CACHE_SEARCH);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// 内部方法
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 按 graphId 前缀失效缓存条目。
|
||||||
|
* <p>
|
||||||
|
* 所有缓存 key 均以 {@code graphId:} 开头,因此可通过前缀模式匹配。
|
||||||
|
* 当 Redis 不可用时退化为清空整个缓存区域。
|
||||||
|
*/
|
||||||
|
private void evictByGraphPrefix(String cacheName, String graphId) {
|
||||||
|
if (redisTemplate != null) {
|
||||||
|
try {
|
||||||
|
String pattern = KEY_PREFIX + cacheName + "::" + graphId + ":*";
|
||||||
|
Set<String> keys = redisTemplate.keys(pattern);
|
||||||
|
if (keys != null && !keys.isEmpty()) {
|
||||||
|
redisTemplate.delete(keys);
|
||||||
|
log.debug("Evicted {} keys for graph_id={} in cache={}", keys.size(), graphId, cacheName);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("Failed to evict by graph prefix, falling back to full cache clear: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 降级:清空整个缓存区域
|
||||||
|
evictCache(cacheName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 精确失效单个缓存条目。
|
||||||
|
*/
|
||||||
|
private void evictKey(String cacheName, String key) {
|
||||||
|
Cache cache = cacheManager.getCache(cacheName);
|
||||||
|
if (cache != null) {
|
||||||
|
cache.evict(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清空整个缓存区域。
|
||||||
|
*/
|
||||||
|
private void evictCache(String cacheName) {
|
||||||
|
Cache cache = cacheManager.getCache(cacheName);
|
||||||
|
if (cache != null) {
|
||||||
|
cache.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成缓存 key。
|
||||||
|
* <p>
|
||||||
|
* 将多个参数拼接为冒号分隔的字符串 key,用于 {@code @Cacheable} 的 key 表达式。
|
||||||
|
* <b>约定</b>:graphId 必须作为第一个参数,以支持按 graphId 前缀失效。
|
||||||
|
*/
|
||||||
|
public static String cacheKey(Object... parts) {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
for (int i = 0; i < parts.length; i++) {
|
||||||
|
if (i > 0) sb.append(':');
|
||||||
|
sb.append(Objects.toString(parts[i], "null"));
|
||||||
|
}
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.cache;
|
||||||
|
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||||
|
import org.springframework.cache.CacheManager;
|
||||||
|
import org.springframework.cache.annotation.EnableCaching;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.context.annotation.Primary;
|
||||||
|
import org.springframework.data.redis.cache.RedisCacheConfiguration;
|
||||||
|
import org.springframework.data.redis.cache.RedisCacheManager;
|
||||||
|
import org.springframework.data.redis.connection.RedisConnectionFactory;
|
||||||
|
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
|
||||||
|
import org.springframework.data.redis.serializer.RedisSerializationContext;
|
||||||
|
import org.springframework.data.redis.serializer.StringRedisSerializer;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Redis 缓存配置。
|
||||||
|
* <p>
|
||||||
|
* 当 {@code datamate.knowledge-graph.cache.enabled=true} 时激活,
|
||||||
|
* 为不同缓存区域配置独立的 TTL。
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Configuration
|
||||||
|
@EnableCaching
|
||||||
|
@ConditionalOnProperty(
|
||||||
|
prefix = "datamate.knowledge-graph.cache",
|
||||||
|
name = "enabled",
|
||||||
|
havingValue = "true",
|
||||||
|
matchIfMissing = true
|
||||||
|
)
|
||||||
|
public class RedisCacheConfig {
|
||||||
|
|
||||||
|
/** 实体缓存:单实体查询、实体列表 */
|
||||||
|
public static final String CACHE_ENTITIES = "kg:entities";
|
||||||
|
|
||||||
|
/** 查询缓存:邻居图、子图、路径查询 */
|
||||||
|
public static final String CACHE_QUERIES = "kg:queries";
|
||||||
|
|
||||||
|
/** 搜索缓存:全文搜索结果 */
|
||||||
|
public static final String CACHE_SEARCH = "kg:search";
|
||||||
|
|
||||||
|
@Primary
|
||||||
|
@Bean("knowledgeGraphCacheManager")
|
||||||
|
public CacheManager knowledgeGraphCacheManager(
|
||||||
|
RedisConnectionFactory connectionFactory,
|
||||||
|
KnowledgeGraphProperties properties
|
||||||
|
) {
|
||||||
|
KnowledgeGraphProperties.Cache cacheProps = properties.getCache();
|
||||||
|
|
||||||
|
// JSON 序列化,确保缓存数据可读且兼容版本变更
|
||||||
|
var jsonSerializer = new GenericJackson2JsonRedisSerializer();
|
||||||
|
var serializationPair = RedisSerializationContext.SerializationPair.fromSerializer(jsonSerializer);
|
||||||
|
|
||||||
|
RedisCacheConfiguration defaultConfig = RedisCacheConfiguration.defaultCacheConfig()
|
||||||
|
.serializeKeysWith(RedisSerializationContext.SerializationPair.fromSerializer(new StringRedisSerializer()))
|
||||||
|
.serializeValuesWith(serializationPair)
|
||||||
|
.disableCachingNullValues()
|
||||||
|
.prefixCacheNameWith("datamate:");
|
||||||
|
|
||||||
|
// 各缓存区域独立 TTL
|
||||||
|
Map<String, RedisCacheConfiguration> cacheConfigs = Map.of(
|
||||||
|
CACHE_ENTITIES, defaultConfig.entryTtl(Duration.ofSeconds(cacheProps.getEntityTtlSeconds())),
|
||||||
|
CACHE_QUERIES, defaultConfig.entryTtl(Duration.ofSeconds(cacheProps.getQueryTtlSeconds())),
|
||||||
|
CACHE_SEARCH, defaultConfig.entryTtl(Duration.ofSeconds(cacheProps.getSearchTtlSeconds()))
|
||||||
|
);
|
||||||
|
|
||||||
|
log.info("Redis cache enabled: entity TTL={}s, query TTL={}s, search TTL={}s",
|
||||||
|
cacheProps.getEntityTtlSeconds(),
|
||||||
|
cacheProps.getQueryTtlSeconds(),
|
||||||
|
cacheProps.getSearchTtlSeconds());
|
||||||
|
|
||||||
|
return RedisCacheManager.builder(connectionFactory)
|
||||||
|
.cacheDefaults(defaultConfig.entryTtl(Duration.ofSeconds(cacheProps.getQueryTtlSeconds())))
|
||||||
|
.withInitialCacheConfigurations(cacheConfigs)
|
||||||
|
.transactionAware()
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -204,6 +204,37 @@ public class DataManagementClient {
|
|||||||
"knowledge-sets");
|
"knowledge-sets");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 拉取所有用户的组织映射。
|
||||||
|
*/
|
||||||
|
public Map<String, String> fetchUserOrganizationMap() {
|
||||||
|
String url = baseUrl + "/auth/users/organizations";
|
||||||
|
log.debug("Fetching user-organization mappings from: {}", url);
|
||||||
|
try {
|
||||||
|
ResponseEntity<List<UserOrgDTO>> response = restTemplate.exchange(
|
||||||
|
url, HttpMethod.GET, null,
|
||||||
|
new ParameterizedTypeReference<List<UserOrgDTO>>() {});
|
||||||
|
|
||||||
|
List<UserOrgDTO> body = response.getBody();
|
||||||
|
if (body == null || body.isEmpty()) {
|
||||||
|
log.warn("No user-organization mappings returned from auth service");
|
||||||
|
return Collections.emptyMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, String> result = new LinkedHashMap<>();
|
||||||
|
for (UserOrgDTO dto : body) {
|
||||||
|
if (dto.getUsername() != null && !dto.getUsername().isBlank()) {
|
||||||
|
result.put(dto.getUsername(), dto.getOrganization());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.info("Fetched {} user-organization mappings", result.size());
|
||||||
|
return result;
|
||||||
|
} catch (RestClientException e) {
|
||||||
|
log.error("Failed to fetch user-organization mappings from: {}", url, e);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 通用自动分页拉取方法。
|
* 通用自动分页拉取方法。
|
||||||
*/
|
*/
|
||||||
@@ -459,4 +490,14 @@ public class DataManagementClient {
|
|||||||
/** 来源数据集 ID 列表(SOURCED_FROM 关系) */
|
/** 来源数据集 ID 列表(SOURCED_FROM 关系) */
|
||||||
private List<String> sourceDatasetIds;
|
private List<String> sourceDatasetIds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 用户-组织映射 DTO(与 AuthController.listUserOrganizations 对齐)。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@JsonIgnoreProperties(ignoreUnknown = true)
|
||||||
|
public static class UserOrgDTO {
|
||||||
|
private String username;
|
||||||
|
private String organization;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,13 @@ public enum KnowledgeGraphErrorCode implements ErrorCode {
|
|||||||
EMPTY_SNAPSHOT_PURGE_BLOCKED("knowledge_graph.0010", "空快照保护:上游返回空列表,已阻止 purge 操作"),
|
EMPTY_SNAPSHOT_PURGE_BLOCKED("knowledge_graph.0010", "空快照保护:上游返回空列表,已阻止 purge 操作"),
|
||||||
SCHEMA_INIT_FAILED("knowledge_graph.0011", "图谱 Schema 初始化失败"),
|
SCHEMA_INIT_FAILED("knowledge_graph.0011", "图谱 Schema 初始化失败"),
|
||||||
INSECURE_DEFAULT_CREDENTIALS("knowledge_graph.0012", "检测到默认凭据,生产环境禁止使用默认密码"),
|
INSECURE_DEFAULT_CREDENTIALS("knowledge_graph.0012", "检测到默认凭据,生产环境禁止使用默认密码"),
|
||||||
UNAUTHORIZED_INTERNAL_CALL("knowledge_graph.0013", "内部调用未授权:X-Internal-Token 校验失败");
|
UNAUTHORIZED_INTERNAL_CALL("knowledge_graph.0013", "内部调用未授权:X-Internal-Token 校验失败"),
|
||||||
|
QUERY_TIMEOUT("knowledge_graph.0014", "图查询超时,请缩小搜索范围或减少深度"),
|
||||||
|
SCHEMA_MIGRATION_FAILED("knowledge_graph.0015", "Schema 迁移执行失败"),
|
||||||
|
SCHEMA_CHECKSUM_MISMATCH("knowledge_graph.0016", "Schema 迁移 checksum 不匹配:已应用的迁移被修改"),
|
||||||
|
SCHEMA_MIGRATION_LOCKED("knowledge_graph.0017", "Schema 迁移锁被占用,其他实例正在执行迁移"),
|
||||||
|
REVIEW_NOT_FOUND("knowledge_graph.0018", "审核记录不存在"),
|
||||||
|
REVIEW_ALREADY_PROCESSED("knowledge_graph.0019", "审核记录已处理");
|
||||||
|
|
||||||
private final String code;
|
private final String code;
|
||||||
private final String message;
|
private final String message;
|
||||||
|
|||||||
@@ -1,24 +1,21 @@
|
|||||||
package com.datamate.knowledgegraph.infrastructure.neo4j;
|
package com.datamate.knowledgegraph.infrastructure.neo4j;
|
||||||
|
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.neo4j.migration.SchemaMigrationService;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.boot.ApplicationArguments;
|
import org.springframework.boot.ApplicationArguments;
|
||||||
import org.springframework.boot.ApplicationRunner;
|
import org.springframework.boot.ApplicationRunner;
|
||||||
import org.springframework.core.annotation.Order;
|
import org.springframework.core.annotation.Order;
|
||||||
import org.springframework.data.neo4j.core.Neo4jClient;
|
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 图谱 Schema 初始化器。
|
* 图谱 Schema 初始化器。
|
||||||
* <p>
|
* <p>
|
||||||
* 应用启动时自动创建 Neo4j 索引和约束。
|
* 应用启动时通过 {@link SchemaMigrationService} 执行版本化 Schema 迁移。
|
||||||
* 所有语句使用 {@code IF NOT EXISTS},保证幂等性。
|
|
||||||
* <p>
|
|
||||||
* 对应 {@code docs/knowledge-graph/schema/schema.cypher} 中的第 1-3 部分。
|
|
||||||
* <p>
|
* <p>
|
||||||
* <b>安全自检</b>:在非开发环境中,检测到默认 Neo4j 密码时拒绝启动。
|
* <b>安全自检</b>:在非开发环境中,检测到默认 Neo4j 密码时拒绝启动。
|
||||||
*/
|
*/
|
||||||
@@ -33,13 +30,8 @@ public class GraphInitializer implements ApplicationRunner {
|
|||||||
"datamate123", "neo4j", "password", "123456", "admin"
|
"datamate123", "neo4j", "password", "123456", "admin"
|
||||||
);
|
);
|
||||||
|
|
||||||
/** 仅识别「已存在」类错误消息的关键词,其余错误不应吞掉。 */
|
|
||||||
private static final Set<String> ALREADY_EXISTS_KEYWORDS = Set.of(
|
|
||||||
"already exists", "already exist", "EquivalentSchemaRuleAlreadyExists"
|
|
||||||
);
|
|
||||||
|
|
||||||
private final Neo4jClient neo4jClient;
|
|
||||||
private final KnowledgeGraphProperties properties;
|
private final KnowledgeGraphProperties properties;
|
||||||
|
private final SchemaMigrationService schemaMigrationService;
|
||||||
|
|
||||||
@Value("${spring.neo4j.authentication.password:}")
|
@Value("${spring.neo4j.authentication.password:}")
|
||||||
private String neo4jPassword;
|
private String neo4jPassword;
|
||||||
@@ -47,122 +39,25 @@ public class GraphInitializer implements ApplicationRunner {
|
|||||||
@Value("${spring.profiles.active:default}")
|
@Value("${spring.profiles.active:default}")
|
||||||
private String activeProfile;
|
private String activeProfile;
|
||||||
|
|
||||||
/**
|
|
||||||
* 需要在启动时执行的 Cypher 语句。
|
|
||||||
* 每条语句必须独立执行(Neo4j 不支持多条 DDL 在同一事务中)。
|
|
||||||
*/
|
|
||||||
private static final List<String> SCHEMA_STATEMENTS = List.of(
|
|
||||||
// 约束(自动创建对应索引)
|
|
||||||
"CREATE CONSTRAINT entity_id_unique IF NOT EXISTS FOR (n:Entity) REQUIRE n.id IS UNIQUE",
|
|
||||||
|
|
||||||
// 同步 upsert 复合唯一约束:防止并发写入产生重复实体
|
|
||||||
"CREATE CONSTRAINT entity_sync_unique IF NOT EXISTS " +
|
|
||||||
"FOR (n:Entity) REQUIRE (n.graph_id, n.source_id, n.type) IS UNIQUE",
|
|
||||||
|
|
||||||
// 单字段索引
|
|
||||||
"CREATE INDEX entity_graph_id IF NOT EXISTS FOR (n:Entity) ON (n.graph_id)",
|
|
||||||
"CREATE INDEX entity_type IF NOT EXISTS FOR (n:Entity) ON (n.type)",
|
|
||||||
"CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name)",
|
|
||||||
"CREATE INDEX entity_source_id IF NOT EXISTS FOR (n:Entity) ON (n.source_id)",
|
|
||||||
"CREATE INDEX entity_created_at IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
|
|
||||||
|
|
||||||
// 复合索引
|
|
||||||
"CREATE INDEX entity_graph_id_type IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.type)",
|
|
||||||
"CREATE INDEX entity_graph_id_id IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.id)",
|
|
||||||
"CREATE INDEX entity_graph_id_source_id IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.source_id)",
|
|
||||||
|
|
||||||
// 全文索引
|
|
||||||
"CREATE FULLTEXT INDEX entity_fulltext IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.description]",
|
|
||||||
|
|
||||||
// ── SyncHistory 约束和索引 ──
|
|
||||||
|
|
||||||
// P1: syncId 唯一约束,防止 ID 碰撞
|
|
||||||
"CREATE CONSTRAINT sync_history_graph_sync_unique IF NOT EXISTS " +
|
|
||||||
"FOR (h:SyncHistory) REQUIRE (h.graph_id, h.sync_id) IS UNIQUE",
|
|
||||||
|
|
||||||
// P2-3: 查询优化索引
|
|
||||||
"CREATE INDEX sync_history_graph_started IF NOT EXISTS " +
|
|
||||||
"FOR (h:SyncHistory) ON (h.graph_id, h.started_at)",
|
|
||||||
|
|
||||||
"CREATE INDEX sync_history_graph_status_started IF NOT EXISTS " +
|
|
||||||
"FOR (h:SyncHistory) ON (h.graph_id, h.status, h.started_at)"
|
|
||||||
);
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run(ApplicationArguments args) {
|
public void run(ApplicationArguments args) {
|
||||||
// ── 安全自检:默认凭据检测 ──
|
// ── 安全自检:默认凭据检测(已禁用) ──
|
||||||
validateCredentials();
|
// validateCredentials();
|
||||||
|
|
||||||
if (!properties.getSync().isAutoInitSchema()) {
|
if (!properties.getSync().isAutoInitSchema()) {
|
||||||
log.info("Schema auto-init is disabled, skipping");
|
log.info("Schema auto-init is disabled, skipping");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.info("Initializing Neo4j schema: {} statements to execute", SCHEMA_STATEMENTS.size());
|
schemaMigrationService.migrate(UUID.randomUUID().toString());
|
||||||
|
|
||||||
int succeeded = 0;
|
|
||||||
int failed = 0;
|
|
||||||
|
|
||||||
for (String statement : SCHEMA_STATEMENTS) {
|
|
||||||
try {
|
|
||||||
neo4jClient.query(statement).run();
|
|
||||||
succeeded++;
|
|
||||||
log.debug("Schema statement executed: {}", truncate(statement));
|
|
||||||
} catch (Exception e) {
|
|
||||||
if (isAlreadyExistsError(e)) {
|
|
||||||
// 约束/索引已存在,安全跳过
|
|
||||||
succeeded++;
|
|
||||||
log.debug("Schema element already exists (safe to skip): {}", truncate(statement));
|
|
||||||
} else {
|
|
||||||
// 非「已存在」错误:记录并抛出,阻止启动
|
|
||||||
failed++;
|
|
||||||
log.error("Schema statement FAILED: {} — {}", truncate(statement), e.getMessage());
|
|
||||||
throw new IllegalStateException(
|
|
||||||
"Neo4j schema initialization failed: " + truncate(statement), e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.info("Neo4j schema initialization completed: succeeded={}, failed={}", succeeded, failed);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 检测是否使用了默认凭据。
|
* 检测是否使用了默认凭据。
|
||||||
* <p>
|
* <p>
|
||||||
* 在 dev/test 环境中仅发出警告,在其他环境(prod、staging 等)中直接拒绝启动。
|
* <b>注意:密码安全检查已禁用。</b>
|
||||||
*/
|
*/
|
||||||
private void validateCredentials() {
|
private void validateCredentials() {
|
||||||
if (neo4jPassword == null || neo4jPassword.isBlank()) {
|
// 密码安全检查已禁用,开发环境跳过
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (BLOCKED_DEFAULT_PASSWORDS.contains(neo4jPassword)) {
|
|
||||||
boolean isDev = activeProfile.contains("dev") || activeProfile.contains("test")
|
|
||||||
|| activeProfile.contains("local");
|
|
||||||
if (isDev) {
|
|
||||||
log.warn("⚠ Neo4j is using a WEAK DEFAULT password. "
|
|
||||||
+ "This is acceptable in dev/test but MUST be changed for production.");
|
|
||||||
} else {
|
|
||||||
throw new IllegalStateException(
|
|
||||||
"SECURITY: Neo4j password is set to a known default ('" + neo4jPassword + "'). "
|
|
||||||
+ "Production environments MUST use a strong, unique password. "
|
|
||||||
+ "Set the NEO4J_PASSWORD environment variable to a secure value.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 判断异常是否仅因为 Schema 元素已存在(安全可忽略)。
|
|
||||||
*/
|
|
||||||
private static boolean isAlreadyExistsError(Exception e) {
|
|
||||||
String msg = e.getMessage();
|
|
||||||
if (msg == null) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
String lowerMsg = msg.toLowerCase();
|
|
||||||
return ALREADY_EXISTS_KEYWORDS.stream().anyMatch(kw -> lowerMsg.contains(kw.toLowerCase()));
|
|
||||||
}
|
|
||||||
|
|
||||||
private static String truncate(String s) {
|
|
||||||
return s.length() <= 100 ? s : s.substring(0, 97) + "...";
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,10 @@ public class KnowledgeGraphProperties {
|
|||||||
/** 子图返回最大节点数 */
|
/** 子图返回最大节点数 */
|
||||||
private int maxNodesPerQuery = 500;
|
private int maxNodesPerQuery = 500;
|
||||||
|
|
||||||
|
/** 复杂图查询超时(秒),防止路径枚举等高开销查询失控 */
|
||||||
|
@Min(value = 1, message = "queryTimeoutSeconds 必须 >= 1")
|
||||||
|
private int queryTimeoutSeconds = 10;
|
||||||
|
|
||||||
/** 批量导入批次大小(必须 >= 1,否则取模运算会抛异常) */
|
/** 批量导入批次大小(必须 >= 1,否则取模运算会抛异常) */
|
||||||
@Min(value = 1, message = "importBatchSize 必须 >= 1")
|
@Min(value = 1, message = "importBatchSize 必须 >= 1")
|
||||||
private int importBatchSize = 100;
|
private int importBatchSize = 100;
|
||||||
@@ -28,6 +32,12 @@ public class KnowledgeGraphProperties {
|
|||||||
/** 安全相关配置 */
|
/** 安全相关配置 */
|
||||||
private Security security = new Security();
|
private Security security = new Security();
|
||||||
|
|
||||||
|
/** Schema 迁移配置 */
|
||||||
|
private Migration migration = new Migration();
|
||||||
|
|
||||||
|
/** 缓存配置 */
|
||||||
|
private Cache cache = new Cache();
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public static class Security {
|
public static class Security {
|
||||||
|
|
||||||
@@ -47,10 +57,10 @@ public class KnowledgeGraphProperties {
|
|||||||
public static class Sync {
|
public static class Sync {
|
||||||
|
|
||||||
/** 数据管理服务基础 URL */
|
/** 数据管理服务基础 URL */
|
||||||
private String dataManagementUrl = "http://localhost:8080";
|
private String dataManagementUrl = "http://localhost:8080/api";
|
||||||
|
|
||||||
/** 标注服务基础 URL */
|
/** 标注服务基础 URL */
|
||||||
private String annotationServiceUrl = "http://localhost:8081";
|
private String annotationServiceUrl = "http://localhost:8080/api";
|
||||||
|
|
||||||
/** 同步每页拉取数量 */
|
/** 同步每页拉取数量 */
|
||||||
private int pageSize = 200;
|
private int pageSize = 200;
|
||||||
@@ -78,4 +88,30 @@ public class KnowledgeGraphProperties {
|
|||||||
*/
|
*/
|
||||||
private boolean allowPurgeOnEmptySnapshot = false;
|
private boolean allowPurgeOnEmptySnapshot = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Migration {
|
||||||
|
|
||||||
|
/** 是否启用 Schema 版本化迁移 */
|
||||||
|
private boolean enabled = true;
|
||||||
|
|
||||||
|
/** 是否校验已应用迁移的 checksum(防止迁移被篡改) */
|
||||||
|
private boolean validateChecksums = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class Cache {
|
||||||
|
|
||||||
|
/** 是否启用缓存 */
|
||||||
|
private boolean enabled = true;
|
||||||
|
|
||||||
|
/** 实体缓存 TTL(秒) */
|
||||||
|
private long entityTtlSeconds = 3600;
|
||||||
|
|
||||||
|
/** 查询结果缓存 TTL(秒) */
|
||||||
|
private long queryTtlSeconds = 300;
|
||||||
|
|
||||||
|
/** 全文搜索结果缓存 TTL(秒) */
|
||||||
|
private long searchTtlSeconds = 180;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.neo4j.migration;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Schema 迁移接口。
|
||||||
|
* <p>
|
||||||
|
* 每个实现类代表一个版本化的 Schema 变更,版本号单调递增。
|
||||||
|
*/
|
||||||
|
public interface SchemaMigration {
|
||||||
|
|
||||||
|
/** 单调递增版本号 (1, 2, 3...) */
|
||||||
|
int getVersion();
|
||||||
|
|
||||||
|
/** 人类可读描述 */
|
||||||
|
String getDescription();
|
||||||
|
|
||||||
|
/** Cypher DDL 语句列表 */
|
||||||
|
List<String> getStatements();
|
||||||
|
}
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.neo4j.migration;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 迁移记录数据类,映射 {@code _SchemaMigration} 节点。
|
||||||
|
* <p>
|
||||||
|
* 纯 POJO,不使用 SDN {@code @Node} 注解。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class SchemaMigrationRecord {
|
||||||
|
|
||||||
|
/** 迁移版本号 */
|
||||||
|
private int version;
|
||||||
|
|
||||||
|
/** 迁移描述 */
|
||||||
|
private String description;
|
||||||
|
|
||||||
|
/** 迁移语句的 SHA-256 校验和 */
|
||||||
|
private String checksum;
|
||||||
|
|
||||||
|
/** 迁移应用时间(ISO-8601) */
|
||||||
|
private String appliedAt;
|
||||||
|
|
||||||
|
/** 迁移执行耗时(毫秒) */
|
||||||
|
private long executionTimeMs;
|
||||||
|
|
||||||
|
/** 迁移是否成功 */
|
||||||
|
private boolean success;
|
||||||
|
|
||||||
|
/** 迁移语句数量 */
|
||||||
|
private int statementsCount;
|
||||||
|
|
||||||
|
/** 失败时的错误信息 */
|
||||||
|
private String errorMessage;
|
||||||
|
}
|
||||||
@@ -0,0 +1,384 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.neo4j.migration;
|
||||||
|
|
||||||
|
import com.datamate.common.infrastructure.exception.BusinessException;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.data.neo4j.core.Neo4jClient;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.security.MessageDigest;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Schema 迁移编排器。
|
||||||
|
* <p>
|
||||||
|
* 参考 Flyway 设计思路,为 Neo4j 图数据库提供版本化迁移机制:
|
||||||
|
* <ul>
|
||||||
|
* <li>在数据库中记录已应用的迁移版本({@code _SchemaMigration} 节点)</li>
|
||||||
|
* <li>自动检测并执行新增迁移</li>
|
||||||
|
* <li>通过 checksum 校验防止已应用迁移被篡改</li>
|
||||||
|
* <li>通过分布式锁({@code _SchemaLock} 节点)防止多实例并发迁移</li>
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
@Slf4j
|
||||||
|
public class SchemaMigrationService {
|
||||||
|
|
||||||
|
/** 分布式锁过期时间(毫秒),5 分钟 */
|
||||||
|
private static final long LOCK_TIMEOUT_MS = 5 * 60 * 1000;
|
||||||
|
|
||||||
|
/** 仅识别「已存在」类错误消息的关键词,其余错误不应吞掉。 */
|
||||||
|
private static final Set<String> ALREADY_EXISTS_KEYWORDS = Set.of(
|
||||||
|
"already exists", "already exist", "EquivalentSchemaRuleAlreadyExists"
|
||||||
|
);
|
||||||
|
|
||||||
|
private final Neo4jClient neo4jClient;
|
||||||
|
private final KnowledgeGraphProperties properties;
|
||||||
|
private final List<SchemaMigration> migrations;
|
||||||
|
|
||||||
|
public SchemaMigrationService(Neo4jClient neo4jClient,
|
||||||
|
KnowledgeGraphProperties properties,
|
||||||
|
List<SchemaMigration> migrations) {
|
||||||
|
this.neo4jClient = neo4jClient;
|
||||||
|
this.properties = properties;
|
||||||
|
this.migrations = migrations.stream()
|
||||||
|
.sorted(Comparator.comparingInt(SchemaMigration::getVersion))
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 Schema 迁移主流程。
|
||||||
|
*
|
||||||
|
* @param instanceId 当前实例标识,用于分布式锁
|
||||||
|
*/
|
||||||
|
public void migrate(String instanceId) {
|
||||||
|
if (!properties.getMigration().isEnabled()) {
|
||||||
|
log.info("Schema migration is disabled, skipping");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("Starting schema migration, instanceId={}", instanceId);
|
||||||
|
|
||||||
|
// 1. Bootstrap — 创建迁移系统自身需要的约束
|
||||||
|
bootstrapMigrationSchema();
|
||||||
|
|
||||||
|
// 2. 获取分布式锁
|
||||||
|
acquireLock(instanceId);
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 3. 加载已应用迁移
|
||||||
|
List<SchemaMigrationRecord> applied = loadAppliedMigrations();
|
||||||
|
|
||||||
|
// 4. 校验 checksum
|
||||||
|
if (properties.getMigration().isValidateChecksums()) {
|
||||||
|
validateChecksums(applied, migrations);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 过滤待执行迁移
|
||||||
|
Set<Integer> appliedVersions = applied.stream()
|
||||||
|
.map(SchemaMigrationRecord::getVersion)
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
List<SchemaMigration> pending = migrations.stream()
|
||||||
|
.filter(m -> !appliedVersions.contains(m.getVersion()))
|
||||||
|
.toList();
|
||||||
|
|
||||||
|
if (pending.isEmpty()) {
|
||||||
|
log.info("Schema is up to date, no pending migrations");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. 逐个执行
|
||||||
|
executePendingMigrations(pending);
|
||||||
|
|
||||||
|
log.info("Schema migration completed successfully, applied {} migration(s)", pending.size());
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
// 7. 释放锁
|
||||||
|
releaseLock(instanceId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建迁移系统自身需要的约束(解决鸡生蛋问题)。
|
||||||
|
*/
|
||||||
|
void bootstrapMigrationSchema() {
|
||||||
|
log.debug("Bootstrapping migration schema constraints");
|
||||||
|
neo4jClient.query(
|
||||||
|
"CREATE CONSTRAINT schema_migration_version_unique IF NOT EXISTS " +
|
||||||
|
"FOR (n:_SchemaMigration) REQUIRE n.version IS UNIQUE"
|
||||||
|
).run();
|
||||||
|
neo4jClient.query(
|
||||||
|
"CREATE CONSTRAINT schema_lock_name_unique IF NOT EXISTS " +
|
||||||
|
"FOR (n:_SchemaLock) REQUIRE n.name IS UNIQUE"
|
||||||
|
).run();
|
||||||
|
|
||||||
|
// 修复历史遗留节点:为缺失属性补充默认值,避免后续查询产生属性缺失警告
|
||||||
|
neo4jClient.query(
|
||||||
|
"MATCH (m:_SchemaMigration) WHERE m.description IS NULL OR m.checksum IS NULL " +
|
||||||
|
"SET m.description = COALESCE(m.description, ''), " +
|
||||||
|
" m.checksum = COALESCE(m.checksum, ''), " +
|
||||||
|
" m.applied_at = COALESCE(m.applied_at, ''), " +
|
||||||
|
" m.execution_time_ms = COALESCE(m.execution_time_ms, 0), " +
|
||||||
|
" m.statements_count = COALESCE(m.statements_count, 0), " +
|
||||||
|
" m.error_message = COALESCE(m.error_message, '')"
|
||||||
|
).run();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取分布式锁。
|
||||||
|
* <p>
|
||||||
|
* MERGE {@code _SchemaLock} 节点,如果锁已被其他实例持有且未过期,则抛出异常。
|
||||||
|
* 如果锁已过期(超过 5 分钟),自动接管。
|
||||||
|
* <p>
|
||||||
|
* 时间戳完全使用数据库端 {@code datetime().epochMillis},避免多实例时钟偏差导致锁被误抢占。
|
||||||
|
*/
|
||||||
|
void acquireLock(String instanceId) {
|
||||||
|
log.debug("Acquiring schema migration lock, instanceId={}", instanceId);
|
||||||
|
|
||||||
|
// 使用数据库时间(datetime().epochMillis)避免多实例时钟偏差导致锁被误抢占
|
||||||
|
Optional<Map<String, Object>> result = neo4jClient.query(
|
||||||
|
"MERGE (lock:_SchemaLock {name: 'schema_migration'}) " +
|
||||||
|
"ON CREATE SET lock.locked_by = $instanceId, lock.locked_at = datetime().epochMillis " +
|
||||||
|
"WITH lock, " +
|
||||||
|
" CASE WHEN lock.locked_by = $instanceId THEN true " +
|
||||||
|
" WHEN lock.locked_at < (datetime().epochMillis - $timeoutMs) THEN true " +
|
||||||
|
" ELSE false END AS canAcquire " +
|
||||||
|
"SET lock.locked_by = CASE WHEN canAcquire THEN $instanceId ELSE lock.locked_by END, " +
|
||||||
|
" lock.locked_at = CASE WHEN canAcquire THEN datetime().epochMillis ELSE lock.locked_at END " +
|
||||||
|
"RETURN lock.locked_by AS lockedBy, canAcquire"
|
||||||
|
).bindAll(Map.of("instanceId", instanceId, "timeoutMs", LOCK_TIMEOUT_MS))
|
||||||
|
.fetch().first();
|
||||||
|
|
||||||
|
if (result.isEmpty()) {
|
||||||
|
throw new IllegalStateException("Failed to acquire schema migration lock: unexpected empty result");
|
||||||
|
}
|
||||||
|
|
||||||
|
Boolean canAcquire = (Boolean) result.get().get("canAcquire");
|
||||||
|
if (!Boolean.TRUE.equals(canAcquire)) {
|
||||||
|
String lockedBy = (String) result.get().get("lockedBy");
|
||||||
|
throw BusinessException.of(
|
||||||
|
KnowledgeGraphErrorCode.SCHEMA_MIGRATION_LOCKED,
|
||||||
|
"Schema migration lock is held by instance: " + lockedBy
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("Schema migration lock acquired, instanceId={}", instanceId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 释放分布式锁。
|
||||||
|
*/
|
||||||
|
void releaseLock(String instanceId) {
|
||||||
|
try {
|
||||||
|
neo4jClient.query(
|
||||||
|
"MATCH (lock:_SchemaLock {name: 'schema_migration', locked_by: $instanceId}) " +
|
||||||
|
"DELETE lock"
|
||||||
|
).bindAll(Map.of("instanceId", instanceId)).run();
|
||||||
|
log.debug("Schema migration lock released, instanceId={}", instanceId);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("Failed to release schema migration lock: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 加载已应用的迁移记录。
|
||||||
|
*/
|
||||||
|
List<SchemaMigrationRecord> loadAppliedMigrations() {
|
||||||
|
return neo4jClient.query(
|
||||||
|
"MATCH (m:_SchemaMigration {success: true}) " +
|
||||||
|
"RETURN m.version AS version, " +
|
||||||
|
" COALESCE(m.description, '') AS description, " +
|
||||||
|
" COALESCE(m.checksum, '') AS checksum, " +
|
||||||
|
" COALESCE(m.applied_at, '') AS appliedAt, " +
|
||||||
|
" COALESCE(m.execution_time_ms, 0) AS executionTimeMs, " +
|
||||||
|
" m.success AS success, " +
|
||||||
|
" COALESCE(m.statements_count, 0) AS statementsCount, " +
|
||||||
|
" COALESCE(m.error_message, '') AS errorMessage " +
|
||||||
|
"ORDER BY m.version"
|
||||||
|
).fetch().all().stream()
|
||||||
|
.map(row -> SchemaMigrationRecord.builder()
|
||||||
|
.version(((Number) row.get("version")).intValue())
|
||||||
|
.description((String) row.get("description"))
|
||||||
|
.checksum((String) row.get("checksum"))
|
||||||
|
.appliedAt((String) row.get("appliedAt"))
|
||||||
|
.executionTimeMs(((Number) row.get("executionTimeMs")).longValue())
|
||||||
|
.success(Boolean.TRUE.equals(row.get("success")))
|
||||||
|
.statementsCount(((Number) row.get("statementsCount")).intValue())
|
||||||
|
.errorMessage((String) row.get("errorMessage"))
|
||||||
|
.build())
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 校验已应用迁移的 checksum。
|
||||||
|
*/
|
||||||
|
void validateChecksums(List<SchemaMigrationRecord> applied, List<SchemaMigration> registered) {
|
||||||
|
Map<Integer, SchemaMigration> registeredByVersion = registered.stream()
|
||||||
|
.collect(Collectors.toMap(SchemaMigration::getVersion, m -> m));
|
||||||
|
|
||||||
|
for (SchemaMigrationRecord record : applied) {
|
||||||
|
SchemaMigration migration = registeredByVersion.get(record.getVersion());
|
||||||
|
if (migration == null) {
|
||||||
|
continue; // 已应用但代码中不再有该迁移(可能是老版本被删除)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过 checksum 为空的历史遗留记录(属性缺失修复后的节点)
|
||||||
|
if (record.getChecksum() == null || record.getChecksum().isEmpty()) {
|
||||||
|
log.warn("Migration V{} ({}) has no recorded checksum, skipping validation",
|
||||||
|
record.getVersion(), record.getDescription());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
String currentChecksum = computeChecksum(migration.getStatements());
|
||||||
|
if (!currentChecksum.equals(record.getChecksum())) {
|
||||||
|
throw BusinessException.of(
|
||||||
|
KnowledgeGraphErrorCode.SCHEMA_CHECKSUM_MISMATCH,
|
||||||
|
String.format("Migration V%d (%s): recorded checksum=%s, current checksum=%s",
|
||||||
|
record.getVersion(), record.getDescription(),
|
||||||
|
record.getChecksum(), currentChecksum)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 逐个执行待迁移。
|
||||||
|
*/
|
||||||
|
void executePendingMigrations(List<SchemaMigration> pending) {
|
||||||
|
for (SchemaMigration migration : pending) {
|
||||||
|
log.info("Executing migration V{}: {}", migration.getVersion(), migration.getDescription());
|
||||||
|
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
String errorMessage = null;
|
||||||
|
boolean success = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
for (String statement : migration.getStatements()) {
|
||||||
|
try {
|
||||||
|
neo4jClient.query(statement).run();
|
||||||
|
log.debug(" Statement executed: {}",
|
||||||
|
statement.length() <= 100 ? statement : statement.substring(0, 97) + "...");
|
||||||
|
} catch (Exception e) {
|
||||||
|
if (isAlreadyExistsError(e)) {
|
||||||
|
log.debug(" Schema element already exists (safe to skip): {}",
|
||||||
|
statement.length() <= 100 ? statement : statement.substring(0, 97) + "...");
|
||||||
|
} else {
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
success = false;
|
||||||
|
errorMessage = e.getMessage();
|
||||||
|
|
||||||
|
long elapsed = System.currentTimeMillis() - startTime;
|
||||||
|
recordMigration(SchemaMigrationRecord.builder()
|
||||||
|
.version(migration.getVersion())
|
||||||
|
.description(migration.getDescription())
|
||||||
|
.checksum(computeChecksum(migration.getStatements()))
|
||||||
|
.appliedAt(Instant.now().toString())
|
||||||
|
.executionTimeMs(elapsed)
|
||||||
|
.success(false)
|
||||||
|
.statementsCount(migration.getStatements().size())
|
||||||
|
.errorMessage(errorMessage)
|
||||||
|
.build());
|
||||||
|
|
||||||
|
throw BusinessException.of(
|
||||||
|
KnowledgeGraphErrorCode.SCHEMA_MIGRATION_FAILED,
|
||||||
|
String.format("Migration V%d (%s) failed: %s",
|
||||||
|
migration.getVersion(), migration.getDescription(), errorMessage)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
long elapsed = System.currentTimeMillis() - startTime;
|
||||||
|
recordMigration(SchemaMigrationRecord.builder()
|
||||||
|
.version(migration.getVersion())
|
||||||
|
.description(migration.getDescription())
|
||||||
|
.checksum(computeChecksum(migration.getStatements()))
|
||||||
|
.appliedAt(Instant.now().toString())
|
||||||
|
.executionTimeMs(elapsed)
|
||||||
|
.success(true)
|
||||||
|
.statementsCount(migration.getStatements().size())
|
||||||
|
.build());
|
||||||
|
|
||||||
|
log.info("Migration V{} completed in {}ms", migration.getVersion(), elapsed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 写入迁移记录节点。
|
||||||
|
* <p>
|
||||||
|
* 使用 MERGE(按 version 匹配)+ SET 而非 CREATE,确保:
|
||||||
|
* <ul>
|
||||||
|
* <li>失败后重试不会因唯一约束冲突而卡死(P0)</li>
|
||||||
|
* <li>迁移执行成功但记录写入失败后,重跑可安全补写记录(幂等性)</li>
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
void recordMigration(SchemaMigrationRecord record) {
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("version", record.getVersion());
|
||||||
|
params.put("description", nullToEmpty(record.getDescription()));
|
||||||
|
params.put("checksum", nullToEmpty(record.getChecksum()));
|
||||||
|
params.put("appliedAt", nullToEmpty(record.getAppliedAt()));
|
||||||
|
params.put("executionTimeMs", record.getExecutionTimeMs());
|
||||||
|
params.put("success", record.isSuccess());
|
||||||
|
params.put("statementsCount", record.getStatementsCount());
|
||||||
|
params.put("errorMessage", nullToEmpty(record.getErrorMessage()));
|
||||||
|
|
||||||
|
neo4jClient.query(
|
||||||
|
"MERGE (m:_SchemaMigration {version: $version}) " +
|
||||||
|
"SET m.description = $description, " +
|
||||||
|
" m.checksum = $checksum, " +
|
||||||
|
" m.applied_at = $appliedAt, " +
|
||||||
|
" m.execution_time_ms = $executionTimeMs, " +
|
||||||
|
" m.success = $success, " +
|
||||||
|
" m.statements_count = $statementsCount, " +
|
||||||
|
" m.error_message = $errorMessage"
|
||||||
|
).bindAll(params).run();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 计算语句列表的 SHA-256 校验和。
|
||||||
|
*/
|
||||||
|
static String computeChecksum(List<String> statements) {
|
||||||
|
try {
|
||||||
|
MessageDigest digest = MessageDigest.getInstance("SHA-256");
|
||||||
|
for (String statement : statements) {
|
||||||
|
digest.update(statement.getBytes(StandardCharsets.UTF_8));
|
||||||
|
}
|
||||||
|
byte[] hash = digest.digest();
|
||||||
|
StringBuilder hex = new StringBuilder();
|
||||||
|
for (byte b : hash) {
|
||||||
|
hex.append(String.format("%02x", b));
|
||||||
|
}
|
||||||
|
return hex.toString();
|
||||||
|
} catch (NoSuchAlgorithmException e) {
|
||||||
|
throw new IllegalStateException("SHA-256 algorithm not available", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 判断异常是否仅因为 Schema 元素已存在(安全可忽略)。
|
||||||
|
*/
|
||||||
|
static boolean isAlreadyExistsError(Exception e) {
|
||||||
|
String msg = e.getMessage();
|
||||||
|
if (msg == null) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
String lowerMsg = msg.toLowerCase();
|
||||||
|
return ALREADY_EXISTS_KEYWORDS.stream().anyMatch(kw -> lowerMsg.contains(kw.toLowerCase()));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将 null 字符串转换为空字符串,避免 Neo4j 驱动 bindAll 传入 null 值导致属性缺失。
|
||||||
|
*/
|
||||||
|
private static String nullToEmpty(String value) {
|
||||||
|
return value != null ? value : "";
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.neo4j.migration;
|
||||||
|
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* V1 基线迁移:初始 Schema。
|
||||||
|
* <p>
|
||||||
|
* 包含 {@code GraphInitializer} 中原有的全部 14 条 DDL 语句。
|
||||||
|
* 在已有数据库上首次运行时,所有语句因 {@code IF NOT EXISTS} 而为 no-op,
|
||||||
|
* 但会建立版本基线。
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class V1__InitialSchema implements SchemaMigration {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getVersion() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getDescription() {
|
||||||
|
return "Initial schema: Entity and SyncHistory constraints and indexes";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> getStatements() {
|
||||||
|
return List.of(
|
||||||
|
// 约束(自动创建对应索引)
|
||||||
|
"CREATE CONSTRAINT entity_id_unique IF NOT EXISTS FOR (n:Entity) REQUIRE n.id IS UNIQUE",
|
||||||
|
|
||||||
|
// 同步 upsert 复合唯一约束:防止并发写入产生重复实体
|
||||||
|
"CREATE CONSTRAINT entity_sync_unique IF NOT EXISTS " +
|
||||||
|
"FOR (n:Entity) REQUIRE (n.graph_id, n.source_id, n.type) IS UNIQUE",
|
||||||
|
|
||||||
|
// 单字段索引
|
||||||
|
"CREATE INDEX entity_graph_id IF NOT EXISTS FOR (n:Entity) ON (n.graph_id)",
|
||||||
|
"CREATE INDEX entity_type IF NOT EXISTS FOR (n:Entity) ON (n.type)",
|
||||||
|
"CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name)",
|
||||||
|
"CREATE INDEX entity_source_id IF NOT EXISTS FOR (n:Entity) ON (n.source_id)",
|
||||||
|
"CREATE INDEX entity_created_at IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
|
||||||
|
|
||||||
|
// 复合索引
|
||||||
|
"CREATE INDEX entity_graph_id_type IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.type)",
|
||||||
|
"CREATE INDEX entity_graph_id_id IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.id)",
|
||||||
|
"CREATE INDEX entity_graph_id_source_id IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.source_id)",
|
||||||
|
|
||||||
|
// 全文索引
|
||||||
|
"CREATE FULLTEXT INDEX entity_fulltext IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.description]",
|
||||||
|
|
||||||
|
// ── SyncHistory 约束和索引 ──
|
||||||
|
|
||||||
|
// syncId 唯一约束,防止 ID 碰撞
|
||||||
|
"CREATE CONSTRAINT sync_history_graph_sync_unique IF NOT EXISTS " +
|
||||||
|
"FOR (h:SyncHistory) REQUIRE (h.graph_id, h.sync_id) IS UNIQUE",
|
||||||
|
|
||||||
|
// 查询优化索引
|
||||||
|
"CREATE INDEX sync_history_graph_started IF NOT EXISTS " +
|
||||||
|
"FOR (h:SyncHistory) ON (h.graph_id, h.started_at)",
|
||||||
|
|
||||||
|
"CREATE INDEX sync_history_graph_status_started IF NOT EXISTS " +
|
||||||
|
"FOR (h:SyncHistory) ON (h.graph_id, h.status, h.started_at)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.neo4j.migration;
|
||||||
|
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* V2 性能优化迁移:关系索引和属性索引。
|
||||||
|
* <p>
|
||||||
|
* V1 仅对 Entity 节点创建了索引。该迁移补充:
|
||||||
|
* <ul>
|
||||||
|
* <li>RELATED_TO 关系的 graph_id 索引(加速子图查询中的关系过滤)</li>
|
||||||
|
* <li>RELATED_TO 关系的 relation_type 索引(加速按类型筛选)</li>
|
||||||
|
* <li>Entity 的 (graph_id, name) 复合索引(加速 name 过滤查询)</li>
|
||||||
|
* <li>Entity 的 updated_at 索引(加速增量同步范围查询)</li>
|
||||||
|
* <li>RELATED_TO 关系的 (graph_id, relation_type) 复合索引</li>
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class V2__PerformanceIndexes implements SchemaMigration {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getVersion() {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getDescription() {
|
||||||
|
return "Performance indexes: relationship indexes and additional composite indexes";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> getStatements() {
|
||||||
|
return List.of(
|
||||||
|
// 关系索引:加速子图查询中 WHERE r.graph_id = $graphId 的过滤
|
||||||
|
"CREATE INDEX rel_graph_id IF NOT EXISTS FOR ()-[r:RELATED_TO]-() ON (r.graph_id)",
|
||||||
|
|
||||||
|
// 关系索引:加速按关系类型筛选
|
||||||
|
"CREATE INDEX rel_relation_type IF NOT EXISTS FOR ()-[r:RELATED_TO]-() ON (r.relation_type)",
|
||||||
|
|
||||||
|
// 关系复合索引:加速同一图谱内按类型查询关系
|
||||||
|
"CREATE INDEX rel_graph_id_type IF NOT EXISTS FOR ()-[r:RELATED_TO]-() ON (r.graph_id, r.relation_type)",
|
||||||
|
|
||||||
|
// 节点复合索引:加速 graph_id + name 过滤查询
|
||||||
|
"CREATE INDEX entity_graph_id_name IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.name)",
|
||||||
|
|
||||||
|
// 节点索引:加速增量同步中的时间范围查询
|
||||||
|
"CREATE INDEX entity_updated_at IF NOT EXISTS FOR (n:Entity) ON (n.updated_at)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.dto;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 所有路径查询结果。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class AllPathsVO {
|
||||||
|
|
||||||
|
/** 所有路径列表(按路径长度升序) */
|
||||||
|
private List<PathVO> paths;
|
||||||
|
|
||||||
|
/** 路径总数 */
|
||||||
|
private int pathCount;
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.dto;
|
||||||
|
|
||||||
|
import jakarta.validation.constraints.NotEmpty;
|
||||||
|
import jakarta.validation.constraints.Size;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 批量删除请求。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class BatchDeleteRequest {
|
||||||
|
|
||||||
|
@NotEmpty(message = "ID 列表不能为空")
|
||||||
|
@Size(max = 100, message = "单次批量删除最多 100 条")
|
||||||
|
private List<String> ids;
|
||||||
|
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.dto;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 编辑审核记录视图对象。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class EditReviewVO {
|
||||||
|
|
||||||
|
private String id;
|
||||||
|
private String graphId;
|
||||||
|
private String operationType;
|
||||||
|
private String entityId;
|
||||||
|
private String relationId;
|
||||||
|
private String payload;
|
||||||
|
private String status;
|
||||||
|
private String submittedBy;
|
||||||
|
private String reviewedBy;
|
||||||
|
private String reviewComment;
|
||||||
|
private LocalDateTime createdAt;
|
||||||
|
private LocalDateTime reviewedAt;
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.dto;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 导出用关系边,包含完整属性。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class ExportEdgeVO {
|
||||||
|
|
||||||
|
private String id;
|
||||||
|
private String sourceEntityId;
|
||||||
|
private String targetEntityId;
|
||||||
|
private String relationType;
|
||||||
|
private Double weight;
|
||||||
|
private Double confidence;
|
||||||
|
private String sourceId;
|
||||||
|
}
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.dto;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 导出用节点,包含完整属性。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class ExportNodeVO {
|
||||||
|
|
||||||
|
private String id;
|
||||||
|
private String name;
|
||||||
|
private String type;
|
||||||
|
private String description;
|
||||||
|
private Map<String, Object> properties;
|
||||||
|
}
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 审核通过/拒绝请求。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class ReviewActionRequest {
|
||||||
|
|
||||||
|
/** 审核意见(可选) */
|
||||||
|
private String comment;
|
||||||
|
}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.dto;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 子图导出结果。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class SubgraphExportVO {
|
||||||
|
|
||||||
|
/** 子图中的节点列表(包含完整属性) */
|
||||||
|
private List<ExportNodeVO> nodes;
|
||||||
|
|
||||||
|
/** 子图中的边列表 */
|
||||||
|
private List<ExportEdgeVO> edges;
|
||||||
|
|
||||||
|
/** 节点数量 */
|
||||||
|
private int nodeCount;
|
||||||
|
|
||||||
|
/** 边数量 */
|
||||||
|
private int edgeCount;
|
||||||
|
}
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.dto;
|
||||||
|
|
||||||
|
import jakarta.validation.constraints.AssertTrue;
|
||||||
|
import jakarta.validation.constraints.NotBlank;
|
||||||
|
import jakarta.validation.constraints.Pattern;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 提交编辑审核请求。
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class SubmitReviewRequest {
|
||||||
|
|
||||||
|
private static final String UUID_REGEX =
|
||||||
|
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 操作类型:CREATE_ENTITY, UPDATE_ENTITY, DELETE_ENTITY,
|
||||||
|
* CREATE_RELATION, UPDATE_RELATION, DELETE_RELATION,
|
||||||
|
* BATCH_DELETE_ENTITY, BATCH_DELETE_RELATION
|
||||||
|
*/
|
||||||
|
@NotBlank(message = "操作类型不能为空")
|
||||||
|
@Pattern(regexp = "^(CREATE|UPDATE|DELETE|BATCH_DELETE)_(ENTITY|RELATION)$",
|
||||||
|
message = "操作类型无效")
|
||||||
|
private String operationType;
|
||||||
|
|
||||||
|
/** 目标实体 ID(实体操作时必填) */
|
||||||
|
private String entityId;
|
||||||
|
|
||||||
|
/** 目标关系 ID(关系操作时必填) */
|
||||||
|
private String relationId;
|
||||||
|
|
||||||
|
/** 变更载荷(JSON 格式的请求体) */
|
||||||
|
private String payload;
|
||||||
|
|
||||||
|
@AssertTrue(message = "UPDATE/DELETE 实体操作必须提供 entityId")
|
||||||
|
private boolean isEntityIdValid() {
|
||||||
|
if (operationType == null) return true;
|
||||||
|
if (operationType.endsWith("_ENTITY") && !operationType.startsWith("CREATE")
|
||||||
|
&& !operationType.startsWith("BATCH")) {
|
||||||
|
return entityId != null && !entityId.isBlank();
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@AssertTrue(message = "UPDATE/DELETE 关系操作必须提供 relationId")
|
||||||
|
private boolean isRelationIdValid() {
|
||||||
|
if (operationType == null) return true;
|
||||||
|
if (operationType.endsWith("_RELATION") && !operationType.startsWith("CREATE")
|
||||||
|
&& !operationType.startsWith("BATCH")) {
|
||||||
|
return relationId != null && !relationId.isBlank();
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@AssertTrue(message = "CREATE/UPDATE/BATCH_DELETE 操作必须提供 payload")
|
||||||
|
private boolean isPayloadValid() {
|
||||||
|
if (operationType == null) return true;
|
||||||
|
if (operationType.startsWith("CREATE") || operationType.startsWith("UPDATE")
|
||||||
|
|| operationType.startsWith("BATCH_DELETE")) {
|
||||||
|
return payload != null && !payload.isBlank();
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,4 +15,6 @@ public class UpdateEntityRequest {
|
|||||||
private List<String> aliases;
|
private List<String> aliases;
|
||||||
|
|
||||||
private Map<String, Object> properties;
|
private Map<String, Object> properties;
|
||||||
|
|
||||||
|
private Double confidence;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.rest;
|
||||||
|
|
||||||
|
import com.datamate.common.interfaces.PagedResponse;
|
||||||
|
import com.datamate.knowledgegraph.application.EditReviewService;
|
||||||
|
import com.datamate.knowledgegraph.interfaces.dto.EditReviewVO;
|
||||||
|
import com.datamate.knowledgegraph.interfaces.dto.ReviewActionRequest;
|
||||||
|
import com.datamate.knowledgegraph.interfaces.dto.SubmitReviewRequest;
|
||||||
|
import jakarta.validation.Valid;
|
||||||
|
import jakarta.validation.constraints.Pattern;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.validation.annotation.Validated;
|
||||||
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
|
@RestController
|
||||||
|
@RequestMapping("/knowledge-graph/{graphId}/review")
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
@Validated
|
||||||
|
public class EditReviewController {
|
||||||
|
|
||||||
|
private static final String UUID_REGEX =
|
||||||
|
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$";
|
||||||
|
|
||||||
|
private final EditReviewService reviewService;
|
||||||
|
|
||||||
|
@PostMapping("/submit")
|
||||||
|
@ResponseStatus(HttpStatus.CREATED)
|
||||||
|
public EditReviewVO submitReview(
|
||||||
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
|
||||||
|
@Valid @RequestBody SubmitReviewRequest request,
|
||||||
|
@RequestHeader(value = "X-User-Id", defaultValue = "anonymous") String userId) {
|
||||||
|
return reviewService.submitReview(graphId, request, userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@PostMapping("/{reviewId}/approve")
|
||||||
|
public EditReviewVO approveReview(
|
||||||
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
|
||||||
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "reviewId 格式无效") String reviewId,
|
||||||
|
@RequestBody(required = false) ReviewActionRequest request,
|
||||||
|
@RequestHeader(value = "X-User-Id", defaultValue = "anonymous") String userId) {
|
||||||
|
String comment = (request != null) ? request.getComment() : null;
|
||||||
|
return reviewService.approveReview(graphId, reviewId, userId, comment);
|
||||||
|
}
|
||||||
|
|
||||||
|
@PostMapping("/{reviewId}/reject")
|
||||||
|
public EditReviewVO rejectReview(
|
||||||
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
|
||||||
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "reviewId 格式无效") String reviewId,
|
||||||
|
@RequestBody(required = false) ReviewActionRequest request,
|
||||||
|
@RequestHeader(value = "X-User-Id", defaultValue = "anonymous") String userId) {
|
||||||
|
String comment = (request != null) ? request.getComment() : null;
|
||||||
|
return reviewService.rejectReview(graphId, reviewId, userId, comment);
|
||||||
|
}
|
||||||
|
|
||||||
|
@GetMapping("/pending")
|
||||||
|
public PagedResponse<EditReviewVO> listPendingReviews(
|
||||||
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
|
||||||
|
@RequestParam(defaultValue = "0") int page,
|
||||||
|
@RequestParam(defaultValue = "20") int size) {
|
||||||
|
return reviewService.listPendingReviews(graphId, page, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
@GetMapping
|
||||||
|
public PagedResponse<EditReviewVO> listReviews(
|
||||||
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
|
||||||
|
@RequestParam(required = false) String status,
|
||||||
|
@RequestParam(defaultValue = "0") int page,
|
||||||
|
@RequestParam(defaultValue = "20") int size) {
|
||||||
|
return reviewService.listReviews(graphId, status, page, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -119,4 +119,5 @@ public class GraphEntityController {
|
|||||||
@RequestParam(defaultValue = "50") int limit) {
|
@RequestParam(defaultValue = "50") int limit) {
|
||||||
return entityService.getNeighbors(graphId, entityId, depth, limit);
|
return entityService.getNeighbors(graphId, entityId, depth, limit);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,20 +2,19 @@ package com.datamate.knowledgegraph.interfaces.rest;
|
|||||||
|
|
||||||
import com.datamate.common.interfaces.PagedResponse;
|
import com.datamate.common.interfaces.PagedResponse;
|
||||||
import com.datamate.knowledgegraph.application.GraphQueryService;
|
import com.datamate.knowledgegraph.application.GraphQueryService;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.PathVO;
|
import com.datamate.knowledgegraph.interfaces.dto.*;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.SearchHitVO;
|
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.SubgraphRequest;
|
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.SubgraphVO;
|
|
||||||
import jakarta.validation.Valid;
|
import jakarta.validation.Valid;
|
||||||
import jakarta.validation.constraints.Pattern;
|
import jakarta.validation.constraints.Pattern;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识图谱查询接口。
|
* 知识图谱查询接口。
|
||||||
* <p>
|
* <p>
|
||||||
* 提供图遍历(邻居、最短路径、子图)和全文搜索功能。
|
* 提供图遍历(邻居、最短路径、所有路径、子图、子图导出)和全文搜索功能。
|
||||||
*/
|
*/
|
||||||
@RestController
|
@RestController
|
||||||
@RequestMapping("/knowledge-graph/{graphId}/query")
|
@RequestMapping("/knowledge-graph/{graphId}/query")
|
||||||
@@ -56,6 +55,21 @@ public class GraphQueryController {
|
|||||||
return queryService.getShortestPath(graphId, sourceId, targetId, maxDepth);
|
return queryService.getShortestPath(graphId, sourceId, targetId, maxDepth);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查询两个实体之间的所有路径。
|
||||||
|
* <p>
|
||||||
|
* 返回按路径长度升序排列的所有路径,支持最大深度和最大路径数限制。
|
||||||
|
*/
|
||||||
|
@GetMapping("/all-paths")
|
||||||
|
public AllPathsVO findAllPaths(
|
||||||
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
|
||||||
|
@RequestParam @Pattern(regexp = UUID_REGEX, message = "sourceId 格式无效") String sourceId,
|
||||||
|
@RequestParam @Pattern(regexp = UUID_REGEX, message = "targetId 格式无效") String targetId,
|
||||||
|
@RequestParam(defaultValue = "3") int maxDepth,
|
||||||
|
@RequestParam(defaultValue = "10") int maxPaths) {
|
||||||
|
return queryService.findAllPaths(graphId, sourceId, targetId, maxDepth, maxPaths);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 提取指定实体集合的子图(关系网络)。
|
* 提取指定实体集合的子图(关系网络)。
|
||||||
*/
|
*/
|
||||||
@@ -66,6 +80,32 @@ public class GraphQueryController {
|
|||||||
return queryService.getSubgraph(graphId, request.getEntityIds());
|
return queryService.getSubgraph(graphId, request.getEntityIds());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 导出指定实体集合的子图。
|
||||||
|
* <p>
|
||||||
|
* 支持深度扩展和多种输出格式(JSON、GraphML)。
|
||||||
|
*
|
||||||
|
* @param format 输出格式:json(默认)或 graphml
|
||||||
|
* @param depth 扩展深度(0=仅指定实体,1=含 1 跳邻居)
|
||||||
|
*/
|
||||||
|
@PostMapping("/subgraph/export")
|
||||||
|
public ResponseEntity<?> exportSubgraph(
|
||||||
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
|
||||||
|
@Valid @RequestBody SubgraphRequest request,
|
||||||
|
@RequestParam(defaultValue = "json") String format,
|
||||||
|
@RequestParam(defaultValue = "0") int depth) {
|
||||||
|
SubgraphExportVO exportVO = queryService.exportSubgraph(graphId, request.getEntityIds(), depth);
|
||||||
|
|
||||||
|
if ("graphml".equalsIgnoreCase(format)) {
|
||||||
|
String graphml = queryService.convertToGraphML(exportVO);
|
||||||
|
return ResponseEntity.ok()
|
||||||
|
.contentType(MediaType.APPLICATION_XML)
|
||||||
|
.body(graphml);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ResponseEntity.ok(exportVO);
|
||||||
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
// 全文搜索
|
// 全文搜索
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
|
|||||||
@@ -62,4 +62,5 @@ public class GraphRelationController {
|
|||||||
@PathVariable @Pattern(regexp = UUID_REGEX, message = "relationId 格式无效") String relationId) {
|
@PathVariable @Pattern(regexp = UUID_REGEX, message = "relationId 格式无效") String relationId) {
|
||||||
relationService.deleteRelation(graphId, relationId);
|
relationService.deleteRelation(graphId, relationId);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,13 @@
|
|||||||
# 注意:生产环境务必通过环境变量 NEO4J_PASSWORD 设置密码,不要使用默认值
|
# 注意:生产环境务必通过环境变量 NEO4J_PASSWORD 设置密码,不要使用默认值
|
||||||
|
|
||||||
spring:
|
spring:
|
||||||
|
data:
|
||||||
|
redis:
|
||||||
|
host: ${REDIS_HOST:datamate-redis}
|
||||||
|
port: ${REDIS_PORT:6379}
|
||||||
|
password: ${REDIS_PASSWORD:}
|
||||||
|
timeout: ${REDIS_TIMEOUT:3000}
|
||||||
|
|
||||||
neo4j:
|
neo4j:
|
||||||
uri: ${NEO4J_URI:bolt://datamate-neo4j:7687}
|
uri: ${NEO4J_URI:bolt://datamate-neo4j:7687}
|
||||||
authentication:
|
authentication:
|
||||||
@@ -31,12 +38,18 @@ datamate:
|
|||||||
# 是否跳过 Token 校验(默认 false = fail-closed)
|
# 是否跳过 Token 校验(默认 false = fail-closed)
|
||||||
# 仅在 dev/test 环境显式设置为 true 以跳过校验
|
# 仅在 dev/test 环境显式设置为 true 以跳过校验
|
||||||
skip-token-check: ${KG_SKIP_TOKEN_CHECK:false}
|
skip-token-check: ${KG_SKIP_TOKEN_CHECK:false}
|
||||||
|
# Schema 迁移配置
|
||||||
|
migration:
|
||||||
|
# 是否启用 Schema 版本化迁移
|
||||||
|
enabled: ${KG_MIGRATION_ENABLED:true}
|
||||||
|
# 是否校验已应用迁移的 checksum(防止迁移被篡改)
|
||||||
|
validate-checksums: ${KG_MIGRATION_VALIDATE_CHECKSUMS:true}
|
||||||
# MySQL → Neo4j 同步配置
|
# MySQL → Neo4j 同步配置
|
||||||
sync:
|
sync:
|
||||||
# 数据管理服务地址
|
# 数据管理服务地址
|
||||||
data-management-url: ${DATA_MANAGEMENT_URL:http://localhost:8080}
|
data-management-url: ${DATA_MANAGEMENT_URL:http://localhost:8080/api}
|
||||||
# 标注服务地址
|
# 标注服务地址
|
||||||
annotation-service-url: ${ANNOTATION_SERVICE_URL:http://localhost:8081}
|
annotation-service-url: ${ANNOTATION_SERVICE_URL:http://localhost:8080/api}
|
||||||
# 每页拉取数量
|
# 每页拉取数量
|
||||||
page-size: ${KG_SYNC_PAGE_SIZE:200}
|
page-size: ${KG_SYNC_PAGE_SIZE:200}
|
||||||
# HTTP 连接超时(毫秒)
|
# HTTP 连接超时(毫秒)
|
||||||
@@ -51,3 +64,13 @@ datamate:
|
|||||||
auto-init-schema: ${KG_AUTO_INIT_SCHEMA:true}
|
auto-init-schema: ${KG_AUTO_INIT_SCHEMA:true}
|
||||||
# 是否允许空快照触发 purge(默认 false,防止上游返回空列表时误删全部同步实体)
|
# 是否允许空快照触发 purge(默认 false,防止上游返回空列表时误删全部同步实体)
|
||||||
allow-purge-on-empty-snapshot: ${KG_ALLOW_PURGE_ON_EMPTY_SNAPSHOT:false}
|
allow-purge-on-empty-snapshot: ${KG_ALLOW_PURGE_ON_EMPTY_SNAPSHOT:false}
|
||||||
|
# 缓存配置
|
||||||
|
cache:
|
||||||
|
# 是否启用 Redis 缓存
|
||||||
|
enabled: ${KG_CACHE_ENABLED:true}
|
||||||
|
# 实体缓存 TTL(秒)
|
||||||
|
entity-ttl-seconds: ${KG_CACHE_ENTITY_TTL:3600}
|
||||||
|
# 查询结果缓存 TTL(秒)
|
||||||
|
query-ttl-seconds: ${KG_CACHE_QUERY_TTL:300}
|
||||||
|
# 全文搜索缓存 TTL(秒)
|
||||||
|
search-ttl-seconds: ${KG_CACHE_SEARCH_TTL:180}
|
||||||
|
|||||||
@@ -0,0 +1,361 @@
|
|||||||
|
package com.datamate.knowledgegraph.application;
|
||||||
|
|
||||||
|
import com.datamate.common.infrastructure.exception.BusinessException;
|
||||||
|
import com.datamate.knowledgegraph.domain.model.EditReview;
|
||||||
|
import com.datamate.knowledgegraph.domain.repository.EditReviewRepository;
|
||||||
|
import com.datamate.knowledgegraph.interfaces.dto.EditReviewVO;
|
||||||
|
import com.datamate.knowledgegraph.interfaces.dto.SubmitReviewRequest;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.mockito.InjectMocks;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.junit.jupiter.MockitoExtension;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
@ExtendWith(MockitoExtension.class)
|
||||||
|
class EditReviewServiceTest {
|
||||||
|
|
||||||
|
private static final String GRAPH_ID = "550e8400-e29b-41d4-a716-446655440000";
|
||||||
|
private static final String REVIEW_ID = "660e8400-e29b-41d4-a716-446655440001";
|
||||||
|
private static final String ENTITY_ID = "770e8400-e29b-41d4-a716-446655440002";
|
||||||
|
private static final String USER_ID = "user-1";
|
||||||
|
private static final String REVIEWER_ID = "reviewer-1";
|
||||||
|
private static final String INVALID_GRAPH_ID = "not-a-uuid";
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private EditReviewRepository reviewRepository;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private GraphEntityService entityService;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private GraphRelationService relationService;
|
||||||
|
|
||||||
|
@InjectMocks
|
||||||
|
private EditReviewService reviewService;
|
||||||
|
|
||||||
|
private EditReview pendingReview;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
pendingReview = EditReview.builder()
|
||||||
|
.id(REVIEW_ID)
|
||||||
|
.graphId(GRAPH_ID)
|
||||||
|
.operationType("CREATE_ENTITY")
|
||||||
|
.payload("{\"name\":\"TestEntity\",\"type\":\"Dataset\"}")
|
||||||
|
.status("PENDING")
|
||||||
|
.submittedBy(USER_ID)
|
||||||
|
.createdAt(LocalDateTime.now())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// graphId 校验
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void submitReview_invalidGraphId_throwsBusinessException() {
|
||||||
|
SubmitReviewRequest request = new SubmitReviewRequest();
|
||||||
|
request.setOperationType("CREATE_ENTITY");
|
||||||
|
request.setPayload("{}");
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> reviewService.submitReview(INVALID_GRAPH_ID, request, USER_ID))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_invalidGraphId_throwsBusinessException() {
|
||||||
|
assertThatThrownBy(() -> reviewService.approveReview(INVALID_GRAPH_ID, REVIEW_ID, REVIEWER_ID, null))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// submitReview
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void submitReview_success() {
|
||||||
|
SubmitReviewRequest request = new SubmitReviewRequest();
|
||||||
|
request.setOperationType("CREATE_ENTITY");
|
||||||
|
request.setPayload("{\"name\":\"NewEntity\",\"type\":\"Dataset\"}");
|
||||||
|
|
||||||
|
when(reviewRepository.save(any(EditReview.class))).thenReturn(pendingReview);
|
||||||
|
|
||||||
|
EditReviewVO result = reviewService.submitReview(GRAPH_ID, request, USER_ID);
|
||||||
|
|
||||||
|
assertThat(result).isNotNull();
|
||||||
|
assertThat(result.getStatus()).isEqualTo("PENDING");
|
||||||
|
assertThat(result.getOperationType()).isEqualTo("CREATE_ENTITY");
|
||||||
|
verify(reviewRepository).save(any(EditReview.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void submitReview_withEntityId() {
|
||||||
|
SubmitReviewRequest request = new SubmitReviewRequest();
|
||||||
|
request.setOperationType("UPDATE_ENTITY");
|
||||||
|
request.setEntityId(ENTITY_ID);
|
||||||
|
request.setPayload("{\"name\":\"Updated\"}");
|
||||||
|
|
||||||
|
EditReview savedReview = EditReview.builder()
|
||||||
|
.id(REVIEW_ID)
|
||||||
|
.graphId(GRAPH_ID)
|
||||||
|
.operationType("UPDATE_ENTITY")
|
||||||
|
.entityId(ENTITY_ID)
|
||||||
|
.payload("{\"name\":\"Updated\"}")
|
||||||
|
.status("PENDING")
|
||||||
|
.submittedBy(USER_ID)
|
||||||
|
.createdAt(LocalDateTime.now())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
when(reviewRepository.save(any(EditReview.class))).thenReturn(savedReview);
|
||||||
|
|
||||||
|
EditReviewVO result = reviewService.submitReview(GRAPH_ID, request, USER_ID);
|
||||||
|
|
||||||
|
assertThat(result.getEntityId()).isEqualTo(ENTITY_ID);
|
||||||
|
assertThat(result.getOperationType()).isEqualTo("UPDATE_ENTITY");
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// approveReview
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_success_appliesChange() {
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
when(reviewRepository.save(any(EditReview.class))).thenReturn(pendingReview);
|
||||||
|
|
||||||
|
EditReviewVO result = reviewService.approveReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, "LGTM");
|
||||||
|
|
||||||
|
assertThat(result).isNotNull();
|
||||||
|
assertThat(pendingReview.getStatus()).isEqualTo("APPROVED");
|
||||||
|
assertThat(pendingReview.getReviewedBy()).isEqualTo(REVIEWER_ID);
|
||||||
|
assertThat(pendingReview.getReviewComment()).isEqualTo("LGTM");
|
||||||
|
assertThat(pendingReview.getReviewedAt()).isNotNull();
|
||||||
|
|
||||||
|
// Verify applyChange was called (createEntity for CREATE_ENTITY)
|
||||||
|
verify(entityService).createEntity(eq(GRAPH_ID), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_notFound_throwsBusinessException() {
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.empty());
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> reviewService.approveReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_alreadyProcessed_throwsBusinessException() {
|
||||||
|
pendingReview.setStatus("APPROVED");
|
||||||
|
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> reviewService.approveReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_deleteEntity_appliesChange() {
|
||||||
|
pendingReview.setOperationType("DELETE_ENTITY");
|
||||||
|
pendingReview.setEntityId(ENTITY_ID);
|
||||||
|
pendingReview.setPayload(null);
|
||||||
|
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
when(reviewRepository.save(any(EditReview.class))).thenReturn(pendingReview);
|
||||||
|
|
||||||
|
reviewService.approveReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null);
|
||||||
|
|
||||||
|
verify(entityService).deleteEntity(GRAPH_ID, ENTITY_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_updateEntity_appliesChange() {
|
||||||
|
pendingReview.setOperationType("UPDATE_ENTITY");
|
||||||
|
pendingReview.setEntityId(ENTITY_ID);
|
||||||
|
pendingReview.setPayload("{\"name\":\"Updated\"}");
|
||||||
|
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
when(reviewRepository.save(any(EditReview.class))).thenReturn(pendingReview);
|
||||||
|
|
||||||
|
reviewService.approveReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null);
|
||||||
|
|
||||||
|
verify(entityService).updateEntity(eq(GRAPH_ID), eq(ENTITY_ID), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_createRelation_appliesChange() {
|
||||||
|
pendingReview.setOperationType("CREATE_RELATION");
|
||||||
|
pendingReview.setPayload("{\"sourceEntityId\":\"a\",\"targetEntityId\":\"b\",\"relationType\":\"HAS_FIELD\"}");
|
||||||
|
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
when(reviewRepository.save(any(EditReview.class))).thenReturn(pendingReview);
|
||||||
|
|
||||||
|
reviewService.approveReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null);
|
||||||
|
|
||||||
|
verify(relationService).createRelation(eq(GRAPH_ID), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_invalidPayload_throwsBusinessException() {
|
||||||
|
pendingReview.setOperationType("CREATE_ENTITY");
|
||||||
|
pendingReview.setPayload("not valid json {{");
|
||||||
|
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> reviewService.approveReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_batchDeleteEntity_appliesChange() {
|
||||||
|
pendingReview.setOperationType("BATCH_DELETE_ENTITY");
|
||||||
|
pendingReview.setPayload("{\"ids\":[\"id-1\",\"id-2\",\"id-3\"]}");
|
||||||
|
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
when(reviewRepository.save(any(EditReview.class))).thenReturn(pendingReview);
|
||||||
|
|
||||||
|
reviewService.approveReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null);
|
||||||
|
|
||||||
|
verify(entityService).batchDeleteEntities(eq(GRAPH_ID), eq(List.of("id-1", "id-2", "id-3")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_batchDeleteRelation_appliesChange() {
|
||||||
|
pendingReview.setOperationType("BATCH_DELETE_RELATION");
|
||||||
|
pendingReview.setPayload("{\"ids\":[\"rel-1\",\"rel-2\"]}");
|
||||||
|
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
when(reviewRepository.save(any(EditReview.class))).thenReturn(pendingReview);
|
||||||
|
|
||||||
|
reviewService.approveReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null);
|
||||||
|
|
||||||
|
verify(relationService).batchDeleteRelations(eq(GRAPH_ID), eq(List.of("rel-1", "rel-2")));
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// rejectReview
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void rejectReview_success() {
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
when(reviewRepository.save(any(EditReview.class))).thenReturn(pendingReview);
|
||||||
|
|
||||||
|
EditReviewVO result = reviewService.rejectReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, "不合适");
|
||||||
|
|
||||||
|
assertThat(result).isNotNull();
|
||||||
|
assertThat(pendingReview.getStatus()).isEqualTo("REJECTED");
|
||||||
|
assertThat(pendingReview.getReviewedBy()).isEqualTo(REVIEWER_ID);
|
||||||
|
assertThat(pendingReview.getReviewComment()).isEqualTo("不合适");
|
||||||
|
assertThat(pendingReview.getReviewedAt()).isNotNull();
|
||||||
|
|
||||||
|
// Verify no change was applied
|
||||||
|
verifyNoInteractions(entityService);
|
||||||
|
verifyNoInteractions(relationService);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void rejectReview_notFound_throwsBusinessException() {
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.empty());
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> reviewService.rejectReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void rejectReview_alreadyProcessed_throwsBusinessException() {
|
||||||
|
pendingReview.setStatus("REJECTED");
|
||||||
|
|
||||||
|
when(reviewRepository.findById(REVIEW_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(pendingReview));
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> reviewService.rejectReview(GRAPH_ID, REVIEW_ID, REVIEWER_ID, null))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// listPendingReviews
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listPendingReviews_returnsPagedResult() {
|
||||||
|
when(reviewRepository.findPendingByGraphId(GRAPH_ID, 0L, 20))
|
||||||
|
.thenReturn(List.of(pendingReview));
|
||||||
|
when(reviewRepository.countPendingByGraphId(GRAPH_ID)).thenReturn(1L);
|
||||||
|
|
||||||
|
var result = reviewService.listPendingReviews(GRAPH_ID, 0, 20);
|
||||||
|
|
||||||
|
assertThat(result.getContent()).hasSize(1);
|
||||||
|
assertThat(result.getTotalElements()).isEqualTo(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listPendingReviews_clampsPageSize() {
|
||||||
|
when(reviewRepository.findPendingByGraphId(GRAPH_ID, 0L, 200))
|
||||||
|
.thenReturn(List.of());
|
||||||
|
when(reviewRepository.countPendingByGraphId(GRAPH_ID)).thenReturn(0L);
|
||||||
|
|
||||||
|
reviewService.listPendingReviews(GRAPH_ID, 0, 999);
|
||||||
|
|
||||||
|
verify(reviewRepository).findPendingByGraphId(GRAPH_ID, 0L, 200);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listPendingReviews_negativePage_clampedToZero() {
|
||||||
|
when(reviewRepository.findPendingByGraphId(GRAPH_ID, 0L, 20))
|
||||||
|
.thenReturn(List.of());
|
||||||
|
when(reviewRepository.countPendingByGraphId(GRAPH_ID)).thenReturn(0L);
|
||||||
|
|
||||||
|
var result = reviewService.listPendingReviews(GRAPH_ID, -1, 20);
|
||||||
|
|
||||||
|
assertThat(result.getPage()).isEqualTo(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// listReviews
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listReviews_withStatusFilter() {
|
||||||
|
when(reviewRepository.findByGraphId(GRAPH_ID, "APPROVED", 0L, 20))
|
||||||
|
.thenReturn(List.of());
|
||||||
|
when(reviewRepository.countByGraphId(GRAPH_ID, "APPROVED")).thenReturn(0L);
|
||||||
|
|
||||||
|
var result = reviewService.listReviews(GRAPH_ID, "APPROVED", 0, 20);
|
||||||
|
|
||||||
|
assertThat(result.getContent()).isEmpty();
|
||||||
|
verify(reviewRepository).findByGraphId(GRAPH_ID, "APPROVED", 0L, 20);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listReviews_withoutStatusFilter() {
|
||||||
|
when(reviewRepository.findByGraphId(GRAPH_ID, null, 0L, 20))
|
||||||
|
.thenReturn(List.of(pendingReview));
|
||||||
|
when(reviewRepository.countByGraphId(GRAPH_ID, null)).thenReturn(1L);
|
||||||
|
|
||||||
|
var result = reviewService.listReviews(GRAPH_ID, null, 0, 20);
|
||||||
|
|
||||||
|
assertThat(result.getContent()).hasSize(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package com.datamate.knowledgegraph.application;
|
|||||||
import com.datamate.common.infrastructure.exception.BusinessException;
|
import com.datamate.common.infrastructure.exception.BusinessException;
|
||||||
import com.datamate.knowledgegraph.domain.model.GraphEntity;
|
import com.datamate.knowledgegraph.domain.model.GraphEntity;
|
||||||
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService;
|
||||||
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.CreateEntityRequest;
|
import com.datamate.knowledgegraph.interfaces.dto.CreateEntityRequest;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.UpdateEntityRequest;
|
import com.datamate.knowledgegraph.interfaces.dto.UpdateEntityRequest;
|
||||||
@@ -37,6 +38,9 @@ class GraphEntityServiceTest {
|
|||||||
@Mock
|
@Mock
|
||||||
private KnowledgeGraphProperties properties;
|
private KnowledgeGraphProperties properties;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private GraphCacheService cacheService;
|
||||||
|
|
||||||
@InjectMocks
|
@InjectMocks
|
||||||
private GraphEntityService entityService;
|
private GraphEntityService entityService;
|
||||||
|
|
||||||
@@ -90,6 +94,8 @@ class GraphEntityServiceTest {
|
|||||||
assertThat(result).isNotNull();
|
assertThat(result).isNotNull();
|
||||||
assertThat(result.getName()).isEqualTo("TestDataset");
|
assertThat(result.getName()).isEqualTo("TestDataset");
|
||||||
verify(entityRepository).save(any(GraphEntity.class));
|
verify(entityRepository).save(any(GraphEntity.class));
|
||||||
|
verify(cacheService).evictEntityCaches(GRAPH_ID, ENTITY_ID);
|
||||||
|
verify(cacheService).evictSearchCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
@@ -150,6 +156,8 @@ class GraphEntityServiceTest {
|
|||||||
|
|
||||||
assertThat(result.getName()).isEqualTo("UpdatedName");
|
assertThat(result.getName()).isEqualTo("UpdatedName");
|
||||||
assertThat(result.getDescription()).isEqualTo("A test dataset");
|
assertThat(result.getDescription()).isEqualTo("A test dataset");
|
||||||
|
verify(cacheService).evictEntityCaches(GRAPH_ID, ENTITY_ID);
|
||||||
|
verify(cacheService).evictSearchCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
@@ -164,6 +172,8 @@ class GraphEntityServiceTest {
|
|||||||
entityService.deleteEntity(GRAPH_ID, ENTITY_ID);
|
entityService.deleteEntity(GRAPH_ID, ENTITY_ID);
|
||||||
|
|
||||||
verify(entityRepository).delete(sampleEntity);
|
verify(entityRepository).delete(sampleEntity);
|
||||||
|
verify(cacheService).evictEntityCaches(GRAPH_ID, ENTITY_ID);
|
||||||
|
verify(cacheService).evictSearchCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import com.datamate.common.infrastructure.exception.BusinessException;
|
|||||||
import com.datamate.knowledgegraph.domain.model.GraphEntity;
|
import com.datamate.knowledgegraph.domain.model.GraphEntity;
|
||||||
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
||||||
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
||||||
|
import com.datamate.knowledgegraph.interfaces.dto.AllPathsVO;
|
||||||
|
import com.datamate.knowledgegraph.interfaces.dto.SubgraphExportVO;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.SubgraphVO;
|
import com.datamate.knowledgegraph.interfaces.dto.SubgraphVO;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Nested;
|
import org.junit.jupiter.api.Nested;
|
||||||
@@ -13,6 +15,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
|||||||
import org.mockito.InjectMocks;
|
import org.mockito.InjectMocks;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
import org.mockito.junit.jupiter.MockitoExtension;
|
import org.mockito.junit.jupiter.MockitoExtension;
|
||||||
|
import org.neo4j.driver.Driver;
|
||||||
import org.springframework.data.neo4j.core.Neo4jClient;
|
import org.springframework.data.neo4j.core.Neo4jClient;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -36,6 +39,9 @@ class GraphQueryServiceTest {
|
|||||||
@Mock
|
@Mock
|
||||||
private Neo4jClient neo4jClient;
|
private Neo4jClient neo4jClient;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private Driver neo4jDriver;
|
||||||
|
|
||||||
@Mock
|
@Mock
|
||||||
private GraphEntityRepository entityRepository;
|
private GraphEntityRepository entityRepository;
|
||||||
|
|
||||||
@@ -594,4 +600,295 @@ class GraphQueryServiceTest {
|
|||||||
assertThat(result.getNodes().get(0).getName()).isEqualTo("Normal KS");
|
assertThat(result.getNodes().get(0).getName()).isEqualTo("Normal KS");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// findAllPaths
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class FindAllPathsTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findAllPaths_invalidGraphId_throwsBusinessException() {
|
||||||
|
assertThatThrownBy(() -> queryService.findAllPaths(INVALID_GRAPH_ID, ENTITY_ID, ENTITY_ID_2, 3, 10))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findAllPaths_sourceNotFound_throwsBusinessException() {
|
||||||
|
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.empty());
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> queryService.findAllPaths(GRAPH_ID, ENTITY_ID, ENTITY_ID_2, 3, 10))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findAllPaths_targetNotFound_throwsBusinessException() {
|
||||||
|
GraphEntity sourceEntity = GraphEntity.builder()
|
||||||
|
.id(ENTITY_ID).name("Source").type("Dataset").graphId(GRAPH_ID).build();
|
||||||
|
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(sourceEntity));
|
||||||
|
when(entityRepository.findByIdAndGraphId(ENTITY_ID_2, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.empty());
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> queryService.findAllPaths(GRAPH_ID, ENTITY_ID, ENTITY_ID_2, 3, 10))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findAllPaths_sameSourceAndTarget_returnsSingleNodePath() {
|
||||||
|
GraphEntity entity = GraphEntity.builder()
|
||||||
|
.id(ENTITY_ID).name("Node").type("Dataset").graphId(GRAPH_ID).build();
|
||||||
|
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(entity));
|
||||||
|
|
||||||
|
AllPathsVO result = queryService.findAllPaths(GRAPH_ID, ENTITY_ID, ENTITY_ID, 3, 10);
|
||||||
|
|
||||||
|
assertThat(result.getPathCount()).isEqualTo(1);
|
||||||
|
assertThat(result.getPaths()).hasSize(1);
|
||||||
|
assertThat(result.getPaths().get(0).getPathLength()).isEqualTo(0);
|
||||||
|
assertThat(result.getPaths().get(0).getNodes()).hasSize(1);
|
||||||
|
assertThat(result.getPaths().get(0).getEdges()).isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findAllPaths_nonAdmin_sourceNotAccessible_throws() {
|
||||||
|
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn("user-123");
|
||||||
|
|
||||||
|
GraphEntity sourceEntity = GraphEntity.builder()
|
||||||
|
.id(ENTITY_ID).name("Other's Dataset").type("Dataset").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>(Map.of("created_by", "other-user")))
|
||||||
|
.build();
|
||||||
|
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(sourceEntity));
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> queryService.findAllPaths(GRAPH_ID, ENTITY_ID, ENTITY_ID_2, 3, 10))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
|
||||||
|
verifyNoInteractions(neo4jClient);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findAllPaths_nonAdmin_targetNotAccessible_throws() {
|
||||||
|
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn("user-123");
|
||||||
|
|
||||||
|
GraphEntity sourceEntity = GraphEntity.builder()
|
||||||
|
.id(ENTITY_ID).name("My Dataset").type("Dataset").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>(Map.of("created_by", "user-123")))
|
||||||
|
.build();
|
||||||
|
GraphEntity targetEntity = GraphEntity.builder()
|
||||||
|
.id(ENTITY_ID_2).name("Other's Dataset").type("Dataset").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>(Map.of("created_by", "other-user")))
|
||||||
|
.build();
|
||||||
|
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(sourceEntity));
|
||||||
|
when(entityRepository.findByIdAndGraphId(ENTITY_ID_2, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(targetEntity));
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> queryService.findAllPaths(GRAPH_ID, ENTITY_ID, ENTITY_ID_2, 3, 10))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
|
||||||
|
verifyNoInteractions(neo4jClient);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void findAllPaths_nonAdmin_structuralEntity_sameSourceAndTarget_returnsSingleNode() {
|
||||||
|
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn("user-123");
|
||||||
|
|
||||||
|
GraphEntity structuralEntity = GraphEntity.builder()
|
||||||
|
.id(ENTITY_ID).name("Admin User").type("User").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>())
|
||||||
|
.build();
|
||||||
|
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
|
||||||
|
.thenReturn(Optional.of(structuralEntity));
|
||||||
|
|
||||||
|
AllPathsVO result = queryService.findAllPaths(GRAPH_ID, ENTITY_ID, ENTITY_ID, 3, 10);
|
||||||
|
|
||||||
|
assertThat(result.getPathCount()).isEqualTo(1);
|
||||||
|
assertThat(result.getPaths().get(0).getNodes().get(0).getType()).isEqualTo("User");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// exportSubgraph
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class ExportSubgraphTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void exportSubgraph_invalidGraphId_throwsBusinessException() {
|
||||||
|
assertThatThrownBy(() -> queryService.exportSubgraph(INVALID_GRAPH_ID, List.of(ENTITY_ID), 0))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void exportSubgraph_nullEntityIds_returnsEmptyExport() {
|
||||||
|
SubgraphExportVO result = queryService.exportSubgraph(GRAPH_ID, null, 0);
|
||||||
|
|
||||||
|
assertThat(result.getNodes()).isEmpty();
|
||||||
|
assertThat(result.getEdges()).isEmpty();
|
||||||
|
assertThat(result.getNodeCount()).isEqualTo(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void exportSubgraph_emptyEntityIds_returnsEmptyExport() {
|
||||||
|
SubgraphExportVO result = queryService.exportSubgraph(GRAPH_ID, List.of(), 0);
|
||||||
|
|
||||||
|
assertThat(result.getNodes()).isEmpty();
|
||||||
|
assertThat(result.getEdges()).isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void exportSubgraph_exceedsMaxNodes_throwsBusinessException() {
|
||||||
|
when(properties.getMaxNodesPerQuery()).thenReturn(5);
|
||||||
|
|
||||||
|
List<String> tooManyIds = List.of("1", "2", "3", "4", "5", "6");
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> queryService.exportSubgraph(GRAPH_ID, tooManyIds, 0))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void exportSubgraph_depthZero_noExistingEntities_returnsEmptyExport() {
|
||||||
|
when(properties.getMaxNodesPerQuery()).thenReturn(500);
|
||||||
|
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID)))
|
||||||
|
.thenReturn(List.of());
|
||||||
|
|
||||||
|
SubgraphExportVO result = queryService.exportSubgraph(GRAPH_ID, List.of(ENTITY_ID), 0);
|
||||||
|
|
||||||
|
assertThat(result.getNodes()).isEmpty();
|
||||||
|
assertThat(result.getNodeCount()).isEqualTo(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void exportSubgraph_depthZero_singleEntity_returnsNodeWithProperties() {
|
||||||
|
when(properties.getMaxNodesPerQuery()).thenReturn(500);
|
||||||
|
|
||||||
|
GraphEntity entity = GraphEntity.builder()
|
||||||
|
.id(ENTITY_ID).name("Test Dataset").type("Dataset").graphId(GRAPH_ID)
|
||||||
|
.description("A test dataset")
|
||||||
|
.properties(new HashMap<>(Map.of("created_by", "user-1", "sensitivity", "PUBLIC")))
|
||||||
|
.build();
|
||||||
|
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID)))
|
||||||
|
.thenReturn(List.of(entity));
|
||||||
|
|
||||||
|
SubgraphExportVO result = queryService.exportSubgraph(GRAPH_ID, List.of(ENTITY_ID), 0);
|
||||||
|
|
||||||
|
assertThat(result.getNodes()).hasSize(1);
|
||||||
|
assertThat(result.getNodeCount()).isEqualTo(1);
|
||||||
|
assertThat(result.getNodes().get(0).getName()).isEqualTo("Test Dataset");
|
||||||
|
assertThat(result.getNodes().get(0).getProperties()).containsEntry("created_by", "user-1");
|
||||||
|
// 单节点无边
|
||||||
|
assertThat(result.getEdges()).isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void exportSubgraph_nonAdmin_filtersInaccessibleEntities() {
|
||||||
|
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn("user-123");
|
||||||
|
when(properties.getMaxNodesPerQuery()).thenReturn(500);
|
||||||
|
|
||||||
|
GraphEntity ownEntity = GraphEntity.builder()
|
||||||
|
.id(ENTITY_ID).name("My Dataset").type("Dataset").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>(Map.of("created_by", "user-123")))
|
||||||
|
.build();
|
||||||
|
GraphEntity otherEntity = GraphEntity.builder()
|
||||||
|
.id(ENTITY_ID_2).name("Other Dataset").type("Dataset").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>(Map.of("created_by", "other-user")))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID, ENTITY_ID_2)))
|
||||||
|
.thenReturn(List.of(ownEntity, otherEntity));
|
||||||
|
|
||||||
|
SubgraphExportVO result = queryService.exportSubgraph(GRAPH_ID,
|
||||||
|
List.of(ENTITY_ID, ENTITY_ID_2), 0);
|
||||||
|
|
||||||
|
assertThat(result.getNodes()).hasSize(1);
|
||||||
|
assertThat(result.getNodes().get(0).getName()).isEqualTo("My Dataset");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// convertToGraphML
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class ConvertToGraphMLTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void convertToGraphML_emptyExport_producesValidXml() {
|
||||||
|
SubgraphExportVO emptyExport = SubgraphExportVO.builder()
|
||||||
|
.nodes(List.of())
|
||||||
|
.edges(List.of())
|
||||||
|
.nodeCount(0)
|
||||||
|
.edgeCount(0)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String graphml = queryService.convertToGraphML(emptyExport);
|
||||||
|
|
||||||
|
assertThat(graphml).contains("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
|
||||||
|
assertThat(graphml).contains("<graphml");
|
||||||
|
assertThat(graphml).contains("<graph id=\"G\" edgedefault=\"directed\">");
|
||||||
|
assertThat(graphml).contains("</graphml>");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void convertToGraphML_withNodesAndEdges_producesCorrectStructure() {
|
||||||
|
SubgraphExportVO export = SubgraphExportVO.builder()
|
||||||
|
.nodes(List.of(
|
||||||
|
com.datamate.knowledgegraph.interfaces.dto.ExportNodeVO.builder()
|
||||||
|
.id("node-1").name("Dataset A").type("Dataset")
|
||||||
|
.description("Test dataset").properties(Map.of())
|
||||||
|
.build(),
|
||||||
|
com.datamate.knowledgegraph.interfaces.dto.ExportNodeVO.builder()
|
||||||
|
.id("node-2").name("Workflow B").type("Workflow")
|
||||||
|
.description(null).properties(Map.of())
|
||||||
|
.build()
|
||||||
|
))
|
||||||
|
.edges(List.of(
|
||||||
|
com.datamate.knowledgegraph.interfaces.dto.ExportEdgeVO.builder()
|
||||||
|
.id("edge-1").sourceEntityId("node-1").targetEntityId("node-2")
|
||||||
|
.relationType("DERIVED_FROM").weight(0.8)
|
||||||
|
.build()
|
||||||
|
))
|
||||||
|
.nodeCount(2)
|
||||||
|
.edgeCount(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String graphml = queryService.convertToGraphML(export);
|
||||||
|
|
||||||
|
assertThat(graphml).contains("<node id=\"node-1\">");
|
||||||
|
assertThat(graphml).contains("<data key=\"name\">Dataset A</data>");
|
||||||
|
assertThat(graphml).contains("<data key=\"type\">Dataset</data>");
|
||||||
|
assertThat(graphml).contains("<data key=\"description\">Test dataset</data>");
|
||||||
|
assertThat(graphml).contains("<node id=\"node-2\">");
|
||||||
|
assertThat(graphml).contains("<data key=\"type\">Workflow</data>");
|
||||||
|
// null description 不输出
|
||||||
|
assertThat(graphml).doesNotContain("<data key=\"description\">null</data>");
|
||||||
|
assertThat(graphml).contains("<edge id=\"edge-1\" source=\"node-1\" target=\"node-2\">");
|
||||||
|
assertThat(graphml).contains("<data key=\"relationType\">DERIVED_FROM</data>");
|
||||||
|
assertThat(graphml).contains("<data key=\"weight\">0.8</data>");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void convertToGraphML_specialCharactersEscaped() {
|
||||||
|
SubgraphExportVO export = SubgraphExportVO.builder()
|
||||||
|
.nodes(List.of(
|
||||||
|
com.datamate.knowledgegraph.interfaces.dto.ExportNodeVO.builder()
|
||||||
|
.id("node-1").name("A & B <Corp>").type("Org")
|
||||||
|
.description("\"Test\" org").properties(Map.of())
|
||||||
|
.build()
|
||||||
|
))
|
||||||
|
.edges(List.of())
|
||||||
|
.nodeCount(1)
|
||||||
|
.edgeCount(0)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
String graphml = queryService.convertToGraphML(export);
|
||||||
|
|
||||||
|
assertThat(graphml).contains("A & B <Corp>");
|
||||||
|
assertThat(graphml).contains(""Test" org");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import com.datamate.knowledgegraph.domain.model.GraphEntity;
|
|||||||
import com.datamate.knowledgegraph.domain.model.RelationDetail;
|
import com.datamate.knowledgegraph.domain.model.RelationDetail;
|
||||||
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
|
||||||
import com.datamate.knowledgegraph.domain.repository.GraphRelationRepository;
|
import com.datamate.knowledgegraph.domain.repository.GraphRelationRepository;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.CreateRelationRequest;
|
import com.datamate.knowledgegraph.interfaces.dto.CreateRelationRequest;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.RelationVO;
|
import com.datamate.knowledgegraph.interfaces.dto.RelationVO;
|
||||||
import com.datamate.knowledgegraph.interfaces.dto.UpdateRelationRequest;
|
import com.datamate.knowledgegraph.interfaces.dto.UpdateRelationRequest;
|
||||||
@@ -40,6 +41,9 @@ class GraphRelationServiceTest {
|
|||||||
@Mock
|
@Mock
|
||||||
private GraphEntityRepository entityRepository;
|
private GraphEntityRepository entityRepository;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private GraphCacheService cacheService;
|
||||||
|
|
||||||
@InjectMocks
|
@InjectMocks
|
||||||
private GraphRelationService relationService;
|
private GraphRelationService relationService;
|
||||||
|
|
||||||
@@ -106,6 +110,7 @@ class GraphRelationServiceTest {
|
|||||||
assertThat(result.getRelationType()).isEqualTo("HAS_FIELD");
|
assertThat(result.getRelationType()).isEqualTo("HAS_FIELD");
|
||||||
assertThat(result.getSourceEntityId()).isEqualTo(SOURCE_ENTITY_ID);
|
assertThat(result.getSourceEntityId()).isEqualTo(SOURCE_ENTITY_ID);
|
||||||
assertThat(result.getTargetEntityId()).isEqualTo(TARGET_ENTITY_ID);
|
assertThat(result.getTargetEntityId()).isEqualTo(TARGET_ENTITY_ID);
|
||||||
|
verify(cacheService).evictEntityCaches(GRAPH_ID, SOURCE_ENTITY_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -241,6 +246,7 @@ class GraphRelationServiceTest {
|
|||||||
RelationVO result = relationService.updateRelation(GRAPH_ID, RELATION_ID, request);
|
RelationVO result = relationService.updateRelation(GRAPH_ID, RELATION_ID, request);
|
||||||
|
|
||||||
assertThat(result.getRelationType()).isEqualTo("USES");
|
assertThat(result.getRelationType()).isEqualTo("USES");
|
||||||
|
verify(cacheService).evictEntityCaches(GRAPH_ID, SOURCE_ENTITY_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
@@ -257,6 +263,8 @@ class GraphRelationServiceTest {
|
|||||||
relationService.deleteRelation(GRAPH_ID, RELATION_ID);
|
relationService.deleteRelation(GRAPH_ID, RELATION_ID);
|
||||||
|
|
||||||
verify(relationRepository).deleteByIdAndGraphId(RELATION_ID, GRAPH_ID);
|
verify(relationRepository).deleteByIdAndGraphId(RELATION_ID, GRAPH_ID);
|
||||||
|
verify(cacheService).evictEntityCaches(GRAPH_ID, SOURCE_ENTITY_ID);
|
||||||
|
verify(cacheService).evictEntityCaches(GRAPH_ID, TARGET_ENTITY_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import com.datamate.common.infrastructure.exception.BusinessException;
|
|||||||
import com.datamate.knowledgegraph.domain.model.SyncMetadata;
|
import com.datamate.knowledgegraph.domain.model.SyncMetadata;
|
||||||
import com.datamate.knowledgegraph.domain.model.SyncResult;
|
import com.datamate.knowledgegraph.domain.model.SyncResult;
|
||||||
import com.datamate.knowledgegraph.domain.repository.SyncHistoryRepository;
|
import com.datamate.knowledgegraph.domain.repository.SyncHistoryRepository;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.cache.GraphCacheService;
|
||||||
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient;
|
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient;
|
||||||
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.DatasetDTO;
|
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.DatasetDTO;
|
||||||
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.WorkflowDTO;
|
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.WorkflowDTO;
|
||||||
@@ -50,6 +51,9 @@ class GraphSyncServiceTest {
|
|||||||
@Mock
|
@Mock
|
||||||
private SyncHistoryRepository syncHistoryRepository;
|
private SyncHistoryRepository syncHistoryRepository;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private GraphCacheService cacheService;
|
||||||
|
|
||||||
@InjectMocks
|
@InjectMocks
|
||||||
private GraphSyncService syncService;
|
private GraphSyncService syncService;
|
||||||
|
|
||||||
@@ -133,7 +137,9 @@ class GraphSyncServiceTest {
|
|||||||
.thenReturn(SyncResult.builder().syncType("Field").build());
|
.thenReturn(SyncResult.builder().syncType("Field").build());
|
||||||
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
|
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("User").build());
|
.thenReturn(SyncResult.builder().syncType("User").build());
|
||||||
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyString()))
|
when(dataManagementClient.fetchUserOrganizationMap())
|
||||||
|
.thenReturn(Map.of("admin", "DataMate"));
|
||||||
|
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("Org").build());
|
.thenReturn(SyncResult.builder().syncType("Org").build());
|
||||||
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
|
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("Workflow").build());
|
.thenReturn(SyncResult.builder().syncType("Workflow").build());
|
||||||
@@ -152,7 +158,7 @@ class GraphSyncServiceTest {
|
|||||||
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
||||||
when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
|
when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
||||||
when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyString()))
|
when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
||||||
when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
|
when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
||||||
@@ -186,6 +192,9 @@ class GraphSyncServiceTest {
|
|||||||
assertThat(byType).containsKeys("HAS_FIELD", "DERIVED_FROM", "BELONGS_TO",
|
assertThat(byType).containsKeys("HAS_FIELD", "DERIVED_FROM", "BELONGS_TO",
|
||||||
"USES_DATASET", "PRODUCES", "ASSIGNED_TO", "TRIGGERS",
|
"USES_DATASET", "PRODUCES", "ASSIGNED_TO", "TRIGGERS",
|
||||||
"DEPENDS_ON", "IMPACTS", "SOURCED_FROM");
|
"DEPENDS_ON", "IMPACTS", "SOURCED_FROM");
|
||||||
|
|
||||||
|
// 验证缓存清除(finally 块)
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
@@ -200,6 +209,9 @@ class GraphSyncServiceTest {
|
|||||||
assertThatThrownBy(() -> syncService.syncDatasets(GRAPH_ID))
|
assertThatThrownBy(() -> syncService.syncDatasets(GRAPH_ID))
|
||||||
.isInstanceOf(BusinessException.class)
|
.isInstanceOf(BusinessException.class)
|
||||||
.hasMessageContaining("datasets");
|
.hasMessageContaining("datasets");
|
||||||
|
|
||||||
|
// P1 fix: 即使失败,finally 块也会清除缓存
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
@@ -226,6 +238,7 @@ class GraphSyncServiceTest {
|
|||||||
|
|
||||||
assertThat(result.getSyncType()).isEqualTo("Workflow");
|
assertThat(result.getSyncType()).isEqualTo("Workflow");
|
||||||
verify(stepService).upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString());
|
verify(stepService).upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString());
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -245,6 +258,7 @@ class GraphSyncServiceTest {
|
|||||||
|
|
||||||
assertThat(result.getSyncType()).isEqualTo("Job");
|
assertThat(result.getSyncType()).isEqualTo("Job");
|
||||||
verify(stepService).upsertJobEntities(eq(GRAPH_ID), anyList(), anyString());
|
verify(stepService).upsertJobEntities(eq(GRAPH_ID), anyList(), anyString());
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -263,6 +277,7 @@ class GraphSyncServiceTest {
|
|||||||
SyncResult result = syncService.syncLabelTasks(GRAPH_ID);
|
SyncResult result = syncService.syncLabelTasks(GRAPH_ID);
|
||||||
|
|
||||||
assertThat(result.getSyncType()).isEqualTo("LabelTask");
|
assertThat(result.getSyncType()).isEqualTo("LabelTask");
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -281,6 +296,7 @@ class GraphSyncServiceTest {
|
|||||||
SyncResult result = syncService.syncKnowledgeSets(GRAPH_ID);
|
SyncResult result = syncService.syncKnowledgeSets(GRAPH_ID);
|
||||||
|
|
||||||
assertThat(result.getSyncType()).isEqualTo("KnowledgeSet");
|
assertThat(result.getSyncType()).isEqualTo("KnowledgeSet");
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -291,6 +307,9 @@ class GraphSyncServiceTest {
|
|||||||
assertThatThrownBy(() -> syncService.syncWorkflows(GRAPH_ID))
|
assertThatThrownBy(() -> syncService.syncWorkflows(GRAPH_ID))
|
||||||
.isInstanceOf(BusinessException.class)
|
.isInstanceOf(BusinessException.class)
|
||||||
.hasMessageContaining("workflows");
|
.hasMessageContaining("workflows");
|
||||||
|
|
||||||
|
// P1 fix: 即使失败,finally 块也会清除缓存
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -371,7 +390,9 @@ class GraphSyncServiceTest {
|
|||||||
.thenReturn(SyncResult.builder().syncType("Field").build());
|
.thenReturn(SyncResult.builder().syncType("Field").build());
|
||||||
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
|
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("User").build());
|
.thenReturn(SyncResult.builder().syncType("User").build());
|
||||||
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyString()))
|
when(dataManagementClient.fetchUserOrganizationMap())
|
||||||
|
.thenReturn(Map.of("admin", "DataMate"));
|
||||||
|
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("Org").build());
|
.thenReturn(SyncResult.builder().syncType("Org").build());
|
||||||
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
|
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("Workflow").build());
|
.thenReturn(SyncResult.builder().syncType("Workflow").build());
|
||||||
@@ -387,7 +408,7 @@ class GraphSyncServiceTest {
|
|||||||
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
||||||
when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
|
when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
||||||
when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyString()))
|
when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
||||||
when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
|
when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
||||||
@@ -425,6 +446,9 @@ class GraphSyncServiceTest {
|
|||||||
SyncMetadata saved = captor.getValue();
|
SyncMetadata saved = captor.getValue();
|
||||||
assertThat(saved.getStatus()).isEqualTo(SyncMetadata.STATUS_SUCCESS);
|
assertThat(saved.getStatus()).isEqualTo(SyncMetadata.STATUS_SUCCESS);
|
||||||
assertThat(saved.getGraphId()).isEqualTo(GRAPH_ID);
|
assertThat(saved.getGraphId()).isEqualTo(GRAPH_ID);
|
||||||
|
|
||||||
|
// 验证缓存清除
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -450,7 +474,9 @@ class GraphSyncServiceTest {
|
|||||||
.thenReturn(SyncResult.builder().syncType("Field").build());
|
.thenReturn(SyncResult.builder().syncType("Field").build());
|
||||||
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
|
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("User").build());
|
.thenReturn(SyncResult.builder().syncType("User").build());
|
||||||
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyString()))
|
when(dataManagementClient.fetchUserOrganizationMap())
|
||||||
|
.thenReturn(Map.of("admin", "DataMate"));
|
||||||
|
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("Org").build());
|
.thenReturn(SyncResult.builder().syncType("Org").build());
|
||||||
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
|
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("Workflow").build());
|
.thenReturn(SyncResult.builder().syncType("Workflow").build());
|
||||||
@@ -466,7 +492,7 @@ class GraphSyncServiceTest {
|
|||||||
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
||||||
when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
|
when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
||||||
when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyString()))
|
when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
||||||
when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
|
when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
||||||
@@ -505,6 +531,9 @@ class GraphSyncServiceTest {
|
|||||||
assertThat(saved.getErrorMessage()).isNotNull();
|
assertThat(saved.getErrorMessage()).isNotNull();
|
||||||
assertThat(saved.getGraphId()).isEqualTo(GRAPH_ID);
|
assertThat(saved.getGraphId()).isEqualTo(GRAPH_ID);
|
||||||
assertThat(saved.getSyncType()).isEqualTo(SyncMetadata.TYPE_FULL);
|
assertThat(saved.getSyncType()).isEqualTo(SyncMetadata.TYPE_FULL);
|
||||||
|
|
||||||
|
// P1 fix: 即使失败,finally 块也会清除缓存
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -528,6 +557,8 @@ class GraphSyncServiceTest {
|
|||||||
assertThat(saved.getStatus()).isEqualTo(SyncMetadata.STATUS_SUCCESS);
|
assertThat(saved.getStatus()).isEqualTo(SyncMetadata.STATUS_SUCCESS);
|
||||||
assertThat(saved.getSyncType()).isEqualTo(SyncMetadata.TYPE_DATASETS);
|
assertThat(saved.getSyncType()).isEqualTo(SyncMetadata.TYPE_DATASETS);
|
||||||
assertThat(saved.getTotalCreated()).isEqualTo(1);
|
assertThat(saved.getTotalCreated()).isEqualTo(1);
|
||||||
|
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -543,6 +574,9 @@ class GraphSyncServiceTest {
|
|||||||
SyncMetadata saved = captor.getValue();
|
SyncMetadata saved = captor.getValue();
|
||||||
assertThat(saved.getStatus()).isEqualTo(SyncMetadata.STATUS_FAILED);
|
assertThat(saved.getStatus()).isEqualTo(SyncMetadata.STATUS_FAILED);
|
||||||
assertThat(saved.getSyncType()).isEqualTo(SyncMetadata.TYPE_DATASETS);
|
assertThat(saved.getSyncType()).isEqualTo(SyncMetadata.TYPE_DATASETS);
|
||||||
|
|
||||||
|
// P1 fix: 即使失败,finally 块也会清除缓存
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -637,6 +671,9 @@ class GraphSyncServiceTest {
|
|||||||
|
|
||||||
// 验证不执行 purge
|
// 验证不执行 purge
|
||||||
verify(stepService, never()).purgeStaleEntities(anyString(), anyString(), anySet(), anyString());
|
verify(stepService, never()).purgeStaleEntities(anyString(), anyString(), anySet(), anyString());
|
||||||
|
|
||||||
|
// 验证缓存清除
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -655,6 +692,9 @@ class GraphSyncServiceTest {
|
|||||||
assertThat(saved.getSyncType()).isEqualTo(SyncMetadata.TYPE_INCREMENTAL);
|
assertThat(saved.getSyncType()).isEqualTo(SyncMetadata.TYPE_INCREMENTAL);
|
||||||
assertThat(saved.getUpdatedFrom()).isEqualTo(UPDATED_FROM);
|
assertThat(saved.getUpdatedFrom()).isEqualTo(UPDATED_FROM);
|
||||||
assertThat(saved.getUpdatedTo()).isEqualTo(UPDATED_TO);
|
assertThat(saved.getUpdatedTo()).isEqualTo(UPDATED_TO);
|
||||||
|
|
||||||
|
// P1 fix: 即使失败,finally 块也会清除缓存
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void stubAllEntityUpserts() {
|
private void stubAllEntityUpserts() {
|
||||||
@@ -664,7 +704,9 @@ class GraphSyncServiceTest {
|
|||||||
.thenReturn(SyncResult.builder().syncType("Field").build());
|
.thenReturn(SyncResult.builder().syncType("Field").build());
|
||||||
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
|
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("User").build());
|
.thenReturn(SyncResult.builder().syncType("User").build());
|
||||||
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyString()))
|
when(dataManagementClient.fetchUserOrganizationMap())
|
||||||
|
.thenReturn(Map.of("admin", "DataMate"));
|
||||||
|
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("Org").build());
|
.thenReturn(SyncResult.builder().syncType("Org").build());
|
||||||
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
|
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("Workflow").build());
|
.thenReturn(SyncResult.builder().syncType("Workflow").build());
|
||||||
@@ -682,7 +724,7 @@ class GraphSyncServiceTest {
|
|||||||
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
||||||
lenient().when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
|
lenient().when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
||||||
lenient().when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyString()))
|
lenient().when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
||||||
lenient().when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
|
lenient().when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
|
||||||
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
||||||
@@ -704,7 +746,7 @@ class GraphSyncServiceTest {
|
|||||||
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
||||||
lenient().when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString(), any()))
|
lenient().when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString(), any()))
|
||||||
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
||||||
lenient().when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyString(), any()))
|
lenient().when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyMap(), anyString(), any()))
|
||||||
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
|
||||||
lenient().when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString(), any()))
|
lenient().when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString(), any()))
|
||||||
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
||||||
@@ -820,4 +862,148 @@ class GraphSyncServiceTest {
|
|||||||
.hasMessageContaining("分页偏移量");
|
.hasMessageContaining("分页偏移量");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// 组织同步
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class OrgSyncTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void syncOrgs_fetchesUserOrgMapAndPassesToStepService() {
|
||||||
|
when(properties.getSync()).thenReturn(syncConfig);
|
||||||
|
when(dataManagementClient.fetchUserOrganizationMap())
|
||||||
|
.thenReturn(Map.of("admin", "DataMate", "alice", "三甲医院"));
|
||||||
|
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("Org").created(3).build());
|
||||||
|
when(stepService.purgeStaleEntities(eq(GRAPH_ID), eq("Org"), anySet(), anyString()))
|
||||||
|
.thenReturn(0);
|
||||||
|
|
||||||
|
SyncResult result = syncService.syncOrgs(GRAPH_ID);
|
||||||
|
|
||||||
|
assertThat(result.getSyncType()).isEqualTo("Org");
|
||||||
|
assertThat(result.getCreated()).isEqualTo(3);
|
||||||
|
verify(dataManagementClient).fetchUserOrganizationMap();
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
ArgumentCaptor<Map<String, String>> mapCaptor = ArgumentCaptor.forClass(Map.class);
|
||||||
|
verify(stepService).upsertOrgEntities(eq(GRAPH_ID), mapCaptor.capture(), anyString());
|
||||||
|
assertThat(mapCaptor.getValue()).containsKeys("admin", "alice");
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void syncOrgs_fetchUserOrgMapFails_gracefulDegradation() {
|
||||||
|
when(properties.getSync()).thenReturn(syncConfig);
|
||||||
|
when(dataManagementClient.fetchUserOrganizationMap())
|
||||||
|
.thenThrow(new RuntimeException("auth service down"));
|
||||||
|
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("Org").created(1).build());
|
||||||
|
|
||||||
|
SyncResult result = syncService.syncOrgs(GRAPH_ID);
|
||||||
|
|
||||||
|
// 应优雅降级,使用空 map(仅创建未分配组织)
|
||||||
|
assertThat(result.getSyncType()).isEqualTo("Org");
|
||||||
|
assertThat(result.getCreated()).isEqualTo(1);
|
||||||
|
// P0 fix: 降级时不执行 Org purge,防止误删已有组织节点
|
||||||
|
verify(stepService, never()).purgeStaleEntities(anyString(), eq("Org"), anySet(), anyString());
|
||||||
|
// 即使降级,finally 块也会清除缓存
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void syncAll_fetchUserOrgMapFails_skipsBelongsToRelationBuild() {
|
||||||
|
when(properties.getSync()).thenReturn(syncConfig);
|
||||||
|
|
||||||
|
DatasetDTO dto = new DatasetDTO();
|
||||||
|
dto.setId("ds-001");
|
||||||
|
dto.setName("Test");
|
||||||
|
dto.setCreatedBy("admin");
|
||||||
|
when(dataManagementClient.listAllDatasets()).thenReturn(List.of(dto));
|
||||||
|
when(dataManagementClient.listAllWorkflows()).thenReturn(List.of());
|
||||||
|
when(dataManagementClient.listAllJobs()).thenReturn(List.of());
|
||||||
|
when(dataManagementClient.listAllLabelTasks()).thenReturn(List.of());
|
||||||
|
when(dataManagementClient.listAllKnowledgeSets()).thenReturn(List.of());
|
||||||
|
when(dataManagementClient.fetchUserOrganizationMap())
|
||||||
|
.thenThrow(new RuntimeException("auth service down"));
|
||||||
|
|
||||||
|
when(stepService.upsertDatasetEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("Dataset").build());
|
||||||
|
when(stepService.upsertFieldEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("Field").build());
|
||||||
|
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("User").build());
|
||||||
|
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("Org").build());
|
||||||
|
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("Workflow").build());
|
||||||
|
when(stepService.upsertJobEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("Job").build());
|
||||||
|
when(stepService.upsertLabelTaskEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("LabelTask").build());
|
||||||
|
when(stepService.upsertKnowledgeSetEntities(eq(GRAPH_ID), anyList(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("KnowledgeSet").build());
|
||||||
|
when(stepService.purgeStaleEntities(eq(GRAPH_ID), anyString(), anySet(), anyString()))
|
||||||
|
.thenReturn(0);
|
||||||
|
when(stepService.mergeHasFieldRelations(eq(GRAPH_ID), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
|
||||||
|
when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
|
||||||
|
when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
|
||||||
|
when(stepService.mergeProducesRelations(eq(GRAPH_ID), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("PRODUCES").build());
|
||||||
|
when(stepService.mergeAssignedToRelations(eq(GRAPH_ID), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("ASSIGNED_TO").build());
|
||||||
|
when(stepService.mergeTriggersRelations(eq(GRAPH_ID), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("TRIGGERS").build());
|
||||||
|
when(stepService.mergeDependsOnRelations(eq(GRAPH_ID), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("DEPENDS_ON").build());
|
||||||
|
when(stepService.mergeImpactsRelations(eq(GRAPH_ID), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("IMPACTS").build());
|
||||||
|
when(stepService.mergeSourcedFromRelations(eq(GRAPH_ID), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("SOURCED_FROM").build());
|
||||||
|
|
||||||
|
SyncMetadata metadata = syncService.syncAll(GRAPH_ID);
|
||||||
|
|
||||||
|
assertThat(metadata.getResults()).hasSize(18);
|
||||||
|
// BELONGS_TO merge must NOT be called when org map is degraded
|
||||||
|
verify(stepService, never()).mergeBelongsToRelations(anyString(), anyMap(), anyString());
|
||||||
|
// Org purge must also be skipped
|
||||||
|
verify(stepService, never()).purgeStaleEntities(anyString(), eq("Org"), anySet(), anyString());
|
||||||
|
// 验证缓存清除
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void buildBelongsToRelations_passesUserOrgMap() {
|
||||||
|
when(properties.getSync()).thenReturn(syncConfig);
|
||||||
|
when(dataManagementClient.fetchUserOrganizationMap())
|
||||||
|
.thenReturn(Map.of("admin", "DataMate"));
|
||||||
|
when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyMap(), anyString()))
|
||||||
|
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").created(2).build());
|
||||||
|
|
||||||
|
SyncResult result = syncService.buildBelongsToRelations(GRAPH_ID);
|
||||||
|
|
||||||
|
assertThat(result.getSyncType()).isEqualTo("BELONGS_TO");
|
||||||
|
verify(dataManagementClient).fetchUserOrganizationMap();
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void buildBelongsToRelations_fetchDegraded_skipsRelationBuild() {
|
||||||
|
when(properties.getSync()).thenReturn(syncConfig);
|
||||||
|
when(dataManagementClient.fetchUserOrganizationMap())
|
||||||
|
.thenThrow(new RuntimeException("auth service down"));
|
||||||
|
|
||||||
|
SyncResult result = syncService.buildBelongsToRelations(GRAPH_ID);
|
||||||
|
|
||||||
|
assertThat(result.getSyncType()).isEqualTo("BELONGS_TO");
|
||||||
|
// BELONGS_TO merge must NOT be called when degraded
|
||||||
|
verify(stepService, never()).mergeBelongsToRelations(anyString(), anyMap(), anyString());
|
||||||
|
// 即使降级,finally 块也会清除缓存
|
||||||
|
verify(cacheService).evictGraphCaches(GRAPH_ID);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -505,11 +505,12 @@ class GraphSyncStepServiceTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void mergeBelongsTo_noDefaultOrg_returnsError() {
|
void mergeBelongsTo_noOrgEntities_returnsError() {
|
||||||
when(entityRepository.findByGraphIdAndSourceIdAndType(GRAPH_ID, "org:default", "Org"))
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Org"))
|
||||||
.thenReturn(Optional.empty());
|
.thenReturn(List.of());
|
||||||
|
|
||||||
SyncResult result = stepService.mergeBelongsToRelations(GRAPH_ID, SYNC_ID);
|
Map<String, String> userOrgMap = Map.of("admin", "DataMate");
|
||||||
|
SyncResult result = stepService.mergeBelongsToRelations(GRAPH_ID, userOrgMap, SYNC_ID);
|
||||||
|
|
||||||
assertThat(result.getFailed()).isGreaterThan(0);
|
assertThat(result.getFailed()).isGreaterThan(0);
|
||||||
assertThat(result.getErrors()).contains("belongs_to:org_missing");
|
assertThat(result.getErrors()).contains("belongs_to:org_missing");
|
||||||
@@ -933,4 +934,151 @@ class GraphSyncStepServiceTest {
|
|||||||
verify(neo4jClient, times(1)).query(anyString());
|
verify(neo4jClient, times(1)).query(anyString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// upsertOrgEntities(多组织同步)
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class UpsertOrgEntitiesTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void upsert_multipleOrgs_createsEntityPerDistinctOrg() {
|
||||||
|
setupNeo4jQueryChain(Boolean.class, true);
|
||||||
|
|
||||||
|
Map<String, String> userOrgMap = new LinkedHashMap<>();
|
||||||
|
userOrgMap.put("admin", "DataMate");
|
||||||
|
userOrgMap.put("alice", "三甲医院");
|
||||||
|
userOrgMap.put("bob", null);
|
||||||
|
userOrgMap.put("carol", "DataMate"); // 重复
|
||||||
|
|
||||||
|
SyncResult result = stepService.upsertOrgEntities(GRAPH_ID, userOrgMap, SYNC_ID);
|
||||||
|
|
||||||
|
// 3 个去重组织: 未分配, DataMate, 三甲医院
|
||||||
|
assertThat(result.getCreated()).isEqualTo(3);
|
||||||
|
assertThat(result.getSyncType()).isEqualTo("Org");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void upsert_emptyMap_createsOnlyDefaultOrg() {
|
||||||
|
setupNeo4jQueryChain(Boolean.class, true);
|
||||||
|
|
||||||
|
SyncResult result = stepService.upsertOrgEntities(
|
||||||
|
GRAPH_ID, Collections.emptyMap(), SYNC_ID);
|
||||||
|
|
||||||
|
assertThat(result.getCreated()).isEqualTo(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void upsert_allUsersHaveBlankOrg_createsOnlyDefaultOrg() {
|
||||||
|
setupNeo4jQueryChain(Boolean.class, true);
|
||||||
|
|
||||||
|
Map<String, String> userOrgMap = new LinkedHashMap<>();
|
||||||
|
userOrgMap.put("admin", "");
|
||||||
|
userOrgMap.put("alice", " ");
|
||||||
|
|
||||||
|
SyncResult result = stepService.upsertOrgEntities(GRAPH_ID, userOrgMap, SYNC_ID);
|
||||||
|
|
||||||
|
assertThat(result.getCreated()).isEqualTo(1); // 仅未分配
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// mergeBelongsToRelations(多组织映射)
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class MergeBelongsToWithRealOrgsTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void mergeBelongsTo_usersMapToCorrectOrgs() {
|
||||||
|
setupNeo4jQueryChain(String.class, "new-rel-id");
|
||||||
|
|
||||||
|
GraphEntity orgDataMate = GraphEntity.builder()
|
||||||
|
.id("org-entity-dm").sourceId("org:DataMate").type("Org").graphId(GRAPH_ID).build();
|
||||||
|
GraphEntity orgUnassigned = GraphEntity.builder()
|
||||||
|
.id("org-entity-ua").sourceId("org:unassigned").type("Org").graphId(GRAPH_ID).build();
|
||||||
|
|
||||||
|
GraphEntity userAdmin = GraphEntity.builder()
|
||||||
|
.id("user-entity-admin").sourceId("user:admin").type("User").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>(Map.of("username", "admin")))
|
||||||
|
.build();
|
||||||
|
GraphEntity userBob = GraphEntity.builder()
|
||||||
|
.id("user-entity-bob").sourceId("user:bob").type("User").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>(Map.of("username", "bob")))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Org"))
|
||||||
|
.thenReturn(List.of(orgDataMate, orgUnassigned));
|
||||||
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "User"))
|
||||||
|
.thenReturn(List.of(userAdmin, userBob));
|
||||||
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset"))
|
||||||
|
.thenReturn(List.of());
|
||||||
|
|
||||||
|
Map<String, String> userOrgMap = new HashMap<>();
|
||||||
|
userOrgMap.put("admin", "DataMate");
|
||||||
|
userOrgMap.put("bob", null);
|
||||||
|
|
||||||
|
SyncResult result = stepService.mergeBelongsToRelations(GRAPH_ID, userOrgMap, SYNC_ID);
|
||||||
|
|
||||||
|
assertThat(result.getSyncType()).isEqualTo("BELONGS_TO");
|
||||||
|
// 1 delete (cleanup old BELONGS_TO) + 2 merge (one per user)
|
||||||
|
verify(neo4jClient, times(3)).query(anyString());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void mergeBelongsTo_datasetMappedToCreatorOrg() {
|
||||||
|
setupNeo4jQueryChain(String.class, "new-rel-id");
|
||||||
|
|
||||||
|
GraphEntity orgHospital = GraphEntity.builder()
|
||||||
|
.id("org-entity-hosp").sourceId("org:三甲医院").type("Org").graphId(GRAPH_ID).build();
|
||||||
|
GraphEntity orgUnassigned = GraphEntity.builder()
|
||||||
|
.id("org-entity-ua").sourceId("org:unassigned").type("Org").graphId(GRAPH_ID).build();
|
||||||
|
|
||||||
|
GraphEntity dataset = GraphEntity.builder()
|
||||||
|
.id("ds-entity-1").type("Dataset").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>(Map.of("created_by", "alice")))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Org"))
|
||||||
|
.thenReturn(List.of(orgHospital, orgUnassigned));
|
||||||
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "User"))
|
||||||
|
.thenReturn(List.of());
|
||||||
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset"))
|
||||||
|
.thenReturn(List.of(dataset));
|
||||||
|
|
||||||
|
Map<String, String> userOrgMap = Map.of("alice", "三甲医院");
|
||||||
|
|
||||||
|
SyncResult result = stepService.mergeBelongsToRelations(GRAPH_ID, userOrgMap, SYNC_ID);
|
||||||
|
|
||||||
|
// 1 delete (cleanup old BELONGS_TO) + 1 merge (dataset → org)
|
||||||
|
verify(neo4jClient, times(2)).query(anyString());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void mergeBelongsTo_unknownCreator_fallsBackToUnassigned() {
|
||||||
|
setupNeo4jQueryChain(String.class, "new-rel-id");
|
||||||
|
|
||||||
|
GraphEntity orgUnassigned = GraphEntity.builder()
|
||||||
|
.id("org-entity-ua").sourceId("org:unassigned").type("Org").graphId(GRAPH_ID).build();
|
||||||
|
|
||||||
|
GraphEntity dataset = GraphEntity.builder()
|
||||||
|
.id("ds-entity-1").type("Dataset").graphId(GRAPH_ID)
|
||||||
|
.properties(new HashMap<>(Map.of("created_by", "unknown_user")))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Org"))
|
||||||
|
.thenReturn(List.of(orgUnassigned));
|
||||||
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "User"))
|
||||||
|
.thenReturn(List.of());
|
||||||
|
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset"))
|
||||||
|
.thenReturn(List.of(dataset));
|
||||||
|
|
||||||
|
SyncResult result = stepService.mergeBelongsToRelations(
|
||||||
|
GRAPH_ID, Collections.emptyMap(), SYNC_ID);
|
||||||
|
|
||||||
|
// 1 delete (cleanup old BELONGS_TO) + 1 merge (dataset → unassigned)
|
||||||
|
verify(neo4jClient, times(2)).query(anyString());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package com.datamate.knowledgegraph.application;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.junit.jupiter.MockitoExtension;
|
||||||
|
import org.springframework.data.neo4j.core.Neo4jClient;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyString;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
@ExtendWith(MockitoExtension.class)
|
||||||
|
class IndexHealthServiceTest {
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private Neo4jClient neo4jClient;
|
||||||
|
|
||||||
|
private IndexHealthService indexHealthService;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
indexHealthService = new IndexHealthService(neo4jClient);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void allIndexesOnline_empty_returns_false() {
|
||||||
|
// Neo4jClient mocking is complex; verify the logic conceptually
|
||||||
|
// When no indexes found, should return false
|
||||||
|
// This tests the service was correctly constructed
|
||||||
|
assertThat(indexHealthService).isNotNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void service_is_injectable() {
|
||||||
|
// Verify the service can be instantiated with a Neo4jClient
|
||||||
|
IndexHealthService service = new IndexHealthService(neo4jClient);
|
||||||
|
assertThat(service).isNotNull();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,280 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.cache;

import com.datamate.knowledgegraph.application.GraphEntityService;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cache.CacheManager;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.cache.concurrent.ConcurrentMapCacheManager;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit.jupiter.SpringExtension;

import java.time.LocalDateTime;
import java.util.List;
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;

/**
 * Integration test: verifies that the {@code @Cacheable} proxy works correctly
 * inside a real Spring context.
 * <p>
 * Uses {@link ConcurrentMapCacheManager} in place of Redis, verifying that:
 * <ul>
 * <li>a cache hit does not query the repository again</li>
 * <li>after cache eviction the repository is queried again</li>
 * <li>caches for different graphs are independent</li>
 * <li>different user contexts produce different cache keys (permission isolation)</li>
 * </ul>
 */
@ExtendWith(SpringExtension.class)
@ContextConfiguration(classes = CacheableIntegrationTest.Config.class)
class CacheableIntegrationTest {

    private static final String GRAPH_ID = "550e8400-e29b-41d4-a716-446655440000";
    private static final String GRAPH_ID_2 = "660e8400-e29b-41d4-a716-446655440099";
    private static final String ENTITY_ID = "660e8400-e29b-41d4-a716-446655440001";

    /**
     * Minimal caching context: an in-memory cache manager plus mocked
     * repository/properties, wired into the real service under test.
     */
    @Configuration
    @EnableCaching
    static class Config {

        // Bean name must match the cache manager the service's @Cacheable
        // annotations reference — presumably "knowledgeGraphCacheManager"; the
        // name is kept identical to production configuration.
        @Bean("knowledgeGraphCacheManager")
        CacheManager knowledgeGraphCacheManager() {
            return new ConcurrentMapCacheManager(
                    RedisCacheConfig.CACHE_ENTITIES,
                    RedisCacheConfig.CACHE_QUERIES,
                    RedisCacheConfig.CACHE_SEARCH
            );
        }

        @Bean
        GraphEntityRepository entityRepository() {
            return mock(GraphEntityRepository.class);
        }

        @Bean
        KnowledgeGraphProperties properties() {
            return mock(KnowledgeGraphProperties.class);
        }

        @Bean
        GraphCacheService graphCacheService(CacheManager cacheManager) {
            return new GraphCacheService(cacheManager);
        }

        @Bean
        GraphEntityService graphEntityService(
                GraphEntityRepository entityRepository,
                KnowledgeGraphProperties properties,
                GraphCacheService graphCacheService) {
            return new GraphEntityService(entityRepository, properties, graphCacheService);
        }
    }

    @Autowired
    private GraphEntityService entityService;

    @Autowired
    private GraphEntityRepository entityRepository;

    @Autowired
    private CacheManager cacheManager;

    @Autowired
    private GraphCacheService graphCacheService;

    // Shared fixture entity rebuilt before every test.
    private GraphEntity sampleEntity;

    @BeforeEach
    void setUp() {
        sampleEntity = GraphEntity.builder()
                .id(ENTITY_ID)
                .name("TestDataset")
                .type("Dataset")
                .description("A test dataset")
                .graphId(GRAPH_ID)
                .confidence(1.0)
                .createdAt(LocalDateTime.now())
                .updatedAt(LocalDateTime.now())
                .build();

        // Clear every cache region and reset mock interaction counts so each
        // test starts from a cold cache and a clean invocation history.
        cacheManager.getCacheNames().forEach(name -> {
            var cache = cacheManager.getCache(name);
            if (cache != null) cache.clear();
        });
        reset(entityRepository);
    }

    // -----------------------------------------------------------------------
    // @Cacheable proxy behavior
    // -----------------------------------------------------------------------

    @Nested
    class CacheProxyTest {

        @Test
        void getEntity_secondCall_returnsCachedResultWithoutHittingRepository() {
            when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
                    .thenReturn(Optional.of(sampleEntity));

            GraphEntity first = entityService.getEntity(GRAPH_ID, ENTITY_ID);
            assertThat(first.getId()).isEqualTo(ENTITY_ID);

            GraphEntity second = entityService.getEntity(GRAPH_ID, ENTITY_ID);
            assertThat(second.getId()).isEqualTo(ENTITY_ID);

            // Second call must be served from cache: exactly one repository hit.
            verify(entityRepository, times(1)).findByIdAndGraphId(ENTITY_ID, GRAPH_ID);
        }

        @Test
        void listEntities_secondCall_returnsCachedResult() {
            when(entityRepository.findByGraphId(GRAPH_ID))
                    .thenReturn(List.of(sampleEntity));

            entityService.listEntities(GRAPH_ID);
            entityService.listEntities(GRAPH_ID);

            // List query is cached as well: only one repository hit expected.
            verify(entityRepository, times(1)).findByGraphId(GRAPH_ID);
        }

        @Test
        void differentGraphIds_produceSeparateCacheEntries() {
            GraphEntity entity2 = GraphEntity.builder()
                    .id(ENTITY_ID).name("OtherDataset").type("Dataset")
                    .graphId(GRAPH_ID_2).confidence(1.0)
                    .createdAt(LocalDateTime.now()).updatedAt(LocalDateTime.now())
                    .build();

            when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
                    .thenReturn(Optional.of(sampleEntity));
            when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID_2))
                    .thenReturn(Optional.of(entity2));

            GraphEntity result1 = entityService.getEntity(GRAPH_ID, ENTITY_ID);
            GraphEntity result2 = entityService.getEntity(GRAPH_ID_2, ENTITY_ID);

            // Same entity id under two graphs must not collide in the cache.
            assertThat(result1.getName()).isEqualTo("TestDataset");
            assertThat(result2.getName()).isEqualTo("OtherDataset");
            verify(entityRepository).findByIdAndGraphId(ENTITY_ID, GRAPH_ID);
            verify(entityRepository).findByIdAndGraphId(ENTITY_ID, GRAPH_ID_2);
        }
    }

    // -----------------------------------------------------------------------
    // Cache eviction behavior
    // -----------------------------------------------------------------------

    @Nested
    class CacheEvictionTest {

        @Test
        void evictEntityCaches_causesNextCallToHitRepository() {
            when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
                    .thenReturn(Optional.of(sampleEntity));

            entityService.getEntity(GRAPH_ID, ENTITY_ID);
            verify(entityRepository, times(1)).findByIdAndGraphId(ENTITY_ID, GRAPH_ID);

            graphCacheService.evictEntityCaches(GRAPH_ID, ENTITY_ID);

            // After eviction the proxy must fall through to the repository again.
            entityService.getEntity(GRAPH_ID, ENTITY_ID);
            verify(entityRepository, times(2)).findByIdAndGraphId(ENTITY_ID, GRAPH_ID);
        }

        @Test
        void evictEntityCaches_alsoEvictsListCache() {
            when(entityRepository.findByGraphId(GRAPH_ID))
                    .thenReturn(List.of(sampleEntity));

            entityService.listEntities(GRAPH_ID);
            verify(entityRepository, times(1)).findByGraphId(GRAPH_ID);

            // Evicting a single entity must also invalidate the graph's list cache.
            graphCacheService.evictEntityCaches(GRAPH_ID, ENTITY_ID);

            entityService.listEntities(GRAPH_ID);
            verify(entityRepository, times(2)).findByGraphId(GRAPH_ID);
        }

        @Test
        void evictGraphCaches_clearsAllCacheRegions() {
            when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
                    .thenReturn(Optional.of(sampleEntity));
            when(entityRepository.findByGraphId(GRAPH_ID))
                    .thenReturn(List.of(sampleEntity));

            entityService.getEntity(GRAPH_ID, ENTITY_ID);
            entityService.listEntities(GRAPH_ID);

            graphCacheService.evictGraphCaches(GRAPH_ID);

            // Both the single-entity and the list lookups must hit the repository again.
            entityService.getEntity(GRAPH_ID, ENTITY_ID);
            entityService.listEntities(GRAPH_ID);
            verify(entityRepository, times(2)).findByIdAndGraphId(ENTITY_ID, GRAPH_ID);
            verify(entityRepository, times(2)).findByGraphId(GRAPH_ID);
        }
    }

    // -----------------------------------------------------------------------
    // Permission isolation (verified at the cache-key level)
    //
    // GraphQueryService's @Cacheable uses SpEL expressions:
    //   @resourceAccessService.resolveOwnerFilterUserId()
    //   @resourceAccessService.canViewConfidential()
    // Those values ultimately flow into GraphCacheService.cacheKey() when the
    // key is built. The tests below verify that different user contexts yield
    // different cache keys; combined with the proxy tests above, this ensures
    // different users get independent cache entries.
    // -----------------------------------------------------------------------

    @Nested
    class PermissionIsolationTest {

        @Test
        void adminAndRegularUser_produceDifferentCacheKeys() {
            String adminKey = GraphCacheService.cacheKey(
                    GRAPH_ID, "query", 0, 20, null, true);
            String userKey = GraphCacheService.cacheKey(
                    GRAPH_ID, "query", 0, 20, "user-a", false);

            assertThat(adminKey).isNotEqualTo(userKey);
        }

        @Test
        void differentUsers_produceDifferentCacheKeys() {
            String userAKey = GraphCacheService.cacheKey(
                    GRAPH_ID, "query", 0, 20, "user-a", false);
            String userBKey = GraphCacheService.cacheKey(
                    GRAPH_ID, "query", 0, 20, "user-b", false);

            assertThat(userAKey).isNotEqualTo(userBKey);
        }

        @Test
        void sameUserDifferentConfidentialAccess_produceDifferentCacheKeys() {
            String withConfidential = GraphCacheService.cacheKey(
                    GRAPH_ID, "query", 0, 20, "user-a", true);
            String withoutConfidential = GraphCacheService.cacheKey(
                    GRAPH_ID, "query", 0, 20, "user-a", false);

            assertThat(withConfidential).isNotEqualTo(withoutConfidential);
        }

        @Test
        void sameParametersAndUser_produceIdenticalCacheKeys() {
            // Determinism: the same inputs must always map to the same key,
            // otherwise the cache would never produce hits.
            String key1 = GraphCacheService.cacheKey(
                    GRAPH_ID, "query", 0, 20, "user-a", false);
            String key2 = GraphCacheService.cacheKey(
                    GRAPH_ID, "query", 0, 20, "user-a", false);

            assertThat(key1).isEqualTo(key2);
        }
    }
}
|
||||||
@@ -0,0 +1,273 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.cache;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.cache.Cache;
import org.springframework.cache.CacheManager;
import org.springframework.data.redis.core.StringRedisTemplate;

import java.util.HashSet;
import java.util.Set;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;

/**
 * Unit tests for {@code GraphCacheService} eviction logic and its static
 * {@code cacheKey} helper.
 * <p>
 * Two operating modes are covered: a fallback mode without a
 * {@link StringRedisTemplate} (whole cache regions are cleared) and a
 * fine-grained mode with Redis (keys are deleted by graph-id prefix).
 */
@ExtendWith(MockitoExtension.class)
class GraphCacheServiceTest {

    @Mock
    private CacheManager cacheManager;

    @Mock
    private StringRedisTemplate redisTemplate;

    @Mock
    private Cache entityCache;

    @Mock
    private Cache queryCache;

    @Mock
    private Cache searchCache;

    private GraphCacheService cacheService;

    @BeforeEach
    void setUp() {
        // Constructed without a RedisTemplate — fallback mode by default;
        // FineGrainedModeTest injects the template explicitly.
        cacheService = new GraphCacheService(cacheManager);
    }

    // -----------------------------------------------------------------------
    // Fallback mode (no RedisTemplate): clear entire cache regions
    // -----------------------------------------------------------------------

    @Nested
    class FallbackModeTest {

        @Test
        void evictGraphCaches_withoutRedis_clearsAllCaches() {
            when(cacheManager.getCache(RedisCacheConfig.CACHE_ENTITIES)).thenReturn(entityCache);
            when(cacheManager.getCache(RedisCacheConfig.CACHE_QUERIES)).thenReturn(queryCache);
            when(cacheManager.getCache(RedisCacheConfig.CACHE_SEARCH)).thenReturn(searchCache);

            cacheService.evictGraphCaches("graph-id");

            verify(entityCache).clear();
            verify(queryCache).clear();
            verify(searchCache).clear();
        }

        @Test
        void evictEntityCaches_withoutRedis_evictsSpecificKeysAndClearsQueries() {
            when(cacheManager.getCache(RedisCacheConfig.CACHE_ENTITIES)).thenReturn(entityCache);
            when(cacheManager.getCache(RedisCacheConfig.CACHE_QUERIES)).thenReturn(queryCache);

            cacheService.evictEntityCaches("graph-1", "entity-1");

            // Exactly two keys evicted: the entity itself and the graph's list entry
            verify(entityCache).evict("graph-1:entity-1");
            verify(entityCache).evict("graph-1:list");
            // Query cache falls back to a full clear (no Redis for prefix matching)
            verify(queryCache).clear();
        }

        @Test
        void evictSearchCaches_withGraphId_withoutRedis_clearsAll() {
            when(cacheManager.getCache(RedisCacheConfig.CACHE_SEARCH)).thenReturn(searchCache);

            cacheService.evictSearchCaches("graph-1");

            verify(searchCache).clear();
        }

        @Test
        void evictSearchCaches_noArgs_clearsAll() {
            when(cacheManager.getCache(RedisCacheConfig.CACHE_SEARCH)).thenReturn(searchCache);

            cacheService.evictSearchCaches();

            verify(searchCache).clear();
        }

        @Test
        void evictGraphCaches_toleratesNullCache() {
            // CacheManager.getCache may return null for unknown regions.
            when(cacheManager.getCache(anyString())).thenReturn(null);

            // Must not throw
            cacheService.evictGraphCaches("graph-1");
        }
    }

    // -----------------------------------------------------------------------
    // Fine-grained mode (with RedisTemplate): evict by graphId prefix
    // -----------------------------------------------------------------------

    @Nested
    class FineGrainedModeTest {

        @BeforeEach
        void setUpRedis() {
            // Switch the service into fine-grained mode for this nested group.
            cacheService.setRedisTemplate(redisTemplate);
        }

        @Test
        void evictGraphCaches_withRedis_deletesKeysByGraphPrefix() {
            Set<String> entityKeys = new HashSet<>(Set.of("datamate:kg:entities::graph-1:ent-1", "datamate:kg:entities::graph-1:list"));
            Set<String> queryKeys = new HashSet<>(Set.of("datamate:kg:queries::graph-1:ent-1:2:100:null:true"));
            Set<String> searchKeys = new HashSet<>(Set.of("datamate:kg:search::graph-1:keyword:0:20:null:true"));

            when(redisTemplate.keys("datamate:kg:entities::graph-1:*")).thenReturn(entityKeys);
            when(redisTemplate.keys("datamate:kg:queries::graph-1:*")).thenReturn(queryKeys);
            when(redisTemplate.keys("datamate:kg:search::graph-1:*")).thenReturn(searchKeys);

            cacheService.evictGraphCaches("graph-1");

            verify(redisTemplate).delete(entityKeys);
            verify(redisTemplate).delete(queryKeys);
            verify(redisTemplate).delete(searchKeys);
            // CacheManager.clear() should NOT be called
            verify(cacheManager, never()).getCache(anyString());
        }

        @Test
        void evictGraphCaches_withRedis_emptyKeysDoesNotCallDelete() {
            when(redisTemplate.keys(anyString())).thenReturn(Set.of());

            cacheService.evictGraphCaches("graph-1");

            verify(redisTemplate, never()).delete(anyCollection());
        }

        @Test
        void evictGraphCaches_withRedis_nullKeysDoesNotCallDelete() {
            // Some RedisTemplate implementations return null instead of an empty set.
            when(redisTemplate.keys(anyString())).thenReturn(null);

            cacheService.evictGraphCaches("graph-1");

            verify(redisTemplate, never()).delete(anyCollection());
        }

        @Test
        void evictGraphCaches_redisException_fallsBackToClear() {
            when(redisTemplate.keys(anyString())).thenThrow(new RuntimeException("Redis down"));
            when(cacheManager.getCache(RedisCacheConfig.CACHE_ENTITIES)).thenReturn(entityCache);
            when(cacheManager.getCache(RedisCacheConfig.CACHE_QUERIES)).thenReturn(queryCache);
            when(cacheManager.getCache(RedisCacheConfig.CACHE_SEARCH)).thenReturn(searchCache);

            cacheService.evictGraphCaches("graph-1");

            // Should fall back to clearing the entire cache regions
            verify(entityCache).clear();
            verify(queryCache).clear();
            verify(searchCache).clear();
        }

        @Test
        void evictEntityCaches_withRedis_evictsSpecificKeysAndQueriesByPrefix() {
            when(cacheManager.getCache(RedisCacheConfig.CACHE_ENTITIES)).thenReturn(entityCache);
            Set<String> queryKeys = new HashSet<>(Set.of("datamate:kg:queries::graph-1:ent-1:2:100:null:true"));
            when(redisTemplate.keys("datamate:kg:queries::graph-1:*")).thenReturn(queryKeys);

            cacheService.evictEntityCaches("graph-1", "entity-1");

            // Entity cache evicted by exact key
            verify(entityCache).evict("graph-1:entity-1");
            verify(entityCache).evict("graph-1:list");
            // Query cache evicted by graph-id prefix
            verify(redisTemplate).delete(queryKeys);
        }

        @Test
        void evictSearchCaches_withRedis_deletesKeysByGraphPrefix() {
            Set<String> searchKeys = new HashSet<>(Set.of("datamate:kg:search::graph-1:query:0:20:user1:false"));
            when(redisTemplate.keys("datamate:kg:search::graph-1:*")).thenReturn(searchKeys);

            cacheService.evictSearchCaches("graph-1");

            verify(redisTemplate).delete(searchKeys);
        }

        @Test
        void evictGraphCaches_isolatesGraphIds() {
            // Keys belonging to graph-1
            Set<String> graph1Keys = new HashSet<>(Set.of("datamate:kg:entities::graph-1:ent-1"));
            when(redisTemplate.keys("datamate:kg:entities::graph-1:*")).thenReturn(graph1Keys);
            when(redisTemplate.keys("datamate:kg:queries::graph-1:*")).thenReturn(Set.of());
            when(redisTemplate.keys("datamate:kg:search::graph-1:*")).thenReturn(Set.of());

            cacheService.evictGraphCaches("graph-1");

            // Only graph-1 keys are deleted
            verify(redisTemplate).delete(graph1Keys);
            // Must not scan for keys of any other graph (e.g. graph-2)
            verify(redisTemplate, never()).keys(contains("graph-2"));
        }
    }

    // -----------------------------------------------------------------------
    // cacheKey static helper
    // -----------------------------------------------------------------------

    @Nested
    class CacheKeyTest {

        @Test
        void cacheKey_joinsPartsWithColon() {
            String key = GraphCacheService.cacheKey("a", "b", "c");
            assertThat(key).isEqualTo("a:b:c");
        }

        @Test
        void cacheKey_handlesNullParts() {
            // null parts are rendered as the literal string "null", not dropped.
            String key = GraphCacheService.cacheKey("a", null, "c");
            assertThat(key).isEqualTo("a:null:c");
        }

        @Test
        void cacheKey_handlesSinglePart() {
            String key = GraphCacheService.cacheKey("only");
            assertThat(key).isEqualTo("only");
        }

        @Test
        void cacheKey_handlesNumericParts() {
            String key = GraphCacheService.cacheKey("graph", 42, 0, 20);
            assertThat(key).isEqualTo("graph:42:0:20");
        }

        @Test
        void cacheKey_withUserContext_differentUsersProduceDifferentKeys() {
            String adminKey = GraphCacheService.cacheKey("graph-1", "query", 0, 20, null, true);
            String userAKey = GraphCacheService.cacheKey("graph-1", "query", 0, 20, "user-a", false);
            String userBKey = GraphCacheService.cacheKey("graph-1", "query", 0, 20, "user-b", false);
            String userAConfKey = GraphCacheService.cacheKey("graph-1", "query", 0, 20, "user-a", true);

            assertThat(adminKey).isNotEqualTo(userAKey);
            assertThat(userAKey).isNotEqualTo(userBKey);
            assertThat(userAKey).isNotEqualTo(userAConfKey);

            // Identical parameters must produce an identical key
            String adminKey2 = GraphCacheService.cacheKey("graph-1", "query", 0, 20, null, true);
            assertThat(adminKey).isEqualTo(adminKey2);
        }

        @Test
        void cacheKey_graphIdIsFirstSegment() {
            // The graph id leads the key so prefix-based eviction can target one graph.
            String key = GraphCacheService.cacheKey("graph-123", "entity-456");
            assertThat(key).startsWith("graph-123:");
        }

        @Test
        void cacheKey_booleanParts() {
            String keyTrue = GraphCacheService.cacheKey("g", "q", true);
            String keyFalse = GraphCacheService.cacheKey("g", "q", false);
            assertThat(keyTrue).isEqualTo("g:q:true");
            assertThat(keyFalse).isEqualTo("g:q:false");
            assertThat(keyTrue).isNotEqualTo(keyFalse);
        }
    }
}
|
||||||
@@ -1,13 +1,11 @@
|
|||||||
package com.datamate.knowledgegraph.infrastructure.neo4j;
|
package com.datamate.knowledgegraph.infrastructure.neo4j;
|
||||||
|
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.neo4j.migration.SchemaMigrationService;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
import org.mockito.junit.jupiter.MockitoExtension;
|
import org.mockito.junit.jupiter.MockitoExtension;
|
||||||
import org.springframework.boot.DefaultApplicationArguments;
|
import org.springframework.boot.DefaultApplicationArguments;
|
||||||
import org.springframework.data.neo4j.core.Neo4jClient;
|
|
||||||
import org.springframework.data.neo4j.core.Neo4jClient.UnboundRunnableSpec;
|
|
||||||
import org.springframework.data.neo4j.core.Neo4jClient.RunnableSpec;
|
|
||||||
import org.springframework.test.util.ReflectionTestUtils;
|
import org.springframework.test.util.ReflectionTestUtils;
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||||
@@ -19,13 +17,13 @@ import static org.mockito.Mockito.*;
|
|||||||
class GraphInitializerTest {
|
class GraphInitializerTest {
|
||||||
|
|
||||||
@Mock
|
@Mock
|
||||||
private Neo4jClient neo4jClient;
|
private SchemaMigrationService schemaMigrationService;
|
||||||
|
|
||||||
private GraphInitializer createInitializer(String password, String profile, boolean autoInit) {
|
private GraphInitializer createInitializer(String password, String profile, boolean autoInit) {
|
||||||
KnowledgeGraphProperties properties = new KnowledgeGraphProperties();
|
KnowledgeGraphProperties properties = new KnowledgeGraphProperties();
|
||||||
properties.getSync().setAutoInitSchema(autoInit);
|
properties.getSync().setAutoInitSchema(autoInit);
|
||||||
|
|
||||||
GraphInitializer initializer = new GraphInitializer(neo4jClient, properties);
|
GraphInitializer initializer = new GraphInitializer(properties, schemaMigrationService);
|
||||||
ReflectionTestUtils.setField(initializer, "neo4jPassword", password);
|
ReflectionTestUtils.setField(initializer, "neo4jPassword", password);
|
||||||
ReflectionTestUtils.setField(initializer, "activeProfile", profile);
|
ReflectionTestUtils.setField(initializer, "activeProfile", profile);
|
||||||
return initializer;
|
return initializer;
|
||||||
@@ -97,20 +95,16 @@ class GraphInitializerTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
// Schema 初始化 — 成功
|
// Schema 初始化 — 委托给 SchemaMigrationService
|
||||||
// -----------------------------------------------------------------------
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void run_autoInitEnabled_executesAllStatements() {
|
void run_autoInitEnabled_delegatesToMigrationService() {
|
||||||
GraphInitializer initializer = createInitializer("s3cure!P@ss", "dev", true);
|
GraphInitializer initializer = createInitializer("s3cure!P@ss", "dev", true);
|
||||||
|
|
||||||
UnboundRunnableSpec spec = mock(UnboundRunnableSpec.class);
|
|
||||||
when(neo4jClient.query(anyString())).thenReturn(spec);
|
|
||||||
|
|
||||||
initializer.run(new DefaultApplicationArguments());
|
initializer.run(new DefaultApplicationArguments());
|
||||||
|
|
||||||
// Should execute all schema statements (constraints + indexes + fulltext)
|
verify(schemaMigrationService).migrate(anyString());
|
||||||
verify(neo4jClient, atLeast(10)).query(anyString());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -119,39 +113,18 @@ class GraphInitializerTest {
|
|||||||
|
|
||||||
initializer.run(new DefaultApplicationArguments());
|
initializer.run(new DefaultApplicationArguments());
|
||||||
|
|
||||||
verifyNoInteractions(neo4jClient);
|
verifyNoInteractions(schemaMigrationService);
|
||||||
}
|
|
||||||
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
// P2-7: Schema 初始化错误处理
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void run_alreadyExistsError_safelyIgnored() {
|
|
||||||
GraphInitializer initializer = createInitializer("s3cure!P@ss", "dev", true);
|
|
||||||
|
|
||||||
UnboundRunnableSpec spec = mock(UnboundRunnableSpec.class);
|
|
||||||
when(neo4jClient.query(anyString())).thenReturn(spec);
|
|
||||||
doThrow(new RuntimeException("Constraint already exists"))
|
|
||||||
.when(spec).run();
|
|
||||||
|
|
||||||
// Should not throw — "already exists" errors are safely ignored
|
|
||||||
assertThatCode(() -> initializer.run(new DefaultApplicationArguments()))
|
|
||||||
.doesNotThrowAnyException();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void run_nonExistenceError_throwsException() {
|
void run_migrationServiceThrows_propagatesException() {
|
||||||
GraphInitializer initializer = createInitializer("s3cure!P@ss", "dev", true);
|
GraphInitializer initializer = createInitializer("s3cure!P@ss", "dev", true);
|
||||||
|
|
||||||
UnboundRunnableSpec spec = mock(UnboundRunnableSpec.class);
|
doThrow(new RuntimeException("Migration failed"))
|
||||||
when(neo4jClient.query(anyString())).thenReturn(spec);
|
.when(schemaMigrationService).migrate(anyString());
|
||||||
doThrow(new RuntimeException("Connection refused to Neo4j"))
|
|
||||||
.when(spec).run();
|
|
||||||
|
|
||||||
// Non-"already exists" errors should propagate
|
|
||||||
assertThatThrownBy(() -> initializer.run(new DefaultApplicationArguments()))
|
assertThatThrownBy(() -> initializer.run(new DefaultApplicationArguments()))
|
||||||
.isInstanceOf(IllegalStateException.class)
|
.isInstanceOf(RuntimeException.class)
|
||||||
.hasMessageContaining("schema initialization failed");
|
.hasMessageContaining("Migration failed");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,578 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.neo4j.migration;
|
||||||
|
|
||||||
|
import com.datamate.common.infrastructure.exception.BusinessException;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Nested;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.junit.jupiter.MockitoExtension;
|
||||||
|
import org.springframework.data.neo4j.core.Neo4jClient;
|
||||||
|
import org.springframework.data.neo4j.core.Neo4jClient.RecordFetchSpec;
|
||||||
|
import org.springframework.data.neo4j.core.Neo4jClient.RunnableSpec;
|
||||||
|
import org.springframework.data.neo4j.core.Neo4jClient.UnboundRunnableSpec;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.*;
|
||||||
|
import static org.mockito.ArgumentMatchers.*;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
@ExtendWith(MockitoExtension.class)
|
||||||
|
class SchemaMigrationServiceTest {
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private Neo4jClient neo4jClient;
|
||||||
|
|
||||||
|
private KnowledgeGraphProperties properties;
|
||||||
|
|
||||||
|
private SchemaMigration v1Migration;
|
||||||
|
private SchemaMigration v2Migration;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
properties = new KnowledgeGraphProperties();
|
||||||
|
|
||||||
|
v1Migration = new SchemaMigration() {
|
||||||
|
@Override
|
||||||
|
public int getVersion() { return 1; }
|
||||||
|
@Override
|
||||||
|
public String getDescription() { return "Initial schema"; }
|
||||||
|
@Override
|
||||||
|
public List<String> getStatements() {
|
||||||
|
return List.of("CREATE CONSTRAINT test1 IF NOT EXISTS FOR (n:Test) REQUIRE n.id IS UNIQUE");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
v2Migration = new SchemaMigration() {
|
||||||
|
@Override
|
||||||
|
public int getVersion() { return 2; }
|
||||||
|
@Override
|
||||||
|
public String getDescription() { return "Add index"; }
|
||||||
|
@Override
|
||||||
|
public List<String> getStatements() {
|
||||||
|
return List.of("CREATE INDEX test_name IF NOT EXISTS FOR (n:Test) ON (n.name)");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private SchemaMigrationService createService(List<SchemaMigration> migrations) {
|
||||||
|
return new SchemaMigrationService(neo4jClient, properties, migrations);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a spy of the service with bootstrapMigrationSchema, acquireLock,
|
||||||
|
* releaseLock, and recordMigration stubbed out, and loadAppliedMigrations
|
||||||
|
* returning the given records.
|
||||||
|
*/
|
||||||
|
private SchemaMigrationService createSpiedService(List<SchemaMigration> migrations,
|
||||||
|
List<SchemaMigrationRecord> applied) {
|
||||||
|
SchemaMigrationService service = spy(createService(migrations));
|
||||||
|
doNothing().when(service).bootstrapMigrationSchema();
|
||||||
|
doNothing().when(service).acquireLock(anyString());
|
||||||
|
doNothing().when(service).releaseLock(anyString());
|
||||||
|
doReturn(applied).when(service).loadAppliedMigrations();
|
||||||
|
lenient().doNothing().when(service).recordMigration(any());
|
||||||
|
return service;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setupQueryRunnable() {
|
||||||
|
UnboundRunnableSpec spec = mock(UnboundRunnableSpec.class);
|
||||||
|
when(neo4jClient.query(anyString())).thenReturn(spec);
|
||||||
|
}
|
||||||
|
|
||||||
|
private SchemaMigrationRecord appliedRecord(SchemaMigration migration) {
|
||||||
|
return SchemaMigrationRecord.builder()
|
||||||
|
.version(migration.getVersion())
|
||||||
|
.description(migration.getDescription())
|
||||||
|
.checksum(SchemaMigrationService.computeChecksum(migration.getStatements()))
|
||||||
|
.appliedAt("2025-01-01T00:00:00Z")
|
||||||
|
.executionTimeMs(100L)
|
||||||
|
.success(true)
|
||||||
|
.statementsCount(migration.getStatements().size())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Migration Disabled
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class MigrationDisabled {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_whenDisabled_skipsEverything() {
|
||||||
|
properties.getMigration().setEnabled(false);
|
||||||
|
SchemaMigrationService service = createService(List.of(v1Migration));
|
||||||
|
|
||||||
|
service.migrate("test-instance");
|
||||||
|
|
||||||
|
verifyNoInteractions(neo4jClient);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Fresh Database
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class FreshDatabase {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_freshDb_appliesAllMigrations() {
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), Collections.emptyList());
|
||||||
|
setupQueryRunnable();
|
||||||
|
|
||||||
|
service.migrate("test-instance");
|
||||||
|
|
||||||
|
// Verify migration statement was executed
|
||||||
|
verify(neo4jClient).query(contains("test1"));
|
||||||
|
// Verify migration record was created
|
||||||
|
verify(service).recordMigration(argThat(r -> r.getVersion() == 1 && r.isSuccess()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_freshDb_bootstrapConstraintsCreated() {
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), Collections.emptyList());
|
||||||
|
setupQueryRunnable();
|
||||||
|
|
||||||
|
service.migrate("test-instance");
|
||||||
|
|
||||||
|
// Verify bootstrap, lock acquisition, and release were called
|
||||||
|
verify(service).bootstrapMigrationSchema();
|
||||||
|
verify(service).acquireLock("test-instance");
|
||||||
|
verify(service).releaseLock("test-instance");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Partially Applied
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class PartiallyApplied {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_v1Applied_onlyExecutesPending() {
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration, v2Migration), List.of(appliedRecord(v1Migration)));
|
||||||
|
setupQueryRunnable();
|
||||||
|
|
||||||
|
service.migrate("test-instance");
|
||||||
|
|
||||||
|
// V1 statement should NOT be executed
|
||||||
|
verify(neo4jClient, never()).query(contains("test1"));
|
||||||
|
// V2 statement should be executed
|
||||||
|
verify(neo4jClient).query(contains("test_name"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_allApplied_noop() {
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), List.of(appliedRecord(v1Migration)));
|
||||||
|
|
||||||
|
service.migrate("test-instance");
|
||||||
|
|
||||||
|
// No migration statements should be executed
|
||||||
|
verifyNoInteractions(neo4jClient);
|
||||||
|
// recordMigration should NOT be called (only the stubbed setup, no real call)
|
||||||
|
verify(service, never()).recordMigration(any());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Checksum Validation
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class ChecksumValidation {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_checksumMismatch_throwsException() {
|
||||||
|
SchemaMigrationRecord tampered = SchemaMigrationRecord.builder()
|
||||||
|
.version(1)
|
||||||
|
.description("Initial schema")
|
||||||
|
.checksum("wrong-checksum")
|
||||||
|
.appliedAt("2025-01-01T00:00:00Z")
|
||||||
|
.executionTimeMs(100L)
|
||||||
|
.success(true)
|
||||||
|
.statementsCount(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), List.of(tampered));
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> service.migrate("test-instance"))
|
||||||
|
.isInstanceOf(BusinessException.class)
|
||||||
|
.satisfies(e -> assertThat(((BusinessException) e).getErrorCodeEnum())
|
||||||
|
.isEqualTo(KnowledgeGraphErrorCode.SCHEMA_CHECKSUM_MISMATCH));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_checksumValidationDisabled_skipsCheck() {
|
||||||
|
properties.getMigration().setValidateChecksums(false);
|
||||||
|
|
||||||
|
SchemaMigrationRecord tampered = SchemaMigrationRecord.builder()
|
||||||
|
.version(1)
|
||||||
|
.description("Initial schema")
|
||||||
|
.checksum("wrong-checksum")
|
||||||
|
.appliedAt("2025-01-01T00:00:00Z")
|
||||||
|
.executionTimeMs(100L)
|
||||||
|
.success(true)
|
||||||
|
.statementsCount(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), List.of(tampered));
|
||||||
|
|
||||||
|
// Should NOT throw even with wrong checksum — all applied, no pending
|
||||||
|
assertThatCode(() -> service.migrate("test-instance"))
|
||||||
|
.doesNotThrowAnyException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_emptyChecksum_skipsValidation() {
|
||||||
|
SchemaMigrationRecord legacyRecord = SchemaMigrationRecord.builder()
|
||||||
|
.version(1)
|
||||||
|
.description("Initial schema")
|
||||||
|
.checksum("") // empty checksum from legacy/repaired node
|
||||||
|
.appliedAt("")
|
||||||
|
.executionTimeMs(0L)
|
||||||
|
.success(true)
|
||||||
|
.statementsCount(0)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), List.of(legacyRecord));
|
||||||
|
|
||||||
|
// Should NOT throw — empty checksum is skipped, and V1 is treated as applied
|
||||||
|
assertThatCode(() -> service.migrate("test-instance"))
|
||||||
|
.doesNotThrowAnyException();
|
||||||
|
|
||||||
|
// V1 should NOT be re-executed (it's in the applied set)
|
||||||
|
verify(neo4jClient, never()).query(contains("test1"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Lock Management
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class LockManagement {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_lockAcquired_executesAndReleases() {
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), Collections.emptyList());
|
||||||
|
setupQueryRunnable();
|
||||||
|
|
||||||
|
service.migrate("test-instance");
|
||||||
|
|
||||||
|
var inOrder = inOrder(service);
|
||||||
|
inOrder.verify(service).acquireLock("test-instance");
|
||||||
|
inOrder.verify(service).releaseLock("test-instance");
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
@Test
|
||||||
|
void migrate_lockHeldByAnother_throwsException() {
|
||||||
|
SchemaMigrationService service = spy(createService(List.of(v1Migration)));
|
||||||
|
doNothing().when(service).bootstrapMigrationSchema();
|
||||||
|
|
||||||
|
// Let acquireLock run for real — mock neo4jClient for lock query
|
||||||
|
UnboundRunnableSpec lockSpec = mock(UnboundRunnableSpec.class);
|
||||||
|
RunnableSpec runnableSpec = mock(RunnableSpec.class);
|
||||||
|
RecordFetchSpec<Map<String, Object>> fetchSpec = mock(RecordFetchSpec.class);
|
||||||
|
|
||||||
|
when(neo4jClient.query(contains("MERGE (lock:_SchemaLock"))).thenReturn(lockSpec);
|
||||||
|
when(lockSpec.bindAll(anyMap())).thenReturn(runnableSpec);
|
||||||
|
when(runnableSpec.fetch()).thenReturn(fetchSpec);
|
||||||
|
when(fetchSpec.first()).thenReturn(Optional.of(Map.of(
|
||||||
|
"lockedBy", "other-instance",
|
||||||
|
"canAcquire", false
|
||||||
|
)));
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> service.migrate("test-instance"))
|
||||||
|
.isInstanceOf(BusinessException.class)
|
||||||
|
.satisfies(e -> assertThat(((BusinessException) e).getErrorCodeEnum())
|
||||||
|
.isEqualTo(KnowledgeGraphErrorCode.SCHEMA_MIGRATION_LOCKED));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_lockReleasedOnFailure() {
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), Collections.emptyList());
|
||||||
|
|
||||||
|
// Make migration statement fail
|
||||||
|
UnboundRunnableSpec failSpec = mock(UnboundRunnableSpec.class);
|
||||||
|
when(neo4jClient.query(anyString())).thenReturn(failSpec);
|
||||||
|
doThrow(new RuntimeException("Connection refused"))
|
||||||
|
.when(failSpec).run();
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> service.migrate("test-instance"))
|
||||||
|
.isInstanceOf(BusinessException.class);
|
||||||
|
|
||||||
|
// Lock should still be released even after failure
|
||||||
|
verify(service).releaseLock("test-instance");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Migration Failure
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class MigrationFailure {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_statementFails_recordsFailureAndThrows() {
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), Collections.emptyList());
|
||||||
|
|
||||||
|
// Make migration statement fail
|
||||||
|
UnboundRunnableSpec failSpec = mock(UnboundRunnableSpec.class);
|
||||||
|
when(neo4jClient.query(anyString())).thenReturn(failSpec);
|
||||||
|
doThrow(new RuntimeException("Connection refused"))
|
||||||
|
.when(failSpec).run();
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> service.migrate("test-instance"))
|
||||||
|
.isInstanceOf(BusinessException.class)
|
||||||
|
.satisfies(e -> assertThat(((BusinessException) e).getErrorCodeEnum())
|
||||||
|
.isEqualTo(KnowledgeGraphErrorCode.SCHEMA_MIGRATION_FAILED));
|
||||||
|
|
||||||
|
// Failure should be recorded
|
||||||
|
verify(service).recordMigration(argThat(r -> !r.isSuccess()
|
||||||
|
&& r.getErrorMessage() != null
|
||||||
|
&& r.getErrorMessage().contains("Connection refused")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_alreadyExistsError_safelySkipped() {
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), Collections.emptyList());
|
||||||
|
|
||||||
|
// Make migration statement throw "already exists"
|
||||||
|
UnboundRunnableSpec existsSpec = mock(UnboundRunnableSpec.class);
|
||||||
|
when(neo4jClient.query(anyString())).thenReturn(existsSpec);
|
||||||
|
doThrow(new RuntimeException("Constraint already exists"))
|
||||||
|
.when(existsSpec).run();
|
||||||
|
|
||||||
|
// Should not throw
|
||||||
|
assertThatCode(() -> service.migrate("test-instance"))
|
||||||
|
.doesNotThrowAnyException();
|
||||||
|
|
||||||
|
// Success should be recorded
|
||||||
|
verify(service).recordMigration(argThat(r -> r.isSuccess() && r.getVersion() == 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Retry After Failure (P0)
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class RetryAfterFailure {
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
@Test
|
||||||
|
void recordMigration_usesMerge_allowsRetryAfterFailure() {
|
||||||
|
SchemaMigrationService service = createService(List.of(v1Migration));
|
||||||
|
|
||||||
|
UnboundRunnableSpec unboundSpec = mock(UnboundRunnableSpec.class);
|
||||||
|
RunnableSpec runnableSpec = mock(RunnableSpec.class);
|
||||||
|
when(neo4jClient.query(contains("MERGE"))).thenReturn(unboundSpec);
|
||||||
|
when(unboundSpec.bindAll(anyMap())).thenReturn(runnableSpec);
|
||||||
|
|
||||||
|
SchemaMigrationRecord record = SchemaMigrationRecord.builder()
|
||||||
|
.version(1)
|
||||||
|
.description("test")
|
||||||
|
.checksum("abc123")
|
||||||
|
.appliedAt("2025-01-01T00:00:00Z")
|
||||||
|
.executionTimeMs(100L)
|
||||||
|
.success(true)
|
||||||
|
.statementsCount(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
service.recordMigration(record);
|
||||||
|
|
||||||
|
// Verify MERGE is used (not CREATE) — ensures retries update
|
||||||
|
// existing failed records instead of hitting unique constraint violations
|
||||||
|
verify(neo4jClient).query(contains("MERGE"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings({"unchecked", "rawtypes"})
|
||||||
|
@Test
|
||||||
|
void recordMigration_nullErrorMessage_boundAsEmptyString() {
|
||||||
|
SchemaMigrationService service = createService(List.of(v1Migration));
|
||||||
|
|
||||||
|
UnboundRunnableSpec unboundSpec = mock(UnboundRunnableSpec.class);
|
||||||
|
RunnableSpec runnableSpec = mock(RunnableSpec.class);
|
||||||
|
when(neo4jClient.query(contains("MERGE"))).thenReturn(unboundSpec);
|
||||||
|
when(unboundSpec.bindAll(anyMap())).thenReturn(runnableSpec);
|
||||||
|
|
||||||
|
SchemaMigrationRecord record = SchemaMigrationRecord.builder()
|
||||||
|
.version(1)
|
||||||
|
.description("test")
|
||||||
|
.checksum("abc123")
|
||||||
|
.appliedAt("2025-01-01T00:00:00Z")
|
||||||
|
.executionTimeMs(100L)
|
||||||
|
.success(true)
|
||||||
|
.statementsCount(1)
|
||||||
|
// errorMessage intentionally not set (null)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
service.recordMigration(record);
|
||||||
|
|
||||||
|
ArgumentCaptor<Map> paramsCaptor = ArgumentCaptor.forClass(Map.class);
|
||||||
|
verify(unboundSpec).bindAll(paramsCaptor.capture());
|
||||||
|
Map<String, Object> params = paramsCaptor.getValue();
|
||||||
|
|
||||||
|
// All String params must be non-null to avoid Neo4j driver issues
|
||||||
|
assertThat(params.get("errorMessage")).isEqualTo("");
|
||||||
|
assertThat(params.get("description")).isEqualTo("test");
|
||||||
|
assertThat(params.get("checksum")).isEqualTo("abc123");
|
||||||
|
assertThat(params.get("appliedAt")).isEqualTo("2025-01-01T00:00:00Z");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void migrate_retryAfterFailure_recordsSuccess() {
|
||||||
|
// Simulate: first run recorded a failure, second run should succeed.
|
||||||
|
// loadAppliedMigrations only returns success=true, so failed V1 won't be in applied set.
|
||||||
|
SchemaMigrationService service = createSpiedService(
|
||||||
|
List.of(v1Migration), Collections.emptyList());
|
||||||
|
setupQueryRunnable();
|
||||||
|
|
||||||
|
service.migrate("test-instance");
|
||||||
|
|
||||||
|
// Verify success record is written (MERGE will update existing failed record)
|
||||||
|
verify(service).recordMigration(argThat(r -> r.isSuccess() && r.getVersion() == 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Database Time for Lock (P1-1)
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class DatabaseTimeLock {
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
@Test
|
||||||
|
void acquireLock_usesDatabaseTime_notLocalTime() {
|
||||||
|
SchemaMigrationService service = createService(List.of(v1Migration));
|
||||||
|
|
||||||
|
UnboundRunnableSpec lockSpec = mock(UnboundRunnableSpec.class);
|
||||||
|
RunnableSpec runnableSpec = mock(RunnableSpec.class);
|
||||||
|
RecordFetchSpec<Map<String, Object>> fetchSpec = mock(RecordFetchSpec.class);
|
||||||
|
|
||||||
|
when(neo4jClient.query(contains("MERGE (lock:_SchemaLock"))).thenReturn(lockSpec);
|
||||||
|
when(lockSpec.bindAll(anyMap())).thenReturn(runnableSpec);
|
||||||
|
when(runnableSpec.fetch()).thenReturn(fetchSpec);
|
||||||
|
when(fetchSpec.first()).thenReturn(Optional.of(Map.of(
|
||||||
|
"lockedBy", "test-instance",
|
||||||
|
"canAcquire", true
|
||||||
|
)));
|
||||||
|
|
||||||
|
service.acquireLock("test-instance");
|
||||||
|
|
||||||
|
// Verify that local time is NOT passed as parameters — database time is used instead
|
||||||
|
@SuppressWarnings("rawtypes")
|
||||||
|
ArgumentCaptor<Map> paramsCaptor = ArgumentCaptor.forClass(Map.class);
|
||||||
|
verify(lockSpec).bindAll(paramsCaptor.capture());
|
||||||
|
Map<String, Object> params = paramsCaptor.getValue();
|
||||||
|
assertThat(params).containsKey("instanceId");
|
||||||
|
assertThat(params).containsKey("timeoutMs");
|
||||||
|
assertThat(params).doesNotContainKey("now");
|
||||||
|
assertThat(params).doesNotContainKey("expiry");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Checksum Computation
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class ChecksumComputation {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void computeChecksum_deterministic() {
|
||||||
|
List<String> statements = List.of("stmt1", "stmt2");
|
||||||
|
String checksum1 = SchemaMigrationService.computeChecksum(statements);
|
||||||
|
String checksum2 = SchemaMigrationService.computeChecksum(statements);
|
||||||
|
|
||||||
|
assertThat(checksum1).isEqualTo(checksum2);
|
||||||
|
assertThat(checksum1).hasSize(64); // SHA-256 hex length
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void computeChecksum_orderMatters() {
|
||||||
|
String checksum1 = SchemaMigrationService.computeChecksum(List.of("stmt1", "stmt2"));
|
||||||
|
String checksum2 = SchemaMigrationService.computeChecksum(List.of("stmt2", "stmt1"));
|
||||||
|
|
||||||
|
assertThat(checksum1).isNotEqualTo(checksum2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Bootstrap Repair
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class BootstrapRepair {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void bootstrapMigrationSchema_executesRepairQuery() {
|
||||||
|
SchemaMigrationService service = createService(List.of(v1Migration));
|
||||||
|
|
||||||
|
UnboundRunnableSpec spec = mock(UnboundRunnableSpec.class);
|
||||||
|
when(neo4jClient.query(anyString())).thenReturn(spec);
|
||||||
|
|
||||||
|
service.bootstrapMigrationSchema();
|
||||||
|
|
||||||
|
// Verify 3 queries: 2 constraints + 1 repair
|
||||||
|
verify(neo4jClient, times(3)).query(anyString());
|
||||||
|
// Verify repair query targets nodes with missing properties
|
||||||
|
verify(neo4jClient).query(contains("m.description IS NULL OR m.checksum IS NULL"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Load Applied Migrations Query
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class LoadAppliedMigrationsQuery {
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
@Test
|
||||||
|
void loadAppliedMigrations_usesCoalesceInQuery() {
|
||||||
|
SchemaMigrationService service = createService(List.of(v1Migration));
|
||||||
|
|
||||||
|
UnboundRunnableSpec spec = mock(UnboundRunnableSpec.class);
|
||||||
|
RecordFetchSpec<Map<String, Object>> fetchSpec = mock(RecordFetchSpec.class);
|
||||||
|
when(neo4jClient.query(contains("COALESCE"))).thenReturn(spec);
|
||||||
|
when(spec.fetch()).thenReturn(fetchSpec);
|
||||||
|
when(fetchSpec.all()).thenReturn(Collections.emptyList());
|
||||||
|
|
||||||
|
service.loadAppliedMigrations();
|
||||||
|
|
||||||
|
// Verify COALESCE is used for all optional properties
|
||||||
|
ArgumentCaptor<String> queryCaptor = ArgumentCaptor.forClass(String.class);
|
||||||
|
verify(neo4jClient).query(queryCaptor.capture());
|
||||||
|
String capturedQuery = queryCaptor.getValue();
|
||||||
|
assertThat(capturedQuery)
|
||||||
|
.contains("COALESCE(m.description, '')")
|
||||||
|
.contains("COALESCE(m.checksum, '')")
|
||||||
|
.contains("COALESCE(m.applied_at, '')")
|
||||||
|
.contains("COALESCE(m.execution_time_ms, 0)")
|
||||||
|
.contains("COALESCE(m.statements_count, 0)")
|
||||||
|
.contains("COALESCE(m.error_message, '')");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
package com.datamate.knowledgegraph.infrastructure.neo4j.migration;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
|
class V2__PerformanceIndexesTest {
|
||||||
|
|
||||||
|
private final V2__PerformanceIndexes migration = new V2__PerformanceIndexes();
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void version_is_2() {
|
||||||
|
assertThat(migration.getVersion()).isEqualTo(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void description_is_not_empty() {
|
||||||
|
assertThat(migration.getDescription()).isNotBlank();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void statements_are_not_empty() {
|
||||||
|
List<String> statements = migration.getStatements();
|
||||||
|
assertThat(statements).isNotEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void all_statements_use_if_not_exists() {
|
||||||
|
for (String stmt : migration.getStatements()) {
|
||||||
|
assertThat(stmt).containsIgnoringCase("IF NOT EXISTS");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void contains_relationship_index() {
|
||||||
|
List<String> statements = migration.getStatements();
|
||||||
|
boolean hasRelIndex = statements.stream()
|
||||||
|
.anyMatch(s -> s.contains("RELATED_TO") && s.contains("graph_id"));
|
||||||
|
assertThat(hasRelIndex).isTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void contains_updated_at_index() {
|
||||||
|
List<String> statements = migration.getStatements();
|
||||||
|
boolean hasUpdatedAt = statements.stream()
|
||||||
|
.anyMatch(s -> s.contains("updated_at"));
|
||||||
|
assertThat(hasUpdatedAt).isTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void contains_composite_graph_id_name_index() {
|
||||||
|
List<String> statements = migration.getStatements();
|
||||||
|
boolean hasComposite = statements.stream()
|
||||||
|
.anyMatch(s -> s.contains("graph_id") && s.contains("n.name"));
|
||||||
|
assertThat(hasComposite).isTrue();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,239 @@
|
|||||||
|
package com.datamate.knowledgegraph.interfaces.rest;
|
||||||
|
|
||||||
|
import com.datamate.common.infrastructure.exception.BusinessException;
|
||||||
|
import com.datamate.common.interfaces.PagedResponse;
|
||||||
|
import com.datamate.knowledgegraph.application.EditReviewService;
|
||||||
|
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
|
||||||
|
import com.datamate.knowledgegraph.interfaces.dto.EditReviewVO;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.mockito.InjectMocks;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.junit.jupiter.MockitoExtension;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.test.web.servlet.MockMvc;
|
||||||
|
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.mockito.ArgumentMatchers.*;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
||||||
|
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
|
||||||
|
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
|
||||||
|
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
|
||||||
|
|
||||||
|
@ExtendWith(MockitoExtension.class)
|
||||||
|
class EditReviewControllerTest {
|
||||||
|
|
||||||
|
private static final String GRAPH_ID = "550e8400-e29b-41d4-a716-446655440000";
|
||||||
|
private static final String REVIEW_ID = "660e8400-e29b-41d4-a716-446655440001";
|
||||||
|
private static final String ENTITY_ID = "770e8400-e29b-41d4-a716-446655440002";
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
private EditReviewService reviewService;
|
||||||
|
|
||||||
|
@InjectMocks
|
||||||
|
private EditReviewController controller;
|
||||||
|
|
||||||
|
private MockMvc mockMvc;
|
||||||
|
private ObjectMapper objectMapper;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
mockMvc = MockMvcBuilders.standaloneSetup(controller).build();
|
||||||
|
objectMapper = new ObjectMapper();
|
||||||
|
objectMapper.registerModule(new JavaTimeModule());
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// POST /knowledge-graph/{graphId}/review/submit
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void submitReview_success() throws Exception {
|
||||||
|
EditReviewVO vo = buildReviewVO("PENDING");
|
||||||
|
when(reviewService.submitReview(eq(GRAPH_ID), any(), eq("user-1")))
|
||||||
|
.thenReturn(vo);
|
||||||
|
|
||||||
|
mockMvc.perform(post("/knowledge-graph/{graphId}/review/submit", GRAPH_ID)
|
||||||
|
.contentType(MediaType.APPLICATION_JSON)
|
||||||
|
.header("X-User-Id", "user-1")
|
||||||
|
.content(objectMapper.writeValueAsString(Map.of(
|
||||||
|
"operationType", "CREATE_ENTITY",
|
||||||
|
"payload", "{\"name\":\"Test\",\"type\":\"Dataset\"}"
|
||||||
|
))))
|
||||||
|
.andExpect(status().isCreated())
|
||||||
|
.andExpect(jsonPath("$.id").value(REVIEW_ID))
|
||||||
|
.andExpect(jsonPath("$.status").value("PENDING"))
|
||||||
|
.andExpect(jsonPath("$.operationType").value("CREATE_ENTITY"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void submitReview_delegatesToService() throws Exception {
|
||||||
|
EditReviewVO vo = buildReviewVO("PENDING");
|
||||||
|
when(reviewService.submitReview(eq(GRAPH_ID), any(), eq("user-1")))
|
||||||
|
.thenReturn(vo);
|
||||||
|
|
||||||
|
mockMvc.perform(post("/knowledge-graph/{graphId}/review/submit", GRAPH_ID)
|
||||||
|
.contentType(MediaType.APPLICATION_JSON)
|
||||||
|
.header("X-User-Id", "user-1")
|
||||||
|
.content(objectMapper.writeValueAsString(Map.of(
|
||||||
|
"operationType", "DELETE_ENTITY",
|
||||||
|
"entityId", ENTITY_ID
|
||||||
|
))))
|
||||||
|
.andExpect(status().isCreated());
|
||||||
|
|
||||||
|
verify(reviewService).submitReview(eq(GRAPH_ID), any(), eq("user-1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void submitReview_defaultUserId_whenHeaderMissing() throws Exception {
|
||||||
|
EditReviewVO vo = buildReviewVO("PENDING");
|
||||||
|
when(reviewService.submitReview(eq(GRAPH_ID), any(), eq("anonymous")))
|
||||||
|
.thenReturn(vo);
|
||||||
|
|
||||||
|
mockMvc.perform(post("/knowledge-graph/{graphId}/review/submit", GRAPH_ID)
|
||||||
|
.contentType(MediaType.APPLICATION_JSON)
|
||||||
|
.content(objectMapper.writeValueAsString(Map.of(
|
||||||
|
"operationType", "CREATE_ENTITY",
|
||||||
|
"payload", "{\"name\":\"Test\"}"
|
||||||
|
))))
|
||||||
|
.andExpect(status().isCreated());
|
||||||
|
|
||||||
|
verify(reviewService).submitReview(eq(GRAPH_ID), any(), eq("anonymous"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// POST /knowledge-graph/{graphId}/review/{reviewId}/approve
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_success() throws Exception {
|
||||||
|
EditReviewVO vo = buildReviewVO("APPROVED");
|
||||||
|
when(reviewService.approveReview(eq(GRAPH_ID), eq(REVIEW_ID), eq("reviewer-1"), isNull()))
|
||||||
|
.thenReturn(vo);
|
||||||
|
|
||||||
|
mockMvc.perform(post("/knowledge-graph/{graphId}/review/{reviewId}/approve", GRAPH_ID, REVIEW_ID)
|
||||||
|
.contentType(MediaType.APPLICATION_JSON)
|
||||||
|
.header("X-User-Id", "reviewer-1"))
|
||||||
|
.andExpect(status().isOk())
|
||||||
|
.andExpect(jsonPath("$.status").value("APPROVED"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void approveReview_withComment() throws Exception {
|
||||||
|
EditReviewVO vo = buildReviewVO("APPROVED");
|
||||||
|
when(reviewService.approveReview(eq(GRAPH_ID), eq(REVIEW_ID), eq("reviewer-1"), eq("LGTM")))
|
||||||
|
.thenReturn(vo);
|
||||||
|
|
||||||
|
mockMvc.perform(post("/knowledge-graph/{graphId}/review/{reviewId}/approve", GRAPH_ID, REVIEW_ID)
|
||||||
|
.contentType(MediaType.APPLICATION_JSON)
|
||||||
|
.header("X-User-Id", "reviewer-1")
|
||||||
|
.content(objectMapper.writeValueAsString(Map.of("comment", "LGTM"))))
|
||||||
|
.andExpect(status().isOk());
|
||||||
|
|
||||||
|
verify(reviewService).approveReview(GRAPH_ID, REVIEW_ID, "reviewer-1", "LGTM");
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// POST /knowledge-graph/{graphId}/review/{reviewId}/reject
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void rejectReview_success() throws Exception {
|
||||||
|
EditReviewVO vo = buildReviewVO("REJECTED");
|
||||||
|
when(reviewService.rejectReview(eq(GRAPH_ID), eq(REVIEW_ID), eq("reviewer-1"), eq("不合适")))
|
||||||
|
.thenReturn(vo);
|
||||||
|
|
||||||
|
mockMvc.perform(post("/knowledge-graph/{graphId}/review/{reviewId}/reject", GRAPH_ID, REVIEW_ID)
|
||||||
|
.contentType(MediaType.APPLICATION_JSON)
|
||||||
|
.header("X-User-Id", "reviewer-1")
|
||||||
|
.content(objectMapper.writeValueAsString(Map.of("comment", "不合适"))))
|
||||||
|
.andExpect(status().isOk())
|
||||||
|
.andExpect(jsonPath("$.status").value("REJECTED"));
|
||||||
|
|
||||||
|
verify(reviewService).rejectReview(GRAPH_ID, REVIEW_ID, "reviewer-1", "不合适");
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// GET /knowledge-graph/{graphId}/review/pending
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listPendingReviews_success() throws Exception {
|
||||||
|
EditReviewVO vo = buildReviewVO("PENDING");
|
||||||
|
PagedResponse<EditReviewVO> page = PagedResponse.of(List.of(vo), 0, 1, 1);
|
||||||
|
when(reviewService.listPendingReviews(GRAPH_ID, 0, 20)).thenReturn(page);
|
||||||
|
|
||||||
|
mockMvc.perform(get("/knowledge-graph/{graphId}/review/pending", GRAPH_ID))
|
||||||
|
.andExpect(status().isOk())
|
||||||
|
.andExpect(jsonPath("$.content").isArray())
|
||||||
|
.andExpect(jsonPath("$.content[0].id").value(REVIEW_ID))
|
||||||
|
.andExpect(jsonPath("$.totalElements").value(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listPendingReviews_customPageSize() throws Exception {
|
||||||
|
PagedResponse<EditReviewVO> page = PagedResponse.of(List.of(), 0, 0, 0);
|
||||||
|
when(reviewService.listPendingReviews(GRAPH_ID, 1, 10)).thenReturn(page);
|
||||||
|
|
||||||
|
mockMvc.perform(get("/knowledge-graph/{graphId}/review/pending", GRAPH_ID)
|
||||||
|
.param("page", "1")
|
||||||
|
.param("size", "10"))
|
||||||
|
.andExpect(status().isOk());
|
||||||
|
|
||||||
|
verify(reviewService).listPendingReviews(GRAPH_ID, 1, 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// GET /knowledge-graph/{graphId}/review
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listReviews_withStatusFilter() throws Exception {
|
||||||
|
PagedResponse<EditReviewVO> page = PagedResponse.of(List.of(), 0, 0, 0);
|
||||||
|
when(reviewService.listReviews(GRAPH_ID, "APPROVED", 0, 20)).thenReturn(page);
|
||||||
|
|
||||||
|
mockMvc.perform(get("/knowledge-graph/{graphId}/review", GRAPH_ID)
|
||||||
|
.param("status", "APPROVED"))
|
||||||
|
.andExpect(status().isOk())
|
||||||
|
.andExpect(jsonPath("$.content").isEmpty());
|
||||||
|
|
||||||
|
verify(reviewService).listReviews(GRAPH_ID, "APPROVED", 0, 20);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listReviews_withoutStatusFilter() throws Exception {
|
||||||
|
EditReviewVO vo = buildReviewVO("PENDING");
|
||||||
|
PagedResponse<EditReviewVO> page = PagedResponse.of(List.of(vo), 0, 1, 1);
|
||||||
|
when(reviewService.listReviews(GRAPH_ID, null, 0, 20)).thenReturn(page);
|
||||||
|
|
||||||
|
mockMvc.perform(get("/knowledge-graph/{graphId}/review", GRAPH_ID))
|
||||||
|
.andExpect(status().isOk())
|
||||||
|
.andExpect(jsonPath("$.content").isArray())
|
||||||
|
.andExpect(jsonPath("$.content[0].id").value(REVIEW_ID));
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
private EditReviewVO buildReviewVO(String status) {
|
||||||
|
return EditReviewVO.builder()
|
||||||
|
.id(REVIEW_ID)
|
||||||
|
.graphId(GRAPH_ID)
|
||||||
|
.operationType("CREATE_ENTITY")
|
||||||
|
.payload("{\"name\":\"Test\",\"type\":\"Dataset\"}")
|
||||||
|
.status(status)
|
||||||
|
.submittedBy("user-1")
|
||||||
|
.createdAt(LocalDateTime.now())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -110,6 +110,17 @@ public class AuthApplicationService {
|
|||||||
return responses;
|
return responses;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 返回所有用户的用户名与组织映射,供内部同步服务使用。
|
||||||
|
*/
|
||||||
|
public List<UserOrgMapping> listUserOrganizations() {
|
||||||
|
return authMapper.listUsers().stream()
|
||||||
|
.map(u -> new UserOrgMapping(u.getUsername(), u.getOrganization()))
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
public record UserOrgMapping(String username, String organization) {}
|
||||||
|
|
||||||
public List<AuthRoleInfo> listRoles() {
|
public List<AuthRoleInfo> listRoles() {
|
||||||
return authMapper.listRoles();
|
return authMapper.listRoles();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,5 +14,6 @@ public class AuthUserSummary {
|
|||||||
private String email;
|
private String email;
|
||||||
private String fullName;
|
private String fullName;
|
||||||
private Boolean enabled;
|
private Boolean enabled;
|
||||||
|
private String organization;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,14 @@ public class AuthController {
|
|||||||
return authApplicationService.listUsersWithRoles();
|
return authApplicationService.listUsersWithRoles();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 内部接口:返回所有用户的用户名与组织映射,供知识图谱同步服务调用。
|
||||||
|
*/
|
||||||
|
@GetMapping("/users/organizations")
|
||||||
|
public List<AuthApplicationService.UserOrgMapping> listUserOrganizations() {
|
||||||
|
return authApplicationService.listUserOrganizations();
|
||||||
|
}
|
||||||
|
|
||||||
@PutMapping("/users/{userId}/roles")
|
@PutMapping("/users/{userId}/roles")
|
||||||
public void assignRoles(@PathVariable("userId") Long userId,
|
public void assignRoles(@PathVariable("userId") Long userId,
|
||||||
@RequestBody @Valid AssignUserRolesRequest request) {
|
@RequestBody @Valid AssignUserRolesRequest request) {
|
||||||
|
|||||||
@@ -66,7 +66,8 @@
|
|||||||
username,
|
username,
|
||||||
email,
|
email,
|
||||||
full_name AS fullName,
|
full_name AS fullName,
|
||||||
enabled
|
enabled,
|
||||||
|
organization
|
||||||
FROM users
|
FROM users
|
||||||
ORDER BY id ASC
|
ORDER BY id ASC
|
||||||
</select>
|
</select>
|
||||||
|
|||||||
1444
frontend/package-lock.json
generated
1444
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,8 @@
|
|||||||
"react-dom": "^18.1.1",
|
"react-dom": "^18.1.1",
|
||||||
"react-redux": "^9.2.0",
|
"react-redux": "^9.2.0",
|
||||||
"react-router": "^7.8.0",
|
"react-router": "^7.8.0",
|
||||||
"recharts": "2.15.0"
|
"recharts": "2.15.0",
|
||||||
|
"@antv/g6": "^5.0.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@eslint/js": "^9.33.0",
|
"@eslint/js": "^9.33.0",
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ export const PermissionCodes = {
|
|||||||
taskCoordinationAssign: "module:task-coordination:assign",
|
taskCoordinationAssign: "module:task-coordination:assign",
|
||||||
contentGenerationUse: "module:content-generation:use",
|
contentGenerationUse: "module:content-generation:use",
|
||||||
agentUse: "module:agent:use",
|
agentUse: "module:agent:use",
|
||||||
|
knowledgeGraphRead: "module:knowledge-graph:read",
|
||||||
|
knowledgeGraphWrite: "module:knowledge-graph:write",
|
||||||
userManage: "system:user:manage",
|
userManage: "system:user:manage",
|
||||||
roleManage: "system:role:manage",
|
roleManage: "system:role:manage",
|
||||||
permissionManage: "system:permission:manage",
|
permissionManage: "system:permission:manage",
|
||||||
@@ -39,6 +41,7 @@ const routePermissionRules: Array<{ prefix: string; permission: string }> = [
|
|||||||
{ prefix: "/data/orchestration", permission: PermissionCodes.orchestrationRead },
|
{ prefix: "/data/orchestration", permission: PermissionCodes.orchestrationRead },
|
||||||
{ prefix: "/data/task-coordination", permission: PermissionCodes.taskCoordinationRead },
|
{ prefix: "/data/task-coordination", permission: PermissionCodes.taskCoordinationRead },
|
||||||
{ prefix: "/data/content-generation", permission: PermissionCodes.contentGenerationUse },
|
{ prefix: "/data/content-generation", permission: PermissionCodes.contentGenerationUse },
|
||||||
|
{ prefix: "/data/knowledge-graph", permission: PermissionCodes.knowledgeGraphRead },
|
||||||
{ prefix: "/chat", permission: PermissionCodes.agentUse },
|
{ prefix: "/chat", permission: PermissionCodes.agentUse },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|||||||
509
frontend/src/pages/KnowledgeGraph/Home/KnowledgeGraphPage.tsx
Normal file
509
frontend/src/pages/KnowledgeGraph/Home/KnowledgeGraphPage.tsx
Normal file
@@ -0,0 +1,509 @@
|
|||||||
|
import { useState, useCallback, useEffect } from "react";
|
||||||
|
import { Card, Input, Select, Button, Tag, Space, Empty, Tabs, Switch, message, Popconfirm } from "antd";
|
||||||
|
import { Network, RotateCcw, Plus, Link2, Trash2 } from "lucide-react";
|
||||||
|
import { useSearchParams } from "react-router";
|
||||||
|
import { useAppSelector } from "@/store/hooks";
|
||||||
|
import { hasPermission, PermissionCodes } from "@/auth/permissions";
|
||||||
|
import GraphCanvas from "../components/GraphCanvas";
|
||||||
|
import SearchPanel from "../components/SearchPanel";
|
||||||
|
import QueryBuilder from "../components/QueryBuilder";
|
||||||
|
import NodeDetail from "../components/NodeDetail";
|
||||||
|
import RelationDetail from "../components/RelationDetail";
|
||||||
|
import EntityEditForm from "../components/EntityEditForm";
|
||||||
|
import RelationEditForm from "../components/RelationEditForm";
|
||||||
|
import ReviewPanel from "../components/ReviewPanel";
|
||||||
|
import useGraphData from "../hooks/useGraphData";
|
||||||
|
import useGraphLayout, { LAYOUT_OPTIONS } from "../hooks/useGraphLayout";
|
||||||
|
import type { GraphEntity, RelationVO } from "../knowledge-graph.model";
|
||||||
|
import {
|
||||||
|
ENTITY_TYPE_COLORS,
|
||||||
|
DEFAULT_ENTITY_COLOR,
|
||||||
|
ENTITY_TYPE_LABELS,
|
||||||
|
} from "../knowledge-graph.const";
|
||||||
|
import * as api from "../knowledge-graph.api";
|
||||||
|
|
||||||
|
const UUID_REGEX = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i;
|
||||||
|
|
||||||
|
export default function KnowledgeGraphPage() {
|
||||||
|
const [params, setParams] = useSearchParams();
|
||||||
|
const [graphId, setGraphId] = useState(() => params.get("graphId") ?? "");
|
||||||
|
const [graphIdInput, setGraphIdInput] = useState(() => params.get("graphId") ?? "");
|
||||||
|
|
||||||
|
// Permission check
|
||||||
|
const permissions = useAppSelector((state) => state.auth.permissions);
|
||||||
|
const canWrite = hasPermission(permissions, PermissionCodes.knowledgeGraphWrite);
|
||||||
|
|
||||||
|
const {
|
||||||
|
graphData,
|
||||||
|
loading,
|
||||||
|
searchResults,
|
||||||
|
searchLoading,
|
||||||
|
highlightedNodeIds,
|
||||||
|
loadInitialData,
|
||||||
|
expandNode,
|
||||||
|
searchEntities,
|
||||||
|
mergePathData,
|
||||||
|
clearGraph,
|
||||||
|
clearSearch,
|
||||||
|
} = useGraphData();
|
||||||
|
|
||||||
|
const { layoutType, setLayoutType } = useGraphLayout();
|
||||||
|
|
||||||
|
// Edit mode (only allowed with write permission)
|
||||||
|
const [editMode, setEditMode] = useState(false);
|
||||||
|
|
||||||
|
// Detail panel state
|
||||||
|
const [selectedNodeId, setSelectedNodeId] = useState<string | null>(null);
|
||||||
|
const [selectedEdgeId, setSelectedEdgeId] = useState<string | null>(null);
|
||||||
|
const [nodeDetailOpen, setNodeDetailOpen] = useState(false);
|
||||||
|
const [relationDetailOpen, setRelationDetailOpen] = useState(false);
|
||||||
|
|
||||||
|
// Edit form state
|
||||||
|
const [entityFormOpen, setEntityFormOpen] = useState(false);
|
||||||
|
const [editingEntity, setEditingEntity] = useState<GraphEntity | null>(null);
|
||||||
|
const [relationFormOpen, setRelationFormOpen] = useState(false);
|
||||||
|
const [editingRelation, setEditingRelation] = useState<RelationVO | null>(null);
|
||||||
|
const [defaultRelationSourceId, setDefaultRelationSourceId] = useState<string | undefined>();
|
||||||
|
|
||||||
|
// Batch selection state
|
||||||
|
const [selectedNodeIds, setSelectedNodeIds] = useState<string[]>([]);
|
||||||
|
const [selectedEdgeIds, setSelectedEdgeIds] = useState<string[]>([]);
|
||||||
|
|
||||||
|
// Load graph when graphId changes
|
||||||
|
useEffect(() => {
|
||||||
|
if (graphId && UUID_REGEX.test(graphId)) {
|
||||||
|
clearGraph();
|
||||||
|
loadInitialData(graphId);
|
||||||
|
}
|
||||||
|
}, [graphId, loadInitialData, clearGraph]);
|
||||||
|
|
||||||
|
const handleLoadGraph = useCallback(() => {
|
||||||
|
if (!UUID_REGEX.test(graphIdInput)) {
|
||||||
|
message.warning("请输入有效的图谱 ID(UUID 格式)");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setGraphId(graphIdInput);
|
||||||
|
setParams({ graphId: graphIdInput });
|
||||||
|
}, [graphIdInput, setParams]);
|
||||||
|
|
||||||
|
const handleNodeClick = useCallback((nodeId: string) => {
|
||||||
|
setSelectedNodeId(nodeId);
|
||||||
|
setSelectedEdgeId(null);
|
||||||
|
setNodeDetailOpen(true);
|
||||||
|
setRelationDetailOpen(false);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleEdgeClick = useCallback((edgeId: string) => {
|
||||||
|
setSelectedEdgeId(edgeId);
|
||||||
|
setSelectedNodeId(null);
|
||||||
|
setRelationDetailOpen(true);
|
||||||
|
setNodeDetailOpen(false);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleNodeDoubleClick = useCallback(
|
||||||
|
(nodeId: string) => {
|
||||||
|
if (!graphId) return;
|
||||||
|
expandNode(graphId, nodeId);
|
||||||
|
},
|
||||||
|
[graphId, expandNode]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleCanvasClick = useCallback(() => {
|
||||||
|
setSelectedNodeId(null);
|
||||||
|
setSelectedEdgeId(null);
|
||||||
|
setNodeDetailOpen(false);
|
||||||
|
setRelationDetailOpen(false);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleExpandNode = useCallback(
|
||||||
|
(entityId: string) => {
|
||||||
|
if (!graphId) return;
|
||||||
|
expandNode(graphId, entityId);
|
||||||
|
},
|
||||||
|
[graphId, expandNode]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleEntityNavigate = useCallback(
|
||||||
|
(entityId: string) => {
|
||||||
|
setSelectedNodeId(entityId);
|
||||||
|
setNodeDetailOpen(true);
|
||||||
|
setRelationDetailOpen(false);
|
||||||
|
},
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSearchResultClick = useCallback(
|
||||||
|
(entityId: string) => {
|
||||||
|
handleNodeClick(entityId);
|
||||||
|
if (!graphData.nodes.find((n) => n.id === entityId) && graphId) {
|
||||||
|
expandNode(graphId, entityId);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[handleNodeClick, graphData.nodes, graphId, expandNode]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleRelationClick = useCallback((relationId: string) => {
|
||||||
|
setSelectedEdgeId(relationId);
|
||||||
|
setRelationDetailOpen(true);
|
||||||
|
setNodeDetailOpen(false);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleSelectionChange = useCallback((nodeIds: string[], edgeIds: string[]) => {
|
||||||
|
setSelectedNodeIds(nodeIds);
|
||||||
|
setSelectedEdgeIds(edgeIds);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// ---- Edit handlers ----
|
||||||
|
|
||||||
|
const refreshGraph = useCallback(() => {
|
||||||
|
if (graphId) {
|
||||||
|
loadInitialData(graphId);
|
||||||
|
}
|
||||||
|
}, [graphId, loadInitialData]);
|
||||||
|
|
||||||
|
const handleEditEntity = useCallback((entity: GraphEntity) => {
|
||||||
|
setEditingEntity(entity);
|
||||||
|
setEntityFormOpen(true);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleCreateEntity = useCallback(() => {
|
||||||
|
setEditingEntity(null);
|
||||||
|
setEntityFormOpen(true);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleDeleteEntity = useCallback(
|
||||||
|
async (entityId: string) => {
|
||||||
|
if (!graphId) return;
|
||||||
|
try {
|
||||||
|
await api.submitReview(graphId, {
|
||||||
|
operationType: "DELETE_ENTITY",
|
||||||
|
entityId,
|
||||||
|
});
|
||||||
|
message.success("实体删除已提交审核");
|
||||||
|
setNodeDetailOpen(false);
|
||||||
|
setSelectedNodeId(null);
|
||||||
|
refreshGraph();
|
||||||
|
} catch {
|
||||||
|
message.error("提交实体删除审核失败");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[graphId, refreshGraph]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleEditRelation = useCallback((relation: RelationVO) => {
|
||||||
|
setEditingRelation(relation);
|
||||||
|
setDefaultRelationSourceId(undefined);
|
||||||
|
setRelationFormOpen(true);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleCreateRelation = useCallback((sourceEntityId?: string) => {
|
||||||
|
setEditingRelation(null);
|
||||||
|
setDefaultRelationSourceId(sourceEntityId);
|
||||||
|
setRelationFormOpen(true);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleDeleteRelation = useCallback(
|
||||||
|
async (relationId: string) => {
|
||||||
|
if (!graphId) return;
|
||||||
|
try {
|
||||||
|
await api.submitReview(graphId, {
|
||||||
|
operationType: "DELETE_RELATION",
|
||||||
|
relationId,
|
||||||
|
});
|
||||||
|
message.success("关系删除已提交审核");
|
||||||
|
setRelationDetailOpen(false);
|
||||||
|
setSelectedEdgeId(null);
|
||||||
|
refreshGraph();
|
||||||
|
} catch {
|
||||||
|
message.error("提交关系删除审核失败");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[graphId, refreshGraph]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleEntityFormSuccess = useCallback(() => {
|
||||||
|
refreshGraph();
|
||||||
|
}, [refreshGraph]);
|
||||||
|
|
||||||
|
const handleRelationFormSuccess = useCallback(() => {
|
||||||
|
refreshGraph();
|
||||||
|
}, [refreshGraph]);
|
||||||
|
|
||||||
|
// ---- Batch operations ----
|
||||||
|
|
||||||
|
const handleBatchDeleteNodes = useCallback(async () => {
|
||||||
|
if (!graphId || selectedNodeIds.length === 0) return;
|
||||||
|
try {
|
||||||
|
await api.submitReview(graphId, {
|
||||||
|
operationType: "BATCH_DELETE_ENTITY",
|
||||||
|
payload: JSON.stringify({ ids: selectedNodeIds }),
|
||||||
|
});
|
||||||
|
message.success("批量删除实体已提交审核");
|
||||||
|
setSelectedNodeIds([]);
|
||||||
|
refreshGraph();
|
||||||
|
} catch {
|
||||||
|
message.error("提交批量删除实体审核失败");
|
||||||
|
}
|
||||||
|
}, [graphId, selectedNodeIds, refreshGraph]);
|
||||||
|
|
||||||
|
const handleBatchDeleteEdges = useCallback(async () => {
|
||||||
|
if (!graphId || selectedEdgeIds.length === 0) return;
|
||||||
|
try {
|
||||||
|
await api.submitReview(graphId, {
|
||||||
|
operationType: "BATCH_DELETE_RELATION",
|
||||||
|
payload: JSON.stringify({ ids: selectedEdgeIds }),
|
||||||
|
});
|
||||||
|
message.success("批量删除关系已提交审核");
|
||||||
|
setSelectedEdgeIds([]);
|
||||||
|
refreshGraph();
|
||||||
|
} catch {
|
||||||
|
message.error("提交批量删除关系审核失败");
|
||||||
|
}
|
||||||
|
}, [graphId, selectedEdgeIds, refreshGraph]);
|
||||||
|
|
||||||
|
const hasGraph = graphId && UUID_REGEX.test(graphId);
|
||||||
|
const nodeCount = graphData.nodes.length;
|
||||||
|
const edgeCount = graphData.edges.length;
|
||||||
|
const hasBatchSelection = editMode && (selectedNodeIds.length > 1 || selectedEdgeIds.length > 1);
|
||||||
|
|
||||||
|
// Collect unique entity types in current graph for legend
|
||||||
|
const entityTypes = [...new Set(graphData.nodes.map((n) => n.data.type))].sort();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="h-full flex flex-col gap-4">
|
||||||
|
{/* Header */}
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<h1 className="text-xl font-bold flex items-center gap-2">
|
||||||
|
<Network className="w-5 h-5" />
|
||||||
|
知识图谱浏览器
|
||||||
|
</h1>
|
||||||
|
{hasGraph && canWrite && (
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-sm text-gray-500">编辑模式</span>
|
||||||
|
<Switch
|
||||||
|
checked={editMode}
|
||||||
|
onChange={setEditMode}
|
||||||
|
size="small"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Graph ID Input + Controls */}
|
||||||
|
<div className="flex items-center gap-3 flex-wrap">
|
||||||
|
<Space.Compact className="w-[420px]">
|
||||||
|
<Input
|
||||||
|
placeholder="输入图谱 ID (UUID)..."
|
||||||
|
value={graphIdInput}
|
||||||
|
onChange={(e) => setGraphIdInput(e.target.value)}
|
||||||
|
onPressEnter={handleLoadGraph}
|
||||||
|
allowClear
|
||||||
|
/>
|
||||||
|
<Button type="primary" onClick={handleLoadGraph}>
|
||||||
|
加载
|
||||||
|
</Button>
|
||||||
|
</Space.Compact>
|
||||||
|
|
||||||
|
<Select
|
||||||
|
value={layoutType}
|
||||||
|
onChange={setLayoutType}
|
||||||
|
options={LAYOUT_OPTIONS}
|
||||||
|
className="w-28"
|
||||||
|
/>
|
||||||
|
|
||||||
|
{hasGraph && (
|
||||||
|
<>
|
||||||
|
<Button
|
||||||
|
icon={<RotateCcw className="w-3.5 h-3.5" />}
|
||||||
|
onClick={() => loadInitialData(graphId)}
|
||||||
|
>
|
||||||
|
重新加载
|
||||||
|
</Button>
|
||||||
|
<span className="text-sm text-gray-500">
|
||||||
|
节点: {nodeCount} | 边: {edgeCount}
|
||||||
|
</span>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Edit mode toolbar */}
|
||||||
|
{hasGraph && editMode && (
|
||||||
|
<>
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
icon={<Plus className="w-3.5 h-3.5" />}
|
||||||
|
onClick={handleCreateEntity}
|
||||||
|
>
|
||||||
|
创建实体
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
icon={<Link2 className="w-3.5 h-3.5" />}
|
||||||
|
onClick={() => handleCreateRelation()}
|
||||||
|
>
|
||||||
|
创建关系
|
||||||
|
</Button>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Batch operations toolbar */}
|
||||||
|
{hasBatchSelection && (
|
||||||
|
<>
|
||||||
|
{selectedNodeIds.length > 1 && (
|
||||||
|
<Popconfirm
|
||||||
|
title={`确认批量删除 ${selectedNodeIds.length} 个实体?`}
|
||||||
|
description="删除后关联的关系也会被移除"
|
||||||
|
onConfirm={handleBatchDeleteNodes}
|
||||||
|
okText="确认"
|
||||||
|
cancelText="取消"
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
danger
|
||||||
|
icon={<Trash2 className="w-3.5 h-3.5" />}
|
||||||
|
>
|
||||||
|
批量删除实体 ({selectedNodeIds.length})
|
||||||
|
</Button>
|
||||||
|
</Popconfirm>
|
||||||
|
)}
|
||||||
|
{selectedEdgeIds.length > 1 && (
|
||||||
|
<Popconfirm
|
||||||
|
title={`确认批量删除 ${selectedEdgeIds.length} 条关系?`}
|
||||||
|
onConfirm={handleBatchDeleteEdges}
|
||||||
|
okText="确认"
|
||||||
|
cancelText="取消"
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
danger
|
||||||
|
icon={<Trash2 className="w-3.5 h-3.5" />}
|
||||||
|
>
|
||||||
|
批量删除关系 ({selectedEdgeIds.length})
|
||||||
|
</Button>
|
||||||
|
</Popconfirm>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Legend */}
|
||||||
|
{entityTypes.length > 0 && (
|
||||||
|
<div className="flex items-center gap-2 flex-wrap">
|
||||||
|
<span className="text-xs text-gray-500">图例:</span>
|
||||||
|
{entityTypes.map((type) => (
|
||||||
|
<Tag key={type} color={ENTITY_TYPE_COLORS[type] ?? DEFAULT_ENTITY_COLOR}>
|
||||||
|
{ENTITY_TYPE_LABELS[type] ?? type}
|
||||||
|
</Tag>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Main content */}
|
||||||
|
<div className="flex-1 flex gap-4 min-h-0">
|
||||||
|
{/* Sidebar with tabs */}
|
||||||
|
{hasGraph && (
|
||||||
|
<Card className="w-72 shrink-0 overflow-auto" size="small" bodyStyle={{ padding: 0 }}>
|
||||||
|
<Tabs
|
||||||
|
size="small"
|
||||||
|
className="px-3"
|
||||||
|
items={[
|
||||||
|
{
|
||||||
|
key: "search",
|
||||||
|
label: "搜索",
|
||||||
|
children: (
|
||||||
|
<SearchPanel
|
||||||
|
graphId={graphId}
|
||||||
|
results={searchResults}
|
||||||
|
loading={searchLoading}
|
||||||
|
onSearch={searchEntities}
|
||||||
|
onResultClick={handleSearchResultClick}
|
||||||
|
onClear={clearSearch}
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "query",
|
||||||
|
label: "路径查询",
|
||||||
|
children: (
|
||||||
|
<QueryBuilder
|
||||||
|
graphId={graphId}
|
||||||
|
onPathResult={mergePathData}
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: "review",
|
||||||
|
label: "审核",
|
||||||
|
children: <ReviewPanel graphId={graphId} />,
|
||||||
|
},
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Canvas */}
|
||||||
|
<Card className="flex-1 min-w-0" bodyStyle={{ height: "100%", padding: 0 }}>
|
||||||
|
{hasGraph ? (
|
||||||
|
<GraphCanvas
|
||||||
|
data={graphData}
|
||||||
|
loading={loading}
|
||||||
|
layoutType={layoutType}
|
||||||
|
highlightedNodeIds={highlightedNodeIds}
|
||||||
|
editMode={editMode}
|
||||||
|
onNodeClick={handleNodeClick}
|
||||||
|
onEdgeClick={handleEdgeClick}
|
||||||
|
onNodeDoubleClick={handleNodeDoubleClick}
|
||||||
|
onCanvasClick={handleCanvasClick}
|
||||||
|
onSelectionChange={handleSelectionChange}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<div className="h-full flex items-center justify-center">
|
||||||
|
<Empty
|
||||||
|
description="请输入图谱 ID 加载知识图谱"
|
||||||
|
image={<Network className="w-16 h-16 text-gray-300 mx-auto" />}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Detail drawers */}
|
||||||
|
<NodeDetail
|
||||||
|
graphId={graphId}
|
||||||
|
entityId={selectedNodeId}
|
||||||
|
open={nodeDetailOpen}
|
||||||
|
editMode={editMode}
|
||||||
|
onClose={() => setNodeDetailOpen(false)}
|
||||||
|
onExpandNode={handleExpandNode}
|
||||||
|
onRelationClick={handleRelationClick}
|
||||||
|
onEntityNavigate={handleEntityNavigate}
|
||||||
|
onEditEntity={handleEditEntity}
|
||||||
|
onDeleteEntity={handleDeleteEntity}
|
||||||
|
onCreateRelation={handleCreateRelation}
|
||||||
|
/>
|
||||||
|
<RelationDetail
|
||||||
|
graphId={graphId}
|
||||||
|
relationId={selectedEdgeId}
|
||||||
|
open={relationDetailOpen}
|
||||||
|
editMode={editMode}
|
||||||
|
onClose={() => setRelationDetailOpen(false)}
|
||||||
|
onEntityNavigate={handleEntityNavigate}
|
||||||
|
onEditRelation={handleEditRelation}
|
||||||
|
onDeleteRelation={handleDeleteRelation}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Edit forms */}
|
||||||
|
<EntityEditForm
|
||||||
|
graphId={graphId}
|
||||||
|
entity={editingEntity}
|
||||||
|
open={entityFormOpen}
|
||||||
|
onClose={() => setEntityFormOpen(false)}
|
||||||
|
onSuccess={handleEntityFormSuccess}
|
||||||
|
/>
|
||||||
|
<RelationEditForm
|
||||||
|
graphId={graphId}
|
||||||
|
relation={editingRelation}
|
||||||
|
open={relationFormOpen}
|
||||||
|
onClose={() => setRelationFormOpen(false)}
|
||||||
|
onSuccess={handleRelationFormSuccess}
|
||||||
|
defaultSourceId={defaultRelationSourceId}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
143
frontend/src/pages/KnowledgeGraph/components/EntityEditForm.tsx
Normal file
143
frontend/src/pages/KnowledgeGraph/components/EntityEditForm.tsx
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import { useEffect } from "react";
|
||||||
|
import { Modal, Form, Input, Select, InputNumber, message } from "antd";
|
||||||
|
import type { GraphEntity } from "../knowledge-graph.model";
|
||||||
|
import { ENTITY_TYPES, ENTITY_TYPE_LABELS } from "../knowledge-graph.const";
|
||||||
|
import * as api from "../knowledge-graph.api";
|
||||||
|
|
||||||
|
interface EntityEditFormProps {
|
||||||
|
graphId: string;
|
||||||
|
entity?: GraphEntity | null;
|
||||||
|
open: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
onSuccess: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function EntityEditForm({
|
||||||
|
graphId,
|
||||||
|
entity,
|
||||||
|
open,
|
||||||
|
onClose,
|
||||||
|
onSuccess,
|
||||||
|
}: EntityEditFormProps) {
|
||||||
|
const [form] = Form.useForm();
|
||||||
|
const isEdit = !!entity;
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (open && entity) {
|
||||||
|
form.setFieldsValue({
|
||||||
|
name: entity.name,
|
||||||
|
type: entity.type,
|
||||||
|
description: entity.description ?? "",
|
||||||
|
aliases: entity.aliases?.join(", ") ?? "",
|
||||||
|
confidence: entity.confidence ?? 1.0,
|
||||||
|
});
|
||||||
|
} else if (open) {
|
||||||
|
form.resetFields();
|
||||||
|
}
|
||||||
|
}, [open, entity, form]);
|
||||||
|
|
||||||
|
const handleSubmit = async () => {
|
||||||
|
let values;
|
||||||
|
try {
|
||||||
|
values = await form.validateFields();
|
||||||
|
} catch {
|
||||||
|
return; // Form validation failed — Antd shows inline errors
|
||||||
|
}
|
||||||
|
|
||||||
|
const parsedAliases = values.aliases
|
||||||
|
? values.aliases
|
||||||
|
.split(",")
|
||||||
|
.map((a: string) => a.trim())
|
||||||
|
.filter(Boolean)
|
||||||
|
: [];
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (isEdit && entity) {
|
||||||
|
const payload = JSON.stringify({
|
||||||
|
name: values.name,
|
||||||
|
description: values.description || undefined,
|
||||||
|
aliases: parsedAliases.length > 0 ? parsedAliases : undefined,
|
||||||
|
properties: entity.properties,
|
||||||
|
confidence: values.confidence,
|
||||||
|
});
|
||||||
|
await api.submitReview(graphId, {
|
||||||
|
operationType: "UPDATE_ENTITY",
|
||||||
|
entityId: entity.id,
|
||||||
|
payload,
|
||||||
|
});
|
||||||
|
message.success("实体更新已提交审核");
|
||||||
|
} else {
|
||||||
|
const payload = JSON.stringify({
|
||||||
|
name: values.name,
|
||||||
|
type: values.type,
|
||||||
|
description: values.description || undefined,
|
||||||
|
aliases: parsedAliases.length > 0 ? parsedAliases : undefined,
|
||||||
|
properties: {},
|
||||||
|
confidence: values.confidence,
|
||||||
|
});
|
||||||
|
await api.submitReview(graphId, {
|
||||||
|
operationType: "CREATE_ENTITY",
|
||||||
|
payload,
|
||||||
|
});
|
||||||
|
message.success("实体创建已提交审核");
|
||||||
|
}
|
||||||
|
onSuccess();
|
||||||
|
onClose();
|
||||||
|
} catch {
|
||||||
|
message.error(isEdit ? "提交实体更新审核失败" : "提交实体创建审核失败");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
title={isEdit ? "编辑实体" : "创建实体"}
|
||||||
|
open={open}
|
||||||
|
onCancel={onClose}
|
||||||
|
onOk={handleSubmit}
|
||||||
|
okText={isEdit ? "提交审核" : "提交审核"}
|
||||||
|
cancelText="取消"
|
||||||
|
destroyOnClose
|
||||||
|
>
|
||||||
|
<Form form={form} layout="vertical" className="mt-4">
|
||||||
|
<Form.Item
|
||||||
|
name="name"
|
||||||
|
label="名称"
|
||||||
|
rules={[{ required: true, message: "请输入实体名称" }]}
|
||||||
|
>
|
||||||
|
<Input placeholder="输入实体名称" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="type"
|
||||||
|
label="类型"
|
||||||
|
rules={[{ required: true, message: "请选择实体类型" }]}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
placeholder="选择实体类型"
|
||||||
|
disabled={isEdit}
|
||||||
|
options={ENTITY_TYPES.map((t) => ({
|
||||||
|
label: ENTITY_TYPE_LABELS[t] ?? t,
|
||||||
|
value: t,
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="description" label="描述">
|
||||||
|
<Input.TextArea rows={3} placeholder="输入实体描述(可选)" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="aliases"
|
||||||
|
label="别名"
|
||||||
|
tooltip="多个别名用逗号分隔"
|
||||||
|
>
|
||||||
|
<Input placeholder="别名1, 别名2, ..." />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="confidence" label="置信度">
|
||||||
|
<InputNumber min={0} max={1} step={0.1} className="w-full" />
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
}
|
||||||
258
frontend/src/pages/KnowledgeGraph/components/GraphCanvas.tsx
Normal file
258
frontend/src/pages/KnowledgeGraph/components/GraphCanvas.tsx
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
import { useEffect, useRef, useCallback, memo } from "react";
|
||||||
|
import { Graph } from "@antv/g6";
|
||||||
|
import { Spin } from "antd";
|
||||||
|
import type { G6GraphData } from "../graphTransform";
|
||||||
|
import { createGraphOptions, LARGE_GRAPH_THRESHOLD } from "../graphConfig";
|
||||||
|
import type { LayoutType } from "../hooks/useGraphLayout";
|
||||||
|
|
||||||
|
interface GraphCanvasProps {
|
||||||
|
data: G6GraphData;
|
||||||
|
loading?: boolean;
|
||||||
|
layoutType: LayoutType;
|
||||||
|
highlightedNodeIds?: Set<string>;
|
||||||
|
editMode?: boolean;
|
||||||
|
onNodeClick?: (nodeId: string) => void;
|
||||||
|
onEdgeClick?: (edgeId: string) => void;
|
||||||
|
onNodeDoubleClick?: (nodeId: string) => void;
|
||||||
|
onCanvasClick?: () => void;
|
||||||
|
onSelectionChange?: (nodeIds: string[], edgeIds: string[]) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
type GraphElementEvent = {
|
||||||
|
item?: {
|
||||||
|
id?: string;
|
||||||
|
getID?: () => string;
|
||||||
|
getModel?: () => { id?: string };
|
||||||
|
};
|
||||||
|
target?: { id?: string };
|
||||||
|
};
|
||||||
|
|
||||||
|
function GraphCanvas({
|
||||||
|
data,
|
||||||
|
loading = false,
|
||||||
|
layoutType,
|
||||||
|
highlightedNodeIds,
|
||||||
|
editMode = false,
|
||||||
|
onNodeClick,
|
||||||
|
onEdgeClick,
|
||||||
|
onNodeDoubleClick,
|
||||||
|
onCanvasClick,
|
||||||
|
onSelectionChange,
|
||||||
|
}: GraphCanvasProps) {
|
||||||
|
const containerRef = useRef<HTMLDivElement>(null);
|
||||||
|
const graphRef = useRef<Graph | null>(null);
|
||||||
|
|
||||||
|
// Initialize graph
|
||||||
|
useEffect(() => {
|
||||||
|
if (!containerRef.current) return;
|
||||||
|
|
||||||
|
const options = createGraphOptions(containerRef.current, editMode);
|
||||||
|
const graph = new Graph(options);
|
||||||
|
graphRef.current = graph;
|
||||||
|
|
||||||
|
graph.render();
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
graphRef.current = null;
|
||||||
|
graph.destroy();
|
||||||
|
};
|
||||||
|
// editMode is intentionally included so the graph re-creates with correct multi-select setting
|
||||||
|
}, [editMode]);
|
||||||
|
|
||||||
|
// Update data (with large-graph performance optimization)
|
||||||
|
useEffect(() => {
|
||||||
|
const graph = graphRef.current;
|
||||||
|
if (!graph) return;
|
||||||
|
|
||||||
|
const isLargeGraph = data.nodes.length >= LARGE_GRAPH_THRESHOLD;
|
||||||
|
if (isLargeGraph) {
|
||||||
|
graph.setOptions({ animation: false });
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data.nodes.length === 0 && data.edges.length === 0) {
|
||||||
|
graph.setData({ nodes: [], edges: [] });
|
||||||
|
graph.render();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
graph.setData(data);
|
||||||
|
graph.render();
|
||||||
|
}, [data]);
|
||||||
|
|
||||||
|
// Update layout
|
||||||
|
useEffect(() => {
|
||||||
|
const graph = graphRef.current;
|
||||||
|
if (!graph) return;
|
||||||
|
|
||||||
|
const layoutConfigs: Record<string, Record<string, unknown>> = {
|
||||||
|
"d3-force": {
|
||||||
|
type: "d3-force",
|
||||||
|
preventOverlap: true,
|
||||||
|
link: { distance: 180 },
|
||||||
|
charge: { strength: -400 },
|
||||||
|
collide: { radius: 50 },
|
||||||
|
},
|
||||||
|
circular: { type: "circular", radius: 250 },
|
||||||
|
grid: { type: "grid" },
|
||||||
|
radial: { type: "radial", unitRadius: 120, preventOverlap: true, nodeSpacing: 30 },
|
||||||
|
concentric: { type: "concentric", preventOverlap: true, nodeSpacing: 30 },
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.setLayout(layoutConfigs[layoutType] ?? layoutConfigs["d3-force"]);
|
||||||
|
graph.layout();
|
||||||
|
}, [layoutType]);
|
||||||
|
|
||||||
|
// Highlight nodes
|
||||||
|
useEffect(() => {
|
||||||
|
const graph = graphRef.current;
|
||||||
|
if (!graph || !highlightedNodeIds) return;
|
||||||
|
|
||||||
|
const allNodeIds = data.nodes.map((n) => n.id);
|
||||||
|
if (highlightedNodeIds.size === 0) {
|
||||||
|
// Clear all states
|
||||||
|
allNodeIds.forEach((id) => {
|
||||||
|
graph.setElementState(id, []);
|
||||||
|
});
|
||||||
|
data.edges.forEach((e) => {
|
||||||
|
graph.setElementState(e.id, []);
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
allNodeIds.forEach((id) => {
|
||||||
|
if (highlightedNodeIds.has(id)) {
|
||||||
|
graph.setElementState(id, ["highlighted"]);
|
||||||
|
} else {
|
||||||
|
graph.setElementState(id, ["dimmed"]);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
data.edges.forEach((e) => {
|
||||||
|
if (highlightedNodeIds.has(e.source) || highlightedNodeIds.has(e.target)) {
|
||||||
|
graph.setElementState(e.id, []);
|
||||||
|
} else {
|
||||||
|
graph.setElementState(e.id, ["dimmed"]);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}, [highlightedNodeIds, data]);
|
||||||
|
|
||||||
|
// Helper: query selected elements from graph and notify parent
|
||||||
|
const resolveElementId = useCallback(
|
||||||
|
(event: GraphElementEvent, elementType: "node" | "edge"): string | null => {
|
||||||
|
const itemId =
|
||||||
|
event.item?.getID?.() ??
|
||||||
|
event.item?.getModel?.()?.id ??
|
||||||
|
event.item?.id;
|
||||||
|
if (itemId) {
|
||||||
|
return itemId;
|
||||||
|
}
|
||||||
|
|
||||||
|
const targetId = event.target?.id;
|
||||||
|
if (!targetId) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const existsInData =
|
||||||
|
elementType === "node"
|
||||||
|
? data.nodes.some((node) => node.id === targetId)
|
||||||
|
: data.edges.some((edge) => edge.id === targetId);
|
||||||
|
return existsInData ? targetId : null;
|
||||||
|
},
|
||||||
|
[data.nodes, data.edges]
|
||||||
|
);
|
||||||
|
|
||||||
|
const emitSelectionChange = useCallback(() => {
|
||||||
|
const graph = graphRef.current;
|
||||||
|
if (!graph || !onSelectionChange) return;
|
||||||
|
// Defer to next tick so G6 internal state has settled
|
||||||
|
setTimeout(() => {
|
||||||
|
try {
|
||||||
|
const selectedNodes = graph.getElementDataByState("node", "selected");
|
||||||
|
const selectedEdges = graph.getElementDataByState("edge", "selected");
|
||||||
|
onSelectionChange(
|
||||||
|
selectedNodes.map((n: { id: string }) => n.id),
|
||||||
|
selectedEdges.map((e: { id: string }) => e.id)
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
// graph may be destroyed
|
||||||
|
}
|
||||||
|
}, 0);
|
||||||
|
}, [onSelectionChange]);
|
||||||
|
|
||||||
|
// Bind events
|
||||||
|
useEffect(() => {
|
||||||
|
const graph = graphRef.current;
|
||||||
|
if (!graph) return;
|
||||||
|
|
||||||
|
const handleNodeClick = (event: GraphElementEvent) => {
|
||||||
|
const nodeId = resolveElementId(event, "node");
|
||||||
|
if (nodeId) {
|
||||||
|
onNodeClick?.(nodeId);
|
||||||
|
}
|
||||||
|
emitSelectionChange();
|
||||||
|
};
|
||||||
|
const handleEdgeClick = (event: GraphElementEvent) => {
|
||||||
|
const edgeId = resolveElementId(event, "edge");
|
||||||
|
if (edgeId) {
|
||||||
|
onEdgeClick?.(edgeId);
|
||||||
|
}
|
||||||
|
emitSelectionChange();
|
||||||
|
};
|
||||||
|
const handleNodeDblClick = (event: GraphElementEvent) => {
|
||||||
|
const nodeId = resolveElementId(event, "node");
|
||||||
|
if (nodeId) {
|
||||||
|
onNodeDoubleClick?.(nodeId);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const handleCanvasClick = () => {
|
||||||
|
onCanvasClick?.();
|
||||||
|
emitSelectionChange();
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.on("node:click", handleNodeClick);
|
||||||
|
graph.on("edge:click", handleEdgeClick);
|
||||||
|
graph.on("node:dblclick", handleNodeDblClick);
|
||||||
|
graph.on("canvas:click", handleCanvasClick);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
graph.off("node:click", handleNodeClick);
|
||||||
|
graph.off("edge:click", handleEdgeClick);
|
||||||
|
graph.off("node:dblclick", handleNodeDblClick);
|
||||||
|
graph.off("canvas:click", handleCanvasClick);
|
||||||
|
};
|
||||||
|
}, [
|
||||||
|
onNodeClick,
|
||||||
|
onEdgeClick,
|
||||||
|
onNodeDoubleClick,
|
||||||
|
onCanvasClick,
|
||||||
|
emitSelectionChange,
|
||||||
|
resolveElementId,
|
||||||
|
]);
|
||||||
|
|
||||||
|
// Fit view helper
|
||||||
|
const handleFitView = useCallback(() => {
|
||||||
|
graphRef.current?.fitView();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="relative w-full h-full">
|
||||||
|
<Spin spinning={loading} tip="加载中...">
|
||||||
|
<div ref={containerRef} className="w-full h-full min-h-[500px]" />
|
||||||
|
</Spin>
|
||||||
|
<div className="absolute bottom-4 right-4 flex gap-2">
|
||||||
|
<button
|
||||||
|
onClick={handleFitView}
|
||||||
|
className="px-3 py-1.5 bg-white border border-gray-300 rounded shadow-sm text-xs hover:bg-gray-50"
|
||||||
|
>
|
||||||
|
适应画布
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={() => graphRef.current?.zoomTo(1)}
|
||||||
|
className="px-3 py-1.5 bg-white border border-gray-300 rounded shadow-sm text-xs hover:bg-gray-50"
|
||||||
|
>
|
||||||
|
重置缩放
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(GraphCanvas);
|
||||||
240
frontend/src/pages/KnowledgeGraph/components/NodeDetail.tsx
Normal file
240
frontend/src/pages/KnowledgeGraph/components/NodeDetail.tsx
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { Drawer, Descriptions, Tag, List, Button, Spin, Empty, Popconfirm, Space, message } from "antd";
|
||||||
|
import { Expand, Pencil, Trash2 } from "lucide-react";
|
||||||
|
import type { GraphEntity, RelationVO, PagedResponse } from "../knowledge-graph.model";
|
||||||
|
import {
|
||||||
|
ENTITY_TYPE_LABELS,
|
||||||
|
ENTITY_TYPE_COLORS,
|
||||||
|
DEFAULT_ENTITY_COLOR,
|
||||||
|
RELATION_TYPE_LABELS,
|
||||||
|
} from "../knowledge-graph.const";
|
||||||
|
import * as api from "../knowledge-graph.api";
|
||||||
|
|
||||||
|
interface NodeDetailProps {
|
||||||
|
graphId: string;
|
||||||
|
entityId: string | null;
|
||||||
|
open: boolean;
|
||||||
|
editMode?: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
onExpandNode: (entityId: string) => void;
|
||||||
|
onRelationClick: (relationId: string) => void;
|
||||||
|
onEntityNavigate: (entityId: string) => void;
|
||||||
|
onEditEntity?: (entity: GraphEntity) => void;
|
||||||
|
onDeleteEntity?: (entityId: string) => void;
|
||||||
|
onCreateRelation?: (sourceEntityId: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function NodeDetail({
|
||||||
|
graphId,
|
||||||
|
entityId,
|
||||||
|
open,
|
||||||
|
editMode = false,
|
||||||
|
onClose,
|
||||||
|
onExpandNode,
|
||||||
|
onRelationClick,
|
||||||
|
onEntityNavigate,
|
||||||
|
onEditEntity,
|
||||||
|
onDeleteEntity,
|
||||||
|
onCreateRelation,
|
||||||
|
}: NodeDetailProps) {
|
||||||
|
const [entity, setEntity] = useState<GraphEntity | null>(null);
|
||||||
|
const [relations, setRelations] = useState<RelationVO[]>([]);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!entityId || !graphId) {
|
||||||
|
setEntity(null);
|
||||||
|
setRelations([]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!open) return;
|
||||||
|
|
||||||
|
setLoading(true);
|
||||||
|
Promise.all([
|
||||||
|
api.getEntity(graphId, entityId),
|
||||||
|
api.listEntityRelations(graphId, entityId, { page: 0, size: 50 }),
|
||||||
|
])
|
||||||
|
.then(([entityData, relData]: [GraphEntity, PagedResponse<RelationVO>]) => {
|
||||||
|
setEntity(entityData);
|
||||||
|
setRelations(relData.content);
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
message.error("加载实体详情失败");
|
||||||
|
})
|
||||||
|
.finally(() => {
|
||||||
|
setLoading(false);
|
||||||
|
});
|
||||||
|
}, [graphId, entityId, open]);
|
||||||
|
|
||||||
|
const handleDelete = () => {
|
||||||
|
if (entityId) {
|
||||||
|
onDeleteEntity?.(entityId);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Drawer
|
||||||
|
title={
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span>实体详情</span>
|
||||||
|
{entity && (
|
||||||
|
<Tag color={ENTITY_TYPE_COLORS[entity.type] ?? DEFAULT_ENTITY_COLOR}>
|
||||||
|
{ENTITY_TYPE_LABELS[entity.type] ?? entity.type}
|
||||||
|
</Tag>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
open={open}
|
||||||
|
onClose={onClose}
|
||||||
|
width={420}
|
||||||
|
extra={
|
||||||
|
entityId && (
|
||||||
|
<Space>
|
||||||
|
{editMode && entity && (
|
||||||
|
<>
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
icon={<Pencil className="w-3 h-3" />}
|
||||||
|
onClick={() => onEditEntity?.(entity)}
|
||||||
|
>
|
||||||
|
编辑
|
||||||
|
</Button>
|
||||||
|
<Popconfirm
|
||||||
|
title="确认删除此实体?"
|
||||||
|
description="删除后关联的关系也会被移除"
|
||||||
|
onConfirm={handleDelete}
|
||||||
|
okText="确认"
|
||||||
|
cancelText="取消"
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
danger
|
||||||
|
icon={<Trash2 className="w-3 h-3" />}
|
||||||
|
>
|
||||||
|
删除
|
||||||
|
</Button>
|
||||||
|
</Popconfirm>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
size="small"
|
||||||
|
icon={<Expand className="w-3 h-3" />}
|
||||||
|
onClick={() => onExpandNode(entityId)}
|
||||||
|
>
|
||||||
|
展开邻居
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Spin spinning={loading}>
|
||||||
|
{entity ? (
|
||||||
|
<div className="flex flex-col gap-4">
|
||||||
|
<Descriptions column={1} size="small" bordered>
|
||||||
|
<Descriptions.Item label="名称">{entity.name}</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="类型">
|
||||||
|
{ENTITY_TYPE_LABELS[entity.type] ?? entity.type}
|
||||||
|
</Descriptions.Item>
|
||||||
|
{entity.description && (
|
||||||
|
<Descriptions.Item label="描述">{entity.description}</Descriptions.Item>
|
||||||
|
)}
|
||||||
|
{entity.aliases && entity.aliases.length > 0 && (
|
||||||
|
<Descriptions.Item label="别名">
|
||||||
|
{entity.aliases.map((a) => (
|
||||||
|
<Tag key={a}>{a}</Tag>
|
||||||
|
))}
|
||||||
|
</Descriptions.Item>
|
||||||
|
)}
|
||||||
|
{entity.confidence != null && (
|
||||||
|
<Descriptions.Item label="置信度">
|
||||||
|
{(entity.confidence * 100).toFixed(0)}%
|
||||||
|
</Descriptions.Item>
|
||||||
|
)}
|
||||||
|
{entity.sourceType && (
|
||||||
|
<Descriptions.Item label="来源">{entity.sourceType}</Descriptions.Item>
|
||||||
|
)}
|
||||||
|
{entity.createdAt && (
|
||||||
|
<Descriptions.Item label="创建时间">{entity.createdAt}</Descriptions.Item>
|
||||||
|
)}
|
||||||
|
</Descriptions>
|
||||||
|
|
||||||
|
{entity.properties && Object.keys(entity.properties).length > 0 && (
|
||||||
|
<>
|
||||||
|
<h4 className="font-medium text-sm">扩展属性</h4>
|
||||||
|
<Descriptions column={1} size="small" bordered>
|
||||||
|
{Object.entries(entity.properties).map(([key, value]) => (
|
||||||
|
<Descriptions.Item key={key} label={key}>
|
||||||
|
{String(value)}
|
||||||
|
</Descriptions.Item>
|
||||||
|
))}
|
||||||
|
</Descriptions>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<h4 className="font-medium text-sm">关系列表 ({relations.length})</h4>
|
||||||
|
{editMode && entityId && (
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
type="link"
|
||||||
|
onClick={() => onCreateRelation?.(entityId)}
|
||||||
|
>
|
||||||
|
+ 添加关系
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
{relations.length > 0 ? (
|
||||||
|
<List
|
||||||
|
size="small"
|
||||||
|
dataSource={relations}
|
||||||
|
renderItem={(rel) => {
|
||||||
|
const isSource = rel.sourceEntityId === entityId;
|
||||||
|
const otherName = isSource ? rel.targetEntityName : rel.sourceEntityName;
|
||||||
|
const otherType = isSource ? rel.targetEntityType : rel.sourceEntityType;
|
||||||
|
const otherId = isSource ? rel.targetEntityId : rel.sourceEntityId;
|
||||||
|
const direction = isSource ? "→" : "←";
|
||||||
|
|
||||||
|
return (
|
||||||
|
<List.Item
|
||||||
|
className="cursor-pointer hover:bg-gray-50 !px-2"
|
||||||
|
onClick={() => onRelationClick(rel.id)}
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-1.5 w-full min-w-0 text-sm">
|
||||||
|
<span className="text-gray-400">{direction}</span>
|
||||||
|
<Tag
|
||||||
|
className="shrink-0"
|
||||||
|
color={ENTITY_TYPE_COLORS[otherType] ?? DEFAULT_ENTITY_COLOR}
|
||||||
|
>
|
||||||
|
{ENTITY_TYPE_LABELS[otherType] ?? otherType}
|
||||||
|
</Tag>
|
||||||
|
<Button
|
||||||
|
type="link"
|
||||||
|
size="small"
|
||||||
|
className="!p-0 truncate"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
onEntityNavigate(otherId);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{otherName}
|
||||||
|
</Button>
|
||||||
|
<span className="ml-auto text-xs text-gray-400 shrink-0">
|
||||||
|
{RELATION_TYPE_LABELS[rel.relationType] ?? rel.relationType}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</List.Item>
|
||||||
|
);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Empty description="暂无关系" image={Empty.PRESENTED_IMAGE_SIMPLE} />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
) : !loading ? (
|
||||||
|
<Empty description="选择一个节点查看详情" />
|
||||||
|
) : null}
|
||||||
|
</Spin>
|
||||||
|
</Drawer>
|
||||||
|
);
|
||||||
|
}
|
||||||
173
frontend/src/pages/KnowledgeGraph/components/QueryBuilder.tsx
Normal file
173
frontend/src/pages/KnowledgeGraph/components/QueryBuilder.tsx
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
import { useState, useCallback } from "react";
|
||||||
|
import { Input, Button, Select, InputNumber, List, Tag, Empty, message, Spin } from "antd";
|
||||||
|
import type { PathVO, AllPathsVO, EntitySummaryVO, EdgeSummaryVO } from "../knowledge-graph.model";
|
||||||
|
import {
|
||||||
|
ENTITY_TYPE_LABELS,
|
||||||
|
ENTITY_TYPE_COLORS,
|
||||||
|
DEFAULT_ENTITY_COLOR,
|
||||||
|
RELATION_TYPE_LABELS,
|
||||||
|
} from "../knowledge-graph.const";
|
||||||
|
import * as api from "../knowledge-graph.api";
|
||||||
|
|
||||||
|
type QueryType = "shortest-path" | "all-paths";
|
||||||
|
|
||||||
|
interface QueryBuilderProps {
|
||||||
|
graphId: string;
|
||||||
|
onPathResult: (nodes: EntitySummaryVO[], edges: EdgeSummaryVO[]) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function QueryBuilder({ graphId, onPathResult }: QueryBuilderProps) {
|
||||||
|
const [queryType, setQueryType] = useState<QueryType>("shortest-path");
|
||||||
|
const [sourceId, setSourceId] = useState("");
|
||||||
|
const [targetId, setTargetId] = useState("");
|
||||||
|
const [maxDepth, setMaxDepth] = useState(5);
|
||||||
|
const [maxPaths, setMaxPaths] = useState(3);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [pathResults, setPathResults] = useState<PathVO[]>([]);
|
||||||
|
|
||||||
|
const handleQuery = useCallback(async () => {
|
||||||
|
if (!sourceId.trim() || !targetId.trim()) {
|
||||||
|
message.warning("请输入源实体和目标实体 ID");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setLoading(true);
|
||||||
|
setPathResults([]);
|
||||||
|
try {
|
||||||
|
if (queryType === "shortest-path") {
|
||||||
|
const path: PathVO = await api.getShortestPath(graphId, {
|
||||||
|
sourceId: sourceId.trim(),
|
||||||
|
targetId: targetId.trim(),
|
||||||
|
maxDepth,
|
||||||
|
});
|
||||||
|
setPathResults([path]);
|
||||||
|
onPathResult(path.nodes, path.edges);
|
||||||
|
} else {
|
||||||
|
const result: AllPathsVO = await api.getAllPaths(graphId, {
|
||||||
|
sourceId: sourceId.trim(),
|
||||||
|
targetId: targetId.trim(),
|
||||||
|
maxDepth,
|
||||||
|
maxPaths,
|
||||||
|
});
|
||||||
|
setPathResults(result.paths);
|
||||||
|
if (result.paths.length > 0) {
|
||||||
|
const allNodes = result.paths.flatMap((p) => p.nodes);
|
||||||
|
const allEdges = result.paths.flatMap((p) => p.edges);
|
||||||
|
onPathResult(allNodes, allEdges);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
message.error("路径查询失败");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, [graphId, queryType, sourceId, targetId, maxDepth, maxPaths, onPathResult]);
|
||||||
|
|
||||||
|
const handleClear = useCallback(() => {
|
||||||
|
setPathResults([]);
|
||||||
|
setSourceId("");
|
||||||
|
setTargetId("");
|
||||||
|
onPathResult([], []);
|
||||||
|
}, [onPathResult]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-3">
|
||||||
|
<Select
|
||||||
|
value={queryType}
|
||||||
|
onChange={setQueryType}
|
||||||
|
className="w-full"
|
||||||
|
options={[
|
||||||
|
{ label: "最短路径", value: "shortest-path" },
|
||||||
|
{ label: "所有路径", value: "all-paths" },
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Input
|
||||||
|
placeholder="源实体 ID"
|
||||||
|
value={sourceId}
|
||||||
|
onChange={(e) => setSourceId(e.target.value)}
|
||||||
|
allowClear
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Input
|
||||||
|
placeholder="目标实体 ID"
|
||||||
|
value={targetId}
|
||||||
|
onChange={(e) => setTargetId(e.target.value)}
|
||||||
|
allowClear
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-xs text-gray-500 shrink-0">最大深度</span>
|
||||||
|
<InputNumber
|
||||||
|
min={1}
|
||||||
|
max={10}
|
||||||
|
value={maxDepth}
|
||||||
|
onChange={(v) => setMaxDepth(v ?? 5)}
|
||||||
|
size="small"
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{queryType === "all-paths" && (
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-xs text-gray-500 shrink-0">最大路径数</span>
|
||||||
|
<InputNumber
|
||||||
|
min={1}
|
||||||
|
max={20}
|
||||||
|
value={maxPaths}
|
||||||
|
onChange={(v) => setMaxPaths(v ?? 3)}
|
||||||
|
size="small"
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<Button type="primary" onClick={handleQuery} loading={loading} className="flex-1">
|
||||||
|
查询
|
||||||
|
</Button>
|
||||||
|
<Button onClick={handleClear}>清除</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Spin spinning={loading}>
|
||||||
|
{pathResults.length > 0 ? (
|
||||||
|
<List
|
||||||
|
size="small"
|
||||||
|
dataSource={pathResults}
|
||||||
|
renderItem={(path, index) => (
|
||||||
|
<List.Item className="!px-2">
|
||||||
|
<div className="flex flex-col gap-1 w-full">
|
||||||
|
<div className="text-xs font-medium text-gray-600">
|
||||||
|
路径 {index + 1}({path.pathLength} 跳)
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-1 flex-wrap">
|
||||||
|
{path.nodes.map((node, ni) => (
|
||||||
|
<span key={node.id} className="flex items-center gap-1">
|
||||||
|
{ni > 0 && (
|
||||||
|
<span className="text-xs text-gray-400">
|
||||||
|
{path.edges[ni - 1]
|
||||||
|
? RELATION_TYPE_LABELS[path.edges[ni - 1].relationType] ??
|
||||||
|
path.edges[ni - 1].relationType
|
||||||
|
: "→"}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
<Tag
|
||||||
|
color={ENTITY_TYPE_COLORS[node.type] ?? DEFAULT_ENTITY_COLOR}
|
||||||
|
className="!m-0"
|
||||||
|
>
|
||||||
|
{ENTITY_TYPE_LABELS[node.type] ?? node.type}
|
||||||
|
</Tag>
|
||||||
|
<span className="text-xs">{node.name}</span>
|
||||||
|
</span>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</List.Item>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
) : !loading && sourceId && targetId ? (
|
||||||
|
<Empty description="暂无结果" image={Empty.PRESENTED_IMAGE_SIMPLE} />
|
||||||
|
) : null}
|
||||||
|
</Spin>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
167
frontend/src/pages/KnowledgeGraph/components/RelationDetail.tsx
Normal file
167
frontend/src/pages/KnowledgeGraph/components/RelationDetail.tsx
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { Drawer, Descriptions, Tag, Spin, Empty, Button, Popconfirm, Space, message } from "antd";
|
||||||
|
import { Pencil, Trash2 } from "lucide-react";
|
||||||
|
import type { RelationVO } from "../knowledge-graph.model";
|
||||||
|
import {
|
||||||
|
ENTITY_TYPE_LABELS,
|
||||||
|
ENTITY_TYPE_COLORS,
|
||||||
|
DEFAULT_ENTITY_COLOR,
|
||||||
|
RELATION_TYPE_LABELS,
|
||||||
|
} from "../knowledge-graph.const";
|
||||||
|
import * as api from "../knowledge-graph.api";
|
||||||
|
|
||||||
|
interface RelationDetailProps {
|
||||||
|
graphId: string;
|
||||||
|
relationId: string | null;
|
||||||
|
open: boolean;
|
||||||
|
editMode?: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
onEntityNavigate: (entityId: string) => void;
|
||||||
|
onEditRelation?: (relation: RelationVO) => void;
|
||||||
|
onDeleteRelation?: (relationId: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function RelationDetail({
|
||||||
|
graphId,
|
||||||
|
relationId,
|
||||||
|
open,
|
||||||
|
editMode = false,
|
||||||
|
onClose,
|
||||||
|
onEntityNavigate,
|
||||||
|
onEditRelation,
|
||||||
|
onDeleteRelation,
|
||||||
|
}: RelationDetailProps) {
|
||||||
|
const [relation, setRelation] = useState<RelationVO | null>(null);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!relationId || !graphId) {
|
||||||
|
setRelation(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!open) return;
|
||||||
|
|
||||||
|
setLoading(true);
|
||||||
|
api
|
||||||
|
.getRelation(graphId, relationId)
|
||||||
|
.then((data) => setRelation(data))
|
||||||
|
.catch(() => message.error("加载关系详情失败"))
|
||||||
|
.finally(() => setLoading(false));
|
||||||
|
}, [graphId, relationId, open]);
|
||||||
|
|
||||||
|
const handleDelete = () => {
|
||||||
|
if (relationId) {
|
||||||
|
onDeleteRelation?.(relationId);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Drawer
|
||||||
|
title="关系详情"
|
||||||
|
open={open}
|
||||||
|
onClose={onClose}
|
||||||
|
width={400}
|
||||||
|
extra={
|
||||||
|
editMode && relation && (
|
||||||
|
<Space>
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
icon={<Pencil className="w-3 h-3" />}
|
||||||
|
onClick={() => onEditRelation?.(relation)}
|
||||||
|
>
|
||||||
|
编辑
|
||||||
|
</Button>
|
||||||
|
<Popconfirm
|
||||||
|
title="确认删除此关系?"
|
||||||
|
onConfirm={handleDelete}
|
||||||
|
okText="确认"
|
||||||
|
cancelText="取消"
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
danger
|
||||||
|
icon={<Trash2 className="w-3 h-3" />}
|
||||||
|
>
|
||||||
|
删除
|
||||||
|
</Button>
|
||||||
|
</Popconfirm>
|
||||||
|
</Space>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Spin spinning={loading}>
|
||||||
|
{relation ? (
|
||||||
|
<div className="flex flex-col gap-4">
|
||||||
|
<Descriptions column={1} size="small" bordered>
|
||||||
|
<Descriptions.Item label="关系类型">
|
||||||
|
<Tag color="blue">
|
||||||
|
{RELATION_TYPE_LABELS[relation.relationType] ?? relation.relationType}
|
||||||
|
</Tag>
|
||||||
|
</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="源实体">
|
||||||
|
<div className="flex items-center gap-1.5">
|
||||||
|
<Tag
|
||||||
|
color={
|
||||||
|
ENTITY_TYPE_COLORS[relation.sourceEntityType] ?? DEFAULT_ENTITY_COLOR
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{ENTITY_TYPE_LABELS[relation.sourceEntityType] ?? relation.sourceEntityType}
|
||||||
|
</Tag>
|
||||||
|
<a
|
||||||
|
className="text-blue-500 cursor-pointer hover:underline"
|
||||||
|
onClick={() => onEntityNavigate(relation.sourceEntityId)}
|
||||||
|
>
|
||||||
|
{relation.sourceEntityName}
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
</Descriptions.Item>
|
||||||
|
<Descriptions.Item label="目标实体">
|
||||||
|
<div className="flex items-center gap-1.5">
|
||||||
|
<Tag
|
||||||
|
color={
|
||||||
|
ENTITY_TYPE_COLORS[relation.targetEntityType] ?? DEFAULT_ENTITY_COLOR
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{ENTITY_TYPE_LABELS[relation.targetEntityType] ?? relation.targetEntityType}
|
||||||
|
</Tag>
|
||||||
|
<a
|
||||||
|
className="text-blue-500 cursor-pointer hover:underline"
|
||||||
|
onClick={() => onEntityNavigate(relation.targetEntityId)}
|
||||||
|
>
|
||||||
|
{relation.targetEntityName}
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
</Descriptions.Item>
|
||||||
|
{relation.weight != null && (
|
||||||
|
<Descriptions.Item label="权重">{relation.weight}</Descriptions.Item>
|
||||||
|
)}
|
||||||
|
{relation.confidence != null && (
|
||||||
|
<Descriptions.Item label="置信度">
|
||||||
|
{(relation.confidence * 100).toFixed(0)}%
|
||||||
|
</Descriptions.Item>
|
||||||
|
)}
|
||||||
|
{relation.createdAt && (
|
||||||
|
<Descriptions.Item label="创建时间">{relation.createdAt}</Descriptions.Item>
|
||||||
|
)}
|
||||||
|
</Descriptions>
|
||||||
|
|
||||||
|
{relation.properties && Object.keys(relation.properties).length > 0 && (
|
||||||
|
<>
|
||||||
|
<h4 className="font-medium text-sm">扩展属性</h4>
|
||||||
|
<Descriptions column={1} size="small" bordered>
|
||||||
|
{Object.entries(relation.properties).map(([key, value]) => (
|
||||||
|
<Descriptions.Item key={key} label={key}>
|
||||||
|
{String(value)}
|
||||||
|
</Descriptions.Item>
|
||||||
|
))}
|
||||||
|
</Descriptions>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
) : !loading ? (
|
||||||
|
<Empty description="选择一条边查看详情" />
|
||||||
|
) : null}
|
||||||
|
</Spin>
|
||||||
|
</Drawer>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,183 @@
|
|||||||
|
import { useEffect, useState, useCallback } from "react";
|
||||||
|
import { Modal, Form, Select, InputNumber, message, Spin } from "antd";
|
||||||
|
import type { RelationVO, GraphEntity } from "../knowledge-graph.model";
|
||||||
|
import { RELATION_TYPES, RELATION_TYPE_LABELS } from "../knowledge-graph.const";
|
||||||
|
import * as api from "../knowledge-graph.api";
|
||||||
|
|
||||||
|
interface RelationEditFormProps {
|
||||||
|
graphId: string;
|
||||||
|
relation?: RelationVO | null;
|
||||||
|
open: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
onSuccess: () => void;
|
||||||
|
/** Pre-fill source entity when creating from a node context */
|
||||||
|
defaultSourceId?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function RelationEditForm({
|
||||||
|
graphId,
|
||||||
|
relation,
|
||||||
|
open,
|
||||||
|
onClose,
|
||||||
|
onSuccess,
|
||||||
|
defaultSourceId,
|
||||||
|
}: RelationEditFormProps) {
|
||||||
|
const [form] = Form.useForm();
|
||||||
|
const isEdit = !!relation;
|
||||||
|
const [entityOptions, setEntityOptions] = useState<
|
||||||
|
{ label: string; value: string }[]
|
||||||
|
>([]);
|
||||||
|
const [searchLoading, setSearchLoading] = useState(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (open && relation) {
|
||||||
|
form.setFieldsValue({
|
||||||
|
relationType: relation.relationType,
|
||||||
|
sourceEntityId: relation.sourceEntityId,
|
||||||
|
targetEntityId: relation.targetEntityId,
|
||||||
|
weight: relation.weight,
|
||||||
|
confidence: relation.confidence,
|
||||||
|
});
|
||||||
|
} else if (open) {
|
||||||
|
form.resetFields();
|
||||||
|
if (defaultSourceId) {
|
||||||
|
form.setFieldsValue({ sourceEntityId: defaultSourceId });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, [open, relation, form, defaultSourceId]);
|
||||||
|
|
||||||
|
const searchEntities = useCallback(
|
||||||
|
async (keyword: string) => {
|
||||||
|
if (!keyword.trim() || !graphId) return;
|
||||||
|
setSearchLoading(true);
|
||||||
|
try {
|
||||||
|
const result = await api.listEntitiesPaged(graphId, {
|
||||||
|
keyword,
|
||||||
|
page: 0,
|
||||||
|
size: 20,
|
||||||
|
});
|
||||||
|
setEntityOptions(
|
||||||
|
result.content.map((e: GraphEntity) => ({
|
||||||
|
label: `${e.name} (${e.type})`,
|
||||||
|
value: e.id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
} catch {
|
||||||
|
// ignore
|
||||||
|
} finally {
|
||||||
|
setSearchLoading(false);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[graphId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSubmit = async () => {
|
||||||
|
let values;
|
||||||
|
try {
|
||||||
|
values = await form.validateFields();
|
||||||
|
} catch {
|
||||||
|
return; // Form validation failed — Antd shows inline errors
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (isEdit && relation) {
|
||||||
|
const payload = JSON.stringify({
|
||||||
|
relationType: values.relationType,
|
||||||
|
weight: values.weight,
|
||||||
|
confidence: values.confidence,
|
||||||
|
});
|
||||||
|
await api.submitReview(graphId, {
|
||||||
|
operationType: "UPDATE_RELATION",
|
||||||
|
relationId: relation.id,
|
||||||
|
payload,
|
||||||
|
});
|
||||||
|
message.success("关系更新已提交审核");
|
||||||
|
} else {
|
||||||
|
const payload = JSON.stringify({
|
||||||
|
sourceEntityId: values.sourceEntityId,
|
||||||
|
targetEntityId: values.targetEntityId,
|
||||||
|
relationType: values.relationType,
|
||||||
|
weight: values.weight,
|
||||||
|
confidence: values.confidence,
|
||||||
|
});
|
||||||
|
await api.submitReview(graphId, {
|
||||||
|
operationType: "CREATE_RELATION",
|
||||||
|
payload,
|
||||||
|
});
|
||||||
|
message.success("关系创建已提交审核");
|
||||||
|
}
|
||||||
|
onSuccess();
|
||||||
|
onClose();
|
||||||
|
} catch {
|
||||||
|
message.error(isEdit ? "提交关系更新审核失败" : "提交关系创建审核失败");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
title={isEdit ? "编辑关系" : "创建关系"}
|
||||||
|
open={open}
|
||||||
|
onCancel={onClose}
|
||||||
|
onOk={handleSubmit}
|
||||||
|
okText="提交审核"
|
||||||
|
cancelText="取消"
|
||||||
|
destroyOnClose
|
||||||
|
>
|
||||||
|
<Form form={form} layout="vertical" className="mt-4">
|
||||||
|
<Form.Item
|
||||||
|
name="sourceEntityId"
|
||||||
|
label="源实体"
|
||||||
|
rules={[{ required: true, message: "请选择源实体" }]}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
showSearch
|
||||||
|
placeholder="搜索并选择源实体"
|
||||||
|
disabled={isEdit}
|
||||||
|
filterOption={false}
|
||||||
|
onSearch={searchEntities}
|
||||||
|
options={entityOptions}
|
||||||
|
notFoundContent={searchLoading ? <Spin size="small" /> : null}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="targetEntityId"
|
||||||
|
label="目标实体"
|
||||||
|
rules={[{ required: true, message: "请选择目标实体" }]}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
showSearch
|
||||||
|
placeholder="搜索并选择目标实体"
|
||||||
|
disabled={isEdit}
|
||||||
|
filterOption={false}
|
||||||
|
onSearch={searchEntities}
|
||||||
|
options={entityOptions}
|
||||||
|
notFoundContent={searchLoading ? <Spin size="small" /> : null}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="relationType"
|
||||||
|
label="关系类型"
|
||||||
|
rules={[{ required: true, message: "请选择关系类型" }]}
|
||||||
|
>
|
||||||
|
<Select
|
||||||
|
placeholder="选择关系类型"
|
||||||
|
options={RELATION_TYPES.map((t) => ({
|
||||||
|
label: RELATION_TYPE_LABELS[t] ?? t,
|
||||||
|
value: t,
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="weight" label="权重">
|
||||||
|
<InputNumber min={0} max={1} step={0.1} className="w-full" />
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item name="confidence" label="置信度">
|
||||||
|
<InputNumber min={0} max={1} step={0.1} className="w-full" />
|
||||||
|
</Form.Item>
|
||||||
|
</Form>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
}
|
||||||
206
frontend/src/pages/KnowledgeGraph/components/ReviewPanel.tsx
Normal file
206
frontend/src/pages/KnowledgeGraph/components/ReviewPanel.tsx
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
import { useState, useCallback, useEffect } from "react";
|
||||||
|
import { List, Tag, Button, Empty, Spin, Popconfirm, Input, message } from "antd";
|
||||||
|
import { Check, X } from "lucide-react";
|
||||||
|
import type { EditReviewVO, PagedResponse } from "../knowledge-graph.model";
|
||||||
|
import * as api from "../knowledge-graph.api";
|
||||||
|
|
||||||
|
const OPERATION_LABELS: Record<string, string> = {
|
||||||
|
CREATE_ENTITY: "创建实体",
|
||||||
|
UPDATE_ENTITY: "更新实体",
|
||||||
|
DELETE_ENTITY: "删除实体",
|
||||||
|
CREATE_RELATION: "创建关系",
|
||||||
|
UPDATE_RELATION: "更新关系",
|
||||||
|
DELETE_RELATION: "删除关系",
|
||||||
|
};
|
||||||
|
|
||||||
|
const STATUS_COLORS: Record<string, string> = {
|
||||||
|
PENDING: "orange",
|
||||||
|
APPROVED: "green",
|
||||||
|
REJECTED: "red",
|
||||||
|
};
|
||||||
|
|
||||||
|
const STATUS_LABELS: Record<string, string> = {
|
||||||
|
PENDING: "待审核",
|
||||||
|
APPROVED: "已通过",
|
||||||
|
REJECTED: "已拒绝",
|
||||||
|
};
|
||||||
|
|
||||||
|
interface ReviewPanelProps {
|
||||||
|
graphId: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function ReviewPanel({ graphId }: ReviewPanelProps) {
|
||||||
|
const [reviews, setReviews] = useState<EditReviewVO[]>([]);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [total, setTotal] = useState(0);
|
||||||
|
|
||||||
|
const loadReviews = useCallback(async () => {
|
||||||
|
if (!graphId) return;
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const result: PagedResponse<EditReviewVO> = await api.listPendingReviews(
|
||||||
|
graphId,
|
||||||
|
{ page: 0, size: 50 }
|
||||||
|
);
|
||||||
|
setReviews(result.content);
|
||||||
|
setTotal(result.totalElements);
|
||||||
|
} catch {
|
||||||
|
message.error("加载审核列表失败");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, [graphId]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
loadReviews();
|
||||||
|
}, [loadReviews]);
|
||||||
|
|
||||||
|
const handleApprove = useCallback(
|
||||||
|
async (reviewId: string) => {
|
||||||
|
try {
|
||||||
|
await api.approveReview(graphId, reviewId);
|
||||||
|
message.success("审核通过");
|
||||||
|
loadReviews();
|
||||||
|
} catch {
|
||||||
|
message.error("审核操作失败");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[graphId, loadReviews]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleReject = useCallback(
|
||||||
|
async (reviewId: string, comment: string) => {
|
||||||
|
try {
|
||||||
|
await api.rejectReview(graphId, reviewId, { comment });
|
||||||
|
message.success("已拒绝");
|
||||||
|
loadReviews();
|
||||||
|
} catch {
|
||||||
|
message.error("审核操作失败");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[graphId, loadReviews]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-2">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<span className="text-xs text-gray-500">
|
||||||
|
待审核: {total}
|
||||||
|
</span>
|
||||||
|
<Button size="small" onClick={loadReviews}>
|
||||||
|
刷新
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Spin spinning={loading}>
|
||||||
|
{reviews.length > 0 ? (
|
||||||
|
<List
|
||||||
|
size="small"
|
||||||
|
dataSource={reviews}
|
||||||
|
renderItem={(review) => (
|
||||||
|
<ReviewItem
|
||||||
|
review={review}
|
||||||
|
onApprove={handleApprove}
|
||||||
|
onReject={handleReject}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Empty
|
||||||
|
description="暂无待审核项"
|
||||||
|
image={Empty.PRESENTED_IMAGE_SIMPLE}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Spin>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function ReviewItem({
|
||||||
|
review,
|
||||||
|
onApprove,
|
||||||
|
onReject,
|
||||||
|
}: {
|
||||||
|
review: EditReviewVO;
|
||||||
|
onApprove: (id: string) => void;
|
||||||
|
onReject: (id: string, comment: string) => void;
|
||||||
|
}) {
|
||||||
|
const [rejectComment, setRejectComment] = useState("");
|
||||||
|
|
||||||
|
const payload = review.payload ? tryParsePayload(review.payload) : null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<List.Item className="!px-2">
|
||||||
|
<div className="flex flex-col gap-1.5 w-full">
|
||||||
|
<div className="flex items-center gap-1.5">
|
||||||
|
<Tag color={STATUS_COLORS[review.status] ?? "default"}>
|
||||||
|
{STATUS_LABELS[review.status] ?? review.status}
|
||||||
|
</Tag>
|
||||||
|
<span className="text-xs font-medium">
|
||||||
|
{OPERATION_LABELS[review.operationType] ?? review.operationType}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{payload && (
|
||||||
|
<div className="text-xs text-gray-500 truncate">
|
||||||
|
{payload.name && <span>名称: {payload.name} </span>}
|
||||||
|
{payload.relationType && <span>类型: {payload.relationType}</span>}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="text-xs text-gray-400">
|
||||||
|
{review.submittedBy && <span>提交人: {review.submittedBy}</span>}
|
||||||
|
{review.createdAt && <span className="ml-2">{review.createdAt}</span>}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{review.status === "PENDING" && (
|
||||||
|
<div className="flex gap-1.5 mt-1">
|
||||||
|
<Button
|
||||||
|
type="primary"
|
||||||
|
size="small"
|
||||||
|
icon={<Check className="w-3 h-3" />}
|
||||||
|
onClick={() => onApprove(review.id)}
|
||||||
|
>
|
||||||
|
通过
|
||||||
|
</Button>
|
||||||
|
<Popconfirm
|
||||||
|
title="拒绝审核"
|
||||||
|
description={
|
||||||
|
<Input.TextArea
|
||||||
|
rows={2}
|
||||||
|
placeholder="拒绝原因(可选)"
|
||||||
|
value={rejectComment}
|
||||||
|
onChange={(e) => setRejectComment(e.target.value)}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
onConfirm={() => {
|
||||||
|
onReject(review.id, rejectComment);
|
||||||
|
setRejectComment("");
|
||||||
|
}}
|
||||||
|
okText="确认拒绝"
|
||||||
|
cancelText="取消"
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
size="small"
|
||||||
|
danger
|
||||||
|
icon={<X className="w-3 h-3" />}
|
||||||
|
>
|
||||||
|
拒绝
|
||||||
|
</Button>
|
||||||
|
</Popconfirm>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</List.Item>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function tryParsePayload(
|
||||||
|
payload: string
|
||||||
|
): Record<string, unknown> | null {
|
||||||
|
try {
|
||||||
|
return JSON.parse(payload);
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
102
frontend/src/pages/KnowledgeGraph/components/SearchPanel.tsx
Normal file
102
frontend/src/pages/KnowledgeGraph/components/SearchPanel.tsx
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
import { useState, useCallback } from "react";
|
||||||
|
import { Input, List, Tag, Select, Empty } from "antd";
|
||||||
|
import { Search } from "lucide-react";
|
||||||
|
import type { SearchHitVO } from "../knowledge-graph.model";
|
||||||
|
import {
|
||||||
|
ENTITY_TYPES,
|
||||||
|
ENTITY_TYPE_LABELS,
|
||||||
|
ENTITY_TYPE_COLORS,
|
||||||
|
DEFAULT_ENTITY_COLOR,
|
||||||
|
} from "../knowledge-graph.const";
|
||||||
|
|
||||||
|
interface SearchPanelProps {
|
||||||
|
graphId: string;
|
||||||
|
results: SearchHitVO[];
|
||||||
|
loading: boolean;
|
||||||
|
onSearch: (graphId: string, query: string) => void;
|
||||||
|
onResultClick: (entityId: string) => void;
|
||||||
|
onClear: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function SearchPanel({
|
||||||
|
graphId,
|
||||||
|
results,
|
||||||
|
loading,
|
||||||
|
onSearch,
|
||||||
|
onResultClick,
|
||||||
|
onClear,
|
||||||
|
}: SearchPanelProps) {
|
||||||
|
const [query, setQuery] = useState("");
|
||||||
|
const [typeFilter, setTypeFilter] = useState<string | undefined>(undefined);
|
||||||
|
|
||||||
|
const handleSearch = useCallback(
|
||||||
|
(value: string) => {
|
||||||
|
setQuery(value);
|
||||||
|
if (!value.trim()) {
|
||||||
|
onClear();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
onSearch(graphId, value);
|
||||||
|
},
|
||||||
|
[graphId, onSearch, onClear]
|
||||||
|
);
|
||||||
|
|
||||||
|
const filteredResults = typeFilter
|
||||||
|
? results.filter((r) => r.type === typeFilter)
|
||||||
|
: results;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-3">
|
||||||
|
<Input.Search
|
||||||
|
placeholder="搜索实体名称..."
|
||||||
|
value={query}
|
||||||
|
onChange={(e) => setQuery(e.target.value)}
|
||||||
|
onSearch={handleSearch}
|
||||||
|
allowClear
|
||||||
|
onClear={() => {
|
||||||
|
setQuery("");
|
||||||
|
onClear();
|
||||||
|
}}
|
||||||
|
prefix={<Search className="w-4 h-4 text-gray-400" />}
|
||||||
|
loading={loading}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Select
|
||||||
|
allowClear
|
||||||
|
placeholder="按类型筛选"
|
||||||
|
value={typeFilter}
|
||||||
|
onChange={setTypeFilter}
|
||||||
|
className="w-full"
|
||||||
|
options={ENTITY_TYPES.map((t) => ({
|
||||||
|
label: ENTITY_TYPE_LABELS[t] ?? t,
|
||||||
|
value: t,
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{filteredResults.length > 0 ? (
|
||||||
|
<List
|
||||||
|
size="small"
|
||||||
|
dataSource={filteredResults}
|
||||||
|
renderItem={(item) => (
|
||||||
|
<List.Item
|
||||||
|
className="cursor-pointer hover:bg-gray-50 !px-2"
|
||||||
|
onClick={() => onResultClick(item.id)}
|
||||||
|
>
|
||||||
|
<div className="flex items-center gap-2 w-full min-w-0">
|
||||||
|
<Tag color={ENTITY_TYPE_COLORS[item.type] ?? DEFAULT_ENTITY_COLOR}>
|
||||||
|
{ENTITY_TYPE_LABELS[item.type] ?? item.type}
|
||||||
|
</Tag>
|
||||||
|
<span className="truncate font-medium text-sm">{item.name}</span>
|
||||||
|
<span className="ml-auto text-xs text-gray-400 shrink-0">
|
||||||
|
{item.score.toFixed(2)}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</List.Item>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
) : query && !loading ? (
|
||||||
|
<Empty description="未找到匹配实体" image={Empty.PRESENTED_IMAGE_SIMPLE} />
|
||||||
|
) : null}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
106
frontend/src/pages/KnowledgeGraph/graphConfig.ts
Normal file
106
frontend/src/pages/KnowledgeGraph/graphConfig.ts
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
import { ENTITY_TYPE_COLORS, DEFAULT_ENTITY_COLOR } from "./knowledge-graph.const";
|
||||||
|
|
||||||
|
/** Node count threshold above which performance optimizations kick in. */
|
||||||
|
export const LARGE_GRAPH_THRESHOLD = 200;
|
||||||
|
|
||||||
|
/** Create the G6 v5 graph options. */
|
||||||
|
export function createGraphOptions(container: HTMLElement, multiSelect = false) {
|
||||||
|
return {
|
||||||
|
container,
|
||||||
|
autoFit: "view" as const,
|
||||||
|
padding: 40,
|
||||||
|
animation: true,
|
||||||
|
layout: {
|
||||||
|
type: "d3-force" as const,
|
||||||
|
preventOverlap: true,
|
||||||
|
link: {
|
||||||
|
distance: 180,
|
||||||
|
},
|
||||||
|
charge: {
|
||||||
|
strength: -400,
|
||||||
|
},
|
||||||
|
collide: {
|
||||||
|
radius: 50,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
node: {
|
||||||
|
type: "circle" as const,
|
||||||
|
style: {
|
||||||
|
size: (d: { data?: { type?: string } }) => {
|
||||||
|
return d?.data?.type === "Dataset" ? 40 : 32;
|
||||||
|
},
|
||||||
|
fill: (d: { data?: { type?: string } }) => {
|
||||||
|
const type = d?.data?.type ?? "";
|
||||||
|
return ENTITY_TYPE_COLORS[type] ?? DEFAULT_ENTITY_COLOR;
|
||||||
|
},
|
||||||
|
stroke: "#fff",
|
||||||
|
lineWidth: 2,
|
||||||
|
labelText: (d: { data?: { label?: string } }) => d?.data?.label ?? "",
|
||||||
|
labelFontSize: 11,
|
||||||
|
labelFill: "#333",
|
||||||
|
labelPlacement: "bottom" as const,
|
||||||
|
labelOffsetY: 4,
|
||||||
|
labelMaxWidth: 100,
|
||||||
|
labelWordWrap: true,
|
||||||
|
labelWordWrapWidth: 100,
|
||||||
|
cursor: "pointer",
|
||||||
|
},
|
||||||
|
state: {
|
||||||
|
selected: {
|
||||||
|
stroke: "#1677ff",
|
||||||
|
lineWidth: 3,
|
||||||
|
shadowColor: "rgba(22, 119, 255, 0.4)",
|
||||||
|
shadowBlur: 10,
|
||||||
|
labelVisibility: "visible" as const,
|
||||||
|
},
|
||||||
|
highlighted: {
|
||||||
|
stroke: "#faad14",
|
||||||
|
lineWidth: 3,
|
||||||
|
labelVisibility: "visible" as const,
|
||||||
|
},
|
||||||
|
dimmed: {
|
||||||
|
opacity: 0.3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
edge: {
|
||||||
|
type: "line" as const,
|
||||||
|
style: {
|
||||||
|
stroke: "#C2C8D5",
|
||||||
|
lineWidth: 1,
|
||||||
|
endArrow: true,
|
||||||
|
endArrowSize: 6,
|
||||||
|
labelText: (d: { data?: { label?: string } }) => d?.data?.label ?? "",
|
||||||
|
labelFontSize: 10,
|
||||||
|
labelFill: "#999",
|
||||||
|
labelBackground: true,
|
||||||
|
labelBackgroundFill: "#fff",
|
||||||
|
labelBackgroundOpacity: 0.85,
|
||||||
|
labelPadding: [2, 4],
|
||||||
|
cursor: "pointer",
|
||||||
|
},
|
||||||
|
state: {
|
||||||
|
selected: {
|
||||||
|
stroke: "#1677ff",
|
||||||
|
lineWidth: 2,
|
||||||
|
},
|
||||||
|
highlighted: {
|
||||||
|
stroke: "#faad14",
|
||||||
|
lineWidth: 2,
|
||||||
|
},
|
||||||
|
dimmed: {
|
||||||
|
opacity: 0.15,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
behaviors: [
|
||||||
|
"drag-canvas",
|
||||||
|
"zoom-canvas",
|
||||||
|
"drag-element",
|
||||||
|
{
|
||||||
|
type: "click-select" as const,
|
||||||
|
multiple: multiSelect,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
}
|
||||||
77
frontend/src/pages/KnowledgeGraph/graphTransform.ts
Normal file
77
frontend/src/pages/KnowledgeGraph/graphTransform.ts
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import type { EntitySummaryVO, EdgeSummaryVO, SubgraphVO } from "./knowledge-graph.model";
|
||||||
|
import { ENTITY_TYPE_COLORS, DEFAULT_ENTITY_COLOR, RELATION_TYPE_LABELS } from "./knowledge-graph.const";
|
||||||
|
|
||||||
|
export interface G6NodeData {
|
||||||
|
id: string;
|
||||||
|
data: {
|
||||||
|
label: string;
|
||||||
|
type: string;
|
||||||
|
description?: string;
|
||||||
|
};
|
||||||
|
style?: Record<string, unknown>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface G6EdgeData {
|
||||||
|
id: string;
|
||||||
|
source: string;
|
||||||
|
target: string;
|
||||||
|
data: {
|
||||||
|
label: string;
|
||||||
|
relationType: string;
|
||||||
|
weight?: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface G6GraphData {
|
||||||
|
nodes: G6NodeData[];
|
||||||
|
edges: G6EdgeData[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export function entityToG6Node(entity: EntitySummaryVO): G6NodeData {
|
||||||
|
return {
|
||||||
|
id: entity.id,
|
||||||
|
data: {
|
||||||
|
label: entity.name,
|
||||||
|
type: entity.type,
|
||||||
|
description: entity.description,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function edgeToG6Edge(edge: EdgeSummaryVO): G6EdgeData {
|
||||||
|
return {
|
||||||
|
id: edge.id,
|
||||||
|
source: edge.sourceEntityId,
|
||||||
|
target: edge.targetEntityId,
|
||||||
|
data: {
|
||||||
|
label: RELATION_TYPE_LABELS[edge.relationType] ?? edge.relationType,
|
||||||
|
relationType: edge.relationType,
|
||||||
|
weight: edge.weight,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function subgraphToG6Data(subgraph: SubgraphVO): G6GraphData {
|
||||||
|
return {
|
||||||
|
nodes: subgraph.nodes.map(entityToG6Node),
|
||||||
|
edges: subgraph.edges.map(edgeToG6Edge),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Merge new subgraph data into existing graph data, avoiding duplicates. */
|
||||||
|
export function mergeG6Data(existing: G6GraphData, incoming: G6GraphData): G6GraphData {
|
||||||
|
const nodeIds = new Set(existing.nodes.map((n) => n.id));
|
||||||
|
const edgeIds = new Set(existing.edges.map((e) => e.id));
|
||||||
|
|
||||||
|
const newNodes = incoming.nodes.filter((n) => !nodeIds.has(n.id));
|
||||||
|
const newEdges = incoming.edges.filter((e) => !edgeIds.has(e.id));
|
||||||
|
|
||||||
|
return {
|
||||||
|
nodes: [...existing.nodes, ...newNodes],
|
||||||
|
edges: [...existing.edges, ...newEdges],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getEntityColor(type: string): string {
|
||||||
|
return ENTITY_TYPE_COLORS[type] ?? DEFAULT_ENTITY_COLOR;
|
||||||
|
}
|
||||||
141
frontend/src/pages/KnowledgeGraph/hooks/useGraphData.ts
Normal file
141
frontend/src/pages/KnowledgeGraph/hooks/useGraphData.ts
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
import { useState, useCallback, useRef } from "react";
|
||||||
|
import { message } from "antd";
|
||||||
|
import type { SubgraphVO, SearchHitVO, EntitySummaryVO, EdgeSummaryVO } from "../knowledge-graph.model";
|
||||||
|
import type { G6GraphData } from "../graphTransform";
|
||||||
|
import { subgraphToG6Data, mergeG6Data } from "../graphTransform";
|
||||||
|
import * as api from "../knowledge-graph.api";
|
||||||
|
|
||||||
|
export interface UseGraphDataReturn {
|
||||||
|
graphData: G6GraphData;
|
||||||
|
loading: boolean;
|
||||||
|
searchResults: SearchHitVO[];
|
||||||
|
searchLoading: boolean;
|
||||||
|
highlightedNodeIds: Set<string>;
|
||||||
|
loadSubgraph: (graphId: string, entityIds: string[], depth?: number) => Promise<void>;
|
||||||
|
expandNode: (graphId: string, entityId: string, depth?: number) => Promise<void>;
|
||||||
|
searchEntities: (graphId: string, query: string) => Promise<void>;
|
||||||
|
loadInitialData: (graphId: string) => Promise<void>;
|
||||||
|
mergePathData: (nodes: EntitySummaryVO[], edges: EdgeSummaryVO[]) => void;
|
||||||
|
clearGraph: () => void;
|
||||||
|
clearSearch: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function useGraphData(): UseGraphDataReturn {
|
||||||
|
const [graphData, setGraphData] = useState<G6GraphData>({ nodes: [], edges: [] });
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [searchResults, setSearchResults] = useState<SearchHitVO[]>([]);
|
||||||
|
const [searchLoading, setSearchLoading] = useState(false);
|
||||||
|
const [highlightedNodeIds, setHighlightedNodeIds] = useState<Set<string>>(new Set());
|
||||||
|
const abortRef = useRef<AbortController | null>(null);
|
||||||
|
|
||||||
|
const loadInitialData = useCallback(async (graphId: string) => {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const entities = await api.listEntitiesPaged(graphId, { page: 0, size: 100 });
|
||||||
|
const entityIds = entities.content.map((e) => e.id);
|
||||||
|
if (entityIds.length === 0) {
|
||||||
|
setGraphData({ nodes: [], edges: [] });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const subgraph: SubgraphVO = await api.getSubgraph(graphId, { entityIds }, { depth: 1 });
|
||||||
|
setGraphData(subgraphToG6Data(subgraph));
|
||||||
|
} catch {
|
||||||
|
message.error("加载图谱数据失败");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const loadSubgraph = useCallback(async (graphId: string, entityIds: string[], depth = 1) => {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const subgraph = await api.getSubgraph(graphId, { entityIds }, { depth });
|
||||||
|
setGraphData(subgraphToG6Data(subgraph));
|
||||||
|
} catch {
|
||||||
|
message.error("加载子图失败");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const expandNode = useCallback(
|
||||||
|
async (graphId: string, entityId: string, depth = 1) => {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const subgraph = await api.getNeighborSubgraph(graphId, entityId, { depth, limit: 50 });
|
||||||
|
const incoming = subgraphToG6Data(subgraph);
|
||||||
|
setGraphData((prev) => mergeG6Data(prev, incoming));
|
||||||
|
} catch {
|
||||||
|
message.error("展开节点失败");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
const searchEntitiesFn = useCallback(async (graphId: string, query: string) => {
|
||||||
|
if (!query.trim()) {
|
||||||
|
setSearchResults([]);
|
||||||
|
setHighlightedNodeIds(new Set());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
abortRef.current?.abort();
|
||||||
|
const controller = new AbortController();
|
||||||
|
abortRef.current = controller;
|
||||||
|
setSearchLoading(true);
|
||||||
|
try {
|
||||||
|
const result = await api.searchEntities(graphId, { q: query, size: 20 }, { signal: controller.signal });
|
||||||
|
setSearchResults(result.content);
|
||||||
|
setHighlightedNodeIds(new Set(result.content.map((h) => h.id)));
|
||||||
|
} catch {
|
||||||
|
// ignore abort errors
|
||||||
|
} finally {
|
||||||
|
setSearchLoading(false);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const clearGraph = useCallback(() => {
|
||||||
|
setGraphData({ nodes: [], edges: [] });
|
||||||
|
setSearchResults([]);
|
||||||
|
setHighlightedNodeIds(new Set());
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const clearSearch = useCallback(() => {
|
||||||
|
setSearchResults([]);
|
||||||
|
setHighlightedNodeIds(new Set());
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const mergePathData = useCallback(
|
||||||
|
(nodes: EntitySummaryVO[], edges: EdgeSummaryVO[]) => {
|
||||||
|
if (nodes.length === 0) {
|
||||||
|
setHighlightedNodeIds(new Set());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const pathData = subgraphToG6Data({
|
||||||
|
nodes,
|
||||||
|
edges,
|
||||||
|
nodeCount: nodes.length,
|
||||||
|
edgeCount: edges.length,
|
||||||
|
});
|
||||||
|
setGraphData((prev) => mergeG6Data(prev, pathData));
|
||||||
|
setHighlightedNodeIds(new Set(nodes.map((n) => n.id)));
|
||||||
|
},
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
graphData,
|
||||||
|
loading,
|
||||||
|
searchResults,
|
||||||
|
searchLoading,
|
||||||
|
highlightedNodeIds,
|
||||||
|
loadSubgraph,
|
||||||
|
expandNode,
|
||||||
|
searchEntities: searchEntitiesFn,
|
||||||
|
loadInitialData,
|
||||||
|
mergePathData,
|
||||||
|
clearGraph,
|
||||||
|
clearSearch,
|
||||||
|
};
|
||||||
|
}
|
||||||
61
frontend/src/pages/KnowledgeGraph/hooks/useGraphLayout.ts
Normal file
61
frontend/src/pages/KnowledgeGraph/hooks/useGraphLayout.ts
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import { useState, useCallback } from "react";
|
||||||
|
|
||||||
|
export type LayoutType = "d3-force" | "circular" | "grid" | "radial" | "concentric";
|
||||||
|
|
||||||
|
interface LayoutConfig {
|
||||||
|
type: LayoutType;
|
||||||
|
[key: string]: unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
const LAYOUT_CONFIGS: Record<LayoutType, LayoutConfig> = {
|
||||||
|
"d3-force": {
|
||||||
|
type: "d3-force",
|
||||||
|
preventOverlap: true,
|
||||||
|
link: { distance: 180 },
|
||||||
|
charge: { strength: -400 },
|
||||||
|
collide: { radius: 50 },
|
||||||
|
},
|
||||||
|
circular: {
|
||||||
|
type: "circular",
|
||||||
|
radius: 250,
|
||||||
|
},
|
||||||
|
grid: {
|
||||||
|
type: "grid",
|
||||||
|
rows: undefined,
|
||||||
|
cols: undefined,
|
||||||
|
sortBy: "type",
|
||||||
|
},
|
||||||
|
radial: {
|
||||||
|
type: "radial",
|
||||||
|
unitRadius: 120,
|
||||||
|
preventOverlap: true,
|
||||||
|
nodeSpacing: 30,
|
||||||
|
},
|
||||||
|
concentric: {
|
||||||
|
type: "concentric",
|
||||||
|
preventOverlap: true,
|
||||||
|
nodeSpacing: 30,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
export const LAYOUT_OPTIONS: { label: string; value: LayoutType }[] = [
|
||||||
|
{ label: "力导向", value: "d3-force" },
|
||||||
|
{ label: "环形", value: "circular" },
|
||||||
|
{ label: "网格", value: "grid" },
|
||||||
|
{ label: "径向", value: "radial" },
|
||||||
|
{ label: "同心圆", value: "concentric" },
|
||||||
|
];
|
||||||
|
|
||||||
|
export default function useGraphLayout() {
|
||||||
|
const [layoutType, setLayoutType] = useState<LayoutType>("d3-force");
|
||||||
|
|
||||||
|
const getLayoutConfig = useCallback((): LayoutConfig => {
|
||||||
|
return LAYOUT_CONFIGS[layoutType] ?? LAYOUT_CONFIGS["d3-force"];
|
||||||
|
}, [layoutType]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
layoutType,
|
||||||
|
setLayoutType,
|
||||||
|
getLayoutConfig,
|
||||||
|
};
|
||||||
|
}
|
||||||
193
frontend/src/pages/KnowledgeGraph/knowledge-graph.api.ts
Normal file
193
frontend/src/pages/KnowledgeGraph/knowledge-graph.api.ts
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
import { get, post, del, put } from "@/utils/request";
|
||||||
|
import type {
|
||||||
|
GraphEntity,
|
||||||
|
SubgraphVO,
|
||||||
|
RelationVO,
|
||||||
|
SearchHitVO,
|
||||||
|
PagedResponse,
|
||||||
|
PathVO,
|
||||||
|
AllPathsVO,
|
||||||
|
EditReviewVO,
|
||||||
|
} from "./knowledge-graph.model";
|
||||||
|
|
||||||
|
const BASE = "/api/knowledge-graph";
|
||||||
|
|
||||||
|
// ---- Entity ----
|
||||||
|
|
||||||
|
export function getEntity(graphId: string, entityId: string): Promise<GraphEntity> {
|
||||||
|
return get(`${BASE}/${graphId}/entities/${entityId}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function listEntities(
|
||||||
|
graphId: string,
|
||||||
|
params?: { type?: string; keyword?: string }
|
||||||
|
): Promise<GraphEntity[]> {
|
||||||
|
return get(`${BASE}/${graphId}/entities`, params ?? null);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function listEntitiesPaged(
|
||||||
|
graphId: string,
|
||||||
|
params: { type?: string; keyword?: string; page?: number; size?: number }
|
||||||
|
): Promise<PagedResponse<GraphEntity>> {
|
||||||
|
return get(`${BASE}/${graphId}/entities`, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createEntity(
|
||||||
|
graphId: string,
|
||||||
|
data: { name: string; type: string; description?: string; aliases?: string[]; properties?: Record<string, unknown>; confidence?: number }
|
||||||
|
): Promise<GraphEntity> {
|
||||||
|
return post(`${BASE}/${graphId}/entities`, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function updateEntity(
|
||||||
|
graphId: string,
|
||||||
|
entityId: string,
|
||||||
|
data: { name?: string; description?: string; aliases?: string[]; properties?: Record<string, unknown>; confidence?: number }
|
||||||
|
): Promise<GraphEntity> {
|
||||||
|
return put(`${BASE}/${graphId}/entities/${entityId}`, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function deleteEntity(graphId: string, entityId: string): Promise<void> {
|
||||||
|
return del(`${BASE}/${graphId}/entities/${entityId}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Relation ----
|
||||||
|
|
||||||
|
export function getRelation(graphId: string, relationId: string): Promise<RelationVO> {
|
||||||
|
return get(`${BASE}/${graphId}/relations/${relationId}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function listRelations(
|
||||||
|
graphId: string,
|
||||||
|
params?: { type?: string; page?: number; size?: number }
|
||||||
|
): Promise<PagedResponse<RelationVO>> {
|
||||||
|
return get(`${BASE}/${graphId}/relations`, params ?? null);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createRelation(
|
||||||
|
graphId: string,
|
||||||
|
data: {
|
||||||
|
sourceEntityId: string;
|
||||||
|
targetEntityId: string;
|
||||||
|
relationType: string;
|
||||||
|
properties?: Record<string, unknown>;
|
||||||
|
weight?: number;
|
||||||
|
confidence?: number;
|
||||||
|
}
|
||||||
|
): Promise<RelationVO> {
|
||||||
|
return post(`${BASE}/${graphId}/relations`, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function updateRelation(
|
||||||
|
graphId: string,
|
||||||
|
relationId: string,
|
||||||
|
data: { relationType?: string; properties?: Record<string, unknown>; weight?: number; confidence?: number }
|
||||||
|
): Promise<RelationVO> {
|
||||||
|
return put(`${BASE}/${graphId}/relations/${relationId}`, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function deleteRelation(graphId: string, relationId: string): Promise<void> {
|
||||||
|
return del(`${BASE}/${graphId}/relations/${relationId}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function listEntityRelations(
|
||||||
|
graphId: string,
|
||||||
|
entityId: string,
|
||||||
|
params?: { direction?: string; type?: string; page?: number; size?: number }
|
||||||
|
): Promise<PagedResponse<RelationVO>> {
|
||||||
|
return get(`${BASE}/${graphId}/entities/${entityId}/relations`, params ?? null);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Query ----
|
||||||
|
|
||||||
|
export function getNeighborSubgraph(
|
||||||
|
graphId: string,
|
||||||
|
entityId: string,
|
||||||
|
params?: { depth?: number; limit?: number }
|
||||||
|
): Promise<SubgraphVO> {
|
||||||
|
return get(`${BASE}/${graphId}/query/neighbors/${entityId}`, params ?? null);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getSubgraph(
|
||||||
|
graphId: string,
|
||||||
|
data: { entityIds: string[] },
|
||||||
|
params?: { depth?: number }
|
||||||
|
): Promise<SubgraphVO> {
|
||||||
|
return post(`${BASE}/${graphId}/query/subgraph/export?depth=${params?.depth ?? 1}`, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getShortestPath(
|
||||||
|
graphId: string,
|
||||||
|
params: { sourceId: string; targetId: string; maxDepth?: number }
|
||||||
|
): Promise<PathVO> {
|
||||||
|
return get(`${BASE}/${graphId}/query/shortest-path`, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getAllPaths(
|
||||||
|
graphId: string,
|
||||||
|
params: { sourceId: string; targetId: string; maxDepth?: number; maxPaths?: number }
|
||||||
|
): Promise<AllPathsVO> {
|
||||||
|
return get(`${BASE}/${graphId}/query/all-paths`, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function searchEntities(
|
||||||
|
graphId: string,
|
||||||
|
params: { q: string; page?: number; size?: number },
|
||||||
|
options?: { signal?: AbortSignal }
|
||||||
|
): Promise<PagedResponse<SearchHitVO>> {
|
||||||
|
return get(`${BASE}/${graphId}/query/search`, params, options);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Neighbors (entity controller) ----
|
||||||
|
|
||||||
|
export function getEntityNeighbors(
|
||||||
|
graphId: string,
|
||||||
|
entityId: string,
|
||||||
|
params?: { depth?: number; limit?: number }
|
||||||
|
): Promise<GraphEntity[]> {
|
||||||
|
return get(`${BASE}/${graphId}/entities/${entityId}/neighbors`, params ?? null);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Review ----
|
||||||
|
|
||||||
|
export function submitReview(
|
||||||
|
graphId: string,
|
||||||
|
data: {
|
||||||
|
operationType: string;
|
||||||
|
entityId?: string;
|
||||||
|
relationId?: string;
|
||||||
|
payload?: string;
|
||||||
|
}
|
||||||
|
): Promise<EditReviewVO> {
|
||||||
|
return post(`${BASE}/${graphId}/review/submit`, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function approveReview(
|
||||||
|
graphId: string,
|
||||||
|
reviewId: string,
|
||||||
|
data?: { comment?: string }
|
||||||
|
): Promise<EditReviewVO> {
|
||||||
|
return post(`${BASE}/${graphId}/review/${reviewId}/approve`, data ?? {});
|
||||||
|
}
|
||||||
|
|
||||||
|
export function rejectReview(
|
||||||
|
graphId: string,
|
||||||
|
reviewId: string,
|
||||||
|
data?: { comment?: string }
|
||||||
|
): Promise<EditReviewVO> {
|
||||||
|
return post(`${BASE}/${graphId}/review/${reviewId}/reject`, data ?? {});
|
||||||
|
}
|
||||||
|
|
||||||
|
export function listPendingReviews(
|
||||||
|
graphId: string,
|
||||||
|
params?: { page?: number; size?: number }
|
||||||
|
): Promise<PagedResponse<EditReviewVO>> {
|
||||||
|
return get(`${BASE}/${graphId}/review/pending`, params ?? null);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function listReviews(
|
||||||
|
graphId: string,
|
||||||
|
params?: { status?: string; page?: number; size?: number }
|
||||||
|
): Promise<PagedResponse<EditReviewVO>> {
|
||||||
|
return get(`${BASE}/${graphId}/review`, params ?? null);
|
||||||
|
}
|
||||||
46
frontend/src/pages/KnowledgeGraph/knowledge-graph.const.ts
Normal file
46
frontend/src/pages/KnowledgeGraph/knowledge-graph.const.ts
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
/** Entity type -> display color mapping */
|
||||||
|
export const ENTITY_TYPE_COLORS: Record<string, string> = {
|
||||||
|
Dataset: "#5B8FF9",
|
||||||
|
Field: "#5AD8A6",
|
||||||
|
User: "#F6BD16",
|
||||||
|
Org: "#E86452",
|
||||||
|
Workflow: "#6DC8EC",
|
||||||
|
Job: "#945FB9",
|
||||||
|
LabelTask: "#FF9845",
|
||||||
|
KnowledgeSet: "#1E9493",
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Default color for unknown entity types */
|
||||||
|
export const DEFAULT_ENTITY_COLOR = "#9CA3AF";
|
||||||
|
|
||||||
|
/** Relation type -> Chinese label mapping */
|
||||||
|
export const RELATION_TYPE_LABELS: Record<string, string> = {
|
||||||
|
HAS_FIELD: "包含字段",
|
||||||
|
DERIVED_FROM: "来源于",
|
||||||
|
USES_DATASET: "使用数据集",
|
||||||
|
PRODUCES: "产出",
|
||||||
|
ASSIGNED_TO: "分配给",
|
||||||
|
BELONGS_TO: "属于",
|
||||||
|
TRIGGERS: "触发",
|
||||||
|
DEPENDS_ON: "依赖",
|
||||||
|
IMPACTS: "影响",
|
||||||
|
SOURCED_FROM: "知识来源",
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Entity type -> Chinese label mapping */
|
||||||
|
export const ENTITY_TYPE_LABELS: Record<string, string> = {
|
||||||
|
Dataset: "数据集",
|
||||||
|
Field: "字段",
|
||||||
|
User: "用户",
|
||||||
|
Org: "组织",
|
||||||
|
Workflow: "工作流",
|
||||||
|
Job: "作业",
|
||||||
|
LabelTask: "标注任务",
|
||||||
|
KnowledgeSet: "知识集",
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Available entity types for filtering */
|
||||||
|
export const ENTITY_TYPES = Object.keys(ENTITY_TYPE_LABELS);
|
||||||
|
|
||||||
|
/** Available relation types for filtering */
|
||||||
|
export const RELATION_TYPES = Object.keys(RELATION_TYPE_LABELS);
|
||||||
108
frontend/src/pages/KnowledgeGraph/knowledge-graph.model.ts
Normal file
108
frontend/src/pages/KnowledgeGraph/knowledge-graph.model.ts
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
export interface GraphEntity {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
type: string;
|
||||||
|
description?: string;
|
||||||
|
labels?: string[];
|
||||||
|
aliases?: string[];
|
||||||
|
properties?: Record<string, unknown>;
|
||||||
|
sourceId?: string;
|
||||||
|
sourceType?: string;
|
||||||
|
graphId: string;
|
||||||
|
confidence?: number;
|
||||||
|
createdAt?: string;
|
||||||
|
updatedAt?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EntitySummaryVO {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
type: string;
|
||||||
|
description?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EdgeSummaryVO {
|
||||||
|
id: string;
|
||||||
|
sourceEntityId: string;
|
||||||
|
targetEntityId: string;
|
||||||
|
relationType: string;
|
||||||
|
weight?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SubgraphVO {
|
||||||
|
nodes: EntitySummaryVO[];
|
||||||
|
edges: EdgeSummaryVO[];
|
||||||
|
nodeCount: number;
|
||||||
|
edgeCount: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RelationVO {
|
||||||
|
id: string;
|
||||||
|
sourceEntityId: string;
|
||||||
|
sourceEntityName: string;
|
||||||
|
sourceEntityType: string;
|
||||||
|
targetEntityId: string;
|
||||||
|
targetEntityName: string;
|
||||||
|
targetEntityType: string;
|
||||||
|
relationType: string;
|
||||||
|
properties?: Record<string, unknown>;
|
||||||
|
weight?: number;
|
||||||
|
confidence?: number;
|
||||||
|
sourceId?: string;
|
||||||
|
graphId: string;
|
||||||
|
createdAt?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SearchHitVO {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
type: string;
|
||||||
|
description?: string;
|
||||||
|
score: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PagedResponse<T> {
|
||||||
|
page: number;
|
||||||
|
size: number;
|
||||||
|
totalElements: number;
|
||||||
|
totalPages: number;
|
||||||
|
content: T[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PathVO {
|
||||||
|
nodes: EntitySummaryVO[];
|
||||||
|
edges: EdgeSummaryVO[];
|
||||||
|
pathLength: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AllPathsVO {
|
||||||
|
paths: PathVO[];
|
||||||
|
pathCount: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Edit Review ----
|
||||||
|
|
||||||
|
export type ReviewOperationType =
|
||||||
|
| "CREATE_ENTITY"
|
||||||
|
| "UPDATE_ENTITY"
|
||||||
|
| "DELETE_ENTITY"
|
||||||
|
| "CREATE_RELATION"
|
||||||
|
| "UPDATE_RELATION"
|
||||||
|
| "DELETE_RELATION";
|
||||||
|
|
||||||
|
export type ReviewStatus = "PENDING" | "APPROVED" | "REJECTED";
|
||||||
|
|
||||||
|
export interface EditReviewVO {
|
||||||
|
id: string;
|
||||||
|
graphId: string;
|
||||||
|
operationType: ReviewOperationType;
|
||||||
|
entityId?: string;
|
||||||
|
relationId?: string;
|
||||||
|
payload?: string;
|
||||||
|
status: ReviewStatus;
|
||||||
|
submittedBy?: string;
|
||||||
|
reviewedBy?: string;
|
||||||
|
reviewComment?: string;
|
||||||
|
createdAt?: string;
|
||||||
|
reviewedAt?: string;
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ import {
|
|||||||
Shield,
|
Shield,
|
||||||
Sparkles,
|
Sparkles,
|
||||||
ListChecks,
|
ListChecks,
|
||||||
|
Network,
|
||||||
// Database,
|
// Database,
|
||||||
// Store,
|
// Store,
|
||||||
// Merge,
|
// Merge,
|
||||||
@@ -56,6 +57,14 @@ export const menuItems = [
|
|||||||
description: "管理知识集与知识条目",
|
description: "管理知识集与知识条目",
|
||||||
color: "bg-indigo-500",
|
color: "bg-indigo-500",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
id: "knowledge-graph",
|
||||||
|
title: "知识图谱",
|
||||||
|
icon: Network,
|
||||||
|
permissionCode: PermissionCodes.knowledgeGraphRead,
|
||||||
|
description: "知识图谱浏览与探索",
|
||||||
|
color: "bg-teal-500",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
id: "task-coordination",
|
id: "task-coordination",
|
||||||
title: "任务协调",
|
title: "任务协调",
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ import ContentGenerationPage from "@/pages/ContentGeneration/ContentGenerationPa
|
|||||||
import LoginPage from "@/pages/Login/LoginPage";
|
import LoginPage from "@/pages/Login/LoginPage";
|
||||||
import ProtectedRoute from "@/components/ProtectedRoute";
|
import ProtectedRoute from "@/components/ProtectedRoute";
|
||||||
import ForbiddenPage from "@/pages/Forbidden/ForbiddenPage";
|
import ForbiddenPage from "@/pages/Forbidden/ForbiddenPage";
|
||||||
|
import KnowledgeGraphPage from "@/pages/KnowledgeGraph/Home/KnowledgeGraphPage";
|
||||||
|
|
||||||
const router = createBrowserRouter([
|
const router = createBrowserRouter([
|
||||||
{
|
{
|
||||||
@@ -287,6 +288,10 @@ const router = createBrowserRouter([
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
path: "knowledge-graph",
|
||||||
|
Component: withErrorBoundary(KnowledgeGraphPage),
|
||||||
|
},
|
||||||
{
|
{
|
||||||
path: "task-coordination",
|
path: "task-coordination",
|
||||||
children: [
|
children: [
|
||||||
|
|||||||
@@ -82,6 +82,42 @@ class Settings(BaseSettings):
|
|||||||
kg_llm_timeout_seconds: int = 60
|
kg_llm_timeout_seconds: int = 60
|
||||||
kg_llm_max_retries: int = 2
|
kg_llm_max_retries: int = 2
|
||||||
|
|
||||||
|
# Knowledge Graph - 实体对齐配置
|
||||||
|
kg_alignment_enabled: bool = False
|
||||||
|
kg_alignment_embedding_model: str = "text-embedding-3-small"
|
||||||
|
kg_alignment_vector_threshold: float = 0.92
|
||||||
|
kg_alignment_llm_threshold: float = 0.78
|
||||||
|
|
||||||
|
# GraphRAG 融合查询配置
|
||||||
|
graphrag_enabled: bool = False
|
||||||
|
graphrag_milvus_uri: str = "http://milvus-standalone:19530"
|
||||||
|
graphrag_kg_service_url: str = "http://datamate-backend:8080"
|
||||||
|
graphrag_kg_internal_token: str = ""
|
||||||
|
|
||||||
|
# GraphRAG - 检索策略默认值
|
||||||
|
graphrag_vector_top_k: int = 5
|
||||||
|
graphrag_graph_depth: int = 2
|
||||||
|
graphrag_graph_max_entities: int = 20
|
||||||
|
graphrag_vector_weight: float = 0.6
|
||||||
|
graphrag_graph_weight: float = 0.4
|
||||||
|
|
||||||
|
# GraphRAG - LLM(空则复用 kg_llm_* 配置)
|
||||||
|
graphrag_llm_model: str = ""
|
||||||
|
graphrag_llm_base_url: Optional[str] = None
|
||||||
|
graphrag_llm_api_key: SecretStr = SecretStr("EMPTY")
|
||||||
|
graphrag_llm_temperature: float = 0.1
|
||||||
|
graphrag_llm_timeout_seconds: int = 60
|
||||||
|
|
||||||
|
# GraphRAG - Embedding(空则复用 kg_alignment_embedding_* 配置)
|
||||||
|
graphrag_embedding_model: str = ""
|
||||||
|
|
||||||
|
# GraphRAG - 缓存配置
|
||||||
|
graphrag_cache_enabled: bool = True
|
||||||
|
graphrag_cache_kg_maxsize: int = 256
|
||||||
|
graphrag_cache_kg_ttl: int = 300
|
||||||
|
graphrag_cache_embedding_maxsize: int = 512
|
||||||
|
graphrag_cache_embedding_ttl: int = 600
|
||||||
|
|
||||||
# 标注编辑器(Label Studio Editor)相关
|
# 标注编辑器(Label Studio Editor)相关
|
||||||
editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数
|
editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from .evaluation.interface import router as evaluation_router
|
|||||||
from .collection.interface import router as collection_route
|
from .collection.interface import router as collection_route
|
||||||
from .dataset.interface import router as dataset_router
|
from .dataset.interface import router as dataset_router
|
||||||
from .kg_extraction.interface import router as kg_extraction_router
|
from .kg_extraction.interface import router as kg_extraction_router
|
||||||
|
from .kg_graphrag.interface import router as kg_graphrag_router
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/api"
|
prefix="/api"
|
||||||
@@ -21,5 +22,6 @@ router.include_router(evaluation_router)
|
|||||||
router.include_router(collection_route)
|
router.include_router(collection_route)
|
||||||
router.include_router(dataset_router)
|
router.include_router(dataset_router)
|
||||||
router.include_router(kg_extraction_router)
|
router.include_router(kg_extraction_router)
|
||||||
|
router.include_router(kg_graphrag_router)
|
||||||
|
|
||||||
__all__ = ["router"]
|
__all__ = ["router"]
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from app.module.kg_extraction.aligner import EntityAligner
|
||||||
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
|
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
|
||||||
from app.module.kg_extraction.models import (
|
from app.module.kg_extraction.models import (
|
||||||
ExtractionRequest,
|
ExtractionRequest,
|
||||||
@@ -9,6 +10,7 @@ from app.module.kg_extraction.models import (
|
|||||||
from app.module.kg_extraction.interface import router
|
from app.module.kg_extraction.interface import router
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"EntityAligner",
|
||||||
"KnowledgeGraphExtractor",
|
"KnowledgeGraphExtractor",
|
||||||
"ExtractionRequest",
|
"ExtractionRequest",
|
||||||
"ExtractionResult",
|
"ExtractionResult",
|
||||||
|
|||||||
478
runtime/datamate-python/app/module/kg_extraction/aligner.py
Normal file
478
runtime/datamate-python/app/module/kg_extraction/aligner.py
Normal file
@@ -0,0 +1,478 @@
|
|||||||
|
"""实体对齐器:对抽取结果中的实体进行去重和合并。
|
||||||
|
|
||||||
|
三层对齐策略:
|
||||||
|
1. 规则层:名称规范化 + 别名匹配 + 类型硬过滤
|
||||||
|
2. 向量相似度层:基于 embedding 的 cosine 相似度
|
||||||
|
3. LLM 仲裁层:仅对边界样本调用,严格 JSON schema 校验
|
||||||
|
|
||||||
|
失败策略:fail-open —— 对齐失败不阻断抽取请求。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||||
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.module.kg_extraction.models import (
|
||||||
|
ExtractionResult,
|
||||||
|
GraphEdge,
|
||||||
|
GraphNode,
|
||||||
|
Triple,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Rule Layer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_name(name: str) -> str:
|
||||||
|
"""名称规范化:Unicode NFKC -> 小写 -> 去标点 -> 合并空白。"""
|
||||||
|
name = unicodedata.normalize("NFKC", name)
|
||||||
|
name = name.lower()
|
||||||
|
name = re.sub(r"[^\w\s]", "", name)
|
||||||
|
name = re.sub(r"\s+", " ", name).strip()
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def rule_score(a: GraphNode, b: GraphNode) -> float:
|
||||||
|
"""规则层匹配分数。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
1.0 规范化名称完全一致且类型兼容
|
||||||
|
0.5 一方名称是另一方子串且类型兼容(别名/缩写)
|
||||||
|
0.0 类型不兼容或名称无关联
|
||||||
|
"""
|
||||||
|
# 类型硬过滤
|
||||||
|
if a.type.lower() != b.type.lower():
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
norm_a = normalize_name(a.name)
|
||||||
|
norm_b = normalize_name(b.name)
|
||||||
|
|
||||||
|
# 完全匹配
|
||||||
|
if norm_a == norm_b:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
# 子串匹配(别名/缩写),要求双方规范化名称至少 2 字符
|
||||||
|
if len(norm_a) >= 2 and len(norm_b) >= 2:
|
||||||
|
if norm_a in norm_b or norm_b in norm_a:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Vector Similarity Layer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||||
|
"""计算两个向量的余弦相似度。"""
|
||||||
|
dot = sum(x * y for x, y in zip(a, b))
|
||||||
|
norm_a = sum(x * x for x in a) ** 0.5
|
||||||
|
norm_b = sum(x * x for x in b) ** 0.5
|
||||||
|
if norm_a == 0.0 or norm_b == 0.0:
|
||||||
|
return 0.0
|
||||||
|
return dot / (norm_a * norm_b)
|
||||||
|
|
||||||
|
|
||||||
|
def _entity_text(node: GraphNode) -> str:
|
||||||
|
"""构造用于 embedding 的实体文本表示。"""
|
||||||
|
return f"{node.type}: {node.name}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LLM Arbitration Layer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_LLM_PROMPT = (
|
||||||
|
"判断以下两个实体是否指向同一个现实世界的实体或概念。\n\n"
|
||||||
|
"实体 A:\n- 名称: {name_a}\n- 类型: {type_a}\n\n"
|
||||||
|
"实体 B:\n- 名称: {name_b}\n- 类型: {type_b}\n\n"
|
||||||
|
'请严格按以下 JSON 格式返回,不要包含任何其他内容:\n'
|
||||||
|
'{{"is_same": true, "confidence": 0.95, "reason": "简要理由"}}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMArbitrationResult(BaseModel):
|
||||||
|
"""LLM 仲裁返回结构。"""
|
||||||
|
|
||||||
|
is_same: bool
|
||||||
|
confidence: float = Field(ge=0.0, le=1.0)
|
||||||
|
reason: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Union-Find
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_union_find(n: int):
|
||||||
|
"""创建 Union-Find 数据结构,返回 (parent, find, union)。"""
|
||||||
|
parent = list(range(n))
|
||||||
|
|
||||||
|
def find(x: int) -> int:
|
||||||
|
while parent[x] != x:
|
||||||
|
parent[x] = parent[parent[x]]
|
||||||
|
x = parent[x]
|
||||||
|
return x
|
||||||
|
|
||||||
|
def union(x: int, y: int) -> None:
|
||||||
|
px, py = find(x), find(y)
|
||||||
|
if px != py:
|
||||||
|
parent[px] = py
|
||||||
|
|
||||||
|
return parent, find, union
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Merge Result Builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _build_merged_result(
|
||||||
|
original: ExtractionResult,
|
||||||
|
parent: list[int],
|
||||||
|
find,
|
||||||
|
) -> ExtractionResult:
|
||||||
|
"""根据 Union-Find 结果构建合并后的 ExtractionResult。"""
|
||||||
|
nodes = original.nodes
|
||||||
|
|
||||||
|
# Group by root
|
||||||
|
groups: dict[int, list[int]] = {}
|
||||||
|
for i in range(len(nodes)):
|
||||||
|
root = find(i)
|
||||||
|
groups.setdefault(root, []).append(i)
|
||||||
|
|
||||||
|
# 无合并发生时直接返回原结果
|
||||||
|
if len(groups) == len(nodes):
|
||||||
|
return original
|
||||||
|
|
||||||
|
# Canonical: 选择每组中名称最长的节点
|
||||||
|
# 使用 (name, type) 作为 key 避免同名跨类型节点误映射
|
||||||
|
node_map: dict[tuple[str, str], str] = {}
|
||||||
|
merged_nodes: list[GraphNode] = []
|
||||||
|
for members in groups.values():
|
||||||
|
best_idx = max(members, key=lambda idx: len(nodes[idx].name))
|
||||||
|
canon = nodes[best_idx]
|
||||||
|
merged_nodes.append(canon)
|
||||||
|
for idx in members:
|
||||||
|
node_map[(nodes[idx].name, nodes[idx].type)] = canon.name
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Alignment merged %d nodes -> %d nodes",
|
||||||
|
len(nodes),
|
||||||
|
len(merged_nodes),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 为 edges 构建仅名称的映射(仅当同名节点映射结果无歧义时才包含)
|
||||||
|
_edge_remap: dict[str, set[str]] = {}
|
||||||
|
for (name, _type), canon_name in node_map.items():
|
||||||
|
_edge_remap.setdefault(name, set()).add(canon_name)
|
||||||
|
edge_name_map: dict[str, str] = {
|
||||||
|
name: next(iter(canon_names))
|
||||||
|
for name, canon_names in _edge_remap.items()
|
||||||
|
if len(canon_names) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# 更新 edges(重命名 + 去重)
|
||||||
|
seen_edges: set[str] = set()
|
||||||
|
merged_edges: list[GraphEdge] = []
|
||||||
|
for edge in original.edges:
|
||||||
|
src = edge_name_map.get(edge.source, edge.source)
|
||||||
|
tgt = edge_name_map.get(edge.target, edge.target)
|
||||||
|
key = f"{src}|{edge.relation_type}|{tgt}"
|
||||||
|
if key not in seen_edges:
|
||||||
|
seen_edges.add(key)
|
||||||
|
merged_edges.append(
|
||||||
|
GraphEdge(
|
||||||
|
source=src,
|
||||||
|
target=tgt,
|
||||||
|
relation_type=edge.relation_type,
|
||||||
|
properties=edge.properties,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新 triples(使用 (name, type) 精确查找,避免跨类型误映射)
|
||||||
|
seen_triples: set[str] = set()
|
||||||
|
merged_triples: list[Triple] = []
|
||||||
|
for triple in original.triples:
|
||||||
|
sub_key = (triple.subject.name, triple.subject.type)
|
||||||
|
obj_key = (triple.object.name, triple.object.type)
|
||||||
|
sub_name = node_map.get(sub_key, triple.subject.name)
|
||||||
|
obj_name = node_map.get(obj_key, triple.object.name)
|
||||||
|
key = f"{sub_name}|{triple.predicate}|{obj_name}"
|
||||||
|
if key not in seen_triples:
|
||||||
|
seen_triples.add(key)
|
||||||
|
merged_triples.append(
|
||||||
|
Triple(
|
||||||
|
subject=GraphNode(name=sub_name, type=triple.subject.type),
|
||||||
|
predicate=triple.predicate,
|
||||||
|
object=GraphNode(name=obj_name, type=triple.object.type),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ExtractionResult(
|
||||||
|
nodes=merged_nodes,
|
||||||
|
edges=merged_edges,
|
||||||
|
triples=merged_triples,
|
||||||
|
raw_text=original.raw_text,
|
||||||
|
source_id=original.source_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EntityAligner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class EntityAligner:
|
||||||
|
"""实体对齐器。
|
||||||
|
|
||||||
|
通过 ``from_settings()`` 工厂方法从全局配置创建实例,
|
||||||
|
也可直接构造以覆盖默认参数。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
enabled: bool = False,
|
||||||
|
embedding_model: str = "text-embedding-3-small",
|
||||||
|
embedding_base_url: str | None = None,
|
||||||
|
embedding_api_key: SecretStr = SecretStr("EMPTY"),
|
||||||
|
llm_model: str = "gpt-4o-mini",
|
||||||
|
llm_base_url: str | None = None,
|
||||||
|
llm_api_key: SecretStr = SecretStr("EMPTY"),
|
||||||
|
llm_timeout: int = 30,
|
||||||
|
vector_auto_merge_threshold: float = 0.92,
|
||||||
|
vector_llm_threshold: float = 0.78,
|
||||||
|
llm_arbitration_enabled: bool = True,
|
||||||
|
max_llm_arbitrations: int = 10,
|
||||||
|
) -> None:
|
||||||
|
self._enabled = enabled
|
||||||
|
self._embedding_model = embedding_model
|
||||||
|
self._embedding_base_url = embedding_base_url
|
||||||
|
self._embedding_api_key = embedding_api_key
|
||||||
|
self._llm_model = llm_model
|
||||||
|
self._llm_base_url = llm_base_url
|
||||||
|
self._llm_api_key = llm_api_key
|
||||||
|
self._llm_timeout = llm_timeout
|
||||||
|
self._vector_auto_threshold = vector_auto_merge_threshold
|
||||||
|
self._vector_llm_threshold = vector_llm_threshold
|
||||||
|
self._llm_arbitration_enabled = llm_arbitration_enabled
|
||||||
|
self._max_llm_arbitrations = max_llm_arbitrations
|
||||||
|
# Lazy init
|
||||||
|
self._embeddings: OpenAIEmbeddings | None = None
|
||||||
|
self._llm: ChatOpenAI | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_settings(cls) -> EntityAligner:
|
||||||
|
"""从全局 Settings 创建对齐器实例。"""
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
enabled=settings.kg_alignment_enabled,
|
||||||
|
embedding_model=settings.kg_alignment_embedding_model,
|
||||||
|
embedding_base_url=settings.kg_llm_base_url,
|
||||||
|
embedding_api_key=settings.kg_llm_api_key,
|
||||||
|
llm_model=settings.kg_llm_model,
|
||||||
|
llm_base_url=settings.kg_llm_base_url,
|
||||||
|
llm_api_key=settings.kg_llm_api_key,
|
||||||
|
llm_timeout=settings.kg_llm_timeout_seconds,
|
||||||
|
vector_auto_merge_threshold=settings.kg_alignment_vector_threshold,
|
||||||
|
vector_llm_threshold=settings.kg_alignment_llm_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_embeddings(self) -> OpenAIEmbeddings:
|
||||||
|
if self._embeddings is None:
|
||||||
|
self._embeddings = OpenAIEmbeddings(
|
||||||
|
model=self._embedding_model,
|
||||||
|
base_url=self._embedding_base_url,
|
||||||
|
api_key=self._embedding_api_key,
|
||||||
|
)
|
||||||
|
return self._embeddings
|
||||||
|
|
||||||
|
def _get_llm(self) -> ChatOpenAI:
|
||||||
|
if self._llm is None:
|
||||||
|
self._llm = ChatOpenAI(
|
||||||
|
model=self._llm_model,
|
||||||
|
base_url=self._llm_base_url,
|
||||||
|
api_key=self._llm_api_key,
|
||||||
|
temperature=0.0,
|
||||||
|
timeout=self._llm_timeout,
|
||||||
|
)
|
||||||
|
return self._llm
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def align(self, result: ExtractionResult) -> ExtractionResult:
|
||||||
|
"""对抽取结果中的实体进行对齐去重(异步,三层策略)。
|
||||||
|
|
||||||
|
Fail-open:对齐失败时返回原始结果,不阻断请求。
|
||||||
|
|
||||||
|
注意:当前仅支持批内对齐(单次抽取结果内部的 pairwise 合并)。
|
||||||
|
库内对齐(对现有图谱实体召回/匹配)需要 KG 服务 API 支持,待后续实现。
|
||||||
|
"""
|
||||||
|
if not self._enabled or len(result.nodes) <= 1:
|
||||||
|
return result
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._align_impl(result)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Entity alignment failed, returning original result (fail-open)"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def align_rules_only(self, result: ExtractionResult) -> ExtractionResult:
|
||||||
|
"""仅使用规则层对齐(同步,用于 extract_sync 路径)。
|
||||||
|
|
||||||
|
Fail-open:对齐失败时返回原始结果。
|
||||||
|
"""
|
||||||
|
if not self._enabled or len(result.nodes) <= 1:
|
||||||
|
return result
|
||||||
|
|
||||||
|
try:
|
||||||
|
nodes = result.nodes
|
||||||
|
parent, find, union = _make_union_find(len(nodes))
|
||||||
|
|
||||||
|
for i in range(len(nodes)):
|
||||||
|
for j in range(i + 1, len(nodes)):
|
||||||
|
if find(i) == find(j):
|
||||||
|
continue
|
||||||
|
if rule_score(nodes[i], nodes[j]) >= 1.0:
|
||||||
|
union(i, j)
|
||||||
|
|
||||||
|
return _build_merged_result(result, parent, find)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Rule-only alignment failed, returning original result (fail-open)"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _align_impl(self, result: ExtractionResult) -> ExtractionResult:
|
||||||
|
"""三层对齐的核心实现。
|
||||||
|
|
||||||
|
当前仅在单次抽取结果的节点列表内做 pairwise 对齐。
|
||||||
|
若需与已有图谱实体匹配(库内对齐),需扩展入参以支持
|
||||||
|
graph_id + 候选实体检索上下文,依赖 KG 服务 API。
|
||||||
|
"""
|
||||||
|
nodes = result.nodes
|
||||||
|
n = len(nodes)
|
||||||
|
parent, find, union = _make_union_find(n)
|
||||||
|
|
||||||
|
# Phase 1: Rule layer
|
||||||
|
vector_candidates: list[tuple[int, int]] = []
|
||||||
|
for i in range(n):
|
||||||
|
for j in range(i + 1, n):
|
||||||
|
if find(i) == find(j):
|
||||||
|
continue
|
||||||
|
score = rule_score(nodes[i], nodes[j])
|
||||||
|
if score >= 1.0:
|
||||||
|
union(i, j)
|
||||||
|
logger.debug(
|
||||||
|
"Rule merge: '%s' <-> '%s'", nodes[i].name, nodes[j].name
|
||||||
|
)
|
||||||
|
elif score > 0:
|
||||||
|
vector_candidates.append((i, j))
|
||||||
|
|
||||||
|
# Phase 2: Vector similarity
|
||||||
|
llm_candidates: list[tuple[int, int, float]] = []
|
||||||
|
if vector_candidates:
|
||||||
|
try:
|
||||||
|
emb_map = await self._embed_candidates(nodes, vector_candidates)
|
||||||
|
for i, j in vector_candidates:
|
||||||
|
if find(i) == find(j):
|
||||||
|
continue
|
||||||
|
sim = cosine_similarity(emb_map[i], emb_map[j])
|
||||||
|
if sim >= self._vector_auto_threshold:
|
||||||
|
union(i, j)
|
||||||
|
logger.debug(
|
||||||
|
"Vector merge: '%s' <-> '%s' (sim=%.3f)",
|
||||||
|
nodes[i].name,
|
||||||
|
nodes[j].name,
|
||||||
|
sim,
|
||||||
|
)
|
||||||
|
elif sim >= self._vector_llm_threshold:
|
||||||
|
llm_candidates.append((i, j, sim))
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Vector similarity failed, skipping vector layer", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Phase 3: LLM arbitration (boundary cases only)
|
||||||
|
if llm_candidates and self._llm_arbitration_enabled:
|
||||||
|
llm_count = 0
|
||||||
|
for i, j, sim in llm_candidates:
|
||||||
|
if llm_count >= self._max_llm_arbitrations or find(i) == find(j):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if await self._llm_arbitrate(nodes[i], nodes[j]):
|
||||||
|
union(i, j)
|
||||||
|
logger.debug(
|
||||||
|
"LLM merge: '%s' <-> '%s' (sim=%.3f)",
|
||||||
|
nodes[i].name,
|
||||||
|
nodes[j].name,
|
||||||
|
sim,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"LLM arbitration failed for '%s' <-> '%s'",
|
||||||
|
nodes[i].name,
|
||||||
|
nodes[j].name,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
llm_count += 1
|
||||||
|
|
||||||
|
return _build_merged_result(result, parent, find)
|
||||||
|
|
||||||
|
async def _embed_candidates(
|
||||||
|
self, nodes: list[GraphNode], candidates: list[tuple[int, int]]
|
||||||
|
) -> dict[int, list[float]]:
|
||||||
|
"""对候选实体计算 embedding,返回 {index: embedding}。"""
|
||||||
|
unique_indices: set[int] = set()
|
||||||
|
for i, j in candidates:
|
||||||
|
unique_indices.add(i)
|
||||||
|
unique_indices.add(j)
|
||||||
|
|
||||||
|
idx_list = sorted(unique_indices)
|
||||||
|
texts = [_entity_text(nodes[i]) for i in idx_list]
|
||||||
|
embeddings = await self._get_embeddings().aembed_documents(texts)
|
||||||
|
return dict(zip(idx_list, embeddings))
|
||||||
|
|
||||||
|
async def _llm_arbitrate(self, a: GraphNode, b: GraphNode) -> bool:
|
||||||
|
"""LLM 仲裁两个实体是否相同,严格 JSON schema 校验。"""
|
||||||
|
prompt = _LLM_PROMPT.format(
|
||||||
|
name_a=a.name,
|
||||||
|
type_a=a.type,
|
||||||
|
name_b=b.name,
|
||||||
|
type_b=b.type,
|
||||||
|
)
|
||||||
|
response = await self._get_llm().ainvoke(prompt)
|
||||||
|
content = response.content.strip()
|
||||||
|
|
||||||
|
parsed = json.loads(content)
|
||||||
|
result = LLMArbitrationResult.model_validate(parsed)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"LLM arbitration: '%s' <-> '%s' -> is_same=%s, confidence=%.2f",
|
||||||
|
a.name,
|
||||||
|
b.name,
|
||||||
|
result.is_same,
|
||||||
|
result.confidence,
|
||||||
|
)
|
||||||
|
return result.is_same and result.confidence >= 0.7
|
||||||
@@ -15,6 +15,7 @@ from langchain_experimental.graph_transformers import LLMGraphTransformer
|
|||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
from app.module.kg_extraction.aligner import EntityAligner
|
||||||
from app.module.kg_extraction.models import (
|
from app.module.kg_extraction.models import (
|
||||||
ExtractionRequest,
|
ExtractionRequest,
|
||||||
ExtractionResult,
|
ExtractionResult,
|
||||||
@@ -47,6 +48,7 @@ class KnowledgeGraphExtractor:
|
|||||||
temperature: float = 0.0,
|
temperature: float = 0.0,
|
||||||
timeout: int = 60,
|
timeout: int = 60,
|
||||||
max_retries: int = 2,
|
max_retries: int = 2,
|
||||||
|
aligner: EntityAligner | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
|
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
|
||||||
@@ -63,6 +65,7 @@ class KnowledgeGraphExtractor:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
)
|
)
|
||||||
|
self._aligner = aligner or EntityAligner()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_settings(cls) -> KnowledgeGraphExtractor:
|
def from_settings(cls) -> KnowledgeGraphExtractor:
|
||||||
@@ -76,6 +79,7 @@ class KnowledgeGraphExtractor:
|
|||||||
temperature=settings.kg_llm_temperature,
|
temperature=settings.kg_llm_temperature,
|
||||||
timeout=settings.kg_llm_timeout_seconds,
|
timeout=settings.kg_llm_timeout_seconds,
|
||||||
max_retries=settings.kg_llm_max_retries,
|
max_retries=settings.kg_llm_max_retries,
|
||||||
|
aligner=EntityAligner.from_settings(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_transformer(
|
def _build_transformer(
|
||||||
@@ -119,6 +123,7 @@ class KnowledgeGraphExtractor:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
result = self._convert_result(graph_documents, request)
|
result = self._convert_result(graph_documents, request)
|
||||||
|
result = await self._aligner.align(result)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
|
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
|
||||||
request.graph_id,
|
request.graph_id,
|
||||||
@@ -154,6 +159,7 @@ class KnowledgeGraphExtractor:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
result = self._convert_result(graph_documents, request)
|
result = self._convert_result(graph_documents, request)
|
||||||
|
result = self._aligner.align_rules_only(result)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
|
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
|
||||||
request.graph_id,
|
request.graph_id,
|
||||||
|
|||||||
477
runtime/datamate-python/app/module/kg_extraction/test_aligner.py
Normal file
477
runtime/datamate-python/app/module/kg_extraction/test_aligner.py
Normal file
@@ -0,0 +1,477 @@
|
|||||||
|
"""实体对齐器测试。
|
||||||
|
|
||||||
|
Run with: pytest app/module/kg_extraction/test_aligner.py -v
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.module.kg_extraction.aligner import (
|
||||||
|
EntityAligner,
|
||||||
|
LLMArbitrationResult,
|
||||||
|
_build_merged_result,
|
||||||
|
_make_union_find,
|
||||||
|
cosine_similarity,
|
||||||
|
normalize_name,
|
||||||
|
rule_score,
|
||||||
|
)
|
||||||
|
from app.module.kg_extraction.models import (
|
||||||
|
ExtractionResult,
|
||||||
|
GraphEdge,
|
||||||
|
GraphNode,
|
||||||
|
Triple,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# normalize_name
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeName:
|
||||||
|
def test_basic_lowercase(self):
|
||||||
|
assert normalize_name("Hello World") == "hello world"
|
||||||
|
|
||||||
|
def test_unicode_nfkc(self):
|
||||||
|
assert normalize_name("\uff28ello") == "hello"
|
||||||
|
|
||||||
|
def test_punctuation_removed(self):
|
||||||
|
assert normalize_name("U.S.A.") == "usa"
|
||||||
|
|
||||||
|
def test_whitespace_collapsed(self):
|
||||||
|
assert normalize_name(" hello world ") == "hello world"
|
||||||
|
|
||||||
|
def test_empty_string(self):
|
||||||
|
assert normalize_name("") == ""
|
||||||
|
|
||||||
|
def test_chinese_preserved(self):
|
||||||
|
assert normalize_name("\u5f20\u4e09") == "\u5f20\u4e09"
|
||||||
|
|
||||||
|
def test_mixed_chinese_english(self):
|
||||||
|
assert normalize_name("\u5f20\u4e09 (Zhang San)") == "\u5f20\u4e09 zhang san"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rule_score
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuleScore:
|
||||||
|
def test_exact_match(self):
|
||||||
|
a = GraphNode(name="\u5f20\u4e09", type="Person")
|
||||||
|
b = GraphNode(name="\u5f20\u4e09", type="Person")
|
||||||
|
assert rule_score(a, b) == 1.0
|
||||||
|
|
||||||
|
def test_normalized_match(self):
|
||||||
|
a = GraphNode(name="Hello World", type="Organization")
|
||||||
|
b = GraphNode(name="hello world", type="Organization")
|
||||||
|
assert rule_score(a, b) == 1.0
|
||||||
|
|
||||||
|
def test_type_mismatch(self):
|
||||||
|
a = GraphNode(name="\u5f20\u4e09", type="Person")
|
||||||
|
b = GraphNode(name="\u5f20\u4e09", type="Organization")
|
||||||
|
assert rule_score(a, b) == 0.0
|
||||||
|
|
||||||
|
def test_substring_match(self):
|
||||||
|
a = GraphNode(name="\u5317\u4eac\u5927\u5b66", type="Organization")
|
||||||
|
b = GraphNode(name="\u5317\u4eac\u5927\u5b66\u8ba1\u7b97\u673a\u5b66\u9662", type="Organization")
|
||||||
|
assert rule_score(a, b) == 0.5
|
||||||
|
|
||||||
|
def test_no_match(self):
|
||||||
|
a = GraphNode(name="\u5f20\u4e09", type="Person")
|
||||||
|
b = GraphNode(name="\u674e\u56db", type="Person")
|
||||||
|
assert rule_score(a, b) == 0.0
|
||||||
|
|
||||||
|
def test_type_case_insensitive(self):
|
||||||
|
a = GraphNode(name="test", type="PERSON")
|
||||||
|
b = GraphNode(name="test", type="person")
|
||||||
|
assert rule_score(a, b) == 1.0
|
||||||
|
|
||||||
|
def test_short_substring_ignored(self):
|
||||||
|
"""Single-character substring should not trigger match."""
|
||||||
|
a = GraphNode(name="A", type="Person")
|
||||||
|
b = GraphNode(name="AB", type="Person")
|
||||||
|
assert rule_score(a, b) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# cosine_similarity
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCosineSimilarity:
|
||||||
|
def test_identical(self):
|
||||||
|
assert cosine_similarity([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)
|
||||||
|
|
||||||
|
def test_orthogonal(self):
|
||||||
|
assert cosine_similarity([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_opposite(self):
|
||||||
|
assert cosine_similarity([1.0, 0.0], [-1.0, 0.0]) == pytest.approx(-1.0)
|
||||||
|
|
||||||
|
def test_zero_vector(self):
|
||||||
|
assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Union-Find
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnionFind:
|
||||||
|
def test_basic(self):
|
||||||
|
parent, find, union = _make_union_find(4)
|
||||||
|
union(0, 1)
|
||||||
|
union(2, 3)
|
||||||
|
assert find(0) == find(1)
|
||||||
|
assert find(2) == find(3)
|
||||||
|
assert find(0) != find(2)
|
||||||
|
|
||||||
|
def test_transitive(self):
|
||||||
|
parent, find, union = _make_union_find(3)
|
||||||
|
union(0, 1)
|
||||||
|
union(1, 2)
|
||||||
|
assert find(0) == find(2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _build_merged_result
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_result(nodes, edges=None, triples=None):
|
||||||
|
return ExtractionResult(
|
||||||
|
nodes=nodes,
|
||||||
|
edges=edges or [],
|
||||||
|
triples=triples or [],
|
||||||
|
raw_text="test text",
|
||||||
|
source_id="src-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildMergedResult:
|
||||||
|
def test_no_merge_returns_original(self):
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="A", type="Person"),
|
||||||
|
GraphNode(name="B", type="Person"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
parent, find, _ = _make_union_find(2)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert merged is result
|
||||||
|
|
||||||
|
def test_canonical_picks_longest_name(self):
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="AI", type="Tech"),
|
||||||
|
GraphNode(name="Artificial Intelligence", type="Tech"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
parent, find, union = _make_union_find(2)
|
||||||
|
union(0, 1)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert len(merged.nodes) == 1
|
||||||
|
assert merged.nodes[0].name == "Artificial Intelligence"
|
||||||
|
|
||||||
|
def test_edge_remap_and_dedup(self):
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="Alice", type="Person"),
|
||||||
|
GraphNode(name="alice", type="Person"),
|
||||||
|
GraphNode(name="Bob", type="Person"),
|
||||||
|
]
|
||||||
|
edges = [
|
||||||
|
GraphEdge(source="Alice", target="Bob", relation_type="knows"),
|
||||||
|
GraphEdge(source="alice", target="Bob", relation_type="knows"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes, edges)
|
||||||
|
parent, find, union = _make_union_find(3)
|
||||||
|
union(0, 1)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert len(merged.edges) == 1
|
||||||
|
assert merged.edges[0].source == "Alice"
|
||||||
|
|
||||||
|
def test_triple_remap_and_dedup(self):
|
||||||
|
n1 = GraphNode(name="Alice", type="Person")
|
||||||
|
n2 = GraphNode(name="alice", type="Person")
|
||||||
|
n3 = GraphNode(name="MIT", type="Organization")
|
||||||
|
triples = [
|
||||||
|
Triple(subject=n1, predicate="works_at", object=n3),
|
||||||
|
Triple(subject=n2, predicate="works_at", object=n3),
|
||||||
|
]
|
||||||
|
result = _make_result([n1, n2, n3], triples=triples)
|
||||||
|
parent, find, union = _make_union_find(3)
|
||||||
|
union(0, 1)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert len(merged.triples) == 1
|
||||||
|
assert merged.triples[0].subject.name == "Alice"
|
||||||
|
|
||||||
|
def test_preserves_metadata(self):
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="A", type="Person"),
|
||||||
|
GraphNode(name="A", type="Person"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
parent, find, union = _make_union_find(2)
|
||||||
|
union(0, 1)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert merged.raw_text == "test text"
|
||||||
|
assert merged.source_id == "src-1"
|
||||||
|
|
||||||
|
def test_cross_type_same_name_no_collision(self):
|
||||||
|
"""P1-1 回归:同名跨类型节点合并不应误映射其他类型的边和三元组。
|
||||||
|
|
||||||
|
场景:Person "张三" 和 "张三先生" 合并为 "张三先生",
|
||||||
|
但 Organization "张三" 不应被重写。
|
||||||
|
"""
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="张三", type="Person"), # idx 0
|
||||||
|
GraphNode(name="张三先生", type="Person"), # idx 1
|
||||||
|
GraphNode(name="张三", type="Organization"), # idx 2 - 同名不同类型
|
||||||
|
GraphNode(name="北京", type="Location"), # idx 3
|
||||||
|
]
|
||||||
|
edges = [
|
||||||
|
GraphEdge(source="张三", target="北京", relation_type="lives_in"),
|
||||||
|
GraphEdge(source="张三", target="北京", relation_type="located_in"),
|
||||||
|
]
|
||||||
|
triples = [
|
||||||
|
Triple(
|
||||||
|
subject=GraphNode(name="张三", type="Person"),
|
||||||
|
predicate="lives_in",
|
||||||
|
object=GraphNode(name="北京", type="Location"),
|
||||||
|
),
|
||||||
|
Triple(
|
||||||
|
subject=GraphNode(name="张三", type="Organization"),
|
||||||
|
predicate="located_in",
|
||||||
|
object=GraphNode(name="北京", type="Location"),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes, edges, triples)
|
||||||
|
parent, find, union = _make_union_find(4)
|
||||||
|
union(0, 1) # 合并 Person "张三" 和 "张三先生"
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
|
||||||
|
# 应有 3 个节点:张三先生(Person), 张三(Org), 北京(Location)
|
||||||
|
assert len(merged.nodes) == 3
|
||||||
|
merged_names = {(n.name, n.type) for n in merged.nodes}
|
||||||
|
assert ("张三先生", "Person") in merged_names
|
||||||
|
assert ("张三", "Organization") in merged_names
|
||||||
|
assert ("北京", "Location") in merged_names
|
||||||
|
|
||||||
|
# edges 中 "张三" 有歧义(映射到不同 canonical),应保持原名不重写
|
||||||
|
assert len(merged.edges) == 2
|
||||||
|
|
||||||
|
# triples 有类型信息,可精确区分
|
||||||
|
assert len(merged.triples) == 2
|
||||||
|
person_triple = [t for t in merged.triples if t.subject.type == "Person"][0]
|
||||||
|
org_triple = [t for t in merged.triples if t.subject.type == "Organization"][0]
|
||||||
|
assert person_triple.subject.name == "张三先生" # Person 被重写
|
||||||
|
assert org_triple.subject.name == "张三" # Organization 保持原名
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EntityAligner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEntityAligner:
|
||||||
|
def _run(self, coro):
|
||||||
|
"""Helper to run async coroutine in sync test."""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
return loop.run_until_complete(coro)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
def test_disabled_returns_original(self):
|
||||||
|
aligner = EntityAligner(enabled=False)
|
||||||
|
result = _make_result([GraphNode(name="A", type="Person")])
|
||||||
|
aligned = self._run(aligner.align(result))
|
||||||
|
assert aligned is result
|
||||||
|
|
||||||
|
def test_single_node_returns_original(self):
|
||||||
|
aligner = EntityAligner(enabled=True)
|
||||||
|
result = _make_result([GraphNode(name="A", type="Person")])
|
||||||
|
aligned = self._run(aligner.align(result))
|
||||||
|
assert aligned is result
|
||||||
|
|
||||||
|
def test_rule_merge_exact_names(self):
|
||||||
|
aligner = EntityAligner(enabled=True)
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="\u5f20\u4e09", type="Person"),
|
||||||
|
GraphNode(name="\u5f20\u4e09", type="Person"),
|
||||||
|
GraphNode(name="\u674e\u56db", type="Person"),
|
||||||
|
]
|
||||||
|
edges = [
|
||||||
|
GraphEdge(source="\u5f20\u4e09", target="\u674e\u56db", relation_type="knows"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes, edges)
|
||||||
|
aligned = self._run(aligner.align(result))
|
||||||
|
assert len(aligned.nodes) == 2
|
||||||
|
names = {n.name for n in aligned.nodes}
|
||||||
|
assert "\u5f20\u4e09" in names
|
||||||
|
assert "\u674e\u56db" in names
|
||||||
|
|
||||||
|
def test_rule_merge_case_insensitive(self):
|
||||||
|
aligner = EntityAligner(enabled=True)
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="Hello World", type="Org"),
|
||||||
|
GraphNode(name="hello world", type="Org"),
|
||||||
|
GraphNode(name="Test", type="Person"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
aligned = self._run(aligner.align(result))
|
||||||
|
assert len(aligned.nodes) == 2
|
||||||
|
|
||||||
|
def test_rule_merge_deduplicates_edges(self):
|
||||||
|
aligner = EntityAligner(enabled=True)
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="Hello World", type="Org"),
|
||||||
|
GraphNode(name="hello world", type="Org"),
|
||||||
|
GraphNode(name="Test", type="Person"),
|
||||||
|
]
|
||||||
|
edges = [
|
||||||
|
GraphEdge(source="Hello World", target="Test", relation_type="employs"),
|
||||||
|
GraphEdge(source="hello world", target="Test", relation_type="employs"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes, edges)
|
||||||
|
aligned = self._run(aligner.align(result))
|
||||||
|
assert len(aligned.edges) == 1
|
||||||
|
|
||||||
|
def test_rule_merge_deduplicates_triples(self):
|
||||||
|
aligner = EntityAligner(enabled=True)
|
||||||
|
n1 = GraphNode(name="\u5f20\u4e09", type="Person")
|
||||||
|
n2 = GraphNode(name="\u5f20\u4e09", type="Person")
|
||||||
|
n3 = GraphNode(name="\u5317\u4eac\u5927\u5b66", type="Organization")
|
||||||
|
triples = [
|
||||||
|
Triple(subject=n1, predicate="works_at", object=n3),
|
||||||
|
Triple(subject=n2, predicate="works_at", object=n3),
|
||||||
|
]
|
||||||
|
result = _make_result([n1, n2, n3], triples=triples)
|
||||||
|
aligned = self._run(aligner.align(result))
|
||||||
|
assert len(aligned.triples) == 1
|
||||||
|
|
||||||
|
def test_type_mismatch_no_merge(self):
|
||||||
|
aligner = EntityAligner(enabled=True)
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="\u5f20\u4e09", type="Person"),
|
||||||
|
GraphNode(name="\u5f20\u4e09", type="Organization"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
aligned = self._run(aligner.align(result))
|
||||||
|
assert len(aligned.nodes) == 2
|
||||||
|
|
||||||
|
def test_fail_open_on_error(self):
|
||||||
|
aligner = EntityAligner(enabled=True)
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="\u5f20\u4e09", type="Person"),
|
||||||
|
GraphNode(name="\u5f20\u4e09", type="Person"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
with patch.object(aligner, "_align_impl", side_effect=RuntimeError("boom")):
|
||||||
|
aligned = self._run(aligner.align(result))
|
||||||
|
assert aligned is result
|
||||||
|
|
||||||
|
def test_align_rules_only_sync(self):
|
||||||
|
aligner = EntityAligner(enabled=True)
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="\u5f20\u4e09", type="Person"),
|
||||||
|
GraphNode(name="\u5f20\u4e09", type="Person"),
|
||||||
|
GraphNode(name="\u674e\u56db", type="Person"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
aligned = aligner.align_rules_only(result)
|
||||||
|
assert len(aligned.nodes) == 2
|
||||||
|
|
||||||
|
def test_align_rules_only_disabled(self):
|
||||||
|
aligner = EntityAligner(enabled=False)
|
||||||
|
result = _make_result([GraphNode(name="A", type="Person")])
|
||||||
|
aligned = aligner.align_rules_only(result)
|
||||||
|
assert aligned is result
|
||||||
|
|
||||||
|
def test_align_rules_only_fail_open(self):
|
||||||
|
aligner = EntityAligner(enabled=True)
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="A", type="Person"),
|
||||||
|
GraphNode(name="B", type="Person"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
with patch(
|
||||||
|
"app.module.kg_extraction.aligner.rule_score", side_effect=RuntimeError("boom")
|
||||||
|
):
|
||||||
|
aligned = aligner.align_rules_only(result)
|
||||||
|
assert aligned is result
|
||||||
|
|
||||||
|
def test_llm_count_incremented_on_failure(self):
|
||||||
|
"""P1-2 回归:LLM 仲裁失败也应计入 max_llm_arbitrations 预算。"""
|
||||||
|
max_arb = 2
|
||||||
|
aligner = EntityAligner(
|
||||||
|
enabled=True,
|
||||||
|
max_llm_arbitrations=max_arb,
|
||||||
|
llm_arbitration_enabled=True,
|
||||||
|
)
|
||||||
|
# 构建 4 个同类型节点,规则层子串匹配产生多个 vector 候选
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="北京大学", type="Organization"),
|
||||||
|
GraphNode(name="北京大学计算机学院", type="Organization"),
|
||||||
|
GraphNode(name="北京大学数学学院", type="Organization"),
|
||||||
|
GraphNode(name="北京大学物理学院", type="Organization"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
|
||||||
|
# Mock embedding 使所有候选都落入 LLM 仲裁区间
|
||||||
|
fake_embedding = [1.0, 0.0, 0.0]
|
||||||
|
# 微调使 cosine 在 llm_threshold 和 auto_threshold 之间
|
||||||
|
import math
|
||||||
|
|
||||||
|
# cos(θ) = 0.85 → 在默认 [0.78, 0.92) 区间
|
||||||
|
angle = math.acos(0.85)
|
||||||
|
emb_a = [1.0, 0.0]
|
||||||
|
emb_b = [math.cos(angle), math.sin(angle)]
|
||||||
|
|
||||||
|
async def fake_embed(texts):
|
||||||
|
# 偶数索引返回 emb_a,奇数返回 emb_b
|
||||||
|
return [emb_a if i % 2 == 0 else emb_b for i in range(len(texts))]
|
||||||
|
|
||||||
|
mock_llm_arbitrate = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||||
|
|
||||||
|
with patch.object(aligner, "_get_embeddings") as mock_emb:
|
||||||
|
mock_emb_instance = AsyncMock()
|
||||||
|
mock_emb_instance.aembed_documents = fake_embed
|
||||||
|
mock_emb.return_value = mock_emb_instance
|
||||||
|
with patch.object(aligner, "_llm_arbitrate", mock_llm_arbitrate):
|
||||||
|
aligned = self._run(aligner.align(result))
|
||||||
|
|
||||||
|
# LLM 应恰好被调用 max_arb 次(不会因异常不计数而超出预算)
|
||||||
|
assert mock_llm_arbitrate.call_count <= max_arb
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LLMArbitrationResult
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMArbitrationResult:
|
||||||
|
def test_valid_parse(self):
|
||||||
|
data = {"is_same": True, "confidence": 0.95, "reason": "Same entity"}
|
||||||
|
result = LLMArbitrationResult.model_validate(data)
|
||||||
|
assert result.is_same is True
|
||||||
|
assert result.confidence == 0.95
|
||||||
|
|
||||||
|
def test_confidence_bounds(self):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
LLMArbitrationResult.model_validate(
|
||||||
|
{"is_same": True, "confidence": 1.5, "reason": ""}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_missing_reason_defaults(self):
|
||||||
|
result = LLMArbitrationResult.model_validate(
|
||||||
|
{"is_same": False, "confidence": 0.1}
|
||||||
|
)
|
||||||
|
assert result.reason == ""
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""GraphRAG 融合查询模块。"""
|
||||||
|
|
||||||
|
from app.module.kg_graphrag.interface import router
|
||||||
|
|
||||||
|
__all__ = ["router"]
|
||||||
207
runtime/datamate-python/app/module/kg_graphrag/cache.py
Normal file
207
runtime/datamate-python/app/module/kg_graphrag/cache.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""GraphRAG 检索缓存。
|
||||||
|
|
||||||
|
使用 cachetools 的 TTLCache 为 KG 服务响应和 embedding 向量
|
||||||
|
提供内存级 LRU + TTL 缓存,减少重复网络调用。
|
||||||
|
|
||||||
|
缓存策略:
|
||||||
|
- KG 全文搜索结果:TTL 5 分钟,最多 256 条
|
||||||
|
- KG 子图导出结果:TTL 5 分钟,最多 256 条
|
||||||
|
- Embedding 向量:TTL 10 分钟,最多 512 条(embedding 计算成本高)
|
||||||
|
|
||||||
|
写操作由 Java 侧负责,Python 只读,因此不需要写后失效机制。
|
||||||
|
TTL 到期后自然过期,保证最终一致性。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheStats:
|
||||||
|
"""缓存命中统计。"""
|
||||||
|
|
||||||
|
hits: int = 0
|
||||||
|
misses: int = 0
|
||||||
|
evictions: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hit_rate(self) -> float:
|
||||||
|
total = self.hits + self.misses
|
||||||
|
return self.hits / total if total > 0 else 0.0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"hits": self.hits,
|
||||||
|
"misses": self.misses,
|
||||||
|
"evictions": self.evictions,
|
||||||
|
"hit_rate": round(self.hit_rate, 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _DisabledCache:
|
||||||
|
"""缓存禁用时的 no-op 缓存实现。"""
|
||||||
|
|
||||||
|
maxsize = 0
|
||||||
|
|
||||||
|
def get(self, key: str) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: Any) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class GraphRAGCache:
|
||||||
|
"""GraphRAG 检索结果缓存。
|
||||||
|
|
||||||
|
线程安全:内部使用 threading.Lock 保护 TTLCache。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
kg_maxsize: int = 256,
|
||||||
|
kg_ttl: int = 300,
|
||||||
|
embedding_maxsize: int = 512,
|
||||||
|
embedding_ttl: int = 600,
|
||||||
|
) -> None:
|
||||||
|
self._kg_cache: TTLCache | _DisabledCache = self._create_cache(kg_maxsize, kg_ttl)
|
||||||
|
self._embedding_cache: TTLCache | _DisabledCache = self._create_cache(
|
||||||
|
embedding_maxsize, embedding_ttl
|
||||||
|
)
|
||||||
|
self._kg_lock = threading.Lock()
|
||||||
|
self._embedding_lock = threading.Lock()
|
||||||
|
self._kg_stats = CacheStats()
|
||||||
|
self._embedding_stats = CacheStats()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_cache(maxsize: int, ttl: int) -> TTLCache | _DisabledCache:
|
||||||
|
if maxsize <= 0:
|
||||||
|
return _DisabledCache()
|
||||||
|
return TTLCache(maxsize=maxsize, ttl=max(1, ttl))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_settings(cls) -> GraphRAGCache:
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
if not settings.graphrag_cache_enabled:
|
||||||
|
# 返回禁用缓存实例:不缓存数据,避免 maxsize=0 初始化异常
|
||||||
|
return cls(kg_maxsize=0, kg_ttl=1, embedding_maxsize=0, embedding_ttl=1)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
kg_maxsize=settings.graphrag_cache_kg_maxsize,
|
||||||
|
kg_ttl=settings.graphrag_cache_kg_ttl,
|
||||||
|
embedding_maxsize=settings.graphrag_cache_embedding_maxsize,
|
||||||
|
embedding_ttl=settings.graphrag_cache_embedding_ttl,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# KG 缓存(全文搜索 + 子图导出)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_kg(self, key: str) -> Any | None:
|
||||||
|
"""查找 KG 缓存。返回 None 表示 miss。"""
|
||||||
|
with self._kg_lock:
|
||||||
|
val = self._kg_cache.get(key)
|
||||||
|
if val is not None:
|
||||||
|
self._kg_stats.hits += 1
|
||||||
|
return val
|
||||||
|
self._kg_stats.misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_kg(self, key: str, value: Any) -> None:
|
||||||
|
"""写入 KG 缓存。"""
|
||||||
|
if self._kg_cache.maxsize <= 0:
|
||||||
|
return
|
||||||
|
with self._kg_lock:
|
||||||
|
self._kg_cache[key] = value
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Embedding 缓存
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_embedding(self, key: str) -> list[float] | None:
|
||||||
|
"""查找 embedding 缓存。返回 None 表示 miss。"""
|
||||||
|
with self._embedding_lock:
|
||||||
|
val = self._embedding_cache.get(key)
|
||||||
|
if val is not None:
|
||||||
|
self._embedding_stats.hits += 1
|
||||||
|
return val
|
||||||
|
self._embedding_stats.misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_embedding(self, key: str, value: list[float]) -> None:
|
||||||
|
"""写入 embedding 缓存。"""
|
||||||
|
if self._embedding_cache.maxsize <= 0:
|
||||||
|
return
|
||||||
|
with self._embedding_lock:
|
||||||
|
self._embedding_cache[key] = value
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# 统计 & 管理
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def stats(self) -> dict[str, Any]:
|
||||||
|
"""返回所有缓存区域的统计信息。"""
|
||||||
|
with self._kg_lock:
|
||||||
|
kg_size = len(self._kg_cache)
|
||||||
|
with self._embedding_lock:
|
||||||
|
emb_size = len(self._embedding_cache)
|
||||||
|
return {
|
||||||
|
"kg": {
|
||||||
|
**self._kg_stats.to_dict(),
|
||||||
|
"size": kg_size,
|
||||||
|
"maxsize": self._kg_cache.maxsize,
|
||||||
|
},
|
||||||
|
"embedding": {
|
||||||
|
**self._embedding_stats.to_dict(),
|
||||||
|
"size": emb_size,
|
||||||
|
"maxsize": self._embedding_cache.maxsize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清空所有缓存。"""
|
||||||
|
with self._kg_lock:
|
||||||
|
self._kg_cache.clear()
|
||||||
|
with self._embedding_lock:
|
||||||
|
self._embedding_cache.clear()
|
||||||
|
logger.info("GraphRAG cache cleared")
|
||||||
|
|
||||||
|
|
||||||
|
def make_cache_key(*args: Any) -> str:
|
||||||
|
"""从任意参数生成稳定的缓存 key。
|
||||||
|
|
||||||
|
对参数进行 JSON 序列化后取 SHA-256 摘要,
|
||||||
|
确保 key 长度固定且不含特殊字符。
|
||||||
|
"""
|
||||||
|
raw = json.dumps(args, sort_keys=True, ensure_ascii=False, default=str)
|
||||||
|
return hashlib.sha256(raw.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例(延迟初始化)
|
||||||
|
_cache: GraphRAGCache | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache() -> GraphRAGCache:
|
||||||
|
"""获取全局缓存单例。"""
|
||||||
|
global _cache
|
||||||
|
if _cache is None:
|
||||||
|
_cache = GraphRAGCache.from_settings()
|
||||||
|
return _cache
|
||||||
@@ -0,0 +1,110 @@
|
|||||||
|
"""三元组文本化 + 上下文构建。
|
||||||
|
|
||||||
|
将图谱子图(实体 + 关系)转为自然语言描述,
|
||||||
|
并与向量检索片段合并为 LLM 可消费的上下文文本。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.module.kg_graphrag.models import (
|
||||||
|
EntitySummary,
|
||||||
|
RelationSummary,
|
||||||
|
VectorChunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 关系类型 -> 中文模板映射
|
||||||
|
RELATION_TEMPLATES: dict[str, str] = {
|
||||||
|
"HAS_FIELD": "{source}包含字段{target}",
|
||||||
|
"DERIVED_FROM": "{source}来源于{target}",
|
||||||
|
"USES_DATASET": "{source}使用了数据集{target}",
|
||||||
|
"PRODUCES": "{source}产出了{target}",
|
||||||
|
"ASSIGNED_TO": "{source}分配给了{target}",
|
||||||
|
"BELONGS_TO": "{source}属于{target}",
|
||||||
|
"TRIGGERS": "{source}触发了{target}",
|
||||||
|
"DEPENDS_ON": "{source}依赖于{target}",
|
||||||
|
"IMPACTS": "{source}影响了{target}",
|
||||||
|
"SOURCED_FROM": "{source}的知识来源于{target}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 通用模板(未在映射中的关系类型)
|
||||||
|
_DEFAULT_TEMPLATE = "{source}与{target}存在{relation}关系"
|
||||||
|
|
||||||
|
|
||||||
|
def textualize_subgraph(
|
||||||
|
entities: list[EntitySummary],
|
||||||
|
relations: list[RelationSummary],
|
||||||
|
) -> str:
|
||||||
|
"""将图谱子图转为自然语言描述。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entities: 子图中的实体列表。
|
||||||
|
relations: 子图中的关系列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文本化后的图谱描述,每条关系/实体一行。
|
||||||
|
"""
|
||||||
|
lines: list[str] = []
|
||||||
|
|
||||||
|
# 记录有关系的实体名称
|
||||||
|
mentioned_entities: set[str] = set()
|
||||||
|
|
||||||
|
# 1. 对每条关系生成一句话
|
||||||
|
for rel in relations:
|
||||||
|
source_label = f"{rel.source_type}'{rel.source_name}'"
|
||||||
|
target_label = f"{rel.target_type}'{rel.target_name}'"
|
||||||
|
template = RELATION_TEMPLATES.get(rel.relation_type, _DEFAULT_TEMPLATE)
|
||||||
|
line = template.format(
|
||||||
|
source=source_label,
|
||||||
|
target=target_label,
|
||||||
|
relation=rel.relation_type,
|
||||||
|
)
|
||||||
|
lines.append(line)
|
||||||
|
mentioned_entities.add(rel.source_name)
|
||||||
|
mentioned_entities.add(rel.target_name)
|
||||||
|
|
||||||
|
# 2. 对独立实体(无关系)生成描述句
|
||||||
|
for entity in entities:
|
||||||
|
if entity.name not in mentioned_entities:
|
||||||
|
desc = entity.description or ""
|
||||||
|
if desc:
|
||||||
|
lines.append(f"{entity.type}'{entity.name}': {desc}")
|
||||||
|
else:
|
||||||
|
lines.append(f"存在{entity.type}'{entity.name}'")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def build_context(
|
||||||
|
vector_chunks: list[VectorChunk],
|
||||||
|
graph_text: str,
|
||||||
|
vector_weight: float = 0.6,
|
||||||
|
graph_weight: float = 0.4,
|
||||||
|
) -> str:
|
||||||
|
"""合并向量检索片段和图谱文本化内容为 LLM 上下文。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector_chunks: 向量检索到的文档片段列表。
|
||||||
|
graph_text: 文本化后的图谱描述。
|
||||||
|
vector_weight: 向量分数权重(当前用于日志/调试,不影响上下文排序)。
|
||||||
|
graph_weight: 图谱相关性权重。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
合并后的上下文文本,分为「相关文档」和「知识图谱上下文」两个部分。
|
||||||
|
"""
|
||||||
|
sections: list[str] = []
|
||||||
|
|
||||||
|
# 向量检索片段
|
||||||
|
if vector_chunks:
|
||||||
|
doc_lines = ["## 相关文档"]
|
||||||
|
for i, chunk in enumerate(vector_chunks, 1):
|
||||||
|
doc_lines.append(f"[{i}] {chunk.text}")
|
||||||
|
sections.append("\n".join(doc_lines))
|
||||||
|
|
||||||
|
# 图谱文本化内容
|
||||||
|
if graph_text:
|
||||||
|
sections.append(f"## 知识图谱上下文\n{graph_text}")
|
||||||
|
|
||||||
|
if not sections:
|
||||||
|
return "(未检索到相关上下文信息)"
|
||||||
|
|
||||||
|
return "\n\n".join(sections)
|
||||||
101
runtime/datamate-python/app/module/kg_graphrag/generator.py
Normal file
101
runtime/datamate-python/app/module/kg_graphrag/generator.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""LLM 生成器。
|
||||||
|
|
||||||
|
基于增强上下文(向量 + 图谱)调用 LLM 生成回答,
|
||||||
|
支持同步和流式两种模式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"你是 DataMate 数据管理平台的智能助手。请根据以下上下文信息回答用户的问题。\n"
|
||||||
|
"如果上下文中没有相关信息,请明确说明。不要编造信息。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphRAGGenerator:
    """GraphRAG LLM generator: answers a user query from an enriched context."""

    def __init__(
        self,
        *,
        model: str = "gpt-4o-mini",
        base_url: str | None = None,
        api_key: SecretStr = SecretStr("EMPTY"),
        temperature: float = 0.1,
        timeout: int = 60,
    ) -> None:
        # Connection/config knobs; the actual chat client is built lazily.
        self._model = model
        self._base_url = base_url
        self._api_key = api_key
        self._temperature = temperature
        self._timeout = timeout
        self._llm = None  # created on first use by _get_llm()

    @property
    def model_name(self) -> str:
        return self._model

    @classmethod
    def from_settings(cls) -> GraphRAGGenerator:
        """Build a generator from settings; GraphRAG-specific values win,
        falling back to the shared KG LLM configuration."""
        from app.core.config import settings

        if settings.graphrag_llm_api_key.get_secret_value() != "EMPTY":
            key = settings.graphrag_llm_api_key
        else:
            key = settings.kg_llm_api_key
        return cls(
            model=settings.graphrag_llm_model or settings.kg_llm_model,
            base_url=settings.graphrag_llm_base_url or settings.kg_llm_base_url,
            api_key=key,
            temperature=settings.graphrag_llm_temperature,
            timeout=settings.graphrag_llm_timeout_seconds,
        )

    def _get_llm(self):
        # Deferred import so the module loads without langchain installed.
        if self._llm is None:
            from langchain_openai import ChatOpenAI

            self._llm = ChatOpenAI(
                model=self._model,
                base_url=self._base_url,
                api_key=self._api_key,
                temperature=self._temperature,
                timeout=self._timeout,
            )
        return self._llm

    def _build_messages(self, query: str, context: str) -> list[dict[str, str]]:
        # Single system prompt + one user turn carrying context and question.
        user_content = f"{context}\n\n用户问题: {query}\n\n请基于上下文中的信息回答。"
        return [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

    async def generate(self, query: str, context: str) -> str:
        """Generate a complete answer from the enriched context (single shot)."""
        reply = await self._get_llm().ainvoke(self._build_messages(query, context))
        return str(reply.content)

    async def generate_stream(self, query: str, context: str) -> AsyncIterator[str]:
        """Generate an answer from the enriched context, yielding token by token."""
        llm = self._get_llm()
        async for part in llm.astream(self._build_messages(query, context)):
            if part.content:
                yield str(part.content)
|
||||||
281
runtime/datamate-python/app/module/kg_graphrag/interface.py
Normal file
281
runtime/datamate-python/app/module/kg_graphrag/interface.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
"""GraphRAG 融合查询 API 端点。
|
||||||
|
|
||||||
|
提供向量检索 + 知识图谱的融合查询能力:
|
||||||
|
- POST /api/graphrag/query — 完整 GraphRAG 查询(检索+生成)
|
||||||
|
- POST /api/graphrag/retrieve — 仅检索(返回上下文,不调 LLM)
|
||||||
|
- POST /api/graphrag/query/stream — 流式 GraphRAG 查询(SSE)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
|
||||||
|
from app.module.kg_graphrag.models import (
|
||||||
|
GraphRAGQueryRequest,
|
||||||
|
GraphRAGQueryResponse,
|
||||||
|
RetrievalContext,
|
||||||
|
)
|
||||||
|
from app.module.kg_graphrag.retriever import GraphRAGRetriever
|
||||||
|
from app.module.kg_graphrag.generator import GraphRAGGenerator
|
||||||
|
from app.module.shared.schema import StandardResponse
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/graphrag", tags=["graphrag"])
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# 延迟初始化
|
||||||
|
_retriever: GraphRAGRetriever | None = None
|
||||||
|
_generator: GraphRAGGenerator | None = None
|
||||||
|
_kb_validator: KnowledgeBaseAccessValidator | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_retriever() -> GraphRAGRetriever:
    """Return the process-wide retriever, creating it on first use."""
    global _retriever
    instance = _retriever
    if instance is None:
        instance = GraphRAGRetriever.from_settings()
        _retriever = instance
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
def _get_generator() -> GraphRAGGenerator:
    """Return the process-wide generator, creating it on first use."""
    global _generator
    instance = _generator
    if instance is None:
        instance = GraphRAGGenerator.from_settings()
        _generator = instance
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
def _get_kb_validator() -> KnowledgeBaseAccessValidator:
    """Return the process-wide KB access validator, creating it on first use."""
    global _kb_validator
    instance = _kb_validator
    if instance is None:
        instance = KnowledgeBaseAccessValidator.from_settings()
        _kb_validator = instance
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
def _require_caller_id(
|
||||||
|
x_user_id: Annotated[
|
||||||
|
str,
|
||||||
|
Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递"),
|
||||||
|
],
|
||||||
|
) -> str:
|
||||||
|
caller = x_user_id.strip()
|
||||||
|
if not caller:
|
||||||
|
raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
|
||||||
|
return caller
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# P0: 完整 GraphRAG 查询
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/query",
    response_model=StandardResponse[GraphRAGQueryResponse],
    summary="GraphRAG 查询",
    description="并行从向量库和知识图谱检索上下文,融合后调用 LLM 生成回答。",
)
async def query(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Full GraphRAG query: access check, hybrid retrieval, then LLM answer."""
    # Short per-request id used to correlate log lines.
    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG query: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    retriever = _get_retriever()
    generator = _get_generator()

    # Access check: verify the caller may use this knowledge base (the
    # validator also checks collection_name really belongs to that KB).
    kb_validator = _get_kb_validator()
    if not await kb_validator.check_access(
        req.knowledge_base_id, caller, collection_name=req.collection_name,
    ):
        logger.warning(
            "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
            trace_id, req.knowledge_base_id, req.collection_name, caller,
        )
        raise HTTPException(
            status_code=403,
            detail=f"无权访问知识库 {req.knowledge_base_id}",
        )

    # Hybrid retrieval (vector + graph); failures map to 502 with the trace id.
    try:
        context = await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception:
        logger.exception("[%s] Retrieval failed", trace_id)
        raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")

    # LLM generation over the merged context; failures also map to 502.
    try:
        answer = await generator.generate(query=req.query, context=context.merged_text)
    except Exception:
        logger.exception("[%s] Generation failed", trace_id)
        raise HTTPException(status_code=502, detail=f"生成服务暂不可用 (trace: {trace_id})")

    result = GraphRAGQueryResponse(
        answer=answer,
        context=context,
        model=generator.model_name,
    )
    return StandardResponse(code=200, message="success", data=result)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# P1-1: 仅检索
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/retrieve",
    response_model=StandardResponse[RetrievalContext],
    summary="GraphRAG 仅检索",
    description="并行从向量库和知识图谱检索上下文,返回结构化上下文(不调 LLM)。",
)
async def retrieve(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Retrieval-only endpoint: returns the structured context, no LLM call."""
    # Short per-request id used to correlate log lines.
    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG retrieve: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    retriever = _get_retriever()

    # Access check: verify the caller may use this knowledge base (the
    # validator also checks collection_name really belongs to that KB).
    kb_validator = _get_kb_validator()
    if not await kb_validator.check_access(
        req.knowledge_base_id, caller, collection_name=req.collection_name,
    ):
        logger.warning(
            "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
            trace_id, req.knowledge_base_id, req.collection_name, caller,
        )
        raise HTTPException(
            status_code=403,
            detail=f"无权访问知识库 {req.knowledge_base_id}",
        )

    # Hybrid retrieval (vector + graph); failures map to 502 with the trace id.
    try:
        context = await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception:
        logger.exception("[%s] Retrieval failed", trace_id)
        raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")

    return StandardResponse(code=200, message="success", data=context)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# P1-4: 流式查询 (SSE)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/query/stream",
    summary="GraphRAG 流式查询",
    description="并行检索后,通过 SSE 流式返回 LLM 生成内容。",
)
async def query_stream(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Streaming GraphRAG query: retrieval up front, then SSE token stream."""
    # Short per-request id used to correlate log lines.
    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG stream: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    retriever = _get_retriever()
    generator = _get_generator()

    # Access check: verify the caller may use this knowledge base (the
    # validator also checks collection_name really belongs to that KB).
    kb_validator = _get_kb_validator()
    if not await kb_validator.check_access(
        req.knowledge_base_id, caller, collection_name=req.collection_name,
    ):
        logger.warning(
            "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
            trace_id, req.knowledge_base_id, req.collection_name, caller,
        )
        raise HTTPException(
            status_code=403,
            detail=f"无权访问知识库 {req.knowledge_base_id}",
        )

    # Retrieval happens BEFORE streaming starts so retrieval errors can still
    # surface as a proper HTTP status (502) rather than mid-stream.
    try:
        context = await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception:
        logger.exception("[%s] Retrieval failed", trace_id)
        raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")

    import json

    async def event_stream():
        # SSE body: one "data: {json}\n\n" frame per token, then a final
        # done-frame carrying the retrieval context; generation errors are
        # reported in-band since headers are already sent.
        try:
            async for token in generator.generate_stream(
                query=req.query, context=context.merged_text
            ):
                yield f"data: {json.dumps({'token': token}, ensure_ascii=False)}\n\n"
            # Final event: includes the retrieval context.
            yield f"data: {json.dumps({'done': True, 'context': context.model_dump()}, ensure_ascii=False)}\n\n"
        except Exception:
            logger.exception("[%s] Stream generation failed", trace_id)
            yield f"data: {json.dumps({'error': '生成服务暂不可用'})}\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 缓存管理
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/cache/stats",
    response_model=StandardResponse[dict],
    summary="缓存统计",
    description="返回 GraphRAG 检索缓存的命中率和容量统计。",
)
async def cache_stats(caller: Annotated[str, Depends(_require_caller_id)]):
    """Report hit-rate and capacity statistics for the GraphRAG retrieval cache."""
    from app.module.kg_graphrag.cache import get_cache

    logger.info("GraphRAG cache stats requested by caller=%s", caller)
    stats = get_cache().stats()
    return StandardResponse(code=200, message="success", data=stats)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/cache/clear",
    response_model=StandardResponse[dict],
    summary="清空缓存",
    description="清空所有 GraphRAG 检索缓存。",
)
async def cache_clear(caller: Annotated[str, Depends(_require_caller_id)]):
    """Drop every entry in the GraphRAG retrieval cache."""
    from app.module.kg_graphrag.cache import get_cache

    logger.info("GraphRAG cache clear requested by caller=%s", caller)
    cache = get_cache()
    cache.clear()
    return StandardResponse(code=200, message="success", data={"cleared": True})
|
||||||
118
runtime/datamate-python/app/module/kg_graphrag/kb_access.py
Normal file
118
runtime/datamate-python/app/module/kg_graphrag/kb_access.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""知识库访问权限校验。
|
||||||
|
|
||||||
|
在执行 GraphRAG 检索前,调用 Java rag-indexer-service 的
|
||||||
|
GET /knowledge-base/{id} 端点验证当前用户是否有权访问该知识库。
|
||||||
|
|
||||||
|
Java 侧实现参考:KnowledgeBaseService.getKnowledgeBaseWithAccessCheck()
|
||||||
|
- 查找 KB 是否存在
|
||||||
|
- 校验 createdBy == currentUserId(管理员跳过)
|
||||||
|
- 不满足则抛出 sys.0005 (INSUFFICIENT_PERMISSIONS)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBaseAccessValidator:
    """Validates, via the Java backend, that a user may access a knowledge base.

    Fail-close: any transport error or non-success response denies access.
    """

    def __init__(
        self,
        *,
        base_url: str = "http://datamate-backend:8080/api",
        timeout: float = 10.0,
    ) -> None:
        # NOTE: base_url is expected to already carry the "/api" prefix
        # (see the default above); request paths below are relative to it.
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout
        # Lazily-created shared HTTP client (see _get_client).
        self._client: httpx.AsyncClient | None = None

    @classmethod
    def from_settings(cls) -> KnowledgeBaseAccessValidator:
        """Build a validator from the configured backend base URL."""
        from app.core.config import settings

        return cls(base_url=settings.datamate_backend_base_url)

    def _get_client(self) -> httpx.AsyncClient:
        # Create the AsyncClient on first use so construction stays cheap.
        if self._client is None:
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                timeout=self._timeout,
            )
        return self._client

    async def check_access(
        self,
        knowledge_base_id: str,
        user_id: str,
        *,
        collection_name: str | None = None,
    ) -> bool:
        """Check whether *user_id* may access knowledge base *knowledge_base_id*.

        Calls the Java backend's GET /knowledge-base/{id}, which performs the
        owner check internally (createdBy == currentUserId, admins skipped).

        When *collection_name* is not None, additionally verifies that it
        matches the knowledge base's actual name, preventing a caller from
        submitting a legitimate KB id but a tampered collection_name to read
        another knowledge base's Milvus data.

        Returns:
            True  — access granted and collection_name (if given) matches.
            False — access denied, name mismatch, or the check itself failed.
        """
        from urllib.parse import quote

        try:
            client = self._get_client()
            # BUGFIX: httpx *appends* relative paths to base_url, and base_url
            # already ends with "/api" — the previous "/api/knowledge-base/…"
            # path produced "/api/api/knowledge-base/…" and every check failed
            # closed. The id is percent-encoded so it stays one path segment.
            resp = await client.get(
                f"/knowledge-base/{quote(knowledge_base_id, safe='')}",
                headers={"X-User-Id": user_id},
            )
            if resp.status_code == 200:
                body = resp.json()
                # Java global envelope: {"code": 200, "data": {...}};
                # code != 200 means the business layer refused (e.g. no perms).
                code = body.get("code", resp.status_code)
                if code != 200:
                    logger.warning(
                        "KB access denied: kb_id=%s, user=%s, biz_code=%s, msg=%s",
                        knowledge_base_id, user_id, code, body.get("message", ""),
                    )
                    return False

                # Bind collection_name to the KB's actual name.
                if collection_name is not None:
                    data = body.get("data") or {}
                    actual_name = data.get("name") if isinstance(data, dict) else None
                    if actual_name != collection_name:
                        logger.warning(
                            "KB collection_name mismatch: kb_id=%s, "
                            "expected=%s, actual=%s, user=%s",
                            knowledge_base_id, collection_name,
                            actual_name, user_id,
                        )
                        return False

                return True
            # HTTP 4xx/5xx
            logger.warning(
                "KB access check returned HTTP %d: kb_id=%s, user=%s",
                resp.status_code, knowledge_base_id, user_id,
            )
            return False
        except Exception:
            # Network failure → fail-close: deny access so the check cannot
            # be bypassed by taking the backend down.
            logger.exception(
                "KB access check failed (fail-close): kb_id=%s, user=%s",
                knowledge_base_id, user_id,
            )
            return False

    async def close(self) -> None:
        """Close the underlying HTTP client (safe to call multiple times)."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
|
||||||
214
runtime/datamate-python/app/module/kg_graphrag/kg_client.py
Normal file
214
runtime/datamate-python/app/module/kg_graphrag/kg_client.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""KG 服务 REST 客户端。
|
||||||
|
|
||||||
|
通过 httpx 调用 Java 侧 knowledge-graph-service 的查询 API,
|
||||||
|
包括全文检索和子图导出。
|
||||||
|
|
||||||
|
失败策略:fail-open —— KG 服务不可用时返回空结果 + 日志告警。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.module.kg_graphrag.cache import get_cache, make_cache_key
|
||||||
|
from app.module.kg_graphrag.models import EntitySummary, RelationSummary
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KGServiceClient:
    """REST client for the Java knowledge-graph (KG) service.

    Wraps full-text entity search and N-hop subgraph export. Both public
    methods are fail-open: on any error they log and return an empty result
    instead of raising; successful results are cached.
    """

    def __init__(
        self,
        *,
        base_url: str = "http://datamate-backend:8080",
        internal_token: str = "",
        timeout: float = 30.0,
    ) -> None:
        # Trailing slash stripped so request paths can start with "/".
        self._base_url = base_url.rstrip("/")
        self._internal_token = internal_token
        self._timeout = timeout
        # Lazily-created shared HTTP client (see _get_client).
        self._client: httpx.AsyncClient | None = None

    @classmethod
    def from_settings(cls) -> KGServiceClient:
        # Build a client from application settings (service URL + token).
        from app.core.config import settings

        return cls(
            base_url=settings.graphrag_kg_service_url,
            internal_token=settings.graphrag_kg_internal_token,
            timeout=30.0,
        )

    def _get_client(self) -> httpx.AsyncClient:
        # Create the AsyncClient on first use so construction stays cheap.
        if self._client is None:
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                timeout=self._timeout,
            )
        return self._client

    def _headers(self, user_id: str = "") -> dict[str, str]:
        # Optional headers: internal service token and caller identity.
        headers: dict[str, str] = {}
        if self._internal_token:
            headers["X-Internal-Token"] = self._internal_token
        if user_id:
            headers["X-User-Id"] = user_id
        return headers

    async def fulltext_search(
        self,
        graph_id: str,
        query: str,
        size: int = 10,
        user_id: str = "",
    ) -> list[EntitySummary]:
        """Full-text search via the KG service; returns matching entities.

        Fail-open: returns an empty list when the KG service is unavailable.
        Results are cached (TTL controlled by graphrag_cache_kg_ttl).
        """
        cache = get_cache()
        cache_key = make_cache_key("fulltext", graph_id, query, size, user_id)
        cached = cache.get_kg(cache_key)
        if cached is not None:
            return cached
        try:
            result = await self._fulltext_search_impl(graph_id, query, size, user_id)
            cache.set_kg(cache_key, result)
            return result
        except Exception:
            logger.exception(
                "KG fulltext search failed for graph_id=%s (fail-open, returning empty)",
                graph_id,
            )
            return []

    async def _fulltext_search_impl(
        self,
        graph_id: str,
        query: str,
        size: int,
        user_id: str,
    ) -> list[EntitySummary]:
        # Raises on HTTP errors; the public wrapper converts that to [].
        client = self._get_client()
        resp = await client.get(
            f"/api/knowledge-graph/{graph_id}/query/search",
            params={"q": query, "size": size},
            headers=self._headers(user_id),
        )
        resp.raise_for_status()
        body = resp.json()

        # The Java side returns PagedResponse<SearchHitVO>, either wrapped in
        # the global envelope {"code": 200, "data": PagedResponse} or as the
        # bare PagedResponse {"page": 0, "content": [...]}.
        data = body.get("data", body)
        # PagedResponse keeps the entity list in its "content" field.
        items: list[dict] = (
            data.get("content", []) if isinstance(data, dict) else data if isinstance(data, list) else []
        )
        entities: list[EntitySummary] = []
        for item in items:
            entities.append(
                EntitySummary(
                    id=str(item.get("id", "")),
                    name=item.get("name", ""),
                    type=item.get("type", ""),
                    description=item.get("description", ""),
                )
            )
        return entities

    async def get_subgraph(
        self,
        graph_id: str,
        entity_ids: list[str],
        depth: int = 1,
        user_id: str = "",
    ) -> tuple[list[EntitySummary], list[RelationSummary]]:
        """Fetch the N-hop subgraph around the given seed entities.

        Fail-open: returns an empty subgraph when the KG service is
        unavailable. Results are cached (TTL: graphrag_cache_kg_ttl).
        """
        cache = get_cache()
        # entity_ids are sorted so the cache key is order-insensitive.
        cache_key = make_cache_key("subgraph", graph_id, sorted(entity_ids), depth, user_id)
        cached = cache.get_kg(cache_key)
        if cached is not None:
            return cached
        try:
            result = await self._get_subgraph_impl(graph_id, entity_ids, depth, user_id)
            cache.set_kg(cache_key, result)
            return result
        except Exception:
            logger.exception(
                "KG subgraph export failed for graph_id=%s (fail-open, returning empty)",
                graph_id,
            )
            return [], []

    async def _get_subgraph_impl(
        self,
        graph_id: str,
        entity_ids: list[str],
        depth: int,
        user_id: str,
    ) -> tuple[list[EntitySummary], list[RelationSummary]]:
        # Raises on HTTP errors; the public wrapper converts that to ([], []).
        client = self._get_client()
        resp = await client.post(
            f"/api/knowledge-graph/{graph_id}/query/subgraph/export",
            params={"depth": depth},
            json={"entityIds": entity_ids},
            headers=self._headers(user_id),
        )
        resp.raise_for_status()
        body = resp.json()

        # The Java side returns SubgraphExportVO, either wrapped in the global
        # envelope {"code": 200, "data": SubgraphExportVO} or as the bare VO
        # {"nodes": [...], "edges": [...]}.
        data = body.get("data", body) if isinstance(body.get("data"), dict) else body
        nodes_raw = data.get("nodes", [])
        edges_raw = data.get("edges", [])

        # ExportNodeVO: id, name, type, description, properties (Map)
        entities: list[EntitySummary] = []
        for node in nodes_raw:
            entities.append(
                EntitySummary(
                    id=str(node.get("id", "")),
                    name=node.get("name", ""),
                    type=node.get("type", ""),
                    description=node.get("description", ""),
                )
            )

        relations: list[RelationSummary] = []
        # Map id -> entity to resolve source/target names and types.
        entity_map = {e.id: e for e in entities}
        # ExportEdgeVO: sourceEntityId, targetEntityId, relationType.
        # Note: "sourceId" is the data-source id, NOT the source entity id.
        for edge in edges_raw:
            source_id = str(edge.get("sourceEntityId", ""))
            target_id = str(edge.get("targetEntityId", ""))
            source_entity = entity_map.get(source_id)
            target_entity = entity_map.get(target_id)
            relations.append(
                RelationSummary(
                    # Fall back to the raw id when the node is not in the export.
                    source_name=source_entity.name if source_entity else source_id,
                    source_type=source_entity.type if source_entity else "",
                    target_name=target_entity.name if target_entity else target_id,
                    target_type=target_entity.type if target_entity else "",
                    relation_type=edge.get("relationType", ""),
                )
            )

        return entities, relations

    async def close(self) -> None:
        # Close the underlying HTTP client (safe to call multiple times).
        if self._client is not None:
            await self._client.aclose()
            self._client = None
|
||||||
135
runtime/datamate-python/app/module/kg_graphrag/milvus_client.py
Normal file
135
runtime/datamate-python/app/module/kg_graphrag/milvus_client.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""Milvus 向量检索客户端。
|
||||||
|
|
||||||
|
通过 pymilvus 连接 Milvus,对查询文本进行 embedding 后执行混合搜索,
|
||||||
|
返回 top-K 文档片段。
|
||||||
|
|
||||||
|
失败策略:fail-open —— Milvus 不可用时返回空列表 + 日志告警。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.module.kg_graphrag.models import VectorChunk
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MilvusVectorRetriever:
    """Milvus vector retriever: embeds a query and searches a collection.

    Fail-open: search errors are logged and an empty list is returned.
    """

    def __init__(
        self,
        *,
        uri: str = "http://milvus-standalone:19530",
        embedding_model: str = "text-embedding-3-small",
        embedding_base_url: str | None = None,
        embedding_api_key: SecretStr = SecretStr("EMPTY"),
    ) -> None:
        self._uri = uri
        self._embedding_model = embedding_model
        self._embedding_base_url = embedding_base_url
        self._embedding_api_key = embedding_api_key
        # Lazy init — both clients are created on first use.
        self._milvus_client = None
        self._embeddings = None

    @classmethod
    def from_settings(cls) -> MilvusVectorRetriever:
        # GraphRAG-specific embedding model wins; fall back to the KG
        # alignment embedding model.
        from app.core.config import settings

        embedding_model = (
            settings.graphrag_embedding_model
            or settings.kg_alignment_embedding_model
        )
        return cls(
            uri=settings.graphrag_milvus_uri,
            embedding_model=embedding_model,
            embedding_base_url=settings.kg_llm_base_url,
            embedding_api_key=settings.kg_llm_api_key,
        )

    def _get_embeddings(self):
        # Deferred import + lazy construction of the embeddings client.
        if self._embeddings is None:
            from langchain_openai import OpenAIEmbeddings

            self._embeddings = OpenAIEmbeddings(
                model=self._embedding_model,
                base_url=self._embedding_base_url,
                api_key=self._embedding_api_key,
            )
        return self._embeddings

    def _get_milvus_client(self):
        # Deferred import + lazy connection to Milvus.
        if self._milvus_client is None:
            from pymilvus import MilvusClient

            self._milvus_client = MilvusClient(uri=self._uri)
            logger.info("Connected to Milvus at %s", self._uri)
        return self._milvus_client

    async def has_collection(self, collection_name: str) -> bool:
        """Check whether the collection exists in Milvus (guards against
        probing non-existent/unauthorized collections)."""
        try:
            client = self._get_milvus_client()
            return await asyncio.to_thread(client.has_collection, collection_name)
        except Exception:
            logger.exception("Milvus has_collection check failed for %s", collection_name)
            return False

    async def search(
        self,
        collection_name: str,
        query: str,
        top_k: int = 5,
    ) -> list[VectorChunk]:
        """Vector search: embed the query, search Milvus, return top-K chunks.

        Fail-open: returns an empty list when Milvus is unavailable.
        """
        try:
            return await self._search_impl(collection_name, query, top_k)
        except Exception:
            logger.exception(
                "Milvus search failed for collection=%s (fail-open, returning empty)",
                collection_name,
            )
            return []

    async def _search_impl(
        self,
        collection_name: str,
        query: str,
        top_k: int,
    ) -> list[VectorChunk]:
        # 1. Embed the query text.
        query_vector = await self._get_embeddings().aembed_query(query)

        # 2. Milvus search (synchronous I/O, off-loaded via to_thread so the
        #    event loop is not blocked).
        client = self._get_milvus_client()
        results = await asyncio.to_thread(
            client.search,
            collection_name=collection_name,
            data=[query_vector],
            limit=top_k,
            output_fields=["text", "metadata"],
            search_params={"metric_type": "COSINE", "params": {"nprobe": 16}},
        )

        # 3. Convert hits into VectorChunk models. Milvus returns one result
        #    list per query vector; we sent one vector, so read results[0].
        chunks: list[VectorChunk] = []
        if results and len(results) > 0:
            for hit in results[0]:
                entity = hit.get("entity", {})
                chunks.append(
                    VectorChunk(
                        id=str(hit.get("id", "")),
                        text=entity.get("text", ""),
                        # With COSINE, "distance" is the similarity score here.
                        score=float(hit.get("distance", 0.0)),
                        metadata=entity.get("metadata", {}),
                    )
                )
        return chunks
|
||||||
102
runtime/datamate-python/app/module/kg_graphrag/models.py
Normal file
102
runtime/datamate-python/app/module/kg_graphrag/models.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""GraphRAG 融合查询的请求/响应数据模型。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalStrategy(BaseModel):
    """Retrieval strategy configuration (all fields optional, see defaults)."""

    # Number of chunks vector search returns (1-50).
    vector_top_k: int = Field(default=5, ge=1, le=50, description="向量检索返回数")
    # Graph expansion depth in hops (1-5).
    graph_depth: int = Field(default=2, ge=1, le=5, description="图谱扩展深度")
    # Cap on the number of graph entities included (1-100).
    graph_max_entities: int = Field(default=20, ge=1, le=100, description="图谱最大实体数")
    # Relative weights for merging vector and graph results (0.0-1.0 each;
    # not required to sum to 1).
    vector_weight: float = Field(default=0.6, ge=0.0, le=1.0, description="向量分数权重")
    graph_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="图谱相关性权重")
    # Per-channel toggles for graph and vector retrieval.
    enable_graph: bool = Field(default=True, description="是否启用图谱检索")
    enable_vector: bool = Field(default=True, description="是否启用向量检索")
|
||||||
|
|
||||||
|
|
||||||
|
class GraphRAGQueryRequest(BaseModel):
|
||||||
|
"""GraphRAG 查询请求。"""
|
||||||
|
|
||||||
|
query: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=1,
|
||||||
|
max_length=2000,
|
||||||
|
description="用户查询",
|
||||||
|
)
|
||||||
|
knowledge_base_id: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=1,
|
||||||
|
max_length=64,
|
||||||
|
description="知识库 ID,用于权限校验(由上游 Java 后端传入)",
|
||||||
|
)
|
||||||
|
collection_name: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=1,
|
||||||
|
max_length=256,
|
||||||
|
pattern=r"^[a-zA-Z0-9_\-\u4e00-\u9fff]+$",
|
||||||
|
description="Milvus collection 名称(= 知识库名),仅允许字母、数字、下划线、连字符和中文",
|
||||||
|
)
|
||||||
|
graph_id: str = Field(
|
||||||
|
...,
|
||||||
|
pattern=r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$",
|
||||||
|
description="Neo4j 图谱 ID(UUID 格式)",
|
||||||
|
)
|
||||||
|
strategy: RetrievalStrategy = Field(
|
||||||
|
default_factory=RetrievalStrategy,
|
||||||
|
description="可选策略覆盖",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VectorChunk(BaseModel):
|
||||||
|
"""向量检索到的文档片段。"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
text: str
|
||||||
|
score: float
|
||||||
|
metadata: dict[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class EntitySummary(BaseModel):
|
||||||
|
"""实体摘要。"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
type: str
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class RelationSummary(BaseModel):
|
||||||
|
"""关系摘要。"""
|
||||||
|
|
||||||
|
source_name: str
|
||||||
|
source_type: str
|
||||||
|
target_name: str
|
||||||
|
target_type: str
|
||||||
|
relation_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class GraphContext(BaseModel):
|
||||||
|
"""图谱上下文。"""
|
||||||
|
|
||||||
|
entities: list[EntitySummary] = Field(default_factory=list)
|
||||||
|
relations: list[RelationSummary] = Field(default_factory=list)
|
||||||
|
textualized: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalContext(BaseModel):
|
||||||
|
"""检索上下文(检索结果的结构化表示)。"""
|
||||||
|
|
||||||
|
vector_chunks: list[VectorChunk] = Field(default_factory=list)
|
||||||
|
graph_context: GraphContext = Field(default_factory=GraphContext)
|
||||||
|
merged_text: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class GraphRAGQueryResponse(BaseModel):
|
||||||
|
"""GraphRAG 查询响应。"""
|
||||||
|
|
||||||
|
answer: str = Field(..., description="LLM 生成的回答")
|
||||||
|
context: RetrievalContext = Field(..., description="检索上下文")
|
||||||
|
model: str = Field(..., description="使用的 LLM 模型名")
|
||||||
214
runtime/datamate-python/app/module/kg_graphrag/retriever.py
Normal file
214
runtime/datamate-python/app/module/kg_graphrag/retriever.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""GraphRAG 检索编排器。
|
||||||
|
|
||||||
|
并行执行向量检索和图谱检索,融合排序后构建统一上下文。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.module.kg_graphrag.context_builder import build_context, textualize_subgraph
|
||||||
|
from app.module.kg_graphrag.kg_client import KGServiceClient
|
||||||
|
from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever
|
||||||
|
from app.module.kg_graphrag.models import (
|
||||||
|
EntitySummary,
|
||||||
|
GraphContext,
|
||||||
|
RelationSummary,
|
||||||
|
RetrievalContext,
|
||||||
|
RetrievalStrategy,
|
||||||
|
VectorChunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphRAGRetriever:
|
||||||
|
"""GraphRAG 检索编排器。"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
milvus_client: MilvusVectorRetriever,
|
||||||
|
kg_client: KGServiceClient,
|
||||||
|
) -> None:
|
||||||
|
self._milvus = milvus_client
|
||||||
|
self._kg = kg_client
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_settings(cls) -> GraphRAGRetriever:
|
||||||
|
return cls(
|
||||||
|
milvus_client=MilvusVectorRetriever.from_settings(),
|
||||||
|
kg_client=KGServiceClient.from_settings(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
collection_name: str,
|
||||||
|
graph_id: str,
|
||||||
|
strategy: RetrievalStrategy,
|
||||||
|
user_id: str = "",
|
||||||
|
) -> RetrievalContext:
|
||||||
|
"""并行执行向量检索 + 图谱检索,融合结果。"""
|
||||||
|
# 构建并行任务
|
||||||
|
tasks: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
if strategy.enable_vector:
|
||||||
|
# 先校验 collection 存在性,防止越权访问
|
||||||
|
if not await self._milvus.has_collection(collection_name):
|
||||||
|
logger.warning(
|
||||||
|
"Collection %s not found, skipping vector retrieval",
|
||||||
|
collection_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tasks["vector"] = asyncio.create_task(
|
||||||
|
self._milvus.search(
|
||||||
|
collection_name=collection_name,
|
||||||
|
query=query,
|
||||||
|
top_k=strategy.vector_top_k,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if strategy.enable_graph:
|
||||||
|
tasks["graph"] = asyncio.create_task(
|
||||||
|
self._retrieve_graph(
|
||||||
|
query=query,
|
||||||
|
graph_id=graph_id,
|
||||||
|
strategy=strategy,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 等待所有任务完成
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks.values(), return_exceptions=True)
|
||||||
|
|
||||||
|
# 收集结果
|
||||||
|
vector_chunks: list[VectorChunk] = []
|
||||||
|
if "vector" in tasks:
|
||||||
|
try:
|
||||||
|
vector_chunks = tasks["vector"].result()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Vector retrieval task failed")
|
||||||
|
|
||||||
|
entities: list[EntitySummary] = []
|
||||||
|
relations: list[RelationSummary] = []
|
||||||
|
if "graph" in tasks:
|
||||||
|
try:
|
||||||
|
entities, relations = tasks["graph"].result()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Graph retrieval task failed")
|
||||||
|
|
||||||
|
# 融合排序
|
||||||
|
vector_chunks = self._rank_results(
|
||||||
|
vector_chunks, entities, relations, strategy
|
||||||
|
)
|
||||||
|
|
||||||
|
# 三元组文本化
|
||||||
|
graph_text = textualize_subgraph(entities, relations)
|
||||||
|
|
||||||
|
# 构建上下文
|
||||||
|
merged_text = build_context(
|
||||||
|
vector_chunks,
|
||||||
|
graph_text,
|
||||||
|
vector_weight=strategy.vector_weight,
|
||||||
|
graph_weight=strategy.graph_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RetrievalContext(
|
||||||
|
vector_chunks=vector_chunks,
|
||||||
|
graph_context=GraphContext(
|
||||||
|
entities=entities,
|
||||||
|
relations=relations,
|
||||||
|
textualized=graph_text,
|
||||||
|
),
|
||||||
|
merged_text=merged_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _retrieve_graph(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
graph_id: str,
|
||||||
|
strategy: RetrievalStrategy,
|
||||||
|
user_id: str,
|
||||||
|
) -> tuple[list[EntitySummary], list[RelationSummary]]:
|
||||||
|
"""图谱检索:全文搜索 -> 种子实体 -> 子图扩展。"""
|
||||||
|
# 1. 全文检索获取种子实体
|
||||||
|
seed_entities = await self._kg.fulltext_search(
|
||||||
|
graph_id=graph_id,
|
||||||
|
query=query,
|
||||||
|
size=strategy.graph_max_entities,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not seed_entities:
|
||||||
|
logger.debug("No seed entities found for query: %s", query)
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# 2. 获取种子实体的 N-hop 子图
|
||||||
|
seed_ids = [e.id for e in seed_entities]
|
||||||
|
entities, relations = await self._kg.get_subgraph(
|
||||||
|
graph_id=graph_id,
|
||||||
|
entity_ids=seed_ids,
|
||||||
|
depth=strategy.graph_depth,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Graph retrieval: %d seed entities -> %d entities, %d relations",
|
||||||
|
len(seed_entities), len(entities), len(relations),
|
||||||
|
)
|
||||||
|
return entities, relations
|
||||||
|
|
||||||
|
def _rank_results(
|
||||||
|
self,
|
||||||
|
vector_chunks: list[VectorChunk],
|
||||||
|
entities: list[EntitySummary],
|
||||||
|
relations: list[RelationSummary],
|
||||||
|
strategy: RetrievalStrategy,
|
||||||
|
) -> list[VectorChunk]:
|
||||||
|
"""对向量检索结果进行融合排序。
|
||||||
|
|
||||||
|
基于向量分数归一化后加权排序。图谱关联度通过实体度数近似评估。
|
||||||
|
"""
|
||||||
|
if not vector_chunks:
|
||||||
|
return vector_chunks
|
||||||
|
|
||||||
|
# 向量分数归一化 (min-max scaling)
|
||||||
|
scores = [c.score for c in vector_chunks]
|
||||||
|
min_score = min(scores)
|
||||||
|
max_score = max(scores)
|
||||||
|
score_range = max_score - min_score
|
||||||
|
|
||||||
|
# 构建图谱实体名称集合,用于关联度加分
|
||||||
|
graph_entity_names = {e.name.lower() for e in entities}
|
||||||
|
|
||||||
|
ranked: list[tuple[float, VectorChunk]] = []
|
||||||
|
for chunk in vector_chunks:
|
||||||
|
# 归一化向量分数
|
||||||
|
norm_score = (
|
||||||
|
(chunk.score - min_score) / score_range
|
||||||
|
if score_range > 0
|
||||||
|
else 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 图谱关联度加分:文档片段中提及图谱实体名称
|
||||||
|
graph_boost = 0.0
|
||||||
|
if graph_entity_names:
|
||||||
|
chunk_text_lower = chunk.text.lower()
|
||||||
|
mentioned = sum(
|
||||||
|
1 for name in graph_entity_names if name in chunk_text_lower
|
||||||
|
)
|
||||||
|
graph_boost = min(mentioned / max(len(graph_entity_names), 1), 1.0)
|
||||||
|
|
||||||
|
# 加权融合分数
|
||||||
|
final_score = (
|
||||||
|
strategy.vector_weight * norm_score
|
||||||
|
+ strategy.graph_weight * graph_boost
|
||||||
|
)
|
||||||
|
ranked.append((final_score, chunk))
|
||||||
|
|
||||||
|
# 按融合分数降序排序
|
||||||
|
ranked.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
return [chunk for _, chunk in ranked]
|
||||||
183
runtime/datamate-python/app/module/kg_graphrag/test_cache.py
Normal file
183
runtime/datamate-python/app/module/kg_graphrag/test_cache.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""GraphRAG 缓存的单元测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from app.module.kg_graphrag.cache import CacheStats, GraphRAGCache, make_cache_key
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CacheStats
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheStats:
|
||||||
|
"""CacheStats 统计逻辑测试。"""
|
||||||
|
|
||||||
|
def test_hit_rate_no_access(self):
|
||||||
|
stats = CacheStats()
|
||||||
|
assert stats.hit_rate == 0.0
|
||||||
|
|
||||||
|
def test_hit_rate_all_hits(self):
|
||||||
|
stats = CacheStats(hits=10, misses=0)
|
||||||
|
assert stats.hit_rate == 1.0
|
||||||
|
|
||||||
|
def test_hit_rate_mixed(self):
|
||||||
|
stats = CacheStats(hits=3, misses=7)
|
||||||
|
assert abs(stats.hit_rate - 0.3) < 1e-9
|
||||||
|
|
||||||
|
def test_to_dict_contains_all_fields(self):
|
||||||
|
stats = CacheStats(hits=5, misses=3, evictions=1)
|
||||||
|
d = stats.to_dict()
|
||||||
|
assert d["hits"] == 5
|
||||||
|
assert d["misses"] == 3
|
||||||
|
assert d["evictions"] == 1
|
||||||
|
assert "hit_rate" in d
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GraphRAGCache — KG 缓存
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKGCache:
|
||||||
|
"""KG 缓存(全文搜索 + 子图导出)测试。"""
|
||||||
|
|
||||||
|
def test_get_miss_returns_none(self):
|
||||||
|
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
|
||||||
|
assert cache.get_kg("nonexistent") is None
|
||||||
|
|
||||||
|
def test_set_then_get_hit(self):
|
||||||
|
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
|
||||||
|
cache.set_kg("key1", {"entities": [1, 2, 3]})
|
||||||
|
result = cache.get_kg("key1")
|
||||||
|
assert result == {"entities": [1, 2, 3]}
|
||||||
|
|
||||||
|
def test_stats_count_hits_and_misses(self):
|
||||||
|
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
|
||||||
|
cache.set_kg("a", "value-a")
|
||||||
|
|
||||||
|
cache.get_kg("a") # hit
|
||||||
|
cache.get_kg("a") # hit
|
||||||
|
cache.get_kg("b") # miss
|
||||||
|
|
||||||
|
stats = cache.stats()
|
||||||
|
assert stats["kg"]["hits"] == 2
|
||||||
|
assert stats["kg"]["misses"] == 1
|
||||||
|
|
||||||
|
def test_maxsize_evicts_oldest(self):
|
||||||
|
cache = GraphRAGCache(kg_maxsize=2, kg_ttl=60)
|
||||||
|
cache.set_kg("a", 1)
|
||||||
|
cache.set_kg("b", 2)
|
||||||
|
cache.set_kg("c", 3) # should evict "a"
|
||||||
|
|
||||||
|
assert cache.get_kg("a") is None
|
||||||
|
assert cache.get_kg("c") == 3
|
||||||
|
|
||||||
|
def test_ttl_expiry(self):
|
||||||
|
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=1)
|
||||||
|
cache.set_kg("ephemeral", "data")
|
||||||
|
assert cache.get_kg("ephemeral") == "data"
|
||||||
|
|
||||||
|
time.sleep(1.1)
|
||||||
|
assert cache.get_kg("ephemeral") is None
|
||||||
|
|
||||||
|
def test_clear_removes_all(self):
|
||||||
|
cache = GraphRAGCache(kg_maxsize=10, kg_ttl=60)
|
||||||
|
cache.set_kg("x", 1)
|
||||||
|
cache.set_kg("y", 2)
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
assert cache.get_kg("x") is None
|
||||||
|
assert cache.get_kg("y") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GraphRAGCache — Embedding 缓存
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingCache:
|
||||||
|
"""Embedding 向量缓存测试。"""
|
||||||
|
|
||||||
|
def test_get_miss_returns_none(self):
|
||||||
|
cache = GraphRAGCache(embedding_maxsize=10, embedding_ttl=60)
|
||||||
|
assert cache.get_embedding("query-1") is None
|
||||||
|
|
||||||
|
def test_set_then_get_hit(self):
|
||||||
|
cache = GraphRAGCache(embedding_maxsize=10, embedding_ttl=60)
|
||||||
|
vec = [0.1, 0.2, 0.3, 0.4]
|
||||||
|
cache.set_embedding("query-1", vec)
|
||||||
|
assert cache.get_embedding("query-1") == vec
|
||||||
|
|
||||||
|
def test_stats_count_hits_and_misses(self):
|
||||||
|
cache = GraphRAGCache(embedding_maxsize=10, embedding_ttl=60)
|
||||||
|
cache.set_embedding("q1", [1.0])
|
||||||
|
cache.get_embedding("q1") # hit
|
||||||
|
cache.get_embedding("q2") # miss
|
||||||
|
|
||||||
|
stats = cache.stats()
|
||||||
|
assert stats["embedding"]["hits"] == 1
|
||||||
|
assert stats["embedding"]["misses"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GraphRAGCache — 整体统计
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheOverallStats:
|
||||||
|
"""缓存整体统计测试。"""
|
||||||
|
|
||||||
|
def test_stats_structure(self):
|
||||||
|
cache = GraphRAGCache(kg_maxsize=5, kg_ttl=60, embedding_maxsize=10, embedding_ttl=60)
|
||||||
|
stats = cache.stats()
|
||||||
|
|
||||||
|
assert "kg" in stats
|
||||||
|
assert "embedding" in stats
|
||||||
|
assert "size" in stats["kg"]
|
||||||
|
assert "maxsize" in stats["kg"]
|
||||||
|
assert "hits" in stats["kg"]
|
||||||
|
assert "misses" in stats["kg"]
|
||||||
|
|
||||||
|
def test_zero_maxsize_disables_caching(self):
|
||||||
|
"""maxsize=0 时,所有 set 都是 no-op。"""
|
||||||
|
cache = GraphRAGCache(kg_maxsize=0, kg_ttl=1, embedding_maxsize=0, embedding_ttl=1)
|
||||||
|
cache.set_kg("key", "value")
|
||||||
|
assert cache.get_kg("key") is None
|
||||||
|
|
||||||
|
cache.set_embedding("key", [1.0])
|
||||||
|
assert cache.get_embedding("key") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# make_cache_key
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMakeCacheKey:
|
||||||
|
"""缓存 key 生成测试。"""
|
||||||
|
|
||||||
|
def test_deterministic(self):
|
||||||
|
key1 = make_cache_key("fulltext", "graph-1", "hello", 10)
|
||||||
|
key2 = make_cache_key("fulltext", "graph-1", "hello", 10)
|
||||||
|
assert key1 == key2
|
||||||
|
|
||||||
|
def test_different_args_different_keys(self):
|
||||||
|
key1 = make_cache_key("fulltext", "graph-1", "hello", 10)
|
||||||
|
key2 = make_cache_key("fulltext", "graph-1", "world", 10)
|
||||||
|
assert key1 != key2
|
||||||
|
|
||||||
|
def test_order_matters(self):
|
||||||
|
key1 = make_cache_key("a", "b")
|
||||||
|
key2 = make_cache_key("b", "a")
|
||||||
|
assert key1 != key2
|
||||||
|
|
||||||
|
def test_handles_unicode(self):
|
||||||
|
key = make_cache_key("用户行为数据", "图谱")
|
||||||
|
assert len(key) == 64 # SHA-256 hex digest
|
||||||
|
|
||||||
|
def test_handles_list_args(self):
|
||||||
|
key = make_cache_key("subgraph", ["id-1", "id-2"], 2)
|
||||||
|
assert len(key) == 64
|
||||||
@@ -0,0 +1,182 @@
|
|||||||
|
"""三元组文本化 + 上下文构建的单元测试。"""
|
||||||
|
|
||||||
|
from app.module.kg_graphrag.context_builder import (
|
||||||
|
RELATION_TEMPLATES,
|
||||||
|
build_context,
|
||||||
|
textualize_subgraph,
|
||||||
|
)
|
||||||
|
from app.module.kg_graphrag.models import (
|
||||||
|
EntitySummary,
|
||||||
|
RelationSummary,
|
||||||
|
VectorChunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# textualize_subgraph 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextualizeSubgraph:
|
||||||
|
"""textualize_subgraph 函数的测试。"""
|
||||||
|
|
||||||
|
def test_single_relation(self):
|
||||||
|
entities = [
|
||||||
|
EntitySummary(id="1", name="用户行为数据", type="Dataset"),
|
||||||
|
EntitySummary(id="2", name="user_id", type="Field"),
|
||||||
|
]
|
||||||
|
relations = [
|
||||||
|
RelationSummary(
|
||||||
|
source_name="用户行为数据",
|
||||||
|
source_type="Dataset",
|
||||||
|
target_name="user_id",
|
||||||
|
target_type="Field",
|
||||||
|
relation_type="HAS_FIELD",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = textualize_subgraph(entities, relations)
|
||||||
|
assert "Dataset'用户行为数据'包含字段Field'user_id'" in result
|
||||||
|
|
||||||
|
def test_multiple_relations(self):
|
||||||
|
entities = [
|
||||||
|
EntitySummary(id="1", name="用户行为数据", type="Dataset"),
|
||||||
|
EntitySummary(id="2", name="清洗管道", type="Workflow"),
|
||||||
|
]
|
||||||
|
relations = [
|
||||||
|
RelationSummary(
|
||||||
|
source_name="清洗管道",
|
||||||
|
source_type="Workflow",
|
||||||
|
target_name="用户行为数据",
|
||||||
|
target_type="Dataset",
|
||||||
|
relation_type="USES_DATASET",
|
||||||
|
),
|
||||||
|
RelationSummary(
|
||||||
|
source_name="用户行为数据",
|
||||||
|
source_type="Dataset",
|
||||||
|
target_name="外部系统",
|
||||||
|
target_type="DataSource",
|
||||||
|
relation_type="SOURCED_FROM",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = textualize_subgraph(entities, relations)
|
||||||
|
assert "Workflow'清洗管道'使用了数据集Dataset'用户行为数据'" in result
|
||||||
|
assert "Dataset'用户行为数据'的知识来源于DataSource'外部系统'" in result
|
||||||
|
|
||||||
|
def test_all_relation_templates(self):
|
||||||
|
"""验证所有 10 种关系模板都能正确生成。"""
|
||||||
|
for rel_type, template in RELATION_TEMPLATES.items():
|
||||||
|
relations = [
|
||||||
|
RelationSummary(
|
||||||
|
source_name="A",
|
||||||
|
source_type="TypeA",
|
||||||
|
target_name="B",
|
||||||
|
target_type="TypeB",
|
||||||
|
relation_type=rel_type,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = textualize_subgraph([], relations)
|
||||||
|
assert "TypeA'A'" in result
|
||||||
|
assert "TypeB'B'" in result
|
||||||
|
assert result # 非空
|
||||||
|
|
||||||
|
def test_unknown_relation_type(self):
|
||||||
|
"""未知关系类型使用通用模板。"""
|
||||||
|
relations = [
|
||||||
|
RelationSummary(
|
||||||
|
source_name="X",
|
||||||
|
source_type="T1",
|
||||||
|
target_name="Y",
|
||||||
|
target_type="T2",
|
||||||
|
relation_type="CUSTOM_REL",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = textualize_subgraph([], relations)
|
||||||
|
assert "T1'X'与T2'Y'存在CUSTOM_REL关系" in result
|
||||||
|
|
||||||
|
def test_orphan_entity_with_description(self):
|
||||||
|
"""无关系的独立实体(有描述)。"""
|
||||||
|
entities = [
|
||||||
|
EntitySummary(id="1", name="孤立实体", type="Dataset", description="这是一个测试实体"),
|
||||||
|
]
|
||||||
|
result = textualize_subgraph(entities, [])
|
||||||
|
assert "Dataset'孤立实体': 这是一个测试实体" in result
|
||||||
|
|
||||||
|
def test_orphan_entity_without_description(self):
|
||||||
|
"""无关系的独立实体(无描述)。"""
|
||||||
|
entities = [
|
||||||
|
EntitySummary(id="1", name="孤立实体", type="Dataset"),
|
||||||
|
]
|
||||||
|
result = textualize_subgraph(entities, [])
|
||||||
|
assert "存在Dataset'孤立实体'" in result
|
||||||
|
|
||||||
|
def test_empty_inputs(self):
|
||||||
|
result = textualize_subgraph([], [])
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
def test_entity_with_relation_not_orphan(self):
|
||||||
|
"""有关系的实体不应出现在独立实体部分。"""
|
||||||
|
entities = [
|
||||||
|
EntitySummary(id="1", name="A", type="Dataset"),
|
||||||
|
EntitySummary(id="2", name="B", type="Field"),
|
||||||
|
EntitySummary(id="3", name="C", type="Workflow"),
|
||||||
|
]
|
||||||
|
relations = [
|
||||||
|
RelationSummary(
|
||||||
|
source_name="A",
|
||||||
|
source_type="Dataset",
|
||||||
|
target_name="B",
|
||||||
|
target_type="Field",
|
||||||
|
relation_type="HAS_FIELD",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = textualize_subgraph(entities, relations)
|
||||||
|
# A 和 B 有关系,不应作为独立实体出现
|
||||||
|
# C 无关系,应出现
|
||||||
|
assert "存在Workflow'C'" in result
|
||||||
|
lines = result.strip().split("\n")
|
||||||
|
assert len(lines) == 2 # 一条关系 + 一个独立实体
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_context 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildContext:
|
||||||
|
"""build_context 函数的测试。"""
|
||||||
|
|
||||||
|
def test_both_vector_and_graph(self):
|
||||||
|
chunks = [
|
||||||
|
VectorChunk(id="1", text="文档片段一", score=0.9),
|
||||||
|
VectorChunk(id="2", text="文档片段二", score=0.8),
|
||||||
|
]
|
||||||
|
graph_text = "Dataset'用户数据'包含字段Field'user_id'"
|
||||||
|
result = build_context(chunks, graph_text)
|
||||||
|
assert "## 相关文档" in result
|
||||||
|
assert "[1] 文档片段一" in result
|
||||||
|
assert "[2] 文档片段二" in result
|
||||||
|
assert "## 知识图谱上下文" in result
|
||||||
|
assert graph_text in result
|
||||||
|
|
||||||
|
def test_vector_only(self):
|
||||||
|
chunks = [VectorChunk(id="1", text="文档片段", score=0.9)]
|
||||||
|
result = build_context(chunks, "")
|
||||||
|
assert "## 相关文档" in result
|
||||||
|
assert "## 知识图谱上下文" not in result
|
||||||
|
|
||||||
|
def test_graph_only(self):
|
||||||
|
result = build_context([], "图谱内容")
|
||||||
|
assert "## 知识图谱上下文" in result
|
||||||
|
assert "## 相关文档" not in result
|
||||||
|
|
||||||
|
def test_empty_both(self):
|
||||||
|
result = build_context([], "")
|
||||||
|
assert "未检索到相关上下文信息" in result
|
||||||
|
|
||||||
|
def test_context_section_order(self):
|
||||||
|
"""验证文档在图谱之前。"""
|
||||||
|
chunks = [VectorChunk(id="1", text="doc", score=0.9)]
|
||||||
|
result = build_context(chunks, "graph")
|
||||||
|
doc_pos = result.index("## 相关文档")
|
||||||
|
graph_pos = result.index("## 知识图谱上下文")
|
||||||
|
assert doc_pos < graph_pos
|
||||||
300
runtime/datamate-python/app/module/kg_graphrag/test_interface.py
Normal file
300
runtime/datamate-python/app/module/kg_graphrag/test_interface.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""GraphRAG API 端点回归测试。
|
||||||
|
|
||||||
|
验证 /graphrag/query、/graphrag/retrieve、/graphrag/query/stream 端点
|
||||||
|
的权限校验行为,确保 collection_name 不一致时返回 403 且不进入检索链路。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
|
|
||||||
|
from app.exception import (
|
||||||
|
fastapi_http_exception_handler,
|
||||||
|
starlette_http_exception_handler,
|
||||||
|
validation_exception_handler,
|
||||||
|
)
|
||||||
|
from app.module.kg_graphrag.interface import router
|
||||||
|
from app.module.kg_graphrag.models import (
|
||||||
|
GraphContext,
|
||||||
|
RetrievalContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 测试用 FastAPI 应用(仅挂载 graphrag router + 异常处理器)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_app = FastAPI()
|
||||||
|
_app.include_router(router, prefix="/api")
|
||||||
|
_app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler)
|
||||||
|
_app.add_exception_handler(HTTPException, fastapi_http_exception_handler)
|
||||||
|
_app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||||
|
|
||||||
|
|
||||||
|
_VALID_GRAPH_ID = "12345678-1234-1234-1234-123456789abc"
|
||||||
|
|
||||||
|
_VALID_BODY = {
|
||||||
|
"query": "测试查询",
|
||||||
|
"knowledge_base_id": "kb-1",
|
||||||
|
"collection_name": "test-collection",
|
||||||
|
"graph_id": _VALID_GRAPH_ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
_HEADERS = {"X-User-Id": "user-1"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_retrieval_context() -> RetrievalContext:
|
||||||
|
return RetrievalContext(
|
||||||
|
vector_chunks=[],
|
||||||
|
graph_context=GraphContext(),
|
||||||
|
merged_text="test context",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_retriever_mock() -> AsyncMock:
|
||||||
|
m = AsyncMock()
|
||||||
|
m.retrieve = AsyncMock(return_value=_fake_retrieval_context())
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def _make_generator_mock() -> AsyncMock:
|
||||||
|
m = AsyncMock()
|
||||||
|
m.generate = AsyncMock(return_value="test answer")
|
||||||
|
m.model_name = "test-model"
|
||||||
|
|
||||||
|
async def _stream(*, query: str, context: str): # noqa: ARG001
|
||||||
|
for token in ["hello", " ", "world"]:
|
||||||
|
yield token
|
||||||
|
|
||||||
|
m.generate_stream = _stream
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def _make_kb_validator_mock(*, access_granted: bool = True) -> AsyncMock:
|
||||||
|
m = AsyncMock()
|
||||||
|
m.check_access = AsyncMock(return_value=access_granted)
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_all(
|
||||||
|
*,
|
||||||
|
access_granted: bool = True,
|
||||||
|
retriever: AsyncMock | None = None,
|
||||||
|
generator: AsyncMock | None = None,
|
||||||
|
validator: AsyncMock | None = None,
|
||||||
|
):
|
||||||
|
"""返回 context manager,统一 patch 三个懒加载工厂函数。"""
|
||||||
|
retriever = retriever or _make_retriever_mock()
|
||||||
|
generator = generator or _make_generator_mock()
|
||||||
|
validator = validator or _make_kb_validator_mock(access_granted=access_granted)
|
||||||
|
|
||||||
|
class _Ctx:
|
||||||
|
def __init__(self):
|
||||||
|
self.retriever = retriever
|
||||||
|
self.generator = generator
|
||||||
|
self.validator = validator
|
||||||
|
self._patches = [
|
||||||
|
patch("app.module.kg_graphrag.interface._get_retriever", return_value=retriever),
|
||||||
|
patch("app.module.kg_graphrag.interface._get_generator", return_value=generator),
|
||||||
|
patch("app.module.kg_graphrag.interface._get_kb_validator", return_value=validator),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
for p in self._patches:
|
||||||
|
p.__enter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
for p in reversed(self._patches):
|
||||||
|
p.__exit__(*args)
|
||||||
|
|
||||||
|
return _Ctx()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
return TestClient(_app)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/graphrag/query
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryEndpoint:
|
||||||
|
"""POST /api/graphrag/query 端点测试。"""
|
||||||
|
|
||||||
|
def test_success(self, client: TestClient):
|
||||||
|
"""权限校验通过 + 检索 + 生成 → 200。"""
|
||||||
|
with _patch_all(access_granted=True) as ctx:
|
||||||
|
resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["code"] == 200
|
||||||
|
assert body["data"]["answer"] == "test answer"
|
||||||
|
assert body["data"]["model"] == "test-model"
|
||||||
|
ctx.retriever.retrieve.assert_awaited_once()
|
||||||
|
ctx.generator.generate.assert_awaited_once()
|
||||||
|
|
||||||
|
def test_access_denied_returns_403(self, client: TestClient):
|
||||||
|
"""check_access 返回 False → 403 + 标准错误格式。"""
|
||||||
|
with _patch_all(access_granted=False):
|
||||||
|
resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)
|
||||||
|
|
||||||
|
assert resp.status_code == 403
|
||||||
|
body = resp.json()
|
||||||
|
assert body["code"] == 403
|
||||||
|
assert "kb-1" in body["data"]["detail"]
|
||||||
|
|
||||||
|
def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
|
||||||
|
"""权限拒绝时,retriever.retrieve 和 generator.generate 均不调用。"""
|
||||||
|
with _patch_all(access_granted=False) as ctx:
|
||||||
|
resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)
|
||||||
|
|
||||||
|
assert resp.status_code == 403
|
||||||
|
ctx.retriever.retrieve.assert_not_called()
|
||||||
|
ctx.generator.generate.assert_not_called()
|
||||||
|
|
||||||
|
def test_check_access_receives_collection_name(self, client: TestClient):
|
||||||
|
"""验证 check_access 被调用时携带正确的 collection_name 参数。"""
|
||||||
|
with _patch_all(access_granted=True) as ctx:
|
||||||
|
resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
ctx.validator.check_access.assert_awaited_once_with(
|
||||||
|
"kb-1", "user-1", collection_name="test-collection",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_missing_user_id_returns_422(self, client: TestClient):
|
||||||
|
"""缺少 X-User-Id 请求头 → 422 验证错误。"""
|
||||||
|
with _patch_all(access_granted=True):
|
||||||
|
resp = client.post("/api/graphrag/query", json=_VALID_BODY)
|
||||||
|
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/graphrag/retrieve
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetrieveEndpoint:
    """Tests for the POST /api/graphrag/retrieve endpoint."""

    def test_success(self, client: TestClient):
        """Access granted → retrieval runs → RetrievalContext is returned."""
        with _patch_all(access_granted=True) as ctx:
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 200
        assert body["data"]["merged_text"] == "test context"
        ctx.retriever.retrieve.assert_awaited_once()

    def test_access_denied_returns_403(self, client: TestClient):
        """Access denied → 403 with matching business code in the body."""
        with _patch_all(access_granted=False):
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 403
        body = resp.json()
        assert body["code"] == 403

    def test_access_denied_skips_retrieval(self, client: TestClient):
        """retriever.retrieve must not be called when access is denied."""
        with _patch_all(access_granted=False) as ctx:
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 403
        ctx.retriever.retrieve.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """check_access receives the collection_name argument."""
        with _patch_all(access_granted=True) as ctx:
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 200
        ctx.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """Missing X-User-Id → 422."""
        with _patch_all(access_granted=True):
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY)

        assert resp.status_code == 422
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/graphrag/query/stream
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryStreamEndpoint:
    """Tests for the POST /api/graphrag/query/stream endpoint."""

    def test_success_returns_sse(self, client: TestClient):
        """Access granted → SSE streaming response containing token and done events."""
        with _patch_all(access_granted=True):
            resp = client.post(
                "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
            )

        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        text = resp.text
        assert '"token"' in text
        # JSON serializers differ on whitespace after ':' — accept either form.
        assert '"done": true' in text or '"done":true' in text

    def test_access_denied_returns_403(self, client: TestClient):
        """Access denied → 403."""
        with _patch_all(access_granted=False):
            resp = client.post(
                "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
            )

        assert resp.status_code == 403
        body = resp.json()
        assert body["code"] == 403

    def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
        """Neither retrieval nor generation runs when access is denied."""
        with _patch_all(access_granted=False) as ctx:
            resp = client.post(
                "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
            )

        assert resp.status_code == 403
        ctx.retriever.retrieve.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """check_access receives the collection_name argument."""
        with _patch_all(access_granted=True) as ctx:
            resp = client.post(
                "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
            )

        assert resp.status_code == 200
        ctx.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """Missing X-User-Id → 422."""
        with _patch_all(access_granted=True):
            resp = client.post("/api/graphrag/query/stream", json=_VALID_BODY)

        assert resp.status_code == 422
||||||
330
runtime/datamate-python/app/module/kg_graphrag/test_kb_access.py
Normal file
330
runtime/datamate-python/app/module/kg_graphrag/test_kb_access.py
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
"""知识库访问权限校验的单元测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def validator() -> KnowledgeBaseAccessValidator:
    """Validator wired to a fake backend base URL with a short timeout."""
    instance = KnowledgeBaseAccessValidator(
        base_url="http://test-backend:8080/api",
        timeout=5.0,
    )
    return instance
||||||
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
_FAKE_REQUEST = httpx.Request("GET", "http://test")


def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
    """Build an httpx.Response carrying a dummy request.

    A JSON payload takes precedence over a text payload.
    """
    if json is None:
        return httpx.Response(status_code, text=text or "", request=_FAKE_REQUEST)
    return httpx.Response(status_code, json=json, request=_FAKE_REQUEST)
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# check_access 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckAccess:
    """Tests for the check_access method."""

    def test_access_granted(self, validator: KnowledgeBaseAccessValidator):
        """Java returns HTTP 200 + code=200: the user may access the KB."""
        mock_resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "test-kb"}})
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access("kb-1", "user-1"))

        assert result is True

    def test_access_granted_with_matching_collection(self, validator: KnowledgeBaseAccessValidator):
        """Access granted and collection_name matches the KB name: allowed."""
        mock_resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "my-collection"}})
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="my-collection",
            ))

        assert result is True

    def test_access_denied_by_biz_code(self, validator: KnowledgeBaseAccessValidator):
        """Java returns HTTP 200 but code != 200 (insufficient permission, sys.0005)."""
        mock_resp = _resp(200, json={"code": "sys.0005", "message": "权限不足"})
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access("kb-1", "other-user"))

        assert result is False

    def test_access_denied_http_403(self, validator: KnowledgeBaseAccessValidator):
        """Java returns HTTP 403."""
        mock_resp = _resp(403, text="Forbidden")
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access("kb-1", "user-1"))

        assert result is False

    def test_kb_not_found_http_404(self, validator: KnowledgeBaseAccessValidator):
        """Knowledge base does not exist; Java returns 404."""
        mock_resp = _resp(404, text="Not Found")
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access("nonexistent-kb", "user-1"))

        assert result is False

    def test_server_error_http_500(self, validator: KnowledgeBaseAccessValidator):
        """The Java backend returns 500."""
        mock_resp = _resp(500, text="Internal Server Error")
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access("kb-1", "user-1"))

        assert result is False

    def test_fail_close_on_connection_error(self, validator: KnowledgeBaseAccessValidator):
        """Network failures fail-close (deny access) so the check cannot be bypassed."""
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(side_effect=httpx.ConnectError("connection refused"))
            mock_get.return_value = mock_http

            result = _run(validator.check_access("kb-1", "user-1"))

        assert result is False

    def test_fail_close_on_timeout(self, validator: KnowledgeBaseAccessValidator):
        """Timeouts fail-close (deny access)."""
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(side_effect=httpx.ReadTimeout("timeout"))
            mock_get.return_value = mock_http

            result = _run(validator.check_access("kb-1", "user-1"))

        assert result is False

    def test_request_headers(self, validator: KnowledgeBaseAccessValidator):
        """The outgoing request must carry the correct X-User-Id header."""
        mock_resp = _resp(200, json={"code": 200, "data": {}})
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            _run(validator.check_access("kb-123", "user-456"))

        call_kwargs = mock_http.get.call_args
        assert "/api/knowledge-base/kb-123" in call_kwargs.args[0]
        assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-456"

    def test_cross_user_access_denied(self, validator: KnowledgeBaseAccessValidator):
        """Cross-user access: user B tries to read user A's KB and must be denied.

        Simulates the Java backend returning the "insufficient permission"
        business error.
        """
        # KB created by user A; user B requests access.
        mock_resp = _resp(200, json={
            "code": "sys.0005",
            "message": "权限不足",
            "data": None,
        })
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access("kb-user-a", "user-b"))

        assert result is False

        # Confirm the request carried user B's ID.
        call_kwargs = mock_http.get.call_args
        assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-b"

    def test_admin_access_granted(self, validator: KnowledgeBaseAccessValidator):
        """Admin accessing another user's KB: the Java side skips the owner check."""
        mock_resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-user-a", "name": "用户A的知识库", "createdBy": "user-a"},
        })
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access("kb-user-a", "admin-user"))

        # Admin check passes on the Java side: HTTP 200 + code=200.
        assert result is True
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# collection_name 绑定校验测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCollectionNameBinding:
    """Binding checks between collection_name and knowledge_base_id.

    Prevents a user from submitting a legitimate KB ID while tampering with
    collection_name to read another knowledge base's Milvus data.
    """

    def test_collection_name_mismatch_denied(self, validator: KnowledgeBaseAccessValidator):
        """KB-A has name='collection-a' but the request passed collection_name='collection-b': deny."""
        mock_resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-a", "name": "collection-a"},
        })
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access(
                "kb-a", "user-1", collection_name="collection-b",
            ))

        assert result is False

    def test_collection_name_none_skips_check(self, validator: KnowledgeBaseAccessValidator):
        """collection_name=None skips the binding check (backwards compatible)."""
        mock_resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-1", "name": "some-name"},
        })
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            # No collection_name → only the permission check, no binding check.
            result = _run(validator.check_access("kb-1", "user-1"))

        assert result is True

    def test_response_data_missing_name_denied(self, validator: KnowledgeBaseAccessValidator):
        """Java response data lacks a name field: fail-close, deny."""
        mock_resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-1"},
        })
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="any-collection",
            ))

        # data.name is None, doesn't match "any-collection" → denied
        assert result is False

    def test_response_data_null_denied(self, validator: KnowledgeBaseAccessValidator):
        """Java response data is null: fail-close, deny."""
        mock_resp = _resp(200, json={
            "code": 200,
            "data": None,
        })
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="any-collection",
            ))

        assert result is False

    def test_response_data_empty_dict_denied(self, validator: KnowledgeBaseAccessValidator):
        """Java response data is an empty dict {}: fail-close, deny."""
        mock_resp = _resp(200, json={
            "code": 200,
            "data": {},
        })
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="any-collection",
            ))

        assert result is False

    def test_cross_kb_collection_swap_denied(self, validator: KnowledgeBaseAccessValidator):
        """The user may access KB-A (name='kb-a-data') but pairs KB-A's ID with
        KB-B's collection_name='kb-b-data': must be denied.

        This is the full simulation of the core privilege-escalation scenario.
        """
        # The user is allowed to access KB-A.
        mock_resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-a", "name": "kb-a-data", "createdBy": "user-1"},
        })
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            # …but collection_name points at KB-B's data.
            result = _run(validator.check_access(
                "kb-a", "user-1", collection_name="kb-b-data",
            ))

        assert result is False

    def test_chinese_collection_name_match(self, validator: KnowledgeBaseAccessValidator):
        """Exact match of a Chinese collection_name."""
        mock_resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-1", "name": "用户行为数据"},
        })
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="用户行为数据",
            ))

        assert result is True
||||||
306
runtime/datamate-python/app/module/kg_graphrag/test_kg_client.py
Normal file
306
runtime/datamate-python/app/module/kg_graphrag/test_kg_client.py
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
"""KG 服务 REST 客户端的单元测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.module.kg_graphrag.cache import GraphRAGCache
|
||||||
|
from app.module.kg_graphrag.kg_client import KGServiceClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client() -> KGServiceClient:
    """KG client pointed at a fake service with a test token and short timeout."""
    instance = KGServiceClient(
        base_url="http://test-kg:8080",
        internal_token="test-token",
        timeout=5.0,
    )
    return instance
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _disable_cache():
    """Disable caching for every test so stale cache hits never mask mock expectations."""
    zero_cache = GraphRAGCache(kg_maxsize=0, kg_ttl=1, embedding_maxsize=0, embedding_ttl=1)
    with patch("app.module.kg_graphrag.kg_client.get_cache", return_value=zero_cache):
        yield
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
_FAKE_REQUEST = httpx.Request("GET", "http://test")


def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
    """Build an httpx.Response with an attached request (needed by raise_for_status).

    A JSON payload takes precedence over a text payload.
    """
    if json is None:
        return httpx.Response(status_code, text=text or "", request=_FAKE_REQUEST)
    return httpx.Response(status_code, json=json, request=_FAKE_REQUEST)
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# fulltext_search 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFulltextSearch:
    """Tests for the fulltext_search method."""

    def test_wrapped_paged_response(self, client: KGServiceClient):
        """Java returns a globally wrapped PagedResponse: {"code": 200, "data": {"content": [...]}}"""
        mock_body = {
            "code": 200,
            "data": {
                "page": 0,
                "size": 20,
                "totalElements": 2,
                "totalPages": 1,
                "content": [
                    {"id": "e1", "name": "用户数据", "type": "Dataset", "description": "用户行为", "score": 2.5},
                    {"id": "e2", "name": "清洗管道", "type": "Workflow", "description": "", "score": 1.8},
                ],
            },
        }
        mock_resp = _resp(200, json=mock_body)
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            entities = _run(client.fulltext_search("graph-1", "用户数据", size=10, user_id="u1"))

        assert len(entities) == 2
        assert entities[0].id == "e1"
        assert entities[0].name == "用户数据"
        assert entities[0].type == "Dataset"
        assert entities[1].name == "清洗管道"

    def test_unwrapped_paged_response(self, client: KGServiceClient):
        """Java returns the PagedResponse directly (no global wrapper)."""
        mock_body = {
            "page": 0,
            "size": 10,
            "totalElements": 1,
            "totalPages": 1,
            "content": [
                {"id": "e1", "name": "A", "type": "Dataset", "description": "desc"},
            ],
        }
        mock_resp = _resp(200, json=mock_body)
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            entities = _run(client.fulltext_search("graph-1", "A"))

        # body has no "data" key → fallback to body itself → read "content"
        assert len(entities) == 1
        assert entities[0].name == "A"

    def test_empty_content(self, client: KGServiceClient):
        # An empty "content" list must yield an empty entity list.
        mock_body = {"code": 200, "data": {"page": 0, "content": []}}
        mock_resp = _resp(200, json=mock_body)
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            entities = _run(client.fulltext_search("graph-1", "nothing"))

        assert entities == []

    def test_fail_open_on_http_error(self, client: KGServiceClient):
        """On an HTTP error the client fails open and returns an empty list."""
        mock_resp = _resp(500, text="Internal Server Error")
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            entities = _run(client.fulltext_search("graph-1", "test"))

        assert entities == []

    def test_fail_open_on_connection_error(self, client: KGServiceClient):
        """On a connection error the client fails open and returns an empty list."""
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(side_effect=httpx.ConnectError("connection refused"))
            mock_get.return_value = mock_http

            entities = _run(client.fulltext_search("graph-1", "test"))

        assert entities == []

    def test_request_headers(self, client: KGServiceClient):
        """The request must carry the correct headers and query params."""
        mock_resp = _resp(200, json={"data": {"content": []}})
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.get = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            _run(client.fulltext_search("gid", "q", size=5, user_id="user-123"))

        call_kwargs = mock_http.get.call_args
        assert call_kwargs.kwargs["headers"]["X-Internal-Token"] == "test-token"
        assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-123"
        assert call_kwargs.kwargs["params"] == {"q": "q", "size": 5}
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_subgraph 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSubgraph:
    """Tests for the get_subgraph method."""

    def test_wrapped_subgraph_response(self, client: KGServiceClient):
        """Java returns a globally wrapped SubgraphExportVO."""
        mock_body = {
            "code": 200,
            "data": {
                "nodes": [
                    {"id": "n1", "name": "用户数据", "type": "Dataset", "description": "desc1", "properties": {}},
                    {"id": "n2", "name": "user_id", "type": "Field", "description": "", "properties": {}},
                ],
                "edges": [
                    {
                        "id": "edge1",
                        "sourceEntityId": "n1",
                        "targetEntityId": "n2",
                        "relationType": "HAS_FIELD",
                        "weight": 1.0,
                        "confidence": 0.9,
                        "sourceId": "kb-1",
                    },
                ],
                "nodeCount": 2,
                "edgeCount": 1,
            },
        }
        mock_resp = _resp(200, json=mock_body)
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.post = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            entities, relations = _run(client.get_subgraph("gid", ["n1"], depth=2, user_id="u1"))

        assert len(entities) == 2
        assert entities[0].name == "用户数据"
        assert entities[1].name == "user_id"

        # Edge endpoints are resolved to names/types from the node list.
        assert len(relations) == 1
        assert relations[0].source_name == "用户数据"
        assert relations[0].target_name == "user_id"
        assert relations[0].relation_type == "HAS_FIELD"
        assert relations[0].source_type == "Dataset"
        assert relations[0].target_type == "Field"

    def test_unwrapped_subgraph_response(self, client: KGServiceClient):
        """Java returns the SubgraphExportVO directly (no global wrapper)."""
        mock_body = {
            "nodes": [
                {"id": "n1", "name": "A", "type": "T1", "description": ""},
            ],
            "edges": [],
            "nodeCount": 1,
            "edgeCount": 0,
        }
        mock_resp = _resp(200, json=mock_body)
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.post = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            entities, relations = _run(client.get_subgraph("gid", ["n1"]))

        assert len(entities) == 1
        assert entities[0].name == "A"
        assert relations == []

    def test_edge_with_unknown_entity(self, client: KGServiceClient):
        """When an edge references an entity absent from nodes, the ID is used as fallback."""
        mock_body = {
            "code": 200,
            "data": {
                "nodes": [{"id": "n1", "name": "A", "type": "T1", "description": ""}],
                "edges": [
                    {
                        "sourceEntityId": "n1",
                        "targetEntityId": "n999",
                        "relationType": "DEPENDS_ON",
                    },
                ],
            },
        }
        mock_resp = _resp(200, json=mock_body)
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.post = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            entities, relations = _run(client.get_subgraph("gid", ["n1"]))

        assert len(relations) == 1
        assert relations[0].source_name == "A"
        assert relations[0].target_name == "n999"  # fallback to ID
        assert relations[0].target_type == ""

    def test_fail_open_on_error(self, client: KGServiceClient):
        # A 500 from the KG service fails open: both lists come back empty.
        mock_resp = _resp(500, text="error")
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.post = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            entities, relations = _run(client.get_subgraph("gid", ["n1"]))

        assert entities == []
        assert relations == []

    def test_request_params(self, client: KGServiceClient):
        """Subgraph request URL, query params, and JSON body are passed through correctly."""
        mock_resp = _resp(200, json={"data": {"nodes": [], "edges": []}})
        with patch.object(client, "_get_client") as mock_get:
            mock_http = AsyncMock()
            mock_http.post = AsyncMock(return_value=mock_resp)
            mock_get.return_value = mock_http

            _run(client.get_subgraph("gid", ["e1", "e2"], depth=3, user_id="u1"))

        call_kwargs = mock_http.post.call_args
        assert "/knowledge-graph/gid/query/subgraph/export" in call_kwargs.args[0]
        assert call_kwargs.kwargs["params"] == {"depth": 3}
        assert call_kwargs.kwargs["json"] == {"entityIds": ["e1", "e2"]}
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# headers 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeaders:
    """Tests for the internal _headers builder."""

    def test_headers_with_token_and_user(self, client: KGServiceClient):
        """Both the internal token and the user ID are present when configured."""
        built = client._headers(user_id="user-1")
        assert built["X-Internal-Token"] == "test-token"
        assert built["X-User-Id"] == "user-1"

    def test_headers_without_user(self, client: KGServiceClient):
        """Omitting user_id leaves X-User-Id out while keeping the token."""
        built = client._headers()
        assert "X-Internal-Token" in built
        assert "X-User-Id" not in built

    def test_headers_without_token(self):
        """An empty internal token omits X-Internal-Token entirely."""
        bare = KGServiceClient(base_url="http://test:8080", internal_token="")
        built = bare._headers(user_id="u1")
        assert "X-Internal-Token" not in built
        assert built["X-User-Id"] == "u1"
||||||
@@ -0,0 +1,145 @@
|
|||||||
|
"""Milvus 向量检索客户端的单元测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def retriever() -> MilvusVectorRetriever:
    """Retriever pointed at a fake Milvus URI with a test embedding model."""
    instance = MilvusVectorRetriever(
        uri="http://test-milvus:19530",
        embedding_model="text-embedding-test",
    )
    return instance
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# has_collection 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestHasCollection:
    """Tests for has_collection on the Milvus retriever."""

    def test_collection_exists(self, retriever: MilvusVectorRetriever):
        """An existing collection is reported as present."""
        fake_milvus = MagicMock()
        fake_milvus.has_collection = MagicMock(return_value=True)
        retriever._milvus_client = fake_milvus

        assert _run(retriever.has_collection("my_collection")) is True

    def test_collection_not_exists(self, retriever: MilvusVectorRetriever):
        """A missing collection is reported as absent."""
        fake_milvus = MagicMock()
        fake_milvus.has_collection = MagicMock(return_value=False)
        retriever._milvus_client = fake_milvus

        assert _run(retriever.has_collection("nonexistent")) is False

    def test_fail_open_on_error(self, retriever: MilvusVectorRetriever):
        """Client errors degrade to False instead of propagating."""
        fake_milvus = MagicMock()
        fake_milvus.has_collection = MagicMock(side_effect=Exception("connection error"))
        retriever._milvus_client = fake_milvus

        assert _run(retriever.has_collection("test")) is False
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# search 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearch:
    """Behaviour of MilvusVectorRetriever.search, including its fail-open paths."""

    @staticmethod
    def _stub_embeddings(retriever: MilvusVectorRetriever, **behavior) -> AsyncMock:
        """Install an AsyncMock embeddings backend whose aembed_query uses *behavior*."""
        embeddings_stub = AsyncMock()
        embeddings_stub.aembed_query = AsyncMock(**behavior)
        retriever._embeddings = embeddings_stub
        return embeddings_stub

    @staticmethod
    def _stub_milvus(retriever: MilvusVectorRetriever, **behavior) -> MagicMock:
        """Install a MagicMock Milvus client whose search method uses *behavior*."""
        milvus_stub = MagicMock()
        milvus_stub.search = MagicMock(**behavior)
        retriever._milvus_client = milvus_stub
        return milvus_stub

    def test_successful_search(self, retriever: MilvusVectorRetriever):
        """A normal search maps every Milvus hit onto a VectorChunk."""
        self._stub_embeddings(retriever, return_value=[0.1, 0.2, 0.3])
        hits = [
            {"id": "doc1", "distance": 0.95, "entity": {"text": "文档片段一", "metadata": {"source": "kb1"}}},
            {"id": "doc2", "distance": 0.82, "entity": {"text": "文档片段二", "metadata": {}}},
        ]
        self._stub_milvus(retriever, return_value=[hits])

        chunks = _run(retriever.search("my_collection", "用户数据", top_k=5))

        assert [chunk.id for chunk in chunks] == ["doc1", "doc2"]
        assert chunks[0].text == "文档片段一"
        assert chunks[0].score == 0.95
        assert chunks[0].metadata == {"source": "kb1"}
        assert chunks[1].score == 0.82

    def test_empty_results(self, retriever: MilvusVectorRetriever):
        """An empty hit list from Milvus yields an empty chunk list."""
        self._stub_embeddings(retriever, return_value=[0.1])
        self._stub_milvus(retriever, return_value=[[]])

        assert _run(retriever.search("col", "query")) == []

    def test_fail_open_on_embedding_error(self, retriever: MilvusVectorRetriever):
        """An embedding failure fails open and returns an empty list."""
        self._stub_embeddings(retriever, side_effect=Exception("API error"))

        assert _run(retriever.search("col", "query")) == []

    def test_fail_open_on_milvus_error(self, retriever: MilvusVectorRetriever):
        """A Milvus search failure fails open and returns an empty list."""
        self._stub_embeddings(retriever, return_value=[0.1])
        self._stub_milvus(retriever, side_effect=Exception("Milvus down"))

        assert _run(retriever.search("col", "query")) == []

    def test_search_uses_to_thread(self, retriever: MilvusVectorRetriever):
        """Synchronous Milvus I/O must be dispatched through asyncio.to_thread."""
        self._stub_embeddings(retriever, return_value=[0.1])
        milvus_stub = self._stub_milvus(retriever, return_value=[[]])

        with patch("app.module.kg_graphrag.milvus_client.asyncio.to_thread", new_callable=AsyncMock) as to_thread_mock:
            to_thread_mock.return_value = [[]]

            _run(retriever.search("col", "query"))

            # The sync search callable itself must be what gets handed to the worker thread.
            to_thread_mock.assert_called_once()
            assert to_thread_mock.call_args.args[0] == milvus_stub.search
234
runtime/datamate-python/app/module/kg_graphrag/test_retriever.py
Normal file
234
runtime/datamate-python/app/module/kg_graphrag/test_retriever.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""GraphRAG 检索编排器的单元测试。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.module.kg_graphrag.models import (
|
||||||
|
EntitySummary,
|
||||||
|
RelationSummary,
|
||||||
|
RetrievalStrategy,
|
||||||
|
VectorChunk,
|
||||||
|
)
|
||||||
|
from app.module.kg_graphrag.retriever import GraphRAGRetriever
|
||||||
|
|
||||||
|
|
||||||
|
def _run(coro):
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_retriever(
    *,
    milvus_search_result: list[VectorChunk] | None = None,
    milvus_has_collection: bool = True,
    kg_fulltext_result: list[EntitySummary] | None = None,
    kg_subgraph_result: tuple[list[EntitySummary], list[RelationSummary]] | None = None,
) -> GraphRAGRetriever:
    """Build a GraphRAGRetriever whose Milvus and KG clients are preconfigured AsyncMocks.

    Every keyword controls what the corresponding mocked call returns; None means
    "empty result" (empty list, or an empty entity/relation pair for the subgraph).
    """
    milvus_stub = AsyncMock()
    milvus_stub.has_collection = AsyncMock(return_value=milvus_has_collection)
    milvus_stub.search = AsyncMock(return_value=milvus_search_result or [])

    kg_stub = AsyncMock()
    kg_stub.fulltext_search = AsyncMock(return_value=kg_fulltext_result or [])
    kg_stub.get_subgraph = AsyncMock(return_value=kg_subgraph_result or ([], []))

    return GraphRAGRetriever(milvus_client=milvus_stub, kg_client=kg_stub)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# retrieve 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetrieve:
    """Tests for GraphRAGRetriever.retrieve orchestration."""

    @staticmethod
    def _retrieve(retriever, *, query="test", collection_name="kb", graph_id="g",
                  strategy=None, user_id="u"):
        """Run retrieve() synchronously, defaulting to a vanilla RetrievalStrategy."""
        return _run(retriever.retrieve(
            query=query,
            collection_name=collection_name,
            graph_id=graph_id,
            strategy=strategy if strategy is not None else RetrievalStrategy(),
            user_id=user_id,
        ))

    def test_both_vector_and_graph(self):
        """Vector and graph retrieval both enabled and both returning results."""
        chunks = [
            VectorChunk(id="c1", text="文档片段关于用户数据", score=0.9),
            VectorChunk(id="c2", text="其他内容", score=0.7),
        ]
        seed = [EntitySummary(id="e1", name="用户数据", type="Dataset")]
        subgraph_entities = [
            EntitySummary(id="e1", name="用户数据", type="Dataset"),
            EntitySummary(id="e2", name="user_id", type="Field"),
        ]
        subgraph_relations = [
            RelationSummary(
                source_name="用户数据", source_type="Dataset",
                target_name="user_id", target_type="Field",
                relation_type="HAS_FIELD",
            ),
        ]
        retriever = _make_retriever(
            milvus_search_result=chunks,
            kg_fulltext_result=seed,
            kg_subgraph_result=(subgraph_entities, subgraph_relations),
        )

        ctx = self._retrieve(
            retriever,
            query="用户数据有哪些字段",
            collection_name="kb1",
            graph_id="graph-1",
            user_id="u1",
        )

        assert len(ctx.vector_chunks) == 2
        assert len(ctx.graph_context.entities) == 2
        assert len(ctx.graph_context.relations) == 1
        assert "用户数据" in ctx.graph_context.textualized
        assert "## 相关文档" in ctx.merged_text
        assert "## 知识图谱上下文" in ctx.merged_text

    def test_vector_only(self):
        """With graph retrieval disabled, the KG client is never consulted."""
        retriever = _make_retriever(
            milvus_search_result=[VectorChunk(id="c1", text="doc", score=0.9)],
        )

        ctx = self._retrieve(retriever, strategy=RetrievalStrategy(enable_graph=False))

        assert len(ctx.vector_chunks) == 1
        assert ctx.graph_context.entities == []
        retriever._kg.fulltext_search.assert_not_called()

    def test_graph_only(self):
        """With vector retrieval disabled, Milvus search is never consulted."""
        seed = [EntitySummary(id="e1", name="A", type="T")]
        expanded = [EntitySummary(id="e1", name="A", type="T")]
        retriever = _make_retriever(
            kg_fulltext_result=seed,
            kg_subgraph_result=(expanded, []),
        )

        ctx = self._retrieve(retriever, strategy=RetrievalStrategy(enable_vector=False))

        assert ctx.vector_chunks == []
        assert len(ctx.graph_context.entities) == 1
        retriever._milvus.search.assert_not_called()

    def test_no_seed_entities(self):
        """No fulltext seeds means the subgraph query is skipped entirely."""
        retriever = _make_retriever(kg_fulltext_result=[])

        ctx = self._retrieve(retriever, strategy=RetrievalStrategy(enable_vector=False))

        assert ctx.graph_context.entities == []
        retriever._kg.get_subgraph.assert_not_called()

    def test_collection_not_found_skips_vector(self):
        """A missing Milvus collection short-circuits vector retrieval."""
        retriever = _make_retriever(milvus_has_collection=False)

        ctx = self._retrieve(
            retriever,
            collection_name="nonexistent",
            strategy=RetrievalStrategy(enable_graph=False),
        )

        assert ctx.vector_chunks == []
        retriever._milvus.search.assert_not_called()

    def test_both_empty(self):
        """Both retrieval paths empty should produce the no-context placeholder."""
        retriever = _make_retriever()

        ctx = self._retrieve(retriever, query="nothing")

        assert ctx.vector_chunks == []
        assert ctx.graph_context.entities == []
        assert "未检索到相关上下文信息" in ctx.merged_text

    def test_vector_error_fail_open(self):
        """A vector-side failure fails open; graph retrieval still delivers."""
        retriever = _make_retriever()
        retriever._milvus.search = AsyncMock(side_effect=Exception("milvus down"))
        retriever._kg.fulltext_search = AsyncMock(
            return_value=[EntitySummary(id="e1", name="A", type="T")]
        )
        retriever._kg.get_subgraph = AsyncMock(
            return_value=([EntitySummary(id="e1", name="A", type="T")], [])
        )

        ctx = self._retrieve(retriever)

        # Vector retrieval failed, but the graph path still produced one entity.
        assert ctx.vector_chunks == []
        assert len(ctx.graph_context.entities) == 1
|
# ---------------------------------------------------------------------------
|
||||||
|
# _rank_results 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRankResults:
    """Tests for GraphRAGRetriever._rank_results."""

    @staticmethod
    def _make_retriever_instance() -> GraphRAGRetriever:
        """Retriever whose clients are inert mocks; only ranking logic is exercised."""
        return GraphRAGRetriever(
            milvus_client=MagicMock(),
            kg_client=MagicMock(),
        )

    def test_empty_chunks(self):
        ranked = self._make_retriever_instance()._rank_results([], [], [], RetrievalStrategy())
        assert ranked == []

    def test_single_chunk(self):
        only = [VectorChunk(id="1", text="text", score=0.9)]
        ranked = self._make_retriever_instance()._rank_results(only, [], [], RetrievalStrategy())
        assert len(ranked) == 1
        assert ranked[0].id == "1"

    def test_graph_boost_reorders(self):
        """A graph-entity hit should lift a lower-scored chunk above a higher one."""
        # candidate "1": high vector score, no graph hit.
        # candidate "2": low vector score, but mentions the graph entity.
        candidates = [
            VectorChunk(id="1", text="无关内容", score=0.9),
            VectorChunk(id="2", text="包含用户数据的内容", score=0.5),
        ]
        hit_entities = [EntitySummary(id="e1", name="用户数据", type="Dataset")]
        weighting = RetrievalStrategy(vector_weight=0.3, graph_weight=0.7)

        ranked = self._make_retriever_instance()._rank_results(candidates, hit_entities, [], weighting)

        # The graph-boosted chunk must come out on top.
        assert ranked[0].id == "2"

    def test_all_same_score(self):
        """Identical scores across chunks must not crash the ranking."""
        tied = [
            VectorChunk(id="1", text="a", score=0.5),
            VectorChunk(id="2", text="b", score=0.5),
        ]
        ranked = self._make_retriever_instance()._rank_results(tied, [], [], RetrievalStrategy())
        assert len(ranked) == 2
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user