Compare commits


8 Commits

Author SHA1 Message Date
75db6daeb5 feat(kg): implement query-stage permission filtering of user data
New features:
- Query-stage permission filtering: administrators see everything; regular users see only data they created
- Structural entities (User, Org, Field) are visible to all users
- Business entities (Dataset, Workflow, Job, LabelTask, KnowledgeSet) are filtered by created_by
- CONFIDENTIAL sensitivity filtering: viewing requires a dedicated permission

Security fixes (four review rounds):
P1-1: CONFIDENTIAL sensitivity filtering
- All 4 query entry points compute excludeConfidential uniformly
- assertEntityAccess / isEntityAccessible now also check for confidential data
- buildPermissionPredicate appends a sensitivity condition to the Cypher

P1-2: structural entities identified by a type whitelist
- New constant STRUCTURAL_ENTITY_TYPES = Set.of("User", "Org", "Field")
- Business entities must match created_by (a missing value means access is denied)
- Cypher changed from IS NULL OR to type IN ['User', 'Org', 'Field'] OR
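The P1-1/P1-2 filter described above can be sketched as a small predicate builder. This is an illustrative reconstruction, not the production buildPermissionPredicate: the alias, parameter name $filterUserId, and property spellings follow the commit message.

```java
public class PermissionPredicate {
    // Mirrors the STRUCTURAL_ENTITY_TYPES whitelist and the created_by /
    // sensitivity rules from the commit message (illustrative sketch).
    static String build(String alias, boolean excludeConfidential) {
        String p = "AND (" + alias + ".type IN ['User', 'Org', 'Field'] OR "
                + alias + ".`properties.created_by` = $filterUserId)";
        if (excludeConfidential) {
            String s = "toUpper(trim(" + alias + ".`properties.sensitivity`))";
            // NULL sensitivity is allowed; only an explicit CONFIDENTIAL is excluded.
            p += " AND (" + s + " IS NULL OR " + s + " <> 'CONFIDENTIAL')";
        }
        return p;
    }

    public static void main(String[] args) {
        String withConf = build("node", true);
        assert withConf.contains("node.type IN ['User', 'Org', 'Field']");
        assert withConf.contains("'CONFIDENTIAL'");
        assert !build("node", false).contains("CONFIDENTIAL");
    }
}
```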

P2-1: path-level permission bypass in getNeighborGraph
- Changed to ALL(n IN nodes(p) WHERE ...) so every node on the path is filtered
- Now consistent with getShortestPath

P2-2: CONFIDENTIAL case normalization
- Cypher compares with toUpper(trim(...))
- Java uses equalsIgnoreCase
- Consistent with data-management-service
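The Java side of the P2-2 normalization rule can be shown in a few lines (a minimal sketch; the method name is illustrative):

```java
public class SensitivityCheck {
    // Trim then compare case-insensitively, matching Cypher's toUpper(trim(...)).
    static boolean isConfidential(String sensitivity) {
        return sensitivity != null && "CONFIDENTIAL".equalsIgnoreCase(sensitivity.trim());
    }

    public static void main(String[] args) {
        assert isConfidential(" confidential ");
        assert isConfidential("Confidential");
        assert !isConfidential(null);
        assert !isConfidential("INTERNAL");
    }
}
```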

Permission model:
- Sync stage: full sync (preserves graph completeness)
- Query stage: results are filtered by the caller's permissions
- Uses RequestUserContextHolder and ResourceAccessService

Code changes: +642 lines, -32 lines
Test results: 130 tests, 0 failures
9 new test cases

Known P3 issues (non-blocking, can be addressed later):
- Component scan scope is broader than necessary
- Test quality can be further improved
- The structural-entity whitelist is maintained in two places
2026-02-18 12:24:09 +08:00
ebb4548ca5 feat(kg): complete knowledge-graph entity sync and relation building
New features:
- Completed sync for 4 more entity types: Workflow, Job, LabelTask, KnowledgeSet
- Completed building of 7 more relation types: USES_DATASET, PRODUCES, ASSIGNED_TO, TRIGGERS, DEPENDS_ON, IMPACTS, SOURCED_FROM
- Added 39 test cases, 111 tests in total

Fixes (three review rounds):
Round 1 (6 issues):
- toStringList now filters null/blank values
- mergeUsesDatasetRelations logic unified
- fetchAllPaged deduplication extracted into a helper
- IMPACTS placeholder marker added
- Test assertions strengthened
- syncAll switched from fixed indices to a Map

Round 2 (2 issues):
- Active-ID null/blank normalization (two layers of defense)
- N+1 queries eliminated from relation building (preloaded Map)
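The preloaded-Map technique from Round 2 can be sketched as follows. Types and field names here are illustrative, not the actual sync-service classes: instead of one lookup query per relation (N+1), all entities are loaded once and indexed by source ID in memory.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

public class PreloadIndex {
    record Entity(String sourceId, String name) {}

    // Build the in-memory index once; relation building then resolves
    // endpoints with O(1) map lookups instead of per-relation queries.
    static Map<String, Entity> indexBySourceId(List<Entity> all) {
        return all.stream()
                .collect(Collectors.toMap(Entity::sourceId, Function.identity(), (a, b) -> a));
    }

    public static void main(String[] args) {
        List<Entity> all = List.of(new Entity("d1", "dataset-1"), new Entity("d2", "dataset-2"));
        Map<String, Entity> byId = indexBySourceId(all);
        assert byId.get("d1").name().equals("dataset-1");
        assert byId.size() == 2;
    }
}
```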

Round 3 (1 issue):
- NPE defense against null elements (12 places in GraphSyncService + 6 in GraphSyncStepService)

Code changes: +1936 lines, -101 lines
Test results: 111 tests, 0 failures

Known P3 issues (non-blocking):
- Security comments inconsistent with the implementation (to be handled together with the permission-filtering task)
- Test coverage gaps (can be filled in later)
2026-02-18 11:30:38 +08:00
37b478a052 fix(kg): fix P1/P2 issues found in the Codex review and fill in tests
Fixes:

P1 (critical):
1. Data-isolation hole: neighbor queries now constrain graph_id along the whole path, preventing cross-graph data leakage
2. Empty-snapshot purge risk: added the allowPurgeOnEmptySnapshot guard flag (default false)
3. Weak default credentials: startup self-check; in production, detecting the default password aborts startup

P2 (important):
4. Config validation: importBatchSize gets @Min(1) and fails fast at startup
5. N+1 performance: upsertEntity rewritten as a single Cypher query (down from 3 queries to 1)
6. Service authentication: added mTLS/JWT documentation
7. Error handling: improved schema-initialization and serialization error handling
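The single-query upsert from item 5 can be sketched as one MERGE that sets both identity and mutable properties, replacing the earlier read-then-write round trips. This Cypher is an illustrative reconstruction; the exact property names and parameters are assumptions.

```java
public class UpsertCypher {
    // Single round trip: MERGE on the composite key, then conditionally
    // initialize on create and always refresh mutable fields.
    static final String UPSERT_ENTITY = """
            MERGE (e:Entity {graph_id: $graphId, source_id: $sourceId, type: $type})
            ON CREATE SET e.id = $id, e.created_at = $now
            SET e.name = $name, e.updated_at = $now
            RETURN e.id
            """;

    public static void main(String[] args) {
        assert UPSERT_ENTITY.contains("MERGE");
        assert UPSERT_ENTITY.contains("ON CREATE SET");
    }
}
```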

Test coverage:
- 69 new unit tests, all passing
- GraphEntityServiceTest: 13 tests (CRUD, validation, pagination)
- GraphRelationServiceTest: 13 tests (CRUD, direction validation)
- GraphSyncServiceTest: 5 tests (validation, full sync)
- GraphSyncStepServiceTest: 14 tests (empty-snapshot guard, N+1 verification)
- GraphQueryServiceTest: 13 tests (neighbors/paths/subgraph/search)
- GraphInitializerTest: 11 tests (credential validation, schema initialization)

Technical details:
- Data isolation: the ALL() function constrains graph_id on every node and relationship in the path
- Empty-snapshot guard: new config key allow-purge-on-empty-snapshot and error code EMPTY_SNAPSHOT_PURGE_BLOCKED
- Credential check: implemented on both the Java and Python sides, with a per-environment policy (dev/test/prod)
- Performance: the SDN composite-property format (properties.key) sets properties directly inside MERGE
- Property safety: an [a-zA-Z0-9_] whitelist guards against Cypher injection
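The whitelist in the last item is simple to state precisely: a property key interpolated into Cypher is accepted only if every character matches [a-zA-Z0-9_]. A minimal sketch (the class and method names are illustrative):

```java
import java.util.regex.Pattern;

public class PropertyKeyGuard {
    // Only letters, digits, and underscore; anything else could break out of
    // a backtick-quoted identifier and inject Cypher.
    private static final Pattern SAFE_KEY = Pattern.compile("^[a-zA-Z0-9_]+$");

    static boolean isSafeKey(String key) {
        return key != null && SAFE_KEY.matcher(key).matches();
    }

    public static void main(String[] args) {
        assert isSafeKey("created_by");
        assert !isSafeKey("x` SET n.admin = true //");  // injection attempt rejected
        assert !isSafeKey("");
    }
}
```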

Code changes: +210 lines, -29 lines
2026-02-18 09:25:00 +08:00
a260134d7c fix(knowledge-graph): fix 5 issues found in the Codex review and add query features
This commit contains two parts:
1. New knowledge-graph query features (neighbor query, shortest path, subgraph extraction, full-text search)
2. Fixes for 5 issues found in the Codex code review (3 critical P1 + 2 minor P2)

## New features

### GraphQueryService and GraphQueryController
- Neighbor query: GET /query/neighbors/{entityId}?depth=2&limit=50
- Shortest path: GET /query/shortest-path?sourceId=...&targetId=...&maxDepth=3
- Subgraph extraction: POST /query/subgraph + body {"entityIds": [...]}
- Full-text search: GET /query/search?q=keyword&page=0&size=20

### New DTOs
- EntitySummaryVO, EdgeSummaryVO: entity and edge summaries
- SubgraphVO: subgraph result (nodes + edges + counts)
- PathVO: path result
- SearchHitVO: search hit (with relevance score)
- SubgraphRequest: subgraph request DTO (with validation)

## Fixes

### P1-1: graph-boundary risk in neighbor queries
**File**: GraphQueryService.java
**Problem**: getNeighborGraph used -[*1..N]- without constraining graph_id on intermediate path nodes/relationships
**Fix**:
- Use a path variable p: MATCH p = ...
- Add ALL(n IN nodes(p) WHERE n.graph_id = $graphId)
- Add ALL(r IN relationships(p) WHERE r.graph_id = $graphId)
- Restrict the relationship type to :RELATED_TO
- Exclude self-loops: WHERE e <> neighbor

### P1-2: full-graph scan performance risk
**File**: GraphRelationRepository.java
**Problem**: findByEntityId/countByEntityId matched every relationship in the graph first, then filtered with s.id = X OR t.id = X
**Fix**:
- findByEntityId: rewritten as CALL { anchored outgoing query UNION ALL anchored incoming query }
- countByEntityId:
  - "in"/"out" direction: id: $entityId written directly into the MATCH pattern
  - "all" direction: rewritten as CALL { outgoing UNION incoming } RETURN count(r)
- Uses the (graph_id, id) index for direct lookup, avoiding a full-graph scan
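The anchored query shape described above can be sketched like this. This is not the exact production Cypher, only an illustration of the pattern: each branch starts from the indexed (graph_id, id) node, so Neo4j never enumerates every relationship in the graph.

```java
public class AnchoredRelationQuery {
    // Each UNION ALL branch anchors the MATCH at the entity node, letting the
    // (graph_id, id) index drive the lookup instead of a full relationship scan.
    static final String FIND_BY_ENTITY_ID = """
            CALL {
              MATCH (s:Entity {graph_id: $graphId, id: $entityId})-[r:RELATED_TO]->(t:Entity)
              RETURN r, s, t
              UNION ALL
              MATCH (s:Entity)-[r:RELATED_TO]->(t:Entity {graph_id: $graphId, id: $entityId})
              RETURN r, s, t
            }
            RETURN r, s, t SKIP $skip LIMIT $limit
            """;

    public static void main(String[] args) {
        assert FIND_BY_ENTITY_ID.contains("UNION ALL");
        assert FIND_BY_ENTITY_ID.contains("{graph_id: $graphId, id: $entityId}");
    }
}
```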

### P1-3: breaking API change
**File**: GraphEntityController.java
**Problem**: GET /knowledge-graph/{graphId}/entities changed from List<GraphEntity> to PagedResponse<GraphEntity>
**Fix**: zero-breakage upgrade via the Spring MVC params attribute
- @GetMapping(params = "!page"): without a page parameter, return a List (backward compatible)
- @GetMapping(params = "page"): with a page parameter, return a PagedResponse (new behavior)
- Existing callers need no changes; new callers can opt into pagination

### P2-4: direction parameter not strictly validated
**File**: GraphEntityController.java, GraphRelationService.java
**Problem**: illegal direction values were silently treated as "all"
**Fix**: two-layer validation
- Controller layer: @Pattern(regexp = "^(all|in|out)$")
- Service layer: VALID_DIRECTIONS.contains() check
- Illegal values raise an INVALID_PARAMETER error

### P2-5: subgraph request body lacked element-level validation
**File**: GraphQueryController.java, SubgraphRequest.java
**Problem**: /query/subgraph accepted a raw List<String> with no UUID validation
**Fix**: introduce a SubgraphRequest DTO
- @NotEmpty: the list must not be empty
- @Size(max = 500): cap on element count
- List<@Pattern(UUID) String>: every element must be a valid UUID
- The controller takes @Valid @RequestBody SubgraphRequest
- ⚠️ API change: the request body changes from ["uuid1"] to {"entityIds": ["uuid1"]}
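The P2-5 rules translate into three checks. A plain-Java equivalent of the Bean Validation annotations above (a sketch for clarity; the real service relies on jakarta.validation, not manual checks):

```java
import java.util.List;
import java.util.regex.Pattern;

public class SubgraphRequestCheck {
    private static final Pattern UUID_PATTERN = Pattern.compile(
            "^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$");

    static boolean isValid(List<String> entityIds) {
        return entityIds != null
                && !entityIds.isEmpty()                        // @NotEmpty
                && entityIds.size() <= 500                     // @Size(max = 500)
                && entityIds.stream().allMatch(                // List<@Pattern(UUID) String>
                        id -> id != null && UUID_PATTERN.matcher(id).matches());
    }

    public static void main(String[] args) {
        assert isValid(List.of("123e4567-e89b-12d3-a456-426614174000"));
        assert !isValid(List.of());             // empty list rejected
        assert !isValid(List.of("not-a-uuid"));
    }
}
```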

## Technical highlights

1. **Graph-boundary safety**: path variables plus ALL constraints keep queries from crossing graph boundaries
2. **Query performance**: entity-anchored queries replace full-graph scans and exploit the indexes
3. **Backward compatibility**: the params attribute serves two endpoints on one path, a zero-breakage upgrade
4. **Defense in depth**: Controller + Service validation, framework-level plus business-level
5. **Type safety**: DTO + Bean Validation guarantee request-body shape and content

## Testing suggestions

1. Compile check: mvn -pl services/knowledge-graph-service -am compile
2. Test the graph-boundary constraint of neighbor queries
3. Test entity-relation query performance (large datasets)
4. Verify backward compatibility of the entity-list endpoint (no page parameter)
5. Test rejection of illegal direction values
6. Test subgraph request-body validation (invalid UUIDs, empty list, over the size cap)

Co-authored-by: Claude (Anthropic)
Reviewed-by: Codex (OpenAI)
2026-02-18 07:49:16 +08:00
8b1ab8ff36 feat(kg-sync): implement the graph build pipeline (MySQL → Neo4j sync)
Implemented:
- GraphSyncService (sync orchestrator)
- GraphSyncStepService (sync step executor)
- GraphSyncController (sync API)
- GraphInitializer (graph initialization)
- DataManagementClient (data-source client)

Sync features:
- syncDatasets: sync dataset entities
- syncFields: sync field entities
- syncUsers: sync user entities
- syncOrgs: sync organization entities
- buildHasFieldRelations: build HAS_FIELD relations
- buildDerivedFromRelations: build DERIVED_FROM relations
- buildBelongsToRelations: build BELONGS_TO relations
- syncAll: full sync (entities + relations + reconciliation deletes)

API endpoints:
- POST /{graphId}/sync/full: full sync
- POST /{graphId}/sync/datasets: sync datasets
- POST /{graphId}/sync/fields: sync fields
- POST /{graphId}/sync/users: sync users
- POST /{graphId}/sync/orgs: sync organizations
- POST /{graphId}/sync/relations/has-field: build HAS_FIELD
- POST /{graphId}/sync/relations/derived-from: build DERIVED_FROM
- POST /{graphId}/sync/relations/belongs-to: build BELONGS_TO

Technical implementation:
- Upsert strategy:
  - Entities: two phases (atomic create via Cypher MERGE + SDN save for extended properties)
  - Relations: idempotent create via Cypher MERGE
- Full reconciliation deletes: purgeStaleEntities() removes entities already deleted from MySQL
- Concurrency safety:
  - Per-graph mutex (ConcurrentHashMap<String, ReentrantLock>)
  - Composite unique constraint (graph_id, source_id, type)
  - Automatic lock reclamation (releaseLock() atomically checks for and removes idle locks)
- Retry: failed HTTP calls are retried with exponential backoff (default 3 attempts)
- Error handling:
  - Per-record handling (one failed record does not affect the others)
  - Uniform exception wrapping (BusinessException.of(SYNC_FAILED))
  - Sanitized error output (only errorCount + syncId are returned)
- Transactions:
  - GraphSyncService (orchestrator, no transaction)
  - GraphSyncStepService (step executor, @Transactional)
- Performance:
  - Full sync shares a single data snapshot
  - Batched log tracking
- Graph initialization:
  - 1 uniqueness constraint (entity ID)
  - 1 composite unique constraint (graph_id, source_id, type)
  - 9 indexes (5 single-field + 3 composite + 1 full-text)
  - Idempotent (IF NOT EXISTS)
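The exponential-backoff retry mentioned above can be sketched as a delay schedule. The base delay and cap are assumptions; only the "retry up to 3 times with exponential backoff" behavior comes from the commit message.

```java
public class Backoff {
    // Delay before retry `attempt` (1-based): base * 2^(attempt-1), capped.
    static long delayMillis(int attempt, long baseMillis, long capMillis) {
        long delay = baseMillis << (attempt - 1);   // 1x, 2x, 4x, ...
        return Math.min(delay, capMillis);
    }

    public static void main(String[] args) {
        assert delayMillis(1, 500, 10_000) == 500;
        assert delayMillis(2, 500, 10_000) == 1_000;
        assert delayMillis(3, 500, 10_000) == 2_000;
        assert delayMillis(10, 500, 10_000) == 10_000; // capped
    }
}
```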

Code review:
- 3 rounds of Codex review and 2 rounds of Claude fixes
- All issues resolved (3 P0 + 5 P1 + 3 P2 + 1 P3)
- Compilation verified (mvn compile SUCCESS)

Design decisions:
- Eventual consistency: brief data inconsistency is acceptable
- Reconciliation: periodically compare and repair differences
- Trust boundary: the gateway handles authentication; the service layer only validates formats
- Multi-instance deployment: the composite unique constraint is the backstop
2026-02-17 23:46:03 +08:00
910251e898 feat(kg-relation): implement the Java Relation feature
Implemented:
- GraphRelationRepository (Neo4jClient + Cypher)
- GraphRelationService (business-logic layer)
- GraphRelationController (REST API)
- New RelationDetail domain object
- New RelationVO and UpdateRelationRequest DTOs

API endpoints:
- POST /{graphId}/relations: create a relation (201)
- GET /{graphId}/relations: paged list query (supports type/page/size)
- GET /{graphId}/relations/{relationId}: fetch one relation
- PUT /{graphId}/relations/{relationId}: update a relation
- DELETE /{graphId}/relations/{relationId}: delete a relation (204)

Technical implementation:
- Repository:
  - CRUD implemented with Neo4jClient + Cypher
  - Parameters bound in one call via bindAll(Map)
  - The properties field is stored as serialized JSON
  - Paged queries supported (SKIP/LIMIT)
  - Type filtering supported
- Service:
  - graphId UUID format validation
  - Entity existence checks
  - @Transactional transaction management
  - Trust boundary documented (the gateway handles authentication)
  - Pagination skip computed as a long, capped at 100,000
- Controller:
  - Every pathVariable carries a UUID pattern check
  - @Validated enables parameter validation
  - Uses the platform-wide PagedResponse for paging
- DTO:
  - weight/confidence carry @DecimalMin/@DecimalMax (0.0-1.0)
  - relationType carries @Size(1-50)
  - sourceEntityId/targetEntityId carry UUID pattern checks

Architecture:
- Clean layering: interfaces → application → domain
- The repository returns the RelationDetail domain object
- DTO conversion happens in the Service layer
- Relation types: Neo4j uses a single RELATED_TO label; the semantic type is stored in the relation_type property

Code review:
- 2 rounds of Codex review and 1 round of Claude fixes
- All issues resolved (2 P0 + 2 P1 + 4 P2)
- Compilation verified (mvn compile SUCCESS)

Design decisions:
- Neo4jClient instead of Neo4jRepository (because of @RelationshipProperties limitations)
- Page size capped at 200 to prevent oversized queries
- properties serialized as JSON for flexible extension
- Existing error codes reused (ENTITY_NOT_FOUND, RELATION_NOT_FOUND, INVALID_RELATION)
2026-02-17 22:40:27 +08:00
0e0782a452 feat(kg-extraction): implement the Python extractor's FastAPI interface
Implemented:
- Created kg_extraction/interface.py (FastAPI routes)
- POST /api/kg/extract (single-text extraction)
- POST /api/kg/extract/batch (batch extraction, up to 50 items)
- Wired into the main FastAPI router (under the /api/kg/ prefix)

Technical implementation:
- Configuration: LLM settings read from environment variables (API key, base URL, model, temperature)
- Security:
  - API key protected with SecretStr
  - Sanitized error output (a trace_id is returned; raw exceptions are never exposed)
  - Request text never written to logs (a SHA-256 hash is logged instead)
  - X-User-Id header required (authentication boundary)
- Timeouts:
  - kg_llm_timeout_seconds (60 s)
  - kg_llm_max_retries (2 attempts)
- Input validation:
  - graph_id and source_id validated with a UUID pattern
  - source_type is an Enum (4 values)
  - allowed_nodes/relationships elements constrained by regex (ASCII, 1-50 characters)
- Audit logging: records caller, trace_id, and text_hash

Code review:
- 3 rounds of Codex review and 2 rounds of Claude fixes
- All issues resolved (5 P1/P2 + 3 P3)
- Syntax check passed

API endpoints:
- POST /api/kg/extract: single-text extraction
- POST /api/kg/extract/batch: batch extraction (up to 50 items)

Configuration environment variables:
- KG_LLM_API_KEY: LLM API key
- KG_LLM_BASE_URL: custom endpoint (optional)
- KG_LLM_MODEL: model name (default gpt-4o-mini)
- KG_LLM_TEMPERATURE: generation temperature (default 0.0)
- KG_LLM_TIMEOUT_SECONDS: timeout (default 60)
- KG_LLM_MAX_RETRIES: retry count (default 2)
2026-02-17 22:01:06 +08:00
5a553ddde3 feat(knowledge-graph): set up the knowledge-graph infrastructure
Implemented:
- Neo4j Docker Compose configuration (Community Edition, ports 7474/7687, persistent data)
- New Neo4j Makefile targets (neo4j-up/down/logs/shell)
- knowledge-graph-service Spring Boot service (full DDD layered architecture)
- kg_extraction Python module (based on LangChain's LLMGraphTransformer)

Technical implementation:
- Neo4j configuration: password from an environment variable, shared default datamate123
- Java service:
  - Domain: GraphEntity and GraphRelation entity models
  - Repository: Spring Data Neo4j with graphId-scoped queries
  - Service: business logic, double graphId validation, query throttling
  - Controller: REST API with UUID format validation
  - Exception: implements the ErrorCode interface, unified exception hierarchy
- Python module:
  - KnowledgeGraphExtractor class
  - Async/sync/batch extraction
  - Schema-guided mode
  - Compatible with OpenAI and self-hosted models

Key design points:
- graphId as a permission boundary: every entity operation is scoped to the correct graphId
- Query throttling: depth and limit parameters bounded by configuration
- Exception handling: BusinessException + ErrorCode throughout
- Credentials: environment variables, never hard-coded
- Double defense: Controller format checks + Service business checks

Code review:
- 3 rounds of Codex review and 2 rounds of Claude fixes
- All P0 and P1 issues resolved
- Compiles; no blocking issues

File changes:
- Added: Neo4j configuration, knowledge-graph-service (11 Java files), kg_extraction (3 Python files)
- Modified: Makefile, pom.xml, application.yml, pyproject.toml
2026-02-17 20:42:55 +08:00
52 changed files with 8148 additions and 2 deletions


@@ -76,6 +76,12 @@ help:
@echo " make download SAVE=true PLATFORM=linux/arm64 Save ARM64 images"
@echo " make load-images Load all downloaded images from dist/"
@echo ""
@echo "Neo4j Commands:"
@echo " make neo4j-up Start Neo4j graph database"
@echo " make neo4j-down Stop Neo4j graph database"
@echo " make neo4j-logs View Neo4j logs"
@echo " make neo4j-shell Open Neo4j Cypher shell"
@echo ""
@echo "Utility Commands:"
@echo " make create-namespace Create Kubernetes namespace"
@echo " make help Show this help message"
@@ -498,3 +504,25 @@ load-images:
else \
echo "Successfully loaded $$count image(s)"; \
fi
# ========== Neo4j Targets ==========
.PHONY: neo4j-up
neo4j-up:
@echo "Starting Neo4j graph database..."
docker compose -f deployment/docker/neo4j/docker-compose.yml up -d
@echo "Neo4j Browser: http://localhost:7474"
@echo "Bolt URI: bolt://localhost:7687"
.PHONY: neo4j-down
neo4j-down:
@echo "Stopping Neo4j graph database..."
docker compose -f deployment/docker/neo4j/docker-compose.yml down
.PHONY: neo4j-logs
neo4j-logs:
docker compose -f deployment/docker/neo4j/docker-compose.yml logs -f
.PHONY: neo4j-shell
neo4j-shell:
docker exec -it datamate-neo4j cypher-shell -u neo4j -p "$${NEO4J_PASSWORD:-datamate123}"


@@ -0,0 +1,114 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>com.datamate</groupId>
<artifactId>services</artifactId>
<version>1.0.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>knowledge-graph-service</artifactId>
<name>Knowledge Graph Service</name>
<description>Knowledge Graph Service - Neo4j-based entity/relation management and graph queries</description>
<dependencies>
<dependency>
<groupId>com.datamate</groupId>
<artifactId>domain-common</artifactId>
<version>${project.version}</version>
</dependency>
<!-- Spring Data Neo4j -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-neo4j</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
<version>${mysql.version}</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springdoc</groupId>
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
</dependency>
<dependency>
<groupId>org.openapitools</groupId>
<artifactId>jackson-databind-nullable</artifactId>
</dependency>
<dependency>
<groupId>jakarta.validation</groupId>
<artifactId>jakarta.validation-api</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<arguments>true</arguments>
<classifier>exec</classifier>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>${maven.compiler.source}</source>
<target>${maven.compiler.target}</target>
<annotationProcessorPaths>
<path>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version>
</path>
<path>
<groupId>org.projectlombok</groupId>
<artifactId>lombok-mapstruct-binding</artifactId>
<version>${lombok-mapstruct-binding.version}</version>
</path>
<path>
<groupId>org.mapstruct</groupId>
<artifactId>mapstruct-processor</artifactId>
<version>${mapstruct.version}</version>
</path>
</annotationProcessorPaths>
<compilerArgs>
<arg>-parameters</arg>
<arg>-Amapstruct.defaultComponentModel=spring</arg>
</compilerArgs>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.2.5</version>
</plugin>
</plugins>
</build>
</project>


@@ -0,0 +1,28 @@
package com.datamate.knowledgegraph;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.neo4j.repository.config.EnableNeo4jRepositories;
import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.web.client.RestTemplate;
import java.time.Duration;
@Configuration
@ComponentScan(basePackages = {"com.datamate.knowledgegraph", "com.datamate.common.auth"})
@EnableNeo4jRepositories(basePackages = "com.datamate.knowledgegraph.domain.repository")
@EnableScheduling
public class KnowledgeGraphServiceConfiguration {
@Bean("kgRestTemplate")
public RestTemplate kgRestTemplate(RestTemplateBuilder builder, KnowledgeGraphProperties properties) {
KnowledgeGraphProperties.Sync syncConfig = properties.getSync();
return builder
.connectTimeout(Duration.ofMillis(syncConfig.getConnectTimeout()))
.readTimeout(Duration.ofMillis(syncConfig.getReadTimeout()))
.build();
}
}


@@ -0,0 +1,170 @@
package com.datamate.knowledgegraph.application;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import com.datamate.common.interfaces.PagedResponse;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import com.datamate.knowledgegraph.interfaces.dto.CreateEntityRequest;
import com.datamate.knowledgegraph.interfaces.dto.UpdateEntityRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.time.LocalDateTime;
import java.util.List;
import java.util.regex.Pattern;
@Service
@Slf4j
@RequiredArgsConstructor
public class GraphEntityService {
/** Upper bound on the pagination offset; guards against Neo4j performance degradation from deep paging. */
private static final long MAX_SKIP = 100_000L;
private static final Pattern UUID_PATTERN = Pattern.compile(
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
);
private final GraphEntityRepository entityRepository;
private final KnowledgeGraphProperties properties;
@Transactional
public GraphEntity createEntity(String graphId, CreateEntityRequest request) {
validateGraphId(graphId);
GraphEntity entity = GraphEntity.builder()
.name(request.getName())
.type(request.getType())
.description(request.getDescription())
.aliases(request.getAliases())
.properties(request.getProperties())
.sourceId(request.getSourceId())
.sourceType(request.getSourceType())
.graphId(graphId)
.confidence(request.getConfidence() != null ? request.getConfidence() : 1.0)
.createdAt(LocalDateTime.now())
.updatedAt(LocalDateTime.now())
.build();
return entityRepository.save(entity);
}
public GraphEntity getEntity(String graphId, String entityId) {
validateGraphId(graphId);
return entityRepository.findByIdAndGraphId(entityId, graphId)
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.ENTITY_NOT_FOUND));
}
public List<GraphEntity> listEntities(String graphId) {
validateGraphId(graphId);
return entityRepository.findByGraphId(graphId);
}
public List<GraphEntity> searchEntities(String graphId, String name) {
validateGraphId(graphId);
return entityRepository.findByGraphIdAndNameContaining(graphId, name);
}
public List<GraphEntity> listEntitiesByType(String graphId, String type) {
validateGraphId(graphId);
return entityRepository.findByGraphIdAndType(graphId, type);
}
// -----------------------------------------------------------------------
// Paged queries
// -----------------------------------------------------------------------
public PagedResponse<GraphEntity> listEntitiesPaged(String graphId, int page, int size) {
validateGraphId(graphId);
int safePage = Math.max(0, page);
int safeSize = Math.max(1, Math.min(size, 200));
long skip = (long) safePage * safeSize;
if (skip > MAX_SKIP) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "pagination offset too large");
}
List<GraphEntity> entities = entityRepository.findByGraphIdPaged(graphId, skip, safeSize);
long total = entityRepository.countByGraphId(graphId);
long totalPages = safeSize > 0 ? (total + safeSize - 1) / safeSize : 0;
return PagedResponse.of(entities, safePage, total, totalPages);
}
public PagedResponse<GraphEntity> listEntitiesByTypePaged(String graphId, String type, int page, int size) {
validateGraphId(graphId);
int safePage = Math.max(0, page);
int safeSize = Math.max(1, Math.min(size, 200));
long skip = (long) safePage * safeSize;
if (skip > MAX_SKIP) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "pagination offset too large");
}
List<GraphEntity> entities = entityRepository.findByGraphIdAndTypePaged(graphId, type, skip, safeSize);
long total = entityRepository.countByGraphIdAndType(graphId, type);
long totalPages = safeSize > 0 ? (total + safeSize - 1) / safeSize : 0;
return PagedResponse.of(entities, safePage, total, totalPages);
}
public PagedResponse<GraphEntity> searchEntitiesPaged(String graphId, String keyword, int page, int size) {
validateGraphId(graphId);
int safePage = Math.max(0, page);
int safeSize = Math.max(1, Math.min(size, 200));
long skip = (long) safePage * safeSize;
if (skip > MAX_SKIP) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "pagination offset too large");
}
List<GraphEntity> entities = entityRepository.findByGraphIdAndNameContainingPaged(graphId, keyword, skip, safeSize);
long total = entityRepository.countByGraphIdAndNameContaining(graphId, keyword);
long totalPages = safeSize > 0 ? (total + safeSize - 1) / safeSize : 0;
return PagedResponse.of(entities, safePage, total, totalPages);
}
@Transactional
public GraphEntity updateEntity(String graphId, String entityId, UpdateEntityRequest request) {
validateGraphId(graphId);
GraphEntity entity = getEntity(graphId, entityId);
if (request.getName() != null) {
entity.setName(request.getName());
}
if (request.getDescription() != null) {
entity.setDescription(request.getDescription());
}
if (request.getAliases() != null) {
entity.setAliases(request.getAliases());
}
if (request.getProperties() != null) {
entity.setProperties(request.getProperties());
}
entity.setUpdatedAt(LocalDateTime.now());
return entityRepository.save(entity);
}
@Transactional
public void deleteEntity(String graphId, String entityId) {
validateGraphId(graphId);
GraphEntity entity = getEntity(graphId, entityId);
entityRepository.delete(entity);
}
public List<GraphEntity> getNeighbors(String graphId, String entityId, int depth, int limit) {
validateGraphId(graphId);
int clampedDepth = Math.max(1, Math.min(depth, properties.getMaxDepth()));
int clampedLimit = Math.max(1, Math.min(limit, properties.getMaxNodesPerQuery()));
return entityRepository.findNeighbors(graphId, entityId, clampedDepth, clampedLimit);
}
public long countEntities(String graphId) {
validateGraphId(graphId);
return entityRepository.countByGraphId(graphId);
}
/**
* Validates the graphId format (UUID).
* Prevents maliciously crafted graphId values from being injected into Cypher queries.
*/
private void validateGraphId(String graphId) {
if (graphId == null || !UUID_PATTERN.matcher(graphId).matches()) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "invalid graphId format");
}
}
}


@@ -0,0 +1,589 @@
package com.datamate.knowledgegraph.application;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import com.datamate.common.interfaces.PagedResponse;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import com.datamate.knowledgegraph.interfaces.dto.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.neo4j.driver.Value;
import org.neo4j.driver.types.MapAccessor;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.regex.Pattern;
/**
* Knowledge-graph query service.
* <p>
* Provides graph traversal (N-hop neighbors, shortest path, subgraph extraction) and full-text search.
* Executes complex Cypher queries via {@link Neo4jClient}.
* <p>
* Query results are filtered by the caller's permissions:
* <ul>
* <li>Administrators: no filtering, all data is visible</li>
* <li>Regular users: filtered by {@code created_by}, so only business entities they created are visible;
* structural entities (User, Org, Field and other entities without created_by) are visible to everyone</li>
* </ul>
*/
@Service
@Slf4j
@RequiredArgsConstructor
public class GraphQueryService {
private static final String REL_TYPE = "RELATED_TO";
private static final long MAX_SKIP = 100_000L;
/** Whitelist of structural entity types: visible to all users, never filtered by created_by */
private static final Set<String> STRUCTURAL_ENTITY_TYPES = Set.of("User", "Org", "Field");
private static final Pattern UUID_PATTERN = Pattern.compile(
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
);
private final Neo4jClient neo4jClient;
private final GraphEntityRepository entityRepository;
private final KnowledgeGraphProperties properties;
private final ResourceAccessService resourceAccessService;
// -----------------------------------------------------------------------
// N-hop neighbors
// -----------------------------------------------------------------------
/**
* Queries an entity's N-hop neighbors, returning the neighbor nodes and connecting edges.
*
* @param depth hop count (1-3, bounded by configuration)
* @param limit upper bound on the number of returned nodes
*/
public SubgraphVO getNeighborGraph(String graphId, String entityId, int depth, int limit) {
validateGraphId(graphId);
String filterUserId = resolveOwnerFilter();
boolean excludeConfidential = filterUserId != null && !resourceAccessService.canViewConfidential();
// Verify the entity exists and the caller has access
GraphEntity startEntity = entityRepository.findByIdAndGraphId(entityId, graphId)
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.ENTITY_NOT_FOUND));
if (filterUserId != null) {
assertEntityAccess(startEntity, filterUserId, excludeConfidential);
}
int clampedDepth = Math.max(1, Math.min(depth, properties.getMaxDepth()));
int clampedLimit = Math.max(1, Math.min(limit, properties.getMaxNodesPerQuery()));
// Path-level permission filtering over every node (consistent with getShortestPath)
String permFilter = "";
if (filterUserId != null) {
StringBuilder pf = new StringBuilder("AND ALL(n IN nodes(p) WHERE ");
pf.append("(n.type IN ['User', 'Org', 'Field'] OR n.`properties.created_by` = $filterUserId)");
if (excludeConfidential) {
pf.append(" AND (toUpper(trim(n.`properties.sensitivity`)) IS NULL OR toUpper(trim(n.`properties.sensitivity`)) <> 'CONFIDENTIAL')");
}
pf.append(") ");
permFilter = pf.toString();
}
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("entityId", entityId);
params.put("limit", clampedLimit);
if (filterUserId != null) {
params.put("filterUserId", filterUserId);
}
// Query neighbor nodes (the path variable constrains intermediate nodes and relationships to the same graph; the permission filter covers every node on the path)
List<EntitySummaryVO> nodes = neo4jClient
.query(
"MATCH p = (e:Entity {graph_id: $graphId, id: $entityId})" +
"-[:" + REL_TYPE + "*1.." + clampedDepth + "]-(neighbor:Entity) " +
"WHERE e <> neighbor " +
" AND ALL(n IN nodes(p) WHERE n.graph_id = $graphId) " +
" AND ALL(r IN relationships(p) WHERE r.graph_id = $graphId) " +
permFilter +
"WITH DISTINCT neighbor LIMIT $limit " +
"RETURN neighbor.id AS id, neighbor.name AS name, neighbor.type AS type, " +
"neighbor.description AS description"
)
.bindAll(params)
.fetchAs(EntitySummaryVO.class)
.mappedBy((ts, record) -> EntitySummaryVO.builder()
.id(record.get("id").asString(null))
.name(record.get("name").asString(null))
.type(record.get("type").asString(null))
.description(record.get("description").asString(null))
.build())
.all()
.stream().toList();
// Collect all node IDs (including the start node)
Set<String> nodeIds = new LinkedHashSet<>();
nodeIds.add(entityId);
nodes.forEach(n -> nodeIds.add(n.getId()));
// Query the edges between these nodes
List<EdgeSummaryVO> edges = queryEdgesBetween(graphId, new ArrayList<>(nodeIds));
// Add the start node to the node list
List<EntitySummaryVO> allNodes = new ArrayList<>();
allNodes.add(EntitySummaryVO.builder()
.id(startEntity.getId())
.name(startEntity.getName())
.type(startEntity.getType())
.description(startEntity.getDescription())
.build());
allNodes.addAll(nodes);
return SubgraphVO.builder()
.nodes(allNodes)
.edges(edges)
.nodeCount(allNodes.size())
.edgeCount(edges.size())
.build();
}
// -----------------------------------------------------------------------
// Shortest path
// -----------------------------------------------------------------------
/**
* Queries the shortest path between two entities.
*
* @param maxDepth maximum search depth (bounded by configuration)
* @return the path result, or an empty path if none exists
*/
public PathVO getShortestPath(String graphId, String sourceId, String targetId, int maxDepth) {
validateGraphId(graphId);
String filterUserId = resolveOwnerFilter();
boolean excludeConfidential = filterUserId != null && !resourceAccessService.canViewConfidential();
// Verify both entities exist and the caller has access
GraphEntity sourceEntity = entityRepository.findByIdAndGraphId(sourceId, graphId)
.orElseThrow(() -> BusinessException.of(
KnowledgeGraphErrorCode.ENTITY_NOT_FOUND, "source entity not found"));
if (filterUserId != null) {
assertEntityAccess(sourceEntity, filterUserId, excludeConfidential);
}
entityRepository.findByIdAndGraphId(targetId, graphId)
.ifPresentOrElse(
targetEntity -> {
if (filterUserId != null && !sourceId.equals(targetId)) {
assertEntityAccess(targetEntity, filterUserId, excludeConfidential);
}
},
() -> { throw BusinessException.of(
KnowledgeGraphErrorCode.ENTITY_NOT_FOUND, "target entity not found"); }
);
if (sourceId.equals(targetId)) {
// Same source and target: return a single-node path
EntitySummaryVO node = EntitySummaryVO.builder()
.id(sourceEntity.getId())
.name(sourceEntity.getName())
.type(sourceEntity.getType())
.description(sourceEntity.getDescription())
.build();
return PathVO.builder()
.nodes(List.of(node))
.edges(List.of())
.pathLength(0)
.build();
}
int clampedDepth = Math.max(1, Math.min(maxDepth, properties.getMaxDepth()));
String permFilter = "";
if (filterUserId != null) {
StringBuilder pf = new StringBuilder("AND ALL(n IN nodes(path) WHERE ");
pf.append("(n.type IN ['User', 'Org', 'Field'] OR n.`properties.created_by` = $filterUserId)");
if (excludeConfidential) {
pf.append(" AND (toUpper(trim(n.`properties.sensitivity`)) IS NULL OR toUpper(trim(n.`properties.sensitivity`)) <> 'CONFIDENTIAL')");
}
pf.append(") ");
permFilter = pf.toString();
}
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("sourceId", sourceId);
params.put("targetId", targetId);
if (filterUserId != null) {
params.put("filterUserId", filterUserId);
}
// Use Neo4j's shortestPath function
String cypher =
"MATCH (s:Entity {graph_id: $graphId, id: $sourceId}), " +
" (t:Entity {graph_id: $graphId, id: $targetId}), " +
" path = shortestPath((s)-[:" + REL_TYPE + "*1.." + clampedDepth + "]-(t)) " +
"WHERE ALL(n IN nodes(path) WHERE n.graph_id = $graphId) " +
permFilter +
"RETURN " +
" [n IN nodes(path) | {id: n.id, name: n.name, type: n.type, description: n.description}] AS pathNodes, " +
" [r IN relationships(path) | {id: r.id, relation_type: r.relation_type, weight: r.weight, " +
" source: startNode(r).id, target: endNode(r).id}] AS pathEdges, " +
" length(path) AS pathLength";
return neo4jClient.query(cypher)
.bindAll(params)
.fetchAs(PathVO.class)
.mappedBy((ts, record) -> mapPathRecord(record))
.one()
.orElse(PathVO.builder()
.nodes(List.of())
.edges(List.of())
.pathLength(-1)
.build());
}
// -----------------------------------------------------------------------
// Subgraph extraction
// -----------------------------------------------------------------------
/**
* Extracts the relation network (subgraph) between the given set of entities.
*
* @param entityIds the entity ID collection
*/
public SubgraphVO getSubgraph(String graphId, List<String> entityIds) {
validateGraphId(graphId);
String filterUserId = resolveOwnerFilter();
boolean excludeConfidential = filterUserId != null && !resourceAccessService.canViewConfidential();
if (entityIds == null || entityIds.isEmpty()) {
return SubgraphVO.builder()
.nodes(List.of())
.edges(List.of())
.nodeCount(0)
.edgeCount(0)
.build();
}
int maxNodes = properties.getMaxNodesPerQuery();
if (entityIds.size() > maxNodes) {
throw BusinessException.of(KnowledgeGraphErrorCode.MAX_NODES_EXCEEDED,
"entity count exceeds the limit (max " + maxNodes + ")");
}
// Look up the entities that exist
List<GraphEntity> entities = entityRepository.findByGraphIdAndIdIn(graphId, entityIds);
// Permission filtering: non-administrators only see structural entities and business entities they created
if (filterUserId != null) {
entities = entities.stream()
.filter(e -> isEntityAccessible(e, filterUserId, excludeConfidential))
.toList();
}
List<EntitySummaryVO> nodes = entities.stream()
.map(e -> EntitySummaryVO.builder()
.id(e.getId())
.name(e.getName())
.type(e.getType())
.description(e.getDescription())
.build())
.toList();
if (nodes.isEmpty()) {
return SubgraphVO.builder()
.nodes(List.of())
.edges(List.of())
.nodeCount(0)
.edgeCount(0)
.build();
}
// Query the edges between these nodes
List<String> existingIds = entities.stream().map(GraphEntity::getId).toList();
List<EdgeSummaryVO> edges = queryEdgesBetween(graphId, existingIds);
return SubgraphVO.builder()
.nodes(nodes)
.edges(edges)
.nodeCount(nodes.size())
.edgeCount(edges.size())
.build();
}
// -----------------------------------------------------------------------
// Full-text search
// -----------------------------------------------------------------------
/**
* Searches entities (name + description) via the Neo4j full-text index.
* <p>
* Uses the {@code entity_fulltext} index created by GraphInitializer and
* returns results ordered by relevance.
*
* @param query search keywords (Lucene special characters are escaped before querying)
*/
public PagedResponse<SearchHitVO> fulltextSearch(String graphId, String query, int page, int size) {
validateGraphId(graphId);
String filterUserId = resolveOwnerFilter();
boolean excludeConfidential = filterUserId != null && !resourceAccessService.canViewConfidential();
if (query == null || query.isBlank()) {
return PagedResponse.of(List.of(), 0, 0, 0);
}
int safePage = Math.max(0, page);
int safeSize = Math.max(1, Math.min(size, 200));
long skip = (long) safePage * safeSize;
if (skip > MAX_SKIP) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "分页偏移量过大");
}
// Sanitize the search keywords: escape Lucene special characters
String safeQuery = escapeLuceneQuery(query);
String permFilter = buildPermissionPredicate("node", filterUserId, excludeConfidential);
Map<String, Object> searchParams = new HashMap<>();
searchParams.put("graphId", graphId);
searchParams.put("query", safeQuery);
searchParams.put("skip", skip);
searchParams.put("size", safeSize);
if (filterUserId != null) {
searchParams.put("filterUserId", filterUserId);
}
List<SearchHitVO> results = neo4jClient
.query(
"CALL db.index.fulltext.queryNodes('entity_fulltext', $query) YIELD node, score " +
"WHERE node.graph_id = $graphId " +
permFilter +
"RETURN node.id AS id, node.name AS name, node.type AS type, " +
"node.description AS description, score " +
"ORDER BY score DESC " +
"SKIP $skip LIMIT $size"
)
.bindAll(searchParams)
.fetchAs(SearchHitVO.class)
.mappedBy((ts, record) -> SearchHitVO.builder()
.id(record.get("id").asString(null))
.name(record.get("name").asString(null))
.type(record.get("type").asString(null))
.description(record.get("description").asString(null))
.score(record.get("score").asDouble())
.build())
.all()
.stream().toList();
Map<String, Object> countParams = new HashMap<>();
countParams.put("graphId", graphId);
countParams.put("query", safeQuery);
if (filterUserId != null) {
countParams.put("filterUserId", filterUserId);
}
long total = neo4jClient
.query(
"CALL db.index.fulltext.queryNodes('entity_fulltext', $query) YIELD node, score " +
"WHERE node.graph_id = $graphId " +
permFilter +
"RETURN count(*) AS total"
)
.bindAll(countParams)
.fetchAs(Long.class)
.mappedBy((ts, record) -> record.get("total").asLong())
.one()
.orElse(0L);
long totalPages = safeSize > 0 ? (total + safeSize - 1) / safeSize : 0;
return PagedResponse.of(results, safePage, total, totalPages);
}
// -----------------------------------------------------------------------
// Permission filtering
// -----------------------------------------------------------------------
/**
* Resolves the owner-filter user ID.
* Returns null for admins (no filtering); returns the current userId for regular users.
*/
private String resolveOwnerFilter() {
return resourceAccessService.resolveOwnerFilterUserId();
}
/**
* Builds the Cypher permission-filter fragment.
* <p>
* Returns an empty string for admins (no filtering); for regular users returns
* an AND clause that keeps only structural entities (User, Org, Field) and
* business entities whose {@code created_by} equals the current user.
* If the caller lacks the confidential-data permission, additionally filters
* out sensitivity=CONFIDENTIAL.
*/
private static String buildPermissionPredicate(String nodeAlias, String filterUserId, boolean excludeConfidential) {
StringBuilder sb = new StringBuilder();
if (filterUserId != null) {
sb.append("AND (").append(nodeAlias).append(".type IN ['User', 'Org', 'Field'] OR ")
.append(nodeAlias).append(".`properties.created_by` = $filterUserId) ");
}
if (excludeConfidential) {
sb.append("AND (toUpper(trim(").append(nodeAlias).append(".`properties.sensitivity`)) IS NULL OR ")
.append("toUpper(trim(").append(nodeAlias).append(".`properties.sensitivity`)) <> 'CONFIDENTIAL') ");
}
return sb.toString();
}
/**
* Asserts that a non-admin user may access the entity.
* Confidential data requires the canViewConfidential permission;
* structural entities (User, Org, Field) are visible to all users;
* business entities must match created_by.
*/
private static void assertEntityAccess(GraphEntity entity, String filterUserId, boolean excludeConfidential) {
// Confidential-data check (case-insensitive, consistent with data-management)
if (excludeConfidential) {
Object sensitivity = entity.getProperties() != null
? entity.getProperties().get("sensitivity") : null;
if (sensitivity != null && sensitivity.toString().trim().equalsIgnoreCase("CONFIDENTIAL")) {
throw BusinessException.of(SystemErrorCode.INSUFFICIENT_PERMISSIONS, "无权访问保密数据");
}
}
// Structural entities pass via the type whitelist
if (STRUCTURAL_ENTITY_TYPES.contains(entity.getType())) {
return;
}
// Business entities must match the owner
Object createdBy = entity.getProperties() != null
? entity.getProperties().get("created_by") : null;
if (createdBy == null || !filterUserId.equals(createdBy.toString())) {
throw BusinessException.of(SystemErrorCode.INSUFFICIENT_PERMISSIONS, "无权访问该实体");
}
}
/**
* Checks whether the entity is accessible to the given user.
* Confidential data requires the canViewConfidential permission;
* structural entities (User, Org, Field) are visible to all users;
* business entities must match created_by.
*/
private static boolean isEntityAccessible(GraphEntity entity, String filterUserId, boolean excludeConfidential) {
// Confidential-data check (case-insensitive, consistent with data-management)
if (excludeConfidential) {
Object sensitivity = entity.getProperties() != null
? entity.getProperties().get("sensitivity") : null;
if (sensitivity != null && sensitivity.toString().trim().equalsIgnoreCase("CONFIDENTIAL")) {
return false;
}
}
// Structural entities pass via the type whitelist
if (STRUCTURAL_ENTITY_TYPES.contains(entity.getType())) {
return true;
}
// Business entities must match the owner
Object createdBy = entity.getProperties() != null
? entity.getProperties().get("created_by") : null;
return createdBy != null && filterUserId.equals(createdBy.toString());
}
// -----------------------------------------------------------------------
// Internal helpers
// -----------------------------------------------------------------------
/**
* Queries all edges between the given set of nodes.
*/
private List<EdgeSummaryVO> queryEdgesBetween(String graphId, List<String> nodeIds) {
if (nodeIds.size() < 2) {
return List.of();
}
return neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId})-[r:" + REL_TYPE + " {graph_id: $graphId}]->(t:Entity {graph_id: $graphId}) " +
"WHERE s.id IN $nodeIds AND t.id IN $nodeIds " +
"RETURN r.id AS id, s.id AS sourceEntityId, t.id AS targetEntityId, " +
"r.relation_type AS relationType, r.weight AS weight"
)
.bindAll(Map.of("graphId", graphId, "nodeIds", nodeIds))
.fetchAs(EdgeSummaryVO.class)
.mappedBy((ts, record) -> EdgeSummaryVO.builder()
.id(record.get("id").asString(null))
.sourceEntityId(record.get("sourceEntityId").asString(null))
.targetEntityId(record.get("targetEntityId").asString(null))
.relationType(record.get("relationType").asString(null))
.weight(record.get("weight").isNull() ? null : record.get("weight").asDouble())
.build())
.all()
.stream().toList();
}
private PathVO mapPathRecord(MapAccessor record) {
// Parse path nodes
List<EntitySummaryVO> nodes = new ArrayList<>();
Value pathNodes = record.get("pathNodes");
if (pathNodes != null && !pathNodes.isNull()) {
for (Value nodeVal : pathNodes.asList(v -> v)) {
nodes.add(EntitySummaryVO.builder()
.id(getStringOrNull(nodeVal, "id"))
.name(getStringOrNull(nodeVal, "name"))
.type(getStringOrNull(nodeVal, "type"))
.description(getStringOrNull(nodeVal, "description"))
.build());
}
}
// Parse path edges
List<EdgeSummaryVO> edges = new ArrayList<>();
Value pathEdges = record.get("pathEdges");
if (pathEdges != null && !pathEdges.isNull()) {
for (Value edgeVal : pathEdges.asList(v -> v)) {
edges.add(EdgeSummaryVO.builder()
.id(getStringOrNull(edgeVal, "id"))
.sourceEntityId(getStringOrNull(edgeVal, "source"))
.targetEntityId(getStringOrNull(edgeVal, "target"))
.relationType(getStringOrNull(edgeVal, "relation_type"))
.weight(getDoubleOrNull(edgeVal, "weight"))
.build());
}
}
int pathLength = record.get("pathLength").asInt(0);
return PathVO.builder()
.nodes(nodes)
.edges(edges)
.pathLength(pathLength)
.build();
}
/**
* Escapes Lucene special characters in the query to prevent query injection.
*/
private static String escapeLuceneQuery(String query) {
// Lucene special characters: + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
StringBuilder sb = new StringBuilder();
for (char c : query.toCharArray()) {
if ("+-&|!(){}[]^\"~*?:\\/".indexOf(c) >= 0) {
sb.append('\\');
}
sb.append(c);
}
return sb.toString();
}
private static String getStringOrNull(Value value, String key) {
Value v = value.get(key);
return (v == null || v.isNull()) ? null : v.asString();
}
private static Double getDoubleOrNull(Value value, String key) {
Value v = value.get(key);
return (v == null || v.isNull()) ? null : v.asDouble();
}
private void validateGraphId(String graphId) {
if (graphId == null || !UUID_PATTERN.matcher(graphId).matches()) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "graphId 格式无效");
}
}
}
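The owner/whitelist/confidential decision implemented by `isEntityAccessible` above can be sketched as a standalone class. This is a minimal illustration, not the service's real API: `AccessCheck`, `accessible`, and the `props` map are hypothetical stand-ins for `GraphEntity` and its properties.

```java
import java.util.Map;
import java.util.Set;

// Illustrative sketch of the query-stage access check: confidential filter,
// structural-entity whitelist, then owner match on created_by.
public class AccessCheck {
    static final Set<String> STRUCTURAL = Set.of("User", "Org", "Field");

    static boolean accessible(String type, Map<String, Object> props,
                              String filterUserId, boolean excludeConfidential) {
        // Confidential data is hidden unless the caller may view it
        // (case-insensitive, whitespace-tolerant comparison).
        if (excludeConfidential) {
            Object s = props == null ? null : props.get("sensitivity");
            if (s != null && s.toString().trim().equalsIgnoreCase("CONFIDENTIAL")) {
                return false;
            }
        }
        // Structural entities are visible to every user.
        if (STRUCTURAL.contains(type)) {
            return true;
        }
        // Business entities must be owned by the caller; a missing owner is denied.
        Object createdBy = props == null ? null : props.get("created_by");
        return createdBy != null && filterUserId.equals(createdBy.toString());
    }

    public static void main(String[] args) {
        System.out.println(accessible("User", Map.of(), "u1", true));                      // true
        System.out.println(accessible("Dataset", Map.of("created_by", "u1"), "u1", true)); // true
        System.out.println(accessible("Dataset", Map.of("created_by", "u2"), "u1", true)); // false
    }
}
```

Note the deny-by-default choice: a business entity without `created_by` is rejected rather than shown, matching the P1-2 fix described in the commit message.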

@@ -0,0 +1,218 @@
package com.datamate.knowledgegraph.application;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import com.datamate.common.interfaces.PagedResponse;
import com.datamate.knowledgegraph.domain.model.RelationDetail;
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
import com.datamate.knowledgegraph.domain.repository.GraphRelationRepository;
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
import com.datamate.knowledgegraph.interfaces.dto.CreateRelationRequest;
import com.datamate.knowledgegraph.interfaces.dto.RelationVO;
import com.datamate.knowledgegraph.interfaces.dto.UpdateRelationRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;
/**
* Knowledge graph relation application service.
* <p>
* <b>Trust boundary</b>: this service is only invoked over the internal network
* by the API Gateway / Java backend. The gateway has already performed user
* authentication and authorization, so this layer does not re-check permissions;
* it only validates the graphId format (to prevent Cypher injection) and data
* integrity constraints.
*/
@Service
@Slf4j
@RequiredArgsConstructor
public class GraphRelationService {
/** Upper bound on the pagination offset, preventing deep paging from degrading Neo4j performance. */
private static final long MAX_SKIP = 100_000L;
/** Valid relation query directions. */
private static final Set<String> VALID_DIRECTIONS = Set.of("all", "in", "out");
private static final Pattern UUID_PATTERN = Pattern.compile(
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
);
private final GraphRelationRepository relationRepository;
private final GraphEntityRepository entityRepository;
@Transactional
public RelationVO createRelation(String graphId, CreateRelationRequest request) {
validateGraphId(graphId);
// Verify the source entity exists
entityRepository.findByIdAndGraphId(request.getSourceEntityId(), graphId)
.orElseThrow(() -> BusinessException.of(
KnowledgeGraphErrorCode.ENTITY_NOT_FOUND, "源实体不存在"));
// Verify the target entity exists
entityRepository.findByIdAndGraphId(request.getTargetEntityId(), graphId)
.orElseThrow(() -> BusinessException.of(
KnowledgeGraphErrorCode.ENTITY_NOT_FOUND, "目标实体不存在"));
RelationDetail detail = relationRepository.create(
graphId,
request.getSourceEntityId(),
request.getTargetEntityId(),
request.getRelationType(),
request.getProperties(),
request.getWeight(),
request.getSourceId(),
request.getConfidence()
).orElseThrow(() -> BusinessException.of(
KnowledgeGraphErrorCode.INVALID_RELATION, "关系创建失败"));
log.info("Relation created: id={}, graphId={}, type={}, source={} -> target={}",
detail.getId(), graphId, request.getRelationType(),
request.getSourceEntityId(), request.getTargetEntityId());
return toVO(detail);
}
public RelationVO getRelation(String graphId, String relationId) {
validateGraphId(graphId);
RelationDetail detail = relationRepository.findByIdAndGraphId(relationId, graphId)
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND));
return toVO(detail);
}
public PagedResponse<RelationVO> listRelations(String graphId, String type, int page, int size) {
validateGraphId(graphId);
int safePage = Math.max(0, page);
int safeSize = Math.max(1, Math.min(size, 200));
long skip = (long) safePage * safeSize;
if (skip > MAX_SKIP) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "分页偏移量过大");
}
List<RelationDetail> details = relationRepository.findByGraphId(graphId, type, skip, safeSize);
long total = relationRepository.countByGraphId(graphId, type);
long totalPages = safeSize > 0 ? (total + safeSize - 1) / safeSize : 0;
List<RelationVO> content = details.stream().map(GraphRelationService::toVO).toList();
return PagedResponse.of(content, safePage, total, totalPages);
}
/**
* Lists the relations of an entity.
*
* @param direction "all", "in", or "out"
*/
public PagedResponse<RelationVO> listEntityRelations(String graphId, String entityId,
String direction, String type,
int page, int size) {
validateGraphId(graphId);
// Verify the entity exists
entityRepository.findByIdAndGraphId(entityId, graphId)
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.ENTITY_NOT_FOUND));
int safePage = Math.max(0, page);
int safeSize = Math.max(1, Math.min(size, 200));
long skip = (long) safePage * safeSize;
if (skip > MAX_SKIP) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "分页偏移量过大");
}
String safeDirection = (direction != null) ? direction : "all";
if (!VALID_DIRECTIONS.contains(safeDirection)) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER,
"direction 参数无效,允许值:all, in, out");
}
List<RelationDetail> details;
switch (safeDirection) {
case "in":
details = relationRepository.findInboundByEntityId(graphId, entityId, type, skip, safeSize);
break;
case "out":
details = relationRepository.findOutboundByEntityId(graphId, entityId, type, skip, safeSize);
break;
default:
details = relationRepository.findByEntityId(graphId, entityId, type, skip, safeSize);
break;
}
long total = relationRepository.countByEntityId(graphId, entityId, type, safeDirection);
long totalPages = safeSize > 0 ? (total + safeSize - 1) / safeSize : 0;
List<RelationVO> content = details.stream().map(GraphRelationService::toVO).toList();
return PagedResponse.of(content, safePage, total, totalPages);
}
@Transactional
public RelationVO updateRelation(String graphId, String relationId, UpdateRelationRequest request) {
validateGraphId(graphId);
// Confirm the relation exists
relationRepository.findByIdAndGraphId(relationId, graphId)
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND));
RelationDetail detail = relationRepository.update(
relationId, graphId,
request.getRelationType(),
request.getProperties(),
request.getWeight(),
request.getConfidence()
).orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND));
log.info("Relation updated: id={}, graphId={}", relationId, graphId);
return toVO(detail);
}
@Transactional
public void deleteRelation(String graphId, String relationId) {
validateGraphId(graphId);
// Confirm the relation exists
relationRepository.findByIdAndGraphId(relationId, graphId)
.orElseThrow(() -> BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND));
long deleted = relationRepository.deleteByIdAndGraphId(relationId, graphId);
if (deleted <= 0) {
throw BusinessException.of(KnowledgeGraphErrorCode.RELATION_NOT_FOUND);
}
log.info("Relation deleted: id={}, graphId={}", relationId, graphId);
}
// -----------------------------------------------------------------------
// Domain object → view object conversion
// -----------------------------------------------------------------------
private static RelationVO toVO(RelationDetail detail) {
return RelationVO.builder()
.id(detail.getId())
.sourceEntityId(detail.getSourceEntityId())
.sourceEntityName(detail.getSourceEntityName())
.sourceEntityType(detail.getSourceEntityType())
.targetEntityId(detail.getTargetEntityId())
.targetEntityName(detail.getTargetEntityName())
.targetEntityType(detail.getTargetEntityType())
.relationType(detail.getRelationType())
.properties(detail.getProperties())
.weight(detail.getWeight())
.confidence(detail.getConfidence())
.sourceId(detail.getSourceId())
.graphId(detail.getGraphId())
.createdAt(detail.getCreatedAt())
.build();
}
/**
* Validates the graphId format (UUID).
* Prevents maliciously crafted graphIds from being injected into Cypher queries.
*/
private void validateGraphId(String graphId) {
if (graphId == null || !UUID_PATTERN.matcher(graphId).matches()) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "graphId 格式无效");
}
}
}
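The pagination normalization repeated in the list endpoints above (clamp page/size, compute the skip offset, cap it, and derive total pages by ceiling division) can be sketched in isolation. `PageMath` and its method names are illustrative only; the service inlines this logic and throws `BusinessException` instead of `IllegalArgumentException`.

```java
// Illustrative sketch of the list endpoints' pagination arithmetic.
public class PageMath {
    /** Same cap as the service: deep paging beyond this offset is rejected. */
    static final long MAX_SKIP = 100_000L;

    /** Returns {safePage, safeSize, skip}; rejects offsets past MAX_SKIP. */
    static long[] normalize(int page, int size) {
        int safePage = Math.max(0, page);                 // negative pages clamp to 0
        int safeSize = Math.max(1, Math.min(size, 200));  // size clamped to [1, 200]
        long skip = (long) safePage * safeSize;           // long math avoids int overflow
        if (skip > MAX_SKIP) {
            throw new IllegalArgumentException("pagination offset too large");
        }
        return new long[] {safePage, safeSize, skip};
    }

    /** Ceiling division: totalPages = ceil(total / safeSize). */
    static long totalPages(long total, int safeSize) {
        return safeSize > 0 ? (total + safeSize - 1) / safeSize : 0;
    }

    public static void main(String[] args) {
        long[] p = normalize(-1, 500);
        System.out.println(p[0] + "," + p[1] + "," + p[2]); // 0,200,0
        System.out.println(totalPages(101, 50));            // 3
    }
}
```

Casting before the multiply matters: `(long) safePage * safeSize` is computed in 64-bit, so a large page number cannot silently overflow `int` before the `MAX_SKIP` check.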

@@ -0,0 +1,679 @@
package com.datamate.knowledgegraph.application;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.common.infrastructure.exception.SystemErrorCode;
import com.datamate.knowledgegraph.domain.model.SyncResult;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.DatasetDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.WorkflowDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.JobDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.LabelTaskDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.KnowledgeSetDTO;
import com.datamate.knowledgegraph.infrastructure.exception.KnowledgeGraphErrorCode;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* Knowledge graph data sync orchestrator (no {@code @Transactional}).
* <p>
* Fetches data, orchestrates the sync steps, and manages the concurrency lock;
* the actual write operations are delegated to {@link GraphSyncStepService}
* (the transaction boundary).
* <p>
* <b>Concurrency control</b>: sync operations on the same graphId are mutually
* exclusive via a {@link ReentrantLock}, preventing inconsistencies from
* concurrent writes.
* <p>
* <b>Data snapshot</b>: a full sync fetches the dataset list only once and
* shares it across steps, avoiding repeated HTTP calls.
* <p>
* <b>Multi-instance deployment</b>: the per-graph lock is an in-process
* {@link ReentrantLock}; with multiple instances, the Neo4j composite unique
* constraint (graph_id, source_id, type) is the backstop against duplicate
* writes. Replace it with a Redis/DB distributed lock if strict mutual
* exclusion is required.
* <p>
* <b>Trust boundary</b>: this service is only invoked over the internal network
* by the API Gateway / scheduled tasks; the gateway has already performed
* authentication and authorization.
*/
@Service
@Slf4j
@RequiredArgsConstructor
public class GraphSyncService {
private static final Pattern UUID_PATTERN = Pattern.compile(
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
);
private final GraphSyncStepService stepService;
private final DataManagementClient dataManagementClient;
private final KnowledgeGraphProperties properties;
/** Per-graphId mutex, preventing concurrent syncs of the same graph. */
private final ConcurrentHashMap<String, ReentrantLock> graphLocks = new ConcurrentHashMap<>();
// -----------------------------------------------------------------------
// Full sync
// -----------------------------------------------------------------------
public List<SyncResult> syncAll(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
log.info("[{}] Starting full sync for graphId={}", syncId, graphId);
// Fetch once, share across all steps
List<DatasetDTO> datasets = fetchDatasetsWithRetry(syncId);
List<WorkflowDTO> workflows = fetchWithRetry(syncId, "workflows",
() -> dataManagementClient.listAllWorkflows());
List<JobDTO> jobs = fetchWithRetry(syncId, "jobs",
() -> dataManagementClient.listAllJobs());
List<LabelTaskDTO> labelTasks = fetchWithRetry(syncId, "label-tasks",
() -> dataManagementClient.listAllLabelTasks());
List<KnowledgeSetDTO> knowledgeSets = fetchWithRetry(syncId, "knowledge-sets",
() -> dataManagementClient.listAllKnowledgeSets());
// Store results in a LinkedHashMap keyed by syncType: preserves insertion order, avoids fixed indices
Map<String, SyncResult> resultMap = new LinkedHashMap<>();
// Entity sync
resultMap.put("Dataset", stepService.upsertDatasetEntities(graphId, datasets, syncId));
resultMap.put("Field", stepService.upsertFieldEntities(graphId, datasets, syncId));
Set<String> usernames = extractUsernames(datasets, workflows, jobs, labelTasks, knowledgeSets);
resultMap.put("User", stepService.upsertUserEntities(graphId, usernames, syncId));
resultMap.put("Org", stepService.upsertOrgEntities(graphId, syncId));
resultMap.put("Workflow", stepService.upsertWorkflowEntities(graphId, workflows, syncId));
resultMap.put("Job", stepService.upsertJobEntities(graphId, jobs, syncId));
resultMap.put("LabelTask", stepService.upsertLabelTaskEntities(graphId, labelTasks, syncId));
resultMap.put("KnowledgeSet", stepService.upsertKnowledgeSetEntities(graphId, knowledgeSets, syncId));
// Full reconciliation: purge records already removed from MySQL (looked up by syncType, no fixed indices)
Set<String> activeDatasetIds = datasets.stream()
.filter(Objects::nonNull)
.map(DatasetDTO::getId)
.filter(Objects::nonNull)
.filter(id -> !id.isBlank())
.collect(Collectors.toSet());
resultMap.get("Dataset").setPurged(
stepService.purgeStaleEntities(graphId, "Dataset", activeDatasetIds, syncId));
Set<String> activeFieldIds = new HashSet<>();
for (DatasetDTO dto : datasets) {
if (dto == null || dto.getTags() == null) {
continue;
}
for (DataManagementClient.TagDTO tag : dto.getTags()) {
if (tag == null || tag.getName() == null) {
continue;
}
activeFieldIds.add(dto.getId() + ":tag:" + tag.getName());
}
}
resultMap.get("Field").setPurged(
stepService.purgeStaleEntities(graphId, "Field", activeFieldIds, syncId));
Set<String> activeUserIds = usernames.stream()
.map(u -> "user:" + u)
.collect(Collectors.toSet());
resultMap.get("User").setPurged(
stepService.purgeStaleEntities(graphId, "User", activeUserIds, syncId));
Set<String> activeWorkflowIds = workflows.stream()
.filter(Objects::nonNull)
.map(WorkflowDTO::getId)
.filter(Objects::nonNull)
.filter(id -> !id.isBlank())
.collect(Collectors.toSet());
resultMap.get("Workflow").setPurged(
stepService.purgeStaleEntities(graphId, "Workflow", activeWorkflowIds, syncId));
Set<String> activeJobIds = jobs.stream()
.filter(Objects::nonNull)
.map(JobDTO::getId)
.filter(Objects::nonNull)
.filter(id -> !id.isBlank())
.collect(Collectors.toSet());
resultMap.get("Job").setPurged(
stepService.purgeStaleEntities(graphId, "Job", activeJobIds, syncId));
Set<String> activeLabelTaskIds = labelTasks.stream()
.filter(Objects::nonNull)
.map(LabelTaskDTO::getId)
.filter(Objects::nonNull)
.filter(id -> !id.isBlank())
.collect(Collectors.toSet());
resultMap.get("LabelTask").setPurged(
stepService.purgeStaleEntities(graphId, "LabelTask", activeLabelTaskIds, syncId));
Set<String> activeKnowledgeSetIds = knowledgeSets.stream()
.filter(Objects::nonNull)
.map(KnowledgeSetDTO::getId)
.filter(Objects::nonNull)
.filter(id -> !id.isBlank())
.collect(Collectors.toSet());
resultMap.get("KnowledgeSet").setPurged(
stepService.purgeStaleEntities(graphId, "KnowledgeSet", activeKnowledgeSetIds, syncId));
// Relation building (idempotent via MERGE)
resultMap.put("HAS_FIELD", stepService.mergeHasFieldRelations(graphId, syncId));
resultMap.put("DERIVED_FROM", stepService.mergeDerivedFromRelations(graphId, syncId));
resultMap.put("BELONGS_TO", stepService.mergeBelongsToRelations(graphId, syncId));
resultMap.put("USES_DATASET", stepService.mergeUsesDatasetRelations(graphId, syncId));
resultMap.put("PRODUCES", stepService.mergeProducesRelations(graphId, syncId));
resultMap.put("ASSIGNED_TO", stepService.mergeAssignedToRelations(graphId, syncId));
resultMap.put("TRIGGERS", stepService.mergeTriggersRelations(graphId, syncId));
resultMap.put("DEPENDS_ON", stepService.mergeDependsOnRelations(graphId, syncId));
resultMap.put("IMPACTS", stepService.mergeImpactsRelations(graphId, syncId));
resultMap.put("SOURCED_FROM", stepService.mergeSourcedFromRelations(graphId, syncId));
List<SyncResult> results = new ArrayList<>(resultMap.values());
log.info("[{}] Full sync completed for graphId={}. Summary: {}", syncId, graphId,
results.stream()
.map(r -> r.getSyncType() + "(+" + r.getCreated() + "/~" + r.getUpdated() + "/-" + r.getFailed() + ")")
.collect(Collectors.joining(", ")));
return results;
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] Full sync failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "全量同步失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
// -----------------------------------------------------------------------
// Single-step sync (each step acquires its own lock and data)
// -----------------------------------------------------------------------
public SyncResult syncDatasets(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
List<DatasetDTO> datasets = fetchDatasetsWithRetry(syncId);
SyncResult result = stepService.upsertDatasetEntities(graphId, datasets, syncId);
Set<String> activeIds = datasets.stream()
.filter(Objects::nonNull).map(DatasetDTO::getId)
.filter(Objects::nonNull).filter(id -> !id.isBlank())
.collect(Collectors.toSet());
int purged = stepService.purgeStaleEntities(graphId, "Dataset", activeIds, syncId);
result.setPurged(purged);
return result;
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] Dataset sync failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "数据集同步失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult syncFields(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
List<DatasetDTO> datasets = fetchDatasetsWithRetry(syncId);
SyncResult result = stepService.upsertFieldEntities(graphId, datasets, syncId);
Set<String> activeFieldIds = new HashSet<>();
for (DatasetDTO dto : datasets) {
if (dto == null || dto.getTags() == null) {
continue;
}
for (DataManagementClient.TagDTO tag : dto.getTags()) {
if (tag == null || tag.getName() == null) {
continue;
}
activeFieldIds.add(dto.getId() + ":tag:" + tag.getName());
}
}
result.setPurged(stepService.purgeStaleEntities(graphId, "Field", activeFieldIds, syncId));
return result;
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] Field sync failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "字段同步失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult syncUsers(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
List<DatasetDTO> datasets = fetchDatasetsWithRetry(syncId);
List<WorkflowDTO> workflows = fetchWithRetry(syncId, "workflows",
() -> dataManagementClient.listAllWorkflows());
List<JobDTO> jobs = fetchWithRetry(syncId, "jobs",
() -> dataManagementClient.listAllJobs());
List<LabelTaskDTO> labelTasks = fetchWithRetry(syncId, "label-tasks",
() -> dataManagementClient.listAllLabelTasks());
List<KnowledgeSetDTO> knowledgeSets = fetchWithRetry(syncId, "knowledge-sets",
() -> dataManagementClient.listAllKnowledgeSets());
Set<String> usernames = extractUsernames(datasets, workflows, jobs, labelTasks, knowledgeSets);
SyncResult result = stepService.upsertUserEntities(graphId, usernames, syncId);
Set<String> activeUserIds = usernames.stream().map(u -> "user:" + u).collect(Collectors.toSet());
result.setPurged(stepService.purgeStaleEntities(graphId, "User", activeUserIds, syncId));
return result;
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] User sync failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "用户同步失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult syncOrgs(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.upsertOrgEntities(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] Org sync failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "组织同步失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult buildHasFieldRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeHasFieldRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] HAS_FIELD relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"HAS_FIELD 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult buildDerivedFromRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeDerivedFromRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] DERIVED_FROM relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"DERIVED_FROM 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult buildBelongsToRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeBelongsToRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] BELONGS_TO relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"BELONGS_TO 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
// -----------------------------------------------------------------------
// Newly added entity sync
// -----------------------------------------------------------------------
public SyncResult syncWorkflows(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
List<WorkflowDTO> workflows = fetchWithRetry(syncId, "workflows",
() -> dataManagementClient.listAllWorkflows());
SyncResult result = stepService.upsertWorkflowEntities(graphId, workflows, syncId);
Set<String> activeIds = workflows.stream()
.filter(Objects::nonNull).map(WorkflowDTO::getId)
.filter(Objects::nonNull).filter(id -> !id.isBlank())
.collect(Collectors.toSet());
result.setPurged(stepService.purgeStaleEntities(graphId, "Workflow", activeIds, syncId));
return result;
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] Workflow sync failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "工作流同步失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult syncJobs(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
List<JobDTO> jobs = fetchWithRetry(syncId, "jobs",
() -> dataManagementClient.listAllJobs());
SyncResult result = stepService.upsertJobEntities(graphId, jobs, syncId);
Set<String> activeIds = jobs.stream()
.filter(Objects::nonNull).map(JobDTO::getId)
.filter(Objects::nonNull).filter(id -> !id.isBlank())
.collect(Collectors.toSet());
result.setPurged(stepService.purgeStaleEntities(graphId, "Job", activeIds, syncId));
return result;
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] Job sync failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "作业同步失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult syncLabelTasks(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
List<LabelTaskDTO> tasks = fetchWithRetry(syncId, "label-tasks",
() -> dataManagementClient.listAllLabelTasks());
SyncResult result = stepService.upsertLabelTaskEntities(graphId, tasks, syncId);
Set<String> activeIds = tasks.stream()
.filter(Objects::nonNull).map(LabelTaskDTO::getId)
.filter(Objects::nonNull).filter(id -> !id.isBlank())
.collect(Collectors.toSet());
result.setPurged(stepService.purgeStaleEntities(graphId, "LabelTask", activeIds, syncId));
return result;
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] LabelTask sync failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "标注任务同步失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult syncKnowledgeSets(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
List<KnowledgeSetDTO> knowledgeSets = fetchWithRetry(syncId, "knowledge-sets",
() -> dataManagementClient.listAllKnowledgeSets());
SyncResult result = stepService.upsertKnowledgeSetEntities(graphId, knowledgeSets, syncId);
Set<String> activeIds = knowledgeSets.stream()
.filter(Objects::nonNull).map(KnowledgeSetDTO::getId)
.filter(Objects::nonNull).filter(id -> !id.isBlank())
.collect(Collectors.toSet());
result.setPurged(stepService.purgeStaleEntities(graphId, "KnowledgeSet", activeIds, syncId));
return result;
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] KnowledgeSet sync failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "知识集同步失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
// -----------------------------------------------------------------------
// Newly added relation builders
// -----------------------------------------------------------------------
public SyncResult buildUsesDatasetRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeUsesDatasetRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] USES_DATASET relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"USES_DATASET 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult buildProducesRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeProducesRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] PRODUCES relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"PRODUCES 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult buildAssignedToRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeAssignedToRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] ASSIGNED_TO relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"ASSIGNED_TO 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult buildTriggersRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeTriggersRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] TRIGGERS relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"TRIGGERS 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult buildDependsOnRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeDependsOnRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] DEPENDS_ON relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"DEPENDS_ON 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult buildImpactsRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeImpactsRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] IMPACTS relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"IMPACTS 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
public SyncResult buildSourcedFromRelations(String graphId) {
validateGraphId(graphId);
String syncId = UUID.randomUUID().toString().substring(0, 8);
ReentrantLock lock = acquireLock(graphId, syncId);
try {
return stepService.mergeSourcedFromRelations(graphId, syncId);
} catch (BusinessException e) {
throw e;
} catch (Exception e) {
log.error("[{}] SOURCED_FROM relation build failed for graphId={}", syncId, graphId, e);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"SOURCED_FROM 关系构建失败,syncId=" + syncId);
} finally {
releaseLock(graphId, lock);
}
}
// -----------------------------------------------------------------------
// Internal helpers
// -----------------------------------------------------------------------
private ReentrantLock acquireLock(String graphId, String syncId) {
ReentrantLock lock = graphLocks.computeIfAbsent(graphId, k -> new ReentrantLock());
if (!lock.tryLock()) {
log.warn("[{}] Graph {} is already being synced, rejecting concurrent request", syncId, graphId);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "该图谱正在同步中,请稍后重试");
}
return lock;
}
/**
 * Releases the lock and, when uncontended, removes the lock object so that
 * graphLocks cannot grow without bound.
*/
private void releaseLock(String graphId, ReentrantLock lock) {
lock.unlock();
graphLocks.compute(graphId, (key, existing) -> {
        // Remove only when the lock is free and has no queued threads; compute() makes this atomic per key
if (existing != null && !existing.isLocked() && !existing.hasQueuedThreads()) {
return null;
}
return existing;
});
}
private List<DatasetDTO> fetchDatasetsWithRetry(String syncId) {
return fetchWithRetry(syncId, "datasets", () -> dataManagementClient.listAllDatasets());
}
/**
 * Generic data-fetch helper with retry.
*/
private <T> List<T> fetchWithRetry(String syncId, String resourceName,
java.util.function.Supplier<List<T>> fetcher) {
int maxRetries = properties.getSync().getMaxRetries();
long retryInterval = properties.getSync().getRetryInterval();
Exception lastException = null;
for (int attempt = 1; attempt <= maxRetries; attempt++) {
try {
return fetcher.get();
} catch (Exception e) {
lastException = e;
log.warn("[{}] {} fetch attempt {}/{} failed: {}", syncId, resourceName, attempt, maxRetries, e.getMessage());
if (attempt < maxRetries) {
try {
Thread.sleep(retryInterval * attempt);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED, "同步被中断");
}
}
}
}
log.error("[{}] All {} fetch attempts for {} failed", syncId, maxRetries, resourceName, lastException);
throw BusinessException.of(KnowledgeGraphErrorCode.SYNC_FAILED,
"拉取" + resourceName + "失败(已重试 " + maxRetries + " 次),syncId=" + syncId);
}
/**
 * Extracts usernames from all entity types.
*/
private static Set<String> extractUsernames(List<DatasetDTO> datasets,
List<WorkflowDTO> workflows,
List<JobDTO> jobs,
List<LabelTaskDTO> labelTasks,
List<KnowledgeSetDTO> knowledgeSets) {
Set<String> usernames = new LinkedHashSet<>();
for (DatasetDTO dto : datasets) {
if (dto == null) { continue; }
addIfPresent(usernames, dto.getCreatedBy());
addIfPresent(usernames, dto.getUpdatedBy());
}
for (WorkflowDTO dto : workflows) {
if (dto == null) { continue; }
addIfPresent(usernames, dto.getCreatedBy());
addIfPresent(usernames, dto.getUpdatedBy());
}
for (JobDTO dto : jobs) {
if (dto == null) { continue; }
addIfPresent(usernames, dto.getCreatedBy());
addIfPresent(usernames, dto.getUpdatedBy());
}
for (LabelTaskDTO dto : labelTasks) {
if (dto == null) { continue; }
addIfPresent(usernames, dto.getCreatedBy());
addIfPresent(usernames, dto.getUpdatedBy());
}
for (KnowledgeSetDTO dto : knowledgeSets) {
if (dto == null) { continue; }
addIfPresent(usernames, dto.getCreatedBy());
addIfPresent(usernames, dto.getUpdatedBy());
}
return usernames;
}
private static void addIfPresent(Set<String> set, String value) {
if (value != null && !value.isBlank()) {
set.add(value);
}
}
private void validateGraphId(String graphId) {
if (graphId == null || !UUID_PATTERN.matcher(graphId).matches()) {
throw BusinessException.of(SystemErrorCode.INVALID_PARAMETER, "graphId 格式无效");
}
}
}
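The retry helper above sleeps `retryInterval * attempt` between attempts (linear backoff), propagates thread interruption, and rethrows after the final attempt. A minimal standalone sketch of the same pattern; `RetrySketch`, its method signature, and the exception types are illustrative, not part of the service:

```java
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Sketch of the fetchWithRetry pattern: linear backoff, interrupt
// propagation, and a terminal failure carrying the last exception.
public class RetrySketch {
    public static <T> T fetchWithRetry(int maxRetries, long retryIntervalMs, Supplier<T> fetcher) {
        RuntimeException last = null;
        for (int attempt = 1; attempt <= maxRetries; attempt++) {
            try {
                return fetcher.get();
            } catch (RuntimeException e) {
                last = e;
                if (attempt < maxRetries) {
                    try {
                        // Linear backoff: wait longer after each failed attempt
                        Thread.sleep(retryIntervalMs * attempt);
                    } catch (InterruptedException ie) {
                        Thread.currentThread().interrupt(); // restore interrupt flag
                        throw new IllegalStateException("sync interrupted", ie);
                    }
                }
            }
        }
        throw new IllegalStateException("all " + maxRetries + " attempts failed", last);
    }

    public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        // Fails twice with a transient error, succeeds on the third attempt.
        List<String> out = fetchWithRetry(3, 1L, () -> {
            if (calls.incrementAndGet() < 3) throw new RuntimeException("transient");
            return List.of("ok");
        });
        System.out.println(out + " after " + calls.get() + " attempts"); // prints "[ok] after 3 attempts"
    }
}
```

The production version differs mainly in that it logs each failed attempt with the syncId and wraps the terminal failure in a BusinessException.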


@@ -0,0 +1,968 @@
package com.datamate.knowledgegraph.application;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import com.datamate.knowledgegraph.domain.model.SyncResult;
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.DatasetDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.TagDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.WorkflowDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.JobDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.LabelTaskDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.KnowledgeSetDTO;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.time.LocalDateTime;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Sync step executor (transaction boundary).
 * <p>
 * All write operations run in independent {@code @Transactional} methods,
 * orchestrated by {@link GraphSyncService}; this avoids self-invocation,
 * which would bypass the transactional proxy.
 * <p>
 * Relation building uses Cypher MERGE to guarantee idempotency; entity
 * upserts use Cypher MERGE against the (graph_id, source_id, type)
 * composite constraint as one atomic operation, with extended properties
 * updated via SDN.
*/
@Service
@Slf4j
@RequiredArgsConstructor
public class GraphSyncStepService {
private static final String SOURCE_TYPE_SYNC = "SYNC";
private static final String REL_TYPE = "RELATED_TO";
private final GraphEntityRepository entityRepository;
private final Neo4jClient neo4jClient;
private final KnowledgeGraphProperties properties;
// -----------------------------------------------------------------------
// Entity upserts
// -----------------------------------------------------------------------
@Transactional
public SyncResult upsertDatasetEntities(String graphId, List<DatasetDTO> datasets, String syncId) {
SyncResult result = beginResult("Dataset", syncId);
int batchSize = properties.getImportBatchSize();
for (int i = 0; i < datasets.size(); i++) {
DatasetDTO dto = datasets.get(i);
if (dto == null) {
result.incrementSkipped();
continue;
}
String sourceId = dto.getId();
try {
Map<String, Object> props = new HashMap<>();
props.put("dataset_type", dto.getDatasetType());
props.put("status", dto.getStatus());
props.put("total_size", dto.getTotalSize());
props.put("file_count", dto.getFileCount());
if (dto.getParentDatasetId() != null) {
props.put("parent_dataset_id", dto.getParentDatasetId());
}
if (dto.getTags() != null) {
List<String> tagNames = dto.getTags().stream()
.map(TagDTO::getName).toList();
props.put("tags", tagNames);
}
if (dto.getCreatedBy() != null) {
props.put("created_by", dto.getCreatedBy());
}
upsertEntity(graphId, dto.getId(), "Dataset",
dto.getName(), dto.getDescription(), props, result);
if ((i + 1) % batchSize == 0) {
log.debug("[{}] Processed {}/{} datasets", syncId, i + 1, datasets.size());
}
} catch (Exception e) {
log.warn("[{}] Failed to upsert dataset: sourceId={}", syncId, sourceId, e);
result.addError("dataset:" + sourceId);
}
}
return endResult(result);
}
@Transactional
public SyncResult upsertFieldEntities(String graphId, List<DatasetDTO> datasets, String syncId) {
SyncResult result = beginResult("Field", syncId);
for (DatasetDTO dto : datasets) {
if (dto == null || dto.getTags() == null || dto.getTags().isEmpty()) {
continue;
}
String dtoId = dto.getId();
for (TagDTO tag : dto.getTags()) {
if (tag == null) {
continue;
}
String tagName = tag.getName();
try {
String fieldSourceId = dtoId + ":tag:" + tagName;
Map<String, Object> props = new HashMap<>();
props.put("data_type", "TAG");
props.put("dataset_source_id", dtoId);
if (tag.getColor() != null) {
props.put("color", tag.getColor());
}
upsertEntity(graphId, fieldSourceId, "Field", tagName,
"数据集[" + dto.getName() + "]的标签字段", props, result);
} catch (Exception e) {
log.warn("[{}] Failed to upsert field: dataset={}, tag={}",
syncId, dtoId, tagName, e);
result.addError("field:" + dtoId + ":" + tagName);
}
}
}
return endResult(result);
}
@Transactional
public SyncResult upsertUserEntities(String graphId, Set<String> usernames, String syncId) {
SyncResult result = beginResult("User", syncId);
for (String username : usernames) {
try {
Map<String, Object> props = new HashMap<>();
props.put("username", username);
upsertEntity(graphId, "user:" + username, "User", username, null, props, result);
} catch (Exception e) {
log.warn("[{}] Failed to upsert user: username={}", syncId, username, e);
result.addError("user:" + username);
}
}
return endResult(result);
}
@Transactional
public SyncResult upsertOrgEntities(String graphId, String syncId) {
SyncResult result = beginResult("Org", syncId);
try {
Map<String, Object> props = new HashMap<>();
props.put("org_code", "DEFAULT");
props.put("level", 1);
upsertEntity(graphId, "org:default", "Org", "默认组织",
"系统默认组织(待对接组织服务后更新)", props, result);
} catch (Exception e) {
log.warn("[{}] Failed to upsert default org", syncId, e);
result.addError("org:default");
}
return endResult(result);
}
@Transactional
public SyncResult upsertWorkflowEntities(String graphId, List<WorkflowDTO> workflows, String syncId) {
SyncResult result = beginResult("Workflow", syncId);
int batchSize = properties.getImportBatchSize();
for (int i = 0; i < workflows.size(); i++) {
WorkflowDTO dto = workflows.get(i);
if (dto == null) {
result.incrementSkipped();
continue;
}
String sourceId = dto.getId();
try {
Map<String, Object> props = new HashMap<>();
props.put("workflow_type", dto.getWorkflowType());
props.put("status", dto.getStatus());
if (dto.getVersion() != null) {
props.put("version", dto.getVersion());
}
if (dto.getOperatorCount() != null) {
props.put("operator_count", dto.getOperatorCount());
}
if (dto.getSchedule() != null) {
props.put("schedule", dto.getSchedule());
}
if (dto.getInputDatasetIds() != null) {
props.put("input_dataset_ids", dto.getInputDatasetIds());
}
if (dto.getCreatedBy() != null) {
props.put("created_by", dto.getCreatedBy());
}
upsertEntity(graphId, dto.getId(), "Workflow",
dto.getName(), dto.getDescription(), props, result);
if ((i + 1) % batchSize == 0) {
log.debug("[{}] Processed {}/{} workflows", syncId, i + 1, workflows.size());
}
} catch (Exception e) {
log.warn("[{}] Failed to upsert workflow: sourceId={}", syncId, sourceId, e);
result.addError("workflow:" + sourceId);
}
}
return endResult(result);
}
@Transactional
public SyncResult upsertJobEntities(String graphId, List<JobDTO> jobs, String syncId) {
SyncResult result = beginResult("Job", syncId);
int batchSize = properties.getImportBatchSize();
for (int i = 0; i < jobs.size(); i++) {
JobDTO dto = jobs.get(i);
if (dto == null) {
result.incrementSkipped();
continue;
}
String sourceId = dto.getId();
try {
Map<String, Object> props = new HashMap<>();
props.put("job_type", dto.getJobType());
props.put("status", dto.getStatus());
if (dto.getStartedAt() != null) {
props.put("started_at", dto.getStartedAt());
}
if (dto.getCompletedAt() != null) {
props.put("completed_at", dto.getCompletedAt());
}
if (dto.getDurationSeconds() != null) {
props.put("duration_seconds", dto.getDurationSeconds());
}
if (dto.getInputCount() != null) {
props.put("input_count", dto.getInputCount());
}
if (dto.getOutputCount() != null) {
props.put("output_count", dto.getOutputCount());
}
if (dto.getErrorMessage() != null) {
props.put("error_message", dto.getErrorMessage());
}
if (dto.getInputDatasetId() != null) {
props.put("input_dataset_id", dto.getInputDatasetId());
}
if (dto.getOutputDatasetId() != null) {
props.put("output_dataset_id", dto.getOutputDatasetId());
}
if (dto.getWorkflowId() != null) {
props.put("workflow_id", dto.getWorkflowId());
}
if (dto.getDependsOnJobId() != null) {
props.put("depends_on_job_id", dto.getDependsOnJobId());
}
if (dto.getCreatedBy() != null) {
props.put("created_by", dto.getCreatedBy());
}
upsertEntity(graphId, dto.getId(), "Job",
dto.getName(), dto.getDescription(), props, result);
if ((i + 1) % batchSize == 0) {
log.debug("[{}] Processed {}/{} jobs", syncId, i + 1, jobs.size());
}
} catch (Exception e) {
log.warn("[{}] Failed to upsert job: sourceId={}", syncId, sourceId, e);
result.addError("job:" + sourceId);
}
}
return endResult(result);
}
@Transactional
public SyncResult upsertLabelTaskEntities(String graphId, List<LabelTaskDTO> tasks, String syncId) {
SyncResult result = beginResult("LabelTask", syncId);
int batchSize = properties.getImportBatchSize();
for (int i = 0; i < tasks.size(); i++) {
LabelTaskDTO dto = tasks.get(i);
if (dto == null) {
result.incrementSkipped();
continue;
}
String sourceId = dto.getId();
try {
Map<String, Object> props = new HashMap<>();
props.put("task_mode", dto.getTaskMode());
props.put("status", dto.getStatus());
if (dto.getDataType() != null) {
props.put("data_type", dto.getDataType());
}
if (dto.getLabelingType() != null) {
props.put("labeling_type", dto.getLabelingType());
}
if (dto.getProgress() != null) {
props.put("progress", dto.getProgress());
}
if (dto.getTemplateName() != null) {
props.put("template_name", dto.getTemplateName());
}
if (dto.getDatasetId() != null) {
props.put("dataset_id", dto.getDatasetId());
}
if (dto.getCreatedBy() != null) {
props.put("created_by", dto.getCreatedBy());
}
upsertEntity(graphId, dto.getId(), "LabelTask",
dto.getName(), dto.getDescription(), props, result);
if ((i + 1) % batchSize == 0) {
log.debug("[{}] Processed {}/{} label tasks", syncId, i + 1, tasks.size());
}
} catch (Exception e) {
log.warn("[{}] Failed to upsert label task: sourceId={}", syncId, sourceId, e);
result.addError("label_task:" + sourceId);
}
}
return endResult(result);
}
@Transactional
public SyncResult upsertKnowledgeSetEntities(String graphId, List<KnowledgeSetDTO> knowledgeSets, String syncId) {
SyncResult result = beginResult("KnowledgeSet", syncId);
int batchSize = properties.getImportBatchSize();
for (int i = 0; i < knowledgeSets.size(); i++) {
KnowledgeSetDTO dto = knowledgeSets.get(i);
if (dto == null) {
result.incrementSkipped();
continue;
}
String sourceId = dto.getId();
try {
Map<String, Object> props = new HashMap<>();
props.put("status", dto.getStatus());
if (dto.getDomain() != null) {
props.put("domain", dto.getDomain());
}
if (dto.getBusinessLine() != null) {
props.put("business_line", dto.getBusinessLine());
}
if (dto.getSensitivity() != null) {
props.put("sensitivity", dto.getSensitivity());
}
if (dto.getItemCount() != null) {
props.put("item_count", dto.getItemCount());
}
if (dto.getValidFrom() != null) {
props.put("valid_from", dto.getValidFrom());
}
if (dto.getValidTo() != null) {
props.put("valid_to", dto.getValidTo());
}
if (dto.getSourceDatasetIds() != null) {
props.put("source_dataset_ids", dto.getSourceDatasetIds());
}
if (dto.getCreatedBy() != null) {
props.put("created_by", dto.getCreatedBy());
}
upsertEntity(graphId, dto.getId(), "KnowledgeSet",
dto.getName(), dto.getDescription(), props, result);
if ((i + 1) % batchSize == 0) {
log.debug("[{}] Processed {}/{} knowledge sets", syncId, i + 1, knowledgeSets.size());
}
} catch (Exception e) {
log.warn("[{}] Failed to upsert knowledge set: sourceId={}", syncId, sourceId, e);
result.addError("knowledge_set:" + sourceId);
}
}
return endResult(result);
}
// -----------------------------------------------------------------------
// Full-reconciliation purge
// -----------------------------------------------------------------------
/**
 * Deletes entities in Neo4j whose source_type is SYNC but whose source_id is
 * not in the active set. Uses DETACH DELETE so attached relations are cleaned up too.
 * <p>
 * <b>Empty-snapshot protection</b>: when activeSourceIds is empty, the purge is
 * refused by default to avoid accidentally deleting every synced entity. An empty
 * set triggers a purge only when {@code allowPurgeOnEmptySnapshot=true} is configured.
*/
@Transactional
public int purgeStaleEntities(String graphId, String type, Set<String> activeSourceIds, String syncId) {
        // Defensive filtering: drop null/blank IDs so Cypher's three-valued logic cannot break the IN check
Set<String> sanitized = activeSourceIds.stream()
.filter(Objects::nonNull)
.map(String::trim)
.filter(s -> !s.isEmpty())
.collect(Collectors.toSet());
if (sanitized.isEmpty()) {
if (!properties.getSync().isAllowPurgeOnEmptySnapshot()) {
log.warn("[{}] Empty snapshot protection: active source IDs empty for type={}, " +
"purge BLOCKED (set allowPurgeOnEmptySnapshot=true to override)", syncId, type);
return 0;
}
log.warn("[{}] Active source IDs empty for type={}, purging ALL SYNC entities " +
"(allowPurgeOnEmptySnapshot=true)", syncId, type);
}
String cypher;
Map<String, Object> params;
if (sanitized.isEmpty()) {
cypher = "MATCH (e:Entity {graph_id: $graphId, type: $type, source_type: 'SYNC'}) " +
"DETACH DELETE e " +
"RETURN count(*) AS deleted";
params = Map.of("graphId", graphId, "type", type);
} else {
cypher = "MATCH (e:Entity {graph_id: $graphId, type: $type, source_type: 'SYNC'}) " +
"WHERE NOT e.source_id IN $activeSourceIds " +
"DETACH DELETE e " +
"RETURN count(*) AS deleted";
params = Map.of(
"graphId", graphId,
"type", type,
"activeSourceIds", new ArrayList<>(sanitized)
);
}
long deleted = neo4jClient.query(cypher)
.bindAll(params)
.fetchAs(Long.class)
.mappedBy((ts, record) -> record.get("deleted").asLong())
.one()
.orElse(0L);
if (deleted > 0) {
log.info("[{}] Purged {} stale {} entities from graphId={}", syncId, deleted, type, graphId);
}
return (int) deleted;
}
// -----------------------------------------------------------------------
// Relation building (MERGE for idempotency)
// -----------------------------------------------------------------------
@Transactional
public SyncResult mergeHasFieldRelations(String graphId, String syncId) {
SyncResult result = beginResult("HAS_FIELD", syncId);
Map<String, String> datasetMap = buildSourceIdToEntityIdMap(graphId, "Dataset");
List<GraphEntity> fields = entityRepository.findByGraphIdAndType(graphId, "Field");
for (GraphEntity field : fields) {
try {
Object datasetSourceId = field.getProperties().get("dataset_source_id");
if (datasetSourceId == null) {
result.incrementSkipped();
continue;
}
String datasetEntityId = datasetMap.get(datasetSourceId.toString());
if (datasetEntityId == null) {
result.incrementSkipped();
continue;
}
boolean created = mergeRelation(graphId, datasetEntityId, field.getId(),
"HAS_FIELD", "{}", syncId);
if (created) {
result.incrementCreated();
} else {
result.incrementSkipped();
}
} catch (Exception e) {
log.warn("[{}] Failed to merge HAS_FIELD for field: id={}", syncId, field.getId(), e);
result.addError("has_field:" + field.getId());
}
}
return endResult(result);
}
@Transactional
public SyncResult mergeDerivedFromRelations(String graphId, String syncId) {
SyncResult result = beginResult("DERIVED_FROM", syncId);
Map<String, String> datasetMap = buildSourceIdToEntityIdMap(graphId, "Dataset");
List<GraphEntity> datasets = entityRepository.findByGraphIdAndType(graphId, "Dataset");
for (GraphEntity dataset : datasets) {
try {
Object parentId = dataset.getProperties().get("parent_dataset_id");
if (parentId == null || parentId.toString().isBlank()) {
continue;
}
String parentEntityId = datasetMap.get(parentId.toString());
if (parentEntityId == null) {
result.incrementSkipped();
continue;
}
boolean created = mergeRelation(graphId, dataset.getId(), parentEntityId,
"DERIVED_FROM", "{\"derivation_type\":\"VERSION\"}", syncId);
if (created) {
result.incrementCreated();
} else {
result.incrementSkipped();
}
} catch (Exception e) {
log.warn("[{}] Failed to merge DERIVED_FROM for dataset: id={}", syncId, dataset.getId(), e);
result.addError("derived_from:" + dataset.getId());
}
}
return endResult(result);
}
@Transactional
public SyncResult mergeBelongsToRelations(String graphId, String syncId) {
SyncResult result = beginResult("BELONGS_TO", syncId);
Optional<GraphEntity> defaultOrgOpt = entityRepository.findByGraphIdAndSourceIdAndType(
graphId, "org:default", "Org");
if (defaultOrgOpt.isEmpty()) {
log.warn("[{}] Default org not found, skipping BELONGS_TO", syncId);
result.addError("belongs_to:org_missing");
return endResult(result);
}
String orgId = defaultOrgOpt.get().getId();
// User → Org
for (GraphEntity user : entityRepository.findByGraphIdAndType(graphId, "User")) {
try {
boolean created = mergeRelation(graphId, user.getId(), orgId,
"BELONGS_TO", "{\"membership_type\":\"PRIMARY\"}", syncId);
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
} catch (Exception e) {
log.warn("[{}] Failed to merge BELONGS_TO for user: id={}", syncId, user.getId(), e);
result.addError("belongs_to:user:" + user.getId());
}
}
// Dataset → Org
for (GraphEntity dataset : entityRepository.findByGraphIdAndType(graphId, "Dataset")) {
try {
boolean created = mergeRelation(graphId, dataset.getId(), orgId,
"BELONGS_TO", "{\"membership_type\":\"PRIMARY\"}", syncId);
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
} catch (Exception e) {
log.warn("[{}] Failed to merge BELONGS_TO for dataset: id={}", syncId, dataset.getId(), e);
result.addError("belongs_to:dataset:" + dataset.getId());
}
}
return endResult(result);
}
/**
 * Builds USES_DATASET relations: Job/LabelTask/Workflow → Dataset.
 * <p>
 * Related Datasets are looked up via foreign-key fields in the entity's
 * extended properties:
 * - Job.input_dataset_id → Dataset
 * - LabelTask.dataset_id → Dataset
 * - Workflow.input_dataset_ids → Dataset (multi-valued)
*/
@Transactional
public SyncResult mergeUsesDatasetRelations(String graphId, String syncId) {
SyncResult result = beginResult("USES_DATASET", syncId);
Map<String, String> datasetMap = buildSourceIdToEntityIdMap(graphId, "Dataset");
// Job → Dataset (via input_dataset_id)
for (GraphEntity job : entityRepository.findByGraphIdAndType(graphId, "Job")) {
mergeEntityToDatasets(graphId, job, "input_dataset_id", datasetMap, result, syncId);
}
// LabelTask → Dataset (via dataset_id)
for (GraphEntity task : entityRepository.findByGraphIdAndType(graphId, "LabelTask")) {
mergeEntityToDatasets(graphId, task, "dataset_id", datasetMap, result, syncId);
}
// Workflow → Dataset (via input_dataset_ids, multi-value)
for (GraphEntity workflow : entityRepository.findByGraphIdAndType(graphId, "Workflow")) {
mergeEntityToDatasets(graphId, workflow, "input_dataset_ids", datasetMap, result, syncId);
}
return endResult(result);
}
/**
 * Shared handler for building USES_DATASET relations from an entity to Datasets.
 * {@link #toStringList} accepts both single-valued (String) and multi-valued
 * (List) properties; the preloaded datasetMap avoids N+1 queries.
*/
private void mergeEntityToDatasets(String graphId, GraphEntity entity, String propertyKey,
Map<String, String> datasetMap,
SyncResult result, String syncId) {
try {
Object value = entity.getProperties().get(propertyKey);
if (value == null) {
return;
}
List<String> datasetIds = toStringList(value);
for (String dsId : datasetIds) {
String datasetEntityId = datasetMap.get(dsId);
if (datasetEntityId == null) {
result.incrementSkipped();
continue;
}
boolean created = mergeRelation(graphId, entity.getId(), datasetEntityId,
"USES_DATASET", "{\"usage_role\":\"INPUT\"}", syncId);
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
}
} catch (Exception e) {
log.warn("[{}] Failed to merge USES_DATASET for entity: id={}", syncId, entity.getId(), e);
result.addError("uses_dataset:" + entity.getId());
}
}
/**
 * Builds PRODUCES relations: Job → Dataset (via output_dataset_id).
*/
@Transactional
public SyncResult mergeProducesRelations(String graphId, String syncId) {
SyncResult result = beginResult("PRODUCES", syncId);
Map<String, String> datasetMap = buildSourceIdToEntityIdMap(graphId, "Dataset");
for (GraphEntity job : entityRepository.findByGraphIdAndType(graphId, "Job")) {
try {
Object outputDatasetId = job.getProperties().get("output_dataset_id");
if (outputDatasetId == null || outputDatasetId.toString().isBlank()) {
continue;
}
String datasetEntityId = datasetMap.get(outputDatasetId.toString());
if (datasetEntityId == null) {
result.incrementSkipped();
continue;
}
boolean created = mergeRelation(graphId, job.getId(), datasetEntityId,
"PRODUCES", "{\"output_type\":\"PRIMARY\"}", syncId);
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
} catch (Exception e) {
log.warn("[{}] Failed to merge PRODUCES for job: id={}", syncId, job.getId(), e);
result.addError("produces:" + job.getId());
}
}
return endResult(result);
}
/**
 * Builds ASSIGNED_TO relations: LabelTask/Job → User (via the createdBy field).
*/
@Transactional
public SyncResult mergeAssignedToRelations(String graphId, String syncId) {
SyncResult result = beginResult("ASSIGNED_TO", syncId);
Map<String, String> userMap = buildSourceIdToEntityIdMap(graphId, "User");
// LabelTask → User
for (GraphEntity task : entityRepository.findByGraphIdAndType(graphId, "LabelTask")) {
mergeCreatorAssignment(graphId, task, "label_task", userMap, result, syncId);
}
// Job → User
for (GraphEntity job : entityRepository.findByGraphIdAndType(graphId, "Job")) {
mergeCreatorAssignment(graphId, job, "job", userMap, result, syncId);
}
return endResult(result);
}
private void mergeCreatorAssignment(String graphId, GraphEntity entity, String entityLabel,
Map<String, String> userMap,
SyncResult result, String syncId) {
try {
Object createdBy = entity.getProperties().get("created_by");
if (createdBy == null || createdBy.toString().isBlank()) {
return;
}
String userSourceId = "user:" + createdBy;
String userEntityId = userMap.get(userSourceId);
if (userEntityId == null) {
result.incrementSkipped();
return;
}
boolean created = mergeRelation(graphId, entity.getId(), userEntityId,
"ASSIGNED_TO", "{\"role\":\"OWNER\"}", syncId);
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
} catch (Exception e) {
log.warn("[{}] Failed to merge ASSIGNED_TO for {}: id={}", syncId, entityLabel, entity.getId(), e);
result.addError("assigned_to:" + entityLabel + ":" + entity.getId());
}
}
/**
 * Builds TRIGGERS relations: Workflow → Job (via Job.workflow_id).
*/
@Transactional
public SyncResult mergeTriggersRelations(String graphId, String syncId) {
SyncResult result = beginResult("TRIGGERS", syncId);
Map<String, String> workflowMap = buildSourceIdToEntityIdMap(graphId, "Workflow");
for (GraphEntity job : entityRepository.findByGraphIdAndType(graphId, "Job")) {
try {
Object workflowId = job.getProperties().get("workflow_id");
if (workflowId == null || workflowId.toString().isBlank()) {
continue;
}
String workflowEntityId = workflowMap.get(workflowId.toString());
if (workflowEntityId == null) {
result.incrementSkipped();
continue;
}
                // Direction: Workflow → Job
boolean created = mergeRelation(graphId, workflowEntityId, job.getId(),
"TRIGGERS", "{\"trigger_type\":\"MANUAL\"}", syncId);
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
} catch (Exception e) {
log.warn("[{}] Failed to merge TRIGGERS for job: id={}", syncId, job.getId(), e);
result.addError("triggers:" + job.getId());
}
}
return endResult(result);
}
/**
 * Builds DEPENDS_ON relations: Job → Job (via Job.depends_on_job_id).
*/
@Transactional
public SyncResult mergeDependsOnRelations(String graphId, String syncId) {
SyncResult result = beginResult("DEPENDS_ON", syncId);
Map<String, String> jobMap = buildSourceIdToEntityIdMap(graphId, "Job");
for (GraphEntity job : entityRepository.findByGraphIdAndType(graphId, "Job")) {
try {
Object depJobId = job.getProperties().get("depends_on_job_id");
if (depJobId == null || depJobId.toString().isBlank()) {
continue;
}
String depJobEntityId = jobMap.get(depJobId.toString());
if (depJobEntityId == null) {
result.incrementSkipped();
continue;
}
boolean created = mergeRelation(graphId, job.getId(), depJobEntityId,
"DEPENDS_ON", "{\"dependency_type\":\"STRICT\"}", syncId);
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
} catch (Exception e) {
log.warn("[{}] Failed to merge DEPENDS_ON for job: id={}", syncId, job.getId(), e);
result.addError("depends_on:" + job.getId());
}
}
return endResult(result);
}
/**
* 构建 IMPACTS 关系:Field → Field。
* <p>
* TODO: 字段影响关系来源于 LLM 抽取或规则引擎,而非简单外键关联。
* 当前 MVP 阶段为占位实现,后续由抽取模块填充。
*/
@Transactional
public SyncResult mergeImpactsRelations(String graphId, String syncId) {
SyncResult result = beginResult("IMPACTS", syncId);
result.setPlaceholder(true);
log.debug("[{}] IMPACTS relations require extraction data, skipping in sync phase", syncId);
return endResult(result);
}
/**
* 构建 SOURCED_FROM 关系:KnowledgeSet → Dataset(通过 source_dataset_ids)。
*/
@Transactional
public SyncResult mergeSourcedFromRelations(String graphId, String syncId) {
SyncResult result = beginResult("SOURCED_FROM", syncId);
Map<String, String> datasetMap = buildSourceIdToEntityIdMap(graphId, "Dataset");
for (GraphEntity ks : entityRepository.findByGraphIdAndType(graphId, "KnowledgeSet")) {
try {
Object sourceIds = ks.getProperties().get("source_dataset_ids");
if (sourceIds == null) {
continue;
}
List<String> datasetIds = toStringList(sourceIds);
for (String dsId : datasetIds) {
String datasetEntityId = datasetMap.get(dsId);
if (datasetEntityId == null) {
result.incrementSkipped();
continue;
}
boolean created = mergeRelation(graphId, ks.getId(), datasetEntityId,
"SOURCED_FROM", "{}", syncId);
if (created) { result.incrementCreated(); } else { result.incrementSkipped(); }
}
} catch (Exception e) {
log.warn("[{}] Failed to merge SOURCED_FROM for knowledge set: id={}", syncId, ks.getId(), e);
result.addError("sourced_from:" + ks.getId());
}
}
return endResult(result);
}
// -----------------------------------------------------------------------
// 内部方法
// -----------------------------------------------------------------------
/**
* 使用单条 Cypher MERGE 原子创建或匹配实体,同时写入扩展属性。
* <p>
* 相比之前的 MERGE + find + save(3 次 DB 调用),
* 现在合并为单条 Cypher(1 次 DB 调用),消除 N+1 性能问题。
* <p>
* 扩展属性通过 SDN composite property 格式存储({@code properties.key}),
* 属性键经过字符白名单过滤,防止 Cypher 注入。
*/
private void upsertEntity(String graphId, String sourceId, String type,
String name, String description,
Map<String, Object> props, SyncResult result) {
String newId = UUID.randomUUID().toString();
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("sourceId", sourceId);
params.put("type", type);
params.put("newId", newId);
params.put("name", name != null ? name : "");
params.put("description", description != null ? description : "");
// 构建扩展属性的 SET 子句,使用 SDN composite property 格式(properties.key)
StringBuilder propSetClauses = new StringBuilder();
if (props != null) {
int idx = 0;
for (Map.Entry<String, Object> entry : props.entrySet()) {
if (entry.getValue() != null) {
String sanitizedKey = sanitizePropertyKey(entry.getKey());
if (sanitizedKey.isEmpty()) {
continue;
}
String paramName = "prop" + idx++;
propSetClauses.append(", e.`properties.").append(sanitizedKey).append("` = $").append(paramName);
params.put(paramName, toNeo4jValue(entry.getValue()));
}
}
}
String extraSet = propSetClauses.toString();
Boolean isNew = neo4jClient.query(
"MERGE (e:Entity {graph_id: $graphId, source_id: $sourceId, type: $type}) " +
"ON CREATE SET e.id = $newId, e.source_type = 'SYNC', e.confidence = 1.0, " +
" e.name = $name, e.description = $description, " +
" e.created_at = datetime(), e.updated_at = datetime()" + extraSet + " " +
"ON MATCH SET e.name = $name, e.description = $description, " +
" e.updated_at = datetime()" + extraSet + " " +
"RETURN e.id = $newId AS isNew"
)
.bindAll(params)
.fetchAs(Boolean.class)
.mappedBy((ts, record) -> record.get("isNew").asBoolean())
.one()
.orElse(false);
if (isNew) {
result.incrementCreated();
} else {
result.incrementUpdated();
}
}
/**
* 清理属性键,仅允许字母、数字和下划线,防止 Cypher 注入。
*/
private static String sanitizePropertyKey(String key) {
return key.replaceAll("[^a-zA-Z0-9_]", "");
}
/**
* 将 Java 值转换为 Neo4j 兼容的属性值。
* <p>
* Neo4j 属性值必须为原始类型或同类型列表,
* 不支持嵌套 Map 或异构列表。
*/
private static Object toNeo4jValue(Object value) {
if (value instanceof List<?> list) {
if (list.isEmpty()) {
return List.of();
}
// Neo4j 要求列表元素类型一致,统一转为 String 列表;过滤 null 防止脏数据
return list.stream()
.filter(Objects::nonNull)
.map(Object::toString)
.toList();
}
return value;
}
/**
* 将属性值(可能是 List 或单个 String)安全转换为 String 列表。
*/
private static List<String> toStringList(Object value) {
if (value instanceof List<?> list) {
return list.stream()
.filter(Objects::nonNull)
.map(Object::toString)
.filter(s -> !s.isBlank())
.toList();
}
if (value instanceof String str && !str.isBlank()) {
return List.of(str);
}
return List.of();
}
/**
* 使用 Cypher MERGE 创建或匹配关系,保证幂等性。
*
* @return true 如果是新创建的关系,false 如果已存在
*/
private boolean mergeRelation(String graphId, String sourceEntityId, String targetEntityId,
String relationType, String propertiesJson, String syncId) {
String newId = UUID.randomUUID().toString();
String mergedId = neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId, id: $sourceEntityId}) " +
"MATCH (t:Entity {graph_id: $graphId, id: $targetEntityId}) " +
"MERGE (s)-[r:" + REL_TYPE + " {graph_id: $graphId, relation_type: $relationType}]->(t) " +
"ON CREATE SET r.id = $newId, r.weight = 1.0, r.confidence = 1.0, " +
" r.source_id = '', r.properties_json = $propertiesJson, r.created_at = datetime() " +
"RETURN r.id AS relId"
)
.bindAll(Map.of(
"graphId", graphId,
"sourceEntityId", sourceEntityId,
"targetEntityId", targetEntityId,
"relationType", relationType,
"newId", newId,
"propertiesJson", propertiesJson
))
.fetchAs(String.class)
.mappedBy((ts, record) -> record.get("relId").asString())
.one()
.orElse(null);
return newId.equals(mergedId);
}
private SyncResult beginResult(String syncType, String syncId) {
return SyncResult.builder()
.syncType(syncType)
.syncId(syncId)
.startedAt(LocalDateTime.now())
.errors(new ArrayList<>())
.build();
}
private SyncResult endResult(SyncResult result) {
result.setCompletedAt(LocalDateTime.now());
return result;
}
/**
* 预加载指定类型的 sourceId → entityId 映射,消除关系构建中的 N+1 查询。
*/
private Map<String, String> buildSourceIdToEntityIdMap(String graphId, String type) {
return entityRepository.findByGraphIdAndType(graphId, type).stream()
.filter(e -> e.getSourceId() != null)
.collect(Collectors.toMap(GraphEntity::getSourceId, GraphEntity::getId, (a, b) -> a));
}
}
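上文的 toStringList 与 toNeo4jValue 都是纯函数,其归一化规则可以脱离 Neo4j 单独演示。下面是一个最小示意(类名 ValueNormalizer 为演示假设,非源码的一部分),逐条复刻源码的过滤与转换行为:

```java
import java.util.List;
import java.util.Objects;

// 演示用示意类:复刻 toStringList / toNeo4jValue 的归一化规则(非源码)
class ValueNormalizer {

    // 与源码一致:List 过滤 null 与空白元素并转 String;单个非空白 String 包装为单元素列表
    static List<String> toStringList(Object value) {
        if (value instanceof List<?> list) {
            return list.stream()
                    .filter(Objects::nonNull)
                    .map(Object::toString)
                    .filter(s -> !s.isBlank())
                    .toList();
        }
        if (value instanceof String str && !str.isBlank()) {
            return List.of(str);
        }
        return List.of();
    }

    // 与源码一致:Neo4j 属性不支持异构列表,统一转为 String 列表并过滤 null
    static Object toNeo4jValue(Object value) {
        if (value instanceof List<?> list) {
            if (list.isEmpty()) {
                return List.of();
            }
            return list.stream()
                    .filter(Objects::nonNull)
                    .map(Object::toString)
                    .toList();
        }
        return value;
    }
}
```

例如 `toStringList(Arrays.asList("a", null, " "))` 只保留 "a",`toNeo4jValue(List.of(1, "x"))` 归一化为 `["1", "x"]`,与同步阶段写入 Neo4j 前的脏数据防御一致。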


@@ -0,0 +1,81 @@
package com.datamate.knowledgegraph.domain.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.data.neo4j.core.schema.CompositeProperty;
import org.springframework.data.neo4j.core.schema.DynamicLabels;
import org.springframework.data.neo4j.core.schema.GeneratedValue;
import org.springframework.data.neo4j.core.schema.Id;
import org.springframework.data.neo4j.core.schema.Node;
import org.springframework.data.neo4j.core.schema.Property;
import org.springframework.data.neo4j.core.support.UUIDStringGenerator;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* 知识图谱实体节点。
* <p>
* 在 Neo4j 中,每个实体作为一个节点存储,
* 通过 {@code type} 属性区分具体类型(Person, Organization, Concept 等),
* 并支持通过 {@code properties} 存储灵活的扩展属性。
*/
@Node("Entity")
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class GraphEntity {
@Id
@GeneratedValue(UUIDStringGenerator.class)
private String id;
@Property("name")
private String name;
@Property("type")
private String type;
@Property("description")
private String description;
@DynamicLabels
@Builder.Default
private List<String> labels = new ArrayList<>();
@Property("aliases")
@Builder.Default
private List<String> aliases = new ArrayList<>();
/** 扩展属性:由 SDN {@code @CompositeProperty} 展开为 {@code properties.key} 形式的节点属性,与 upsertEntity 的写入格式一致。 */
@CompositeProperty
@Builder.Default
private Map<String, Object> properties = new HashMap<>();
/** 来源数据集/知识库的 ID */
@Property("source_id")
private String sourceId;
/** 来源类型:ANNOTATION, KNOWLEDGE_BASE, IMPORT, MANUAL */
@Property("source_type")
private String sourceType;
/** 所属图谱 ID(对应 MySQL 中的 t_dm_knowledge_graphs.id) */
@Property("graph_id")
private String graphId;
/** 自动抽取的置信度 */
@Property("confidence")
@Builder.Default
private Double confidence = 1.0;
@Property("created_at")
private LocalDateTime createdAt;
@Property("updated_at")
private LocalDateTime updatedAt;
}
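GraphEntity 的扩展属性以 composite property 形式落库(节点属性键为 properties.key),键名先经白名单清洗防止 Cypher 注入。下面用一个示意方法复刻这一清洗与扁平化约定(类名 CompositeProps 为演示假设,非源码):

```java
import java.util.LinkedHashMap;
import java.util.Map;

// 演示用示意类:复刻 upsertEntity 中的键清洗与 composite 键扁平化(非源码)
class CompositeProps {

    // 与源码 sanitizePropertyKey 一致:仅保留字母、数字、下划线
    static String sanitizeKey(String key) {
        return key.replaceAll("[^a-zA-Z0-9_]", "");
    }

    // 将扩展属性 Map 扁平化为 "properties.<key>" 形式的节点属性键;
    // null 值与清洗后为空的键被丢弃(与源码行为一致)
    static Map<String, Object> flatten(Map<String, Object> props) {
        Map<String, Object> out = new LinkedHashMap<>();
        if (props == null) {
            return out;
        }
        for (Map.Entry<String, Object> e : props.entrySet()) {
            if (e.getValue() == null) {
                continue;
            }
            String key = sanitizeKey(e.getKey());
            if (!key.isEmpty()) {
                out.put("properties." + key, e.getValue());
            }
        }
        return out;
    }
}
```

注入尝试如 "x` = 1 WITH e //" 会被清洗为普通标识符 "x1WITHe",无法逃逸反引号包裹的属性名。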


@@ -0,0 +1,61 @@
package com.datamate.knowledgegraph.domain.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.data.neo4j.core.schema.GeneratedValue;
import org.springframework.data.neo4j.core.schema.Id;
import org.springframework.data.neo4j.core.schema.Property;
import org.springframework.data.neo4j.core.schema.RelationshipProperties;
import org.springframework.data.neo4j.core.schema.TargetNode;
import org.springframework.data.neo4j.core.support.UUIDStringGenerator;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Map;
/**
* 知识图谱关系(边)。
* <p>
* 使用 Spring Data Neo4j 的 {@code @RelationshipProperties} 表示带属性的关系。
* 关系的具体类型通过 {@code relationType} 表达(如 belongs_to, located_in)。
*/
@RelationshipProperties
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class GraphRelation {
@Id
@GeneratedValue(UUIDStringGenerator.class)
private String id;
@TargetNode
private GraphEntity target;
@Property("relation_type")
private String relationType;
@Property("properties")
@Builder.Default
private Map<String, Object> properties = new HashMap<>();
@Property("weight")
@Builder.Default
private Double weight = 1.0;
@Property("source_id")
private String sourceId;
@Property("confidence")
@Builder.Default
private Double confidence = 1.0;
@Property("graph_id")
private String graphId;
@Property("created_at")
private LocalDateTime createdAt;
}


@@ -0,0 +1,54 @@
package com.datamate.knowledgegraph.domain.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Map;
/**
* 关系及其端点实体摘要,用于仓储层查询返回。
* <p>
* 由于 {@link GraphRelation} 使用 {@code @RelationshipProperties} 且仅持有
* 目标节点引用,无法完整表达 Cypher 查询返回的"源节点 + 关系 + 目标节点"结构,
* 因此使用该领域对象作为仓储层的返回类型。
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class RelationDetail {
private String id;
private String sourceEntityId;
private String sourceEntityName;
private String sourceEntityType;
private String targetEntityId;
private String targetEntityName;
private String targetEntityType;
private String relationType;
@Builder.Default
private Map<String, Object> properties = new HashMap<>();
private Double weight;
private Double confidence;
/** 来源数据集/知识库的 ID */
private String sourceId;
private String graphId;
private LocalDateTime createdAt;
}


@@ -0,0 +1,81 @@
package com.datamate.knowledgegraph.domain.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
/**
* 同步操作结果统计。
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class SyncResult {
/** 本次同步的追踪标识 */
private String syncId;
/** 同步的实体/关系类型 */
private String syncType;
@Builder.Default
private int created = 0;
@Builder.Default
private int updated = 0;
@Builder.Default
private int skipped = 0;
@Builder.Default
private int failed = 0;
/** 全量对账删除的过期实体数 */
@Builder.Default
private int purged = 0;
/** 标记为占位符的步骤(功能尚未实现,结果无实际数据) */
@Builder.Default
private boolean placeholder = false;
@Builder.Default
private List<String> errors = new ArrayList<>();
private LocalDateTime startedAt;
private LocalDateTime completedAt;
public int total() {
return created + updated + skipped + failed;
}
public long durationMillis() {
if (startedAt == null || completedAt == null) {
return 0;
}
return java.time.Duration.between(startedAt, completedAt).toMillis();
}
public void incrementCreated() {
created++;
}
public void incrementUpdated() {
updated++;
}
public void incrementSkipped() {
skipped++;
}
public void addError(String error) {
failed++;
errors.add(error);
}
}
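SyncResult 的统计口径是 total = created + updated + skipped + failed,且任一时间戳缺失时耗时记为 0。下面用一个去 Lombok 的最小示意(类名 MiniSyncResult 为演示假设,省略 errors/purged 等字段)演示典型的关系构建计数流程:

```java
import java.time.Duration;
import java.time.LocalDateTime;

// 演示用示意类:复刻 SyncResult 的计数口径(非源码,省略 Lombok 与部分字段)
class MiniSyncResult {
    int created, updated, skipped, failed;
    LocalDateTime startedAt, completedAt;

    void incrementCreated() { created++; }
    void incrementSkipped() { skipped++; }
    void addError(String error) { failed++; }

    // 与源码一致:四类计数之和
    int total() { return created + updated + skipped + failed; }

    // 与源码一致:任一时间缺失则返回 0
    long durationMillis() {
        if (startedAt == null || completedAt == null) {
            return 0;
        }
        return Duration.between(startedAt, completedAt).toMillis();
    }
}
```

例如某次 TRIGGERS 构建中 MERGE 新建 2 条、命中已存在 1 条、异常 1 条,则 total() 为 4。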


@@ -0,0 +1,103 @@
package com.datamate.knowledgegraph.domain.repository;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import org.springframework.data.neo4j.repository.Neo4jRepository;
import org.springframework.data.neo4j.repository.query.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.Optional;
@Repository
public interface GraphEntityRepository extends Neo4jRepository<GraphEntity, String> {
@Query("MATCH (e:Entity {graph_id: $graphId}) WHERE e.id = $entityId RETURN e")
Optional<GraphEntity> findByIdAndGraphId(
@Param("entityId") String entityId,
@Param("graphId") String graphId);
List<GraphEntity> findByGraphId(String graphId);
List<GraphEntity> findByGraphIdAndType(String graphId, String type);
List<GraphEntity> findByGraphIdAndNameContaining(String graphId, String name);
@Query("MATCH (e:Entity {graph_id: $graphId}) " +
"WHERE e.name = $name AND e.type = $type " +
"RETURN e")
List<GraphEntity> findByGraphIdAndNameAndType(
@Param("graphId") String graphId,
@Param("name") String name,
@Param("type") String type);
/**
* 查询实体 1..depth 跳内的邻居(路径上的全部节点与关系均约束在同一 graph_id,防止跨图谱泄漏)。
* <p>
* 注意:Cypher 不允许在可变长度模式的跳数上绑定普通参数({@code [*1..$depth]} 是语法错误),
* 此处使用 SDN 的 SpEL {@code literal()} 扩展将 depth 作为字面量内联。
*/
@Query("MATCH p = (e:Entity {graph_id: $graphId, id: $entityId})-[*1..:#{literal(#depth)}]-(neighbor:Entity) " +
"WHERE e <> neighbor " +
" AND ALL(n IN nodes(p) WHERE n.graph_id = $graphId) " +
" AND ALL(r IN relationships(p) WHERE r.graph_id = $graphId) " +
"RETURN DISTINCT neighbor LIMIT $limit")
List<GraphEntity> findNeighbors(
@Param("graphId") String graphId,
@Param("entityId") String entityId,
@Param("depth") int depth,
@Param("limit") int limit);
@Query("MATCH (e:Entity {graph_id: $graphId}) RETURN count(e)")
long countByGraphId(@Param("graphId") String graphId);
@Query("MATCH (e:Entity {graph_id: $graphId}) " +
"WHERE e.source_id = $sourceId AND e.type = $type " +
"RETURN e")
Optional<GraphEntity> findByGraphIdAndSourceIdAndType(
@Param("graphId") String graphId,
@Param("sourceId") String sourceId,
@Param("type") String type);
// -----------------------------------------------------------------------
// 分页查询
// -----------------------------------------------------------------------
@Query("MATCH (e:Entity {graph_id: $graphId}) " +
"RETURN e ORDER BY e.created_at DESC SKIP $skip LIMIT $limit")
List<GraphEntity> findByGraphIdPaged(
@Param("graphId") String graphId,
@Param("skip") long skip,
@Param("limit") int limit);
@Query("MATCH (e:Entity {graph_id: $graphId}) WHERE e.type = $type " +
"RETURN e ORDER BY e.created_at DESC SKIP $skip LIMIT $limit")
List<GraphEntity> findByGraphIdAndTypePaged(
@Param("graphId") String graphId,
@Param("type") String type,
@Param("skip") long skip,
@Param("limit") int limit);
@Query("MATCH (e:Entity {graph_id: $graphId}) WHERE e.type = $type " +
"RETURN count(e)")
long countByGraphIdAndType(
@Param("graphId") String graphId,
@Param("type") String type);
@Query("MATCH (e:Entity {graph_id: $graphId}) WHERE e.name CONTAINS $name " +
"RETURN e ORDER BY e.created_at DESC SKIP $skip LIMIT $limit")
List<GraphEntity> findByGraphIdAndNameContainingPaged(
@Param("graphId") String graphId,
@Param("name") String name,
@Param("skip") long skip,
@Param("limit") int limit);
@Query("MATCH (e:Entity {graph_id: $graphId}) WHERE e.name CONTAINS $name " +
"RETURN count(e)")
long countByGraphIdAndNameContaining(
@Param("graphId") String graphId,
@Param("name") String name);
// -----------------------------------------------------------------------
// 图查询
// -----------------------------------------------------------------------
@Query("MATCH (e:Entity {graph_id: $graphId}) WHERE e.id IN $entityIds RETURN e")
List<GraphEntity> findByGraphIdAndIdIn(
@Param("graphId") String graphId,
@Param("entityIds") List<String> entityIds);
}
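上面的分页接口基于 SKIP/LIMIT 语义,上层(如客户端的 fetchAllPaged)通过循环拉取全量,终止条件是"返回不足一页"。这一约定可以用内存数据验证(类名 PagedFetch 为演示假设,非源码):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

// 演示用示意类:基于 SKIP/LIMIT 语义的全量分页拉取循环(非源码)
class PagedFetch {

    // page 函数接收 (skip, limit) 返回一页数据;返回不足一页即视为拉取完毕
    static <T> List<T> fetchAll(BiFunction<Long, Integer, List<T>> page, int pageSize) {
        List<T> all = new ArrayList<>();
        long skip = 0;
        while (true) {
            List<T> batch = page.apply(skip, pageSize);
            all.addAll(batch);
            if (batch.size() < pageSize) {
                return all;
            }
            skip += pageSize;
        }
    }

    // 内存版 SKIP/LIMIT,模拟 findByGraphIdPaged 的切片行为
    static <T> List<T> slice(List<T> data, long skip, int limit) {
        if (skip >= data.size()) {
            return List.of();
        }
        int from = (int) skip;
        int to = Math.min(data.size(), from + limit);
        return data.subList(from, to);
    }
}
```

当总量恰为 pageSize 的整数倍时,循环会多拉一次空页后终止,与按"不足一页"判停的约定一致。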


@@ -0,0 +1,502 @@
package com.datamate.knowledgegraph.domain.repository;
import com.datamate.knowledgegraph.domain.model.RelationDetail;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.neo4j.driver.Value;
import org.neo4j.driver.types.MapAccessor;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.stereotype.Repository;
import java.time.LocalDateTime;
import java.util.*;
/**
* 知识图谱关系仓储。
* <p>
* 由于 {@code GraphRelation} 使用 {@code @RelationshipProperties},
* 无法通过 {@code Neo4jRepository} 直接管理,
* 因此使用 {@code Neo4jClient} 执行 Cypher 查询实现 CRUD。
* <p>
* Neo4j 中使用统一的 {@code RELATED_TO} 关系类型,
* 语义类型通过 {@code relation_type} 属性区分。
* 扩展属性(properties)序列化为 JSON 字符串存储在 {@code properties_json} 属性中。
*/
@Repository
@Slf4j
@RequiredArgsConstructor
public class GraphRelationRepository {
private static final String REL_TYPE = "RELATED_TO";
private static final TypeReference<Map<String, Object>> MAP_TYPE = new TypeReference<>() {};
private static final ObjectMapper MAPPER = new ObjectMapper();
/** 查询返回列(源节点 + 关系 + 目标节点)。 */
private static final String RETURN_COLUMNS =
"RETURN r, " +
"s.id AS sourceEntityId, s.name AS sourceEntityName, s.type AS sourceEntityType, " +
"t.id AS targetEntityId, t.name AS targetEntityName, t.type AS targetEntityType";
private final Neo4jClient neo4jClient;
// -----------------------------------------------------------------------
// 查询
// -----------------------------------------------------------------------
public Optional<RelationDetail> findByIdAndGraphId(String relationId, String graphId) {
return neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {id: $relationId, graph_id: $graphId}]->" +
"(t:Entity {graph_id: $graphId}) " +
RETURN_COLUMNS
)
.bindAll(Map.of("graphId", graphId, "relationId", relationId))
.fetchAs(RelationDetail.class)
.mappedBy((typeSystem, record) -> mapRecord(record))
.one();
}
public List<RelationDetail> findByGraphId(String graphId, String type, long skip, int size) {
String typeFilter = (type != null && !type.isBlank())
? "AND r.relation_type = $type "
: "";
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("type", type != null ? type : "");
params.put("skip", skip);
params.put("size", size);
return neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(t:Entity {graph_id: $graphId}) " +
"WHERE true " + typeFilter +
RETURN_COLUMNS + " " +
"ORDER BY r.created_at DESC " +
"SKIP $skip LIMIT $size"
)
.bindAll(params)
.fetchAs(RelationDetail.class)
.mappedBy((typeSystem, record) -> mapRecord(record))
.all()
.stream().toList();
}
/**
* 查询实体的所有关系(出边 + 入边)。
* <p>
* 使用 {@code CALL{UNION ALL}} 分别锚定出边和入边查询,
* 避免全图扫描后再过滤的性能瓶颈。
* {@code WITH DISTINCT} 处理自环关系的去重。
*/
public List<RelationDetail> findByEntityId(String graphId, String entityId, String type,
long skip, int size) {
String typeFilter = (type != null && !type.isBlank())
? "WHERE r.relation_type = $type "
: "";
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("entityId", entityId);
params.put("type", type != null ? type : "");
params.put("skip", skip);
params.put("size", size);
return neo4jClient
.query(
"CALL { " +
"MATCH (s:Entity {graph_id: $graphId, id: $entityId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(t:Entity {graph_id: $graphId}) " +
typeFilter +
"RETURN r, s, t " +
"UNION ALL " +
"MATCH (s:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(t:Entity {graph_id: $graphId, id: $entityId}) " +
typeFilter +
"RETURN r, s, t " +
"} " +
"WITH DISTINCT r, s, t " +
"ORDER BY r.created_at DESC SKIP $skip LIMIT $size " +
RETURN_COLUMNS
)
.bindAll(params)
.fetchAs(RelationDetail.class)
.mappedBy((typeSystem, record) -> mapRecord(record))
.all()
.stream().toList();
}
/**
* 查询实体的入边关系(该实体为目标节点)。
*/
public List<RelationDetail> findInboundByEntityId(String graphId, String entityId, String type,
long skip, int size) {
String typeFilter = (type != null && !type.isBlank())
? "AND r.relation_type = $type "
: "";
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("entityId", entityId);
params.put("type", type != null ? type : "");
params.put("skip", skip);
params.put("size", size);
return neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(t:Entity {graph_id: $graphId, id: $entityId}) " +
"WHERE true " + typeFilter +
RETURN_COLUMNS + " " +
"ORDER BY r.created_at DESC " +
"SKIP $skip LIMIT $size"
)
.bindAll(params)
.fetchAs(RelationDetail.class)
.mappedBy((typeSystem, record) -> mapRecord(record))
.all()
.stream().toList();
}
/**
* 查询实体的出边关系(该实体为源节点)。
*/
public List<RelationDetail> findOutboundByEntityId(String graphId, String entityId, String type,
long skip, int size) {
String typeFilter = (type != null && !type.isBlank())
? "AND r.relation_type = $type "
: "";
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("entityId", entityId);
params.put("type", type != null ? type : "");
params.put("skip", skip);
params.put("size", size);
return neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId, id: $entityId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(t:Entity {graph_id: $graphId}) " +
"WHERE true " + typeFilter +
RETURN_COLUMNS + " " +
"ORDER BY r.created_at DESC " +
"SKIP $skip LIMIT $size"
)
.bindAll(params)
.fetchAs(RelationDetail.class)
.mappedBy((typeSystem, record) -> mapRecord(record))
.all()
.stream().toList();
}
/**
* 统计实体的关系数量。
* <p>
* 各方向均以实体锚定 MATCH 模式,避免全图扫描。
* "all" 方向使用 {@code CALL{UNION}} 自动去重自环关系。
*
* @param direction "all"、"in" 或 "out"
*/
public long countByEntityId(String graphId, String entityId, String type, String direction) {
String typeFilter = (type != null && !type.isBlank())
? "WHERE r.relation_type = $type "
: "";
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("entityId", entityId);
params.put("type", type != null ? type : "");
String cypher;
switch (direction) {
case "in":
cypher = "MATCH (:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(:Entity {graph_id: $graphId, id: $entityId}) " +
typeFilter +
"RETURN count(r) AS cnt";
break;
case "out":
cypher = "MATCH (:Entity {graph_id: $graphId, id: $entityId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(:Entity {graph_id: $graphId}) " +
typeFilter +
"RETURN count(r) AS cnt";
break;
default:
cypher = "CALL { " +
"MATCH (:Entity {graph_id: $graphId, id: $entityId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(:Entity {graph_id: $graphId}) " +
typeFilter +
"RETURN r " +
"UNION " +
"MATCH (:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(:Entity {graph_id: $graphId, id: $entityId}) " +
typeFilter +
"RETURN r " +
"} " +
"RETURN count(r) AS cnt";
break;
}
return neo4jClient
.query(cypher)
.bindAll(params)
.fetchAs(Long.class)
.mappedBy((typeSystem, record) -> record.get("cnt").asLong())
.one()
.orElse(0L);
}
public List<RelationDetail> findBySourceAndTarget(String graphId, String sourceEntityId, String targetEntityId) {
return neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId, id: $sourceEntityId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(t:Entity {graph_id: $graphId, id: $targetEntityId}) " +
RETURN_COLUMNS
)
.bindAll(Map.of(
"graphId", graphId,
"sourceEntityId", sourceEntityId,
"targetEntityId", targetEntityId
))
.fetchAs(RelationDetail.class)
.mappedBy((typeSystem, record) -> mapRecord(record))
.all()
.stream().toList();
}
public List<RelationDetail> findByType(String graphId, String type) {
return neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId, relation_type: $type}]->" +
"(t:Entity {graph_id: $graphId}) " +
RETURN_COLUMNS
)
.bindAll(Map.of("graphId", graphId, "type", type))
.fetchAs(RelationDetail.class)
.mappedBy((typeSystem, record) -> mapRecord(record))
.all()
.stream().toList();
}
public long countByGraphId(String graphId, String type) {
String typeFilter = (type != null && !type.isBlank())
? "AND r.relation_type = $type "
: "";
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("type", type != null ? type : "");
return neo4jClient
.query(
"MATCH (:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {graph_id: $graphId}]->" +
"(:Entity {graph_id: $graphId}) " +
"WHERE true " + typeFilter +
"RETURN count(r) AS cnt"
)
.bindAll(params)
.fetchAs(Long.class)
.mappedBy((typeSystem, record) -> record.get("cnt").asLong())
.one()
.orElse(0L);
}
// -----------------------------------------------------------------------
// 写入
// -----------------------------------------------------------------------
public Optional<RelationDetail> create(String graphId, String sourceEntityId, String targetEntityId,
String relationType, Map<String, Object> properties,
Double weight, String sourceId, Double confidence) {
String id = UUID.randomUUID().toString();
LocalDateTime now = LocalDateTime.now();
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("sourceEntityId", sourceEntityId);
params.put("targetEntityId", targetEntityId);
params.put("id", id);
params.put("relationType", relationType);
params.put("weight", weight != null ? weight : 1.0);
params.put("confidence", confidence != null ? confidence : 1.0);
params.put("sourceId", sourceId != null ? sourceId : "");
params.put("propertiesJson", serializeProperties(properties));
params.put("createdAt", now);
return neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId, id: $sourceEntityId}) " +
"MATCH (t:Entity {graph_id: $graphId, id: $targetEntityId}) " +
"CREATE (s)-[r:" + REL_TYPE + " {" +
" id: $id," +
" relation_type: $relationType," +
" weight: $weight," +
" confidence: $confidence," +
" source_id: $sourceId," +
" graph_id: $graphId," +
" properties_json: $propertiesJson," +
" created_at: $createdAt" +
"}]->(t) " +
RETURN_COLUMNS
)
.bindAll(params)
.fetchAs(RelationDetail.class)
.mappedBy((typeSystem, record) -> mapRecord(record))
.one();
}
public Optional<RelationDetail> update(String relationId, String graphId,
String relationType, Map<String, Object> properties,
Double weight, Double confidence) {
Map<String, Object> params = new HashMap<>();
params.put("graphId", graphId);
params.put("relationId", relationId);
StringBuilder setClauses = new StringBuilder();
if (relationType != null) {
setClauses.append("SET r.relation_type = $relationType ");
params.put("relationType", relationType);
}
if (properties != null) {
setClauses.append("SET r.properties_json = $propertiesJson ");
params.put("propertiesJson", serializeProperties(properties));
}
if (weight != null) {
setClauses.append("SET r.weight = $weight ");
params.put("weight", weight);
}
if (confidence != null) {
setClauses.append("SET r.confidence = $confidence ");
params.put("confidence", confidence);
}
if (setClauses.isEmpty()) {
return findByIdAndGraphId(relationId, graphId);
}
return neo4jClient
.query(
"MATCH (s:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {id: $relationId, graph_id: $graphId}]->" +
"(t:Entity {graph_id: $graphId}) " +
setClauses +
RETURN_COLUMNS
)
.bindAll(params)
.fetchAs(RelationDetail.class)
.mappedBy((typeSystem, record) -> mapRecord(record))
.one();
}
/**
* 删除指定关系,返回实际删除的数量(0 或 1)。
*/
public long deleteByIdAndGraphId(String relationId, String graphId) {
// MATCH 未命中时管道中没有行,count(*) 聚合后仍返回 0;
// 找到 1 条时 DELETE 后管道保留该行,count(*) 返回 1。
return neo4jClient
.query(
"MATCH (:Entity {graph_id: $graphId})" +
"-[r:" + REL_TYPE + " {id: $relationId, graph_id: $graphId}]->" +
"(:Entity {graph_id: $graphId}) " +
"DELETE r " +
"RETURN count(*) AS deleted"
)
.bindAll(Map.of("graphId", graphId, "relationId", relationId))
.fetchAs(Long.class)
.mappedBy((typeSystem, record) -> record.get("deleted").asLong())
.one()
.orElse(0L);
}
// -----------------------------------------------------------------------
// 内部映射
// -----------------------------------------------------------------------
private RelationDetail mapRecord(MapAccessor record) {
Value r = record.get("r");
return RelationDetail.builder()
.id(getStringOrNull(r, "id"))
.sourceEntityId(record.get("sourceEntityId").asString(null))
.sourceEntityName(record.get("sourceEntityName").asString(null))
.sourceEntityType(record.get("sourceEntityType").asString(null))
.targetEntityId(record.get("targetEntityId").asString(null))
.targetEntityName(record.get("targetEntityName").asString(null))
.targetEntityType(record.get("targetEntityType").asString(null))
.relationType(getStringOrNull(r, "relation_type"))
.properties(deserializeProperties(getStringOrNull(r, "properties_json")))
.weight(getDoubleOrNull(r, "weight"))
.confidence(getDoubleOrNull(r, "confidence"))
.sourceId(getStringOrNull(r, "source_id"))
.graphId(getStringOrNull(r, "graph_id"))
.createdAt(getLocalDateTimeOrNull(r, "created_at"))
.build();
}
// -----------------------------------------------------------------------
// Properties JSON 序列化
// -----------------------------------------------------------------------
private static String serializeProperties(Map<String, Object> properties) {
if (properties == null || properties.isEmpty()) {
return "{}";
}
try {
return MAPPER.writeValueAsString(properties);
} catch (JsonProcessingException e) {
// 序列化失败不应静默吞掉,向上抛出以暴露数据问题
throw new IllegalArgumentException("Failed to serialize relation properties to JSON", e);
}
}
private static Map<String, Object> deserializeProperties(String json) {
if (json == null || json.isBlank()) {
return new HashMap<>();
}
try {
return MAPPER.readValue(json, MAP_TYPE);
} catch (JsonProcessingException e) {
log.warn("Failed to deserialize properties_json (returning empty map): json='{}', error={}",
json.length() > 100 ? json.substring(0, 100) + "..." : json, e.getMessage());
return new HashMap<>();
}
}
// -----------------------------------------------------------------------
// 字段读取辅助
// -----------------------------------------------------------------------
private static String getStringOrNull(Value value, String key) {
Value v = value.get(key);
return (v == null || v.isNull()) ? null : v.asString();
}
private static Double getDoubleOrNull(Value value, String key) {
Value v = value.get(key);
return (v == null || v.isNull()) ? null : v.asDouble();
}
private static LocalDateTime getLocalDateTimeOrNull(Value value, String key) {
Value v = value.get(key);
return (v == null || v.isNull()) ? null : v.asLocalDateTime();
}
}
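仓储中多处按同一模式拼接可选的 relation_type 过滤:type 为 null 或空白时不过滤,且无论是否过滤,$type 参数都以空串占位绑定,保持绑定集合稳定。这一约定可抽成如下示意(类名 TypeFilter 为演示假设,非源码):

```java
import java.util.HashMap;
import java.util.Map;

// 演示用示意类:复刻可选 relation_type 过滤片段的拼接约定(非源码)
class TypeFilter {

    // 与源码一致:type 为 null/空白时返回空片段,否则追加 AND 条件
    static String fragment(String type) {
        return (type != null && !type.isBlank())
                ? "AND r.relation_type = $type "
                : "";
    }

    // 与源码一致:无论是否过滤,$type 参数都绑定(空串占位),避免缺参错误
    static Map<String, Object> params(String graphId, String type) {
        Map<String, Object> params = new HashMap<>();
        params.put("graphId", graphId);
        params.put("type", type != null ? type : "");
        return params;
    }
}
```

配合查询中的 "WHERE true " 前缀,空片段与 AND 片段均能直接拼接,无需区分是否已有 WHERE 条件。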


@@ -0,0 +1,278 @@
package com.datamate.knowledgegraph.infrastructure.client;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
/**
* REST client for the data-management service.
* <p>
* Pulls dataset, file, and other metadata from the data-management-service
* REST API over HTTP for synchronization into Neo4j.
*/
@Component
@Slf4j
public class DataManagementClient {
private final RestTemplate restTemplate;
private final String baseUrl;
private final String annotationBaseUrl;
private final int pageSize;
public DataManagementClient(
@Qualifier("kgRestTemplate") RestTemplate restTemplate,
KnowledgeGraphProperties properties) {
this.restTemplate = restTemplate;
this.baseUrl = properties.getSync().getDataManagementUrl();
this.annotationBaseUrl = properties.getSync().getAnnotationServiceUrl();
this.pageSize = properties.getSync().getPageSize();
}
/**
* Pulls all datasets (auto-paginated).
*/
public List<DatasetDTO> listAllDatasets() {
return fetchAllPaged(
baseUrl + "/data-management/datasets",
new ParameterizedTypeReference<PagedResult<DatasetDTO>>() {},
"datasets");
}
/**
* Pulls all workflows (auto-paginated).
*/
public List<WorkflowDTO> listAllWorkflows() {
return fetchAllPaged(
baseUrl + "/data-management/workflows",
new ParameterizedTypeReference<PagedResult<WorkflowDTO>>() {},
"workflows");
}
/**
* Pulls all jobs (auto-paginated).
*/
public List<JobDTO> listAllJobs() {
return fetchAllPaged(
baseUrl + "/data-management/jobs",
new ParameterizedTypeReference<PagedResult<JobDTO>>() {},
"jobs");
}
/**
* Pulls all label tasks (auto-paginated, from the annotation service).
*/
public List<LabelTaskDTO> listAllLabelTasks() {
return fetchAllPaged(
annotationBaseUrl + "/annotation/label-tasks",
new ParameterizedTypeReference<PagedResult<LabelTaskDTO>>() {},
"label-tasks");
}
/**
* Pulls all knowledge sets (auto-paginated).
*/
public List<KnowledgeSetDTO> listAllKnowledgeSets() {
return fetchAllPaged(
baseUrl + "/data-management/knowledge-sets",
new ParameterizedTypeReference<PagedResult<KnowledgeSetDTO>>() {},
"knowledge-sets");
}
/**
* Generic auto-pagination fetch helper.
*/
private <T> List<T> fetchAllPaged(String baseEndpoint,
ParameterizedTypeReference<PagedResult<T>> typeRef,
String resourceName) {
List<T> allItems = new ArrayList<>();
int page = 0;
while (true) {
String url = baseEndpoint + "?page=" + page + "&size=" + pageSize;
log.debug("Fetching {}: page={}, size={}", resourceName, page, pageSize);
try {
ResponseEntity<PagedResult<T>> response = restTemplate.exchange(
url, HttpMethod.GET, null, typeRef);
PagedResult<T> body = response.getBody();
if (body == null || body.getContent() == null || body.getContent().isEmpty()) {
break;
}
allItems.addAll(body.getContent());
log.debug("Fetched {} {} (page {}), total so far: {}",
body.getContent().size(), resourceName, page, allItems.size());
if (page >= body.getTotalPages() - 1) {
break;
}
page++;
} catch (RestClientException e) {
log.error("Failed to fetch {}: page={}, url={}", resourceName, page, url, e);
throw e;
}
}
log.info("Fetched {} {} in total", allItems.size(), resourceName);
return allItems;
}
// -----------------------------------------------------------------------
// Response DTOs (only the fields needed for sync)
// -----------------------------------------------------------------------
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class PagedResult<T> {
private List<T> content;
private long page;
private long totalElements;
private long totalPages;
}
/**
* Aligned with DatasetResponse in data-management-service.
*/
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class DatasetDTO {
private String id;
private String name;
private String description;
private String parentDatasetId;
private String datasetType;
private String status;
private Long totalSize;
private Integer fileCount;
private String createdBy;
private String updatedBy;
private LocalDateTime createdAt;
private LocalDateTime updatedAt;
private List<TagDTO> tags;
}
/**
* Aligned with TagResponse in data-management-service.
*/
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class TagDTO {
private String id;
private String name;
private String color;
private String description;
}
/**
* Aligned with Workflow in data-management-service / data-cleaning-service.
*/
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class WorkflowDTO {
private String id;
private String name;
private String description;
private String workflowType;
private String status;
private String version;
private Integer operatorCount;
private String schedule;
private String createdBy;
private String updatedBy;
private LocalDateTime createdAt;
private LocalDateTime updatedAt;
/** IDs of the input datasets used by this workflow */
private List<String> inputDatasetIds;
}
/**
* Aligned with Job / CleaningTask / DataSynthInstance etc. in data-management-service.
*/
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class JobDTO {
private String id;
private String name;
private String description;
private String jobType;
private String status;
private String startedAt;
private String completedAt;
private Long durationSeconds;
private Long inputCount;
private Long outputCount;
private String errorMessage;
private String createdBy;
private String updatedBy;
private LocalDateTime createdAt;
private LocalDateTime updatedAt;
/** Input dataset ID */
private String inputDatasetId;
/** Output dataset ID */
private String outputDatasetId;
/** Owning workflow ID (TRIGGERS relation) */
private String workflowId;
/** Depended-on job ID (DEPENDS_ON relation) */
private String dependsOnJobId;
}
/**
* Aligned with LabelingProject / AutoAnnotationTask in data-annotation-service.
*/
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class LabelTaskDTO {
private String id;
private String name;
private String description;
private String taskMode;
private String dataType;
private String labelingType;
private String status;
private Double progress;
private String templateName;
private String createdBy;
private String updatedBy;
private LocalDateTime createdAt;
private LocalDateTime updatedAt;
/** Dataset ID used for labeling (USES_DATASET relation) */
private String datasetId;
}
/**
* Aligned with KnowledgeSet in data-management-service.
*/
@Data
@JsonIgnoreProperties(ignoreUnknown = true)
public static class KnowledgeSetDTO {
private String id;
private String name;
private String description;
private String status;
private String domain;
private String businessLine;
private String sensitivity;
private Integer itemCount;
private String validFrom;
private String validTo;
private String createdBy;
private String updatedBy;
private LocalDateTime createdAt;
private LocalDateTime updatedAt;
/** Source dataset ID list (SOURCED_FROM relation) */
private List<String> sourceDatasetIds;
}
}
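The `fetchAllPaged` loop above terminates either on an empty (or missing) page body or once `page >= totalPages - 1`. A minimal stdlib-only sketch of that termination logic, with a stubbed page source in place of the real HTTP call (`PagedFetchSketch`, `Page`, and `PageSource` are hypothetical stand-ins, not part of the service):

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-alone sketch of the fetchAllPaged termination logic (not the real client). */
public class PagedFetchSketch {

    /** One page of results plus the reported total page count, mirroring PagedResult. */
    public record Page(List<String> content, long totalPages) {}

    /** Stand-in for the remote paged endpoint. */
    public interface PageSource {
        Page fetch(int page);
    }

    /** Accumulates pages until an empty page arrives or the last reported page is consumed. */
    public static List<String> fetchAll(PageSource source) {
        List<String> all = new ArrayList<>();
        int page = 0;
        while (true) {
            Page body = source.fetch(page);
            // Defensive stop: missing body or empty content ends the loop.
            if (body == null || body.content() == null || body.content().isEmpty()) {
                break;
            }
            all.addAll(body.content());
            // Normal stop: the last reported page has been consumed.
            if (page >= body.totalPages() - 1) {
                break;
            }
            page++;
        }
        return all;
    }

    public static void main(String[] args) {
        // Three pages of two items each: exactly six items, no request beyond page 2.
        PageSource src = p -> p < 3
                ? new Page(List.of("item" + p + "a", "item" + p + "b"), 3)
                : new Page(List.of(), 3);
        System.out.println(fetchAll(src).size()); // 6
    }
}
```

Note that the empty-content check doubles as protection against an upstream that reports `totalPages` incorrectly: the loop can never spin forever on empty pages.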


@@ -0,0 +1,29 @@
package com.datamate.knowledgegraph.infrastructure.exception;
import com.datamate.common.infrastructure.exception.ErrorCode;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* Knowledge graph module error codes.
*/
@Getter
@AllArgsConstructor
public enum KnowledgeGraphErrorCode implements ErrorCode {
ENTITY_NOT_FOUND("knowledge_graph.0001", "实体不存在"),
RELATION_NOT_FOUND("knowledge_graph.0002", "关系不存在"),
GRAPH_NOT_FOUND("knowledge_graph.0003", "图谱不存在"),
DUPLICATE_ENTITY("knowledge_graph.0004", "实体已存在"),
INVALID_RELATION("knowledge_graph.0005", "无效的关系定义"),
IMPORT_FAILED("knowledge_graph.0006", "图谱导入失败"),
QUERY_DEPTH_EXCEEDED("knowledge_graph.0007", "查询深度超出限制"),
MAX_NODES_EXCEEDED("knowledge_graph.0008", "查询结果节点数超出限制"),
SYNC_FAILED("knowledge_graph.0009", "数据同步失败"),
EMPTY_SNAPSHOT_PURGE_BLOCKED("knowledge_graph.0010", "空快照保护:上游返回空列表,已阻止 purge 操作"),
SCHEMA_INIT_FAILED("knowledge_graph.0011", "图谱 Schema 初始化失败"),
INSECURE_DEFAULT_CREDENTIALS("knowledge_graph.0012", "检测到默认凭据,生产环境禁止使用默认密码");
private final String code;
private final String message;
}


@@ -0,0 +1,155 @@
package com.datamate.knowledgegraph.infrastructure.neo4j;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.core.annotation.Order;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Set;
/**
* Graph schema initializer.
* <p>
* Automatically creates Neo4j indexes and constraints at application startup.
* All statements use {@code IF NOT EXISTS}, so initialization is idempotent.
* <p>
* Corresponds to parts 1-3 of {@code docs/knowledge-graph/schema/schema.cypher}.
* <p>
* <b>Security self-check</b>: in non-development environments, startup is refused
* if the default Neo4j password is detected.
*/
@Component
@Slf4j
@RequiredArgsConstructor
@Order(1)
public class GraphInitializer implements ApplicationRunner {
/** Known weak default passwords; startup is refused when one is detected. */
private static final Set<String> BLOCKED_DEFAULT_PASSWORDS = Set.of(
"datamate123", "neo4j", "password", "123456", "admin"
);
/** Only match "already exists" style error messages by keyword; other errors must not be swallowed. */
private static final Set<String> ALREADY_EXISTS_KEYWORDS = Set.of(
"already exists", "already exist", "EquivalentSchemaRuleAlreadyExists"
);
private final Neo4jClient neo4jClient;
private final KnowledgeGraphProperties properties;
@Value("${spring.neo4j.authentication.password:}")
private String neo4jPassword;
@Value("${spring.profiles.active:default}")
private String activeProfile;
/**
* Cypher statements to execute at startup.
* Each statement must be executed independently (Neo4j does not support multiple
* DDL statements in one transaction).
*/
private static final List<String> SCHEMA_STATEMENTS = List.of(
// Constraints (each automatically creates a backing index)
"CREATE CONSTRAINT entity_id_unique IF NOT EXISTS FOR (n:Entity) REQUIRE n.id IS UNIQUE",
// Composite unique constraint for sync upserts: prevents duplicate entities under concurrent writes
"CREATE CONSTRAINT entity_sync_unique IF NOT EXISTS " +
"FOR (n:Entity) REQUIRE (n.graph_id, n.source_id, n.type) IS UNIQUE",
// Single-property indexes
"CREATE INDEX entity_graph_id IF NOT EXISTS FOR (n:Entity) ON (n.graph_id)",
"CREATE INDEX entity_type IF NOT EXISTS FOR (n:Entity) ON (n.type)",
"CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX entity_source_id IF NOT EXISTS FOR (n:Entity) ON (n.source_id)",
"CREATE INDEX entity_created_at IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
// Composite indexes
"CREATE INDEX entity_graph_id_type IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.type)",
"CREATE INDEX entity_graph_id_id IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.id)",
"CREATE INDEX entity_graph_id_source_id IF NOT EXISTS FOR (n:Entity) ON (n.graph_id, n.source_id)",
// Full-text index
"CREATE FULLTEXT INDEX entity_fulltext IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.description]"
);
@Override
public void run(ApplicationArguments args) {
// ── Security self-check: default-credential detection ──
validateCredentials();
if (!properties.getSync().isAutoInitSchema()) {
log.info("Schema auto-init is disabled, skipping");
return;
}
log.info("Initializing Neo4j schema: {} statements to execute", SCHEMA_STATEMENTS.size());
int succeeded = 0;
int failed = 0;
for (String statement : SCHEMA_STATEMENTS) {
try {
neo4jClient.query(statement).run();
succeeded++;
log.debug("Schema statement executed: {}", truncate(statement));
} catch (Exception e) {
if (isAlreadyExistsError(e)) {
// Constraint/index already exists; safe to skip
succeeded++;
log.debug("Schema element already exists (safe to skip): {}", truncate(statement));
} else {
// Not an "already exists" error: log and rethrow to block startup
failed++;
log.error("Schema statement FAILED: {} — {}", truncate(statement), e.getMessage());
throw new IllegalStateException(
"Neo4j schema initialization failed: " + truncate(statement), e);
}
}
}
log.info("Neo4j schema initialization completed: succeeded={}, failed={}", succeeded, failed);
}
/**
* Detects whether default credentials are in use.
* <p>
* Only warns in dev/test environments; in all other environments (prod, staging, etc.)
* startup is refused outright.
*/
private void validateCredentials() {
if (neo4jPassword == null || neo4jPassword.isBlank()) {
return;
}
if (BLOCKED_DEFAULT_PASSWORDS.contains(neo4jPassword)) {
boolean isDev = activeProfile.contains("dev") || activeProfile.contains("test")
|| activeProfile.contains("local");
if (isDev) {
log.warn("⚠ Neo4j is using a WEAK DEFAULT password. "
+ "This is acceptable in dev/test but MUST be changed for production.");
} else {
throw new IllegalStateException(
"SECURITY: Neo4j password is set to a known default ('" + neo4jPassword + "'). "
+ "Production environments MUST use a strong, unique password. "
+ "Set the NEO4J_PASSWORD environment variable to a secure value.");
}
}
}
/**
* Returns whether the exception merely indicates the schema element already exists (safe to ignore).
*/
private static boolean isAlreadyExistsError(Exception e) {
String msg = e.getMessage();
if (msg == null) {
return false;
}
String lowerMsg = msg.toLowerCase();
return ALREADY_EXISTS_KEYWORDS.stream().anyMatch(kw -> lowerMsg.contains(kw.toLowerCase()));
}
private static String truncate(String s) {
return s.length() <= 100 ? s : s.substring(0, 97) + "...";
}
}
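The `isAlreadyExistsError` keyword check can be exercised in isolation. A stdlib-only sketch of the same case-insensitive containment logic (class name and sample messages are hypothetical; the keyword set is copied from the initializer above):

```java
import java.util.Set;

/** Simplified copy of GraphInitializer's "already exists" classification (class name hypothetical). */
public class SchemaErrorSketch {

    private static final Set<String> ALREADY_EXISTS_KEYWORDS = Set.of(
            "already exists", "already exist", "EquivalentSchemaRuleAlreadyExists");

    /** Case-insensitive containment check against each known keyword; null means "not ignorable". */
    public static boolean isAlreadyExists(String message) {
        if (message == null) {
            return false;
        }
        String lower = message.toLowerCase();
        return ALREADY_EXISTS_KEYWORDS.stream().anyMatch(kw -> lower.contains(kw.toLowerCase()));
    }

    public static void main(String[] args) {
        System.out.println(isAlreadyExists("An equivalent constraint already exists.")); // true
        System.out.println(isAlreadyExists("Connection refused"));                       // false
    }
}
```

Lower-casing both sides keeps the match robust to driver-dependent message casing; anything outside the whitelist still fails startup.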


@@ -0,0 +1,63 @@
package com.datamate.knowledgegraph.infrastructure.neo4j;
import jakarta.validation.constraints.Min;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import org.springframework.validation.annotation.Validated;
@Data
@Component
@Validated
@ConfigurationProperties(prefix = "datamate.knowledge-graph")
public class KnowledgeGraphProperties {
/** Default query hop-count limit */
private int maxDepth = 3;
/** Maximum number of nodes returned per subgraph query */
private int maxNodesPerQuery = 500;
/** Bulk import batch size (must be >= 1, otherwise the modulo arithmetic throws) */
@Min(value = 1, message = "importBatchSize 必须 >= 1")
private int importBatchSize = 100;
/** Sync-related configuration */
private Sync sync = new Sync();
@Data
public static class Sync {
/** Base URL of the data-management service */
private String dataManagementUrl = "http://localhost:8080";
/** Base URL of the annotation service */
private String annotationServiceUrl = "http://localhost:8081";
/** Page size for sync fetches */
private int pageSize = 200;
/** HTTP connect timeout (milliseconds) */
private int connectTimeout = 5000;
/** HTTP read timeout (milliseconds) */
private int readTimeout = 30000;
/** Maximum retries on failure */
private int maxRetries = 3;
/** Retry interval (milliseconds) */
private long retryInterval = 1000;
/** Whether to auto-initialize the schema at startup */
private boolean autoInitSchema = true;
/**
 * Whether an empty snapshot may trigger a purge (default false).
 * <p>
 * When the upstream returns an empty list and this flag is false, the purge is
 * skipped to avoid wiping all synchronized entities by mistake.
 * Enable only when the data source is confirmed to be genuinely empty.
 */
private boolean allowPurgeOnEmptySnapshot = false;
}
}
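Assuming Spring Boot's relaxed binding of the fields above to kebab-case keys, a matching `application.yml` fragment might look like this (a sketch; the host names and values are illustrative, not defaults from any real deployment):

```yaml
# Hypothetical application.yml sketch for KnowledgeGraphProperties (values illustrative).
datamate:
  knowledge-graph:
    max-depth: 3
    max-nodes-per-query: 500
    import-batch-size: 100          # must be >= 1; validated at startup (fail-fast)
    sync:
      data-management-url: http://data-management:8080
      annotation-service-url: http://annotation:8081
      page-size: 200
      connect-timeout: 5000         # milliseconds
      read-timeout: 30000           # milliseconds
      max-retries: 3
      retry-interval: 1000          # milliseconds
      auto-init-schema: true
      allow-purge-on-empty-snapshot: false
```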


@@ -0,0 +1,31 @@
package com.datamate.knowledgegraph.interfaces.dto;
import jakarta.validation.constraints.NotBlank;
import lombok.Data;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Data
public class CreateEntityRequest {
@NotBlank(message = "实体名称不能为空")
private String name;
@NotBlank(message = "实体类型不能为空")
private String type;
private String description;
private List<String> aliases = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>();
private String sourceId;
private String sourceType;
private Double confidence;
}


@@ -0,0 +1,42 @@
package com.datamate.knowledgegraph.interfaces.dto;
import jakarta.validation.constraints.DecimalMax;
import jakarta.validation.constraints.DecimalMin;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import jakarta.validation.constraints.Size;
import lombok.Data;
import java.util.HashMap;
import java.util.Map;
@Data
public class CreateRelationRequest {
private static final String UUID_REGEX =
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$";
@NotBlank(message = "源实体ID不能为空")
@Pattern(regexp = UUID_REGEX, message = "源实体ID格式无效")
private String sourceEntityId;
@NotBlank(message = "目标实体ID不能为空")
@Pattern(regexp = UUID_REGEX, message = "目标实体ID格式无效")
private String targetEntityId;
@NotBlank(message = "关系类型不能为空")
@Size(min = 1, max = 50, message = "关系类型长度必须在1-50之间")
private String relationType;
private Map<String, Object> properties = new HashMap<>();
@DecimalMin(value = "0.0", message = "权重必须在0.0-1.0之间")
@DecimalMax(value = "1.0", message = "权重必须在0.0-1.0之间")
private Double weight;
private String sourceId;
@DecimalMin(value = "0.0", message = "置信度必须在0.0-1.0之间")
@DecimalMax(value = "1.0", message = "置信度必须在0.0-1.0之间")
private Double confidence;
}
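The `@Pattern` constraints in these DTOs all share the same UUID regex. A stdlib-only sketch of the check it performs (`UuidCheckSketch` and its helper are hypothetical; the regex is copied verbatim from the request class above):

```java
import java.util.regex.Pattern;

/** Stand-alone sketch of the UUID check behind the @Pattern constraints (same regex, stdlib only). */
public class UuidCheckSketch {

    private static final Pattern UUID_PATTERN = Pattern.compile(
            "^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$");

    /** Matches the 8-4-4-4-12 hex layout; both cases accepted via the character classes. */
    public static boolean isValidUuid(String s) {
        return s != null && UUID_PATTERN.matcher(s).matches();
    }

    public static void main(String[] args) {
        System.out.println(isValidUuid("123e4567-e89b-12d3-a456-426614174000")); // true
        System.out.println(isValidUuid("not-a-uuid"));                           // false
    }
}
```

Rejecting malformed IDs at the validation layer keeps such input out of Cypher queries entirely, rather than relying on the database to return an empty result.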


@@ -0,0 +1,22 @@
package com.datamate.knowledgegraph.interfaces.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* Relation summary used as the edge representation in graph-traversal results.
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class EdgeSummaryVO {
private String id;
private String sourceEntityId;
private String targetEntityId;
private String relationType;
private Double weight;
}


@@ -0,0 +1,21 @@
package com.datamate.knowledgegraph.interfaces.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* Entity summary used as the node representation in graph-traversal results.
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class EntitySummaryVO {
private String id;
private String name;
private String type;
private String description;
}


@@ -0,0 +1,27 @@
package com.datamate.knowledgegraph.interfaces.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
* Shortest-path query result.
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class PathVO {
/** Nodes along the path (in order) */
private List<EntitySummaryVO> nodes;
/** Edges along the path (in order) */
private List<EdgeSummaryVO> edges;
/** Path length (hop count) */
private int pathLength;
}


@@ -0,0 +1,53 @@
package com.datamate.knowledgegraph.interfaces.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Map;
/**
* Relation query result view object.
* <p>
* Carries the full relation, including summaries of the source and target entities,
* for REST API responses.
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class RelationVO {
private String id;
private String sourceEntityId;
private String sourceEntityName;
private String sourceEntityType;
private String targetEntityId;
private String targetEntityName;
private String targetEntityType;
private String relationType;
@Builder.Default
private Map<String, Object> properties = new HashMap<>();
private Double weight;
private Double confidence;
/** ID of the source dataset / knowledge base */
private String sourceId;
private String graphId;
private LocalDateTime createdAt;
}


@@ -0,0 +1,24 @@
package com.datamate.knowledgegraph.interfaces.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* Full-text search hit, including a relevance score.
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class SearchHitVO {
private String id;
private String name;
private String type;
private String description;
/** Full-text relevance score (higher is more relevant) */
private double score;
}


@@ -0,0 +1,26 @@
package com.datamate.knowledgegraph.interfaces.dto;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.Pattern;
import jakarta.validation.constraints.Size;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
* Subgraph query request.
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
public class SubgraphRequest {
private static final String UUID_REGEX =
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$";
@NotEmpty(message = "实体 ID 列表不能为空")
@Size(max = 500, message = "实体数量超出限制(最大 500)")
private List<@Pattern(regexp = UUID_REGEX, message = "entityId 格式无效") String> entityIds;
}


@@ -0,0 +1,30 @@
package com.datamate.knowledgegraph.interfaces.dto;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
/**
* Subgraph query result.
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class SubgraphVO {
/** Nodes in the subgraph */
private List<EntitySummaryVO> nodes;
/** Edges in the subgraph */
private List<EdgeSummaryVO> edges;
/** Node count */
private int nodeCount;
/** Edge count */
private int edgeCount;
}


@@ -0,0 +1,56 @@
package com.datamate.knowledgegraph.interfaces.dto;
import com.datamate.knowledgegraph.domain.model.SyncResult;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.time.LocalDateTime;
/**
* Sync result view object.
* <p>
* Does not expose internal error details (the errors list); only an error count
* and the syncId are returned, and the concrete logs can be looked up by syncId.
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class SyncResultVO {
private String syncId;
private String syncType;
private int created;
private int updated;
private int skipped;
private int failed;
private int purged;
private int total;
private long durationMillis;
/** Marks a placeholder step (feature not yet implemented) */
private boolean placeholder;
/** Number of errors (specific error details are not exposed) */
private int errorCount;
private LocalDateTime startedAt;
private LocalDateTime completedAt;
public static SyncResultVO from(SyncResult result) {
return SyncResultVO.builder()
.syncId(result.getSyncId())
.syncType(result.getSyncType())
.created(result.getCreated())
.updated(result.getUpdated())
.skipped(result.getSkipped())
.failed(result.getFailed())
.purged(result.getPurged())
.total(result.total())
.durationMillis(result.durationMillis())
.placeholder(result.isPlaceholder())
.errorCount(result.getErrors() != null ? result.getErrors().size() : 0)
.startedAt(result.getStartedAt())
.completedAt(result.getCompletedAt())
.build();
}
}


@@ -0,0 +1,18 @@
package com.datamate.knowledgegraph.interfaces.dto;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class UpdateEntityRequest {
private String name;
private String description;
private List<String> aliases;
private Map<String, Object> properties;
}


@@ -0,0 +1,30 @@
package com.datamate.knowledgegraph.interfaces.dto;
import jakarta.validation.constraints.DecimalMax;
import jakarta.validation.constraints.DecimalMin;
import jakarta.validation.constraints.Size;
import lombok.Data;
import java.util.Map;
/**
* Relation update request.
* <p>
* All fields are optional; only fields with a provided value are updated (patch semantics).
*/
@Data
public class UpdateRelationRequest {
@Size(min = 1, max = 50, message = "关系类型长度必须在1-50之间")
private String relationType;
private Map<String, Object> properties;
@DecimalMin(value = "0.0", message = "权重必须在0.0-1.0之间")
@DecimalMax(value = "1.0", message = "权重必须在0.0-1.0之间")
private Double weight;
@DecimalMin(value = "0.0", message = "置信度必须在0.0-1.0之间")
@DecimalMax(value = "1.0", message = "置信度必须在0.0-1.0之间")
private Double confidence;
}
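A minimal sketch of the patch semantics described in the javadoc above: null request fields leave the stored values untouched (`RelationPatchSketch` and its `Relation` stand-in are hypothetical, not the real domain model or service code):

```java
import java.util.Map;

/** Hypothetical sketch of patch semantics: only non-null patch fields overwrite stored values. */
public class RelationPatchSketch {

    /** Minimal stand-in for the stored relation (not the real domain model). */
    public static class Relation {
        public String relationType;
        public Map<String, Object> properties;
        public Double weight;
        public Double confidence;
    }

    /** Applies a patch: each null argument leaves the corresponding existing value untouched. */
    public static Relation apply(Relation existing, String relationType,
                                 Map<String, Object> properties, Double weight, Double confidence) {
        if (relationType != null) existing.relationType = relationType;
        if (properties != null) existing.properties = properties;
        if (weight != null) existing.weight = weight;
        if (confidence != null) existing.confidence = confidence;
        return existing;
    }

    public static void main(String[] args) {
        Relation r = new Relation();
        r.relationType = "USES_DATASET";
        r.weight = 0.5;
        apply(r, null, null, 0.9, null);                         // only weight changes
        System.out.println(r.relationType + " " + r.weight);     // USES_DATASET 0.9
    }
}
```

Using boxed `Double` (rather than `double`) in the request DTO is what makes "field absent" distinguishable from "field set to 0.0", which this patch style depends on.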


@@ -0,0 +1,122 @@
package com.datamate.knowledgegraph.interfaces.rest;
import com.datamate.common.interfaces.PagedResponse;
import com.datamate.knowledgegraph.application.GraphEntityService;
import com.datamate.knowledgegraph.application.GraphRelationService;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import com.datamate.knowledgegraph.interfaces.dto.CreateEntityRequest;
import com.datamate.knowledgegraph.interfaces.dto.RelationVO;
import com.datamate.knowledgegraph.interfaces.dto.UpdateEntityRequest;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Pattern;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
@RestController
@RequestMapping("/knowledge-graph/{graphId}/entities")
@RequiredArgsConstructor
@Validated
public class GraphEntityController {
private static final String UUID_REGEX =
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$";
private final GraphEntityService entityService;
private final GraphRelationService relationService;
@PostMapping
@ResponseStatus(HttpStatus.CREATED)
public GraphEntity createEntity(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@Valid @RequestBody CreateEntityRequest request) {
return entityService.createEntity(graphId, request);
}
@GetMapping("/{entityId}")
public GraphEntity getEntity(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@PathVariable @Pattern(regexp = UUID_REGEX, message = "entityId 格式无效") String entityId) {
return entityService.getEntity(graphId, entityId);
}
/**
* Lists entities (unpaged, backward compatible).
* <p>
* Matched when the request carries no {@code page} parameter; returns a {@code List}.
* Pass a {@code page} parameter to route to the paged endpoint instead.
*/
@GetMapping(params = "!page")
public List<GraphEntity> listEntities(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@RequestParam(required = false) String type,
@RequestParam(required = false) String keyword) {
if (keyword != null && !keyword.isBlank()) {
return entityService.searchEntities(graphId, keyword);
}
if (type != null && !type.isBlank()) {
return entityService.listEntitiesByType(graphId, type);
}
return entityService.listEntities(graphId);
}
/**
* Lists entities (paged).
* <p>
* Matched when the request carries a {@code page} parameter; returns a {@code PagedResponse}.
*/
@GetMapping(params = "page")
public PagedResponse<GraphEntity> listEntitiesPaged(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@RequestParam(required = false) String type,
@RequestParam(required = false) String keyword,
@RequestParam(defaultValue = "0") int page,
@RequestParam(defaultValue = "20") int size) {
if (keyword != null && !keyword.isBlank()) {
return entityService.searchEntitiesPaged(graphId, keyword, page, size);
}
if (type != null && !type.isBlank()) {
return entityService.listEntitiesByTypePaged(graphId, type, page, size);
}
return entityService.listEntitiesPaged(graphId, page, size);
}
@PutMapping("/{entityId}")
public GraphEntity updateEntity(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@PathVariable @Pattern(regexp = UUID_REGEX, message = "entityId 格式无效") String entityId,
@Valid @RequestBody UpdateEntityRequest request) {
return entityService.updateEntity(graphId, entityId, request);
}
@DeleteMapping("/{entityId}")
@ResponseStatus(HttpStatus.NO_CONTENT)
public void deleteEntity(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@PathVariable @Pattern(regexp = UUID_REGEX, message = "entityId 格式无效") String entityId) {
entityService.deleteEntity(graphId, entityId);
}
@GetMapping("/{entityId}/relations")
public PagedResponse<RelationVO> listEntityRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@PathVariable @Pattern(regexp = UUID_REGEX, message = "entityId 格式无效") String entityId,
@RequestParam(defaultValue = "all") @Pattern(regexp = "^(all|in|out)$", message = "direction 参数无效,允许值:all, in, out") String direction,
@RequestParam(required = false) String type,
@RequestParam(defaultValue = "0") int page,
@RequestParam(defaultValue = "20") int size) {
return relationService.listEntityRelations(graphId, entityId, direction, type, page, size);
}
@GetMapping("/{entityId}/neighbors")
public List<GraphEntity> getNeighbors(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@PathVariable @Pattern(regexp = UUID_REGEX, message = "entityId 格式无效") String entityId,
@RequestParam(defaultValue = "2") int depth,
@RequestParam(defaultValue = "50") int limit) {
return entityService.getNeighbors(graphId, entityId, depth, limit);
}
}


@@ -0,0 +1,86 @@
package com.datamate.knowledgegraph.interfaces.rest;
import com.datamate.common.interfaces.PagedResponse;
import com.datamate.knowledgegraph.application.GraphQueryService;
import com.datamate.knowledgegraph.interfaces.dto.PathVO;
import com.datamate.knowledgegraph.interfaces.dto.SearchHitVO;
import com.datamate.knowledgegraph.interfaces.dto.SubgraphRequest;
import com.datamate.knowledgegraph.interfaces.dto.SubgraphVO;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Pattern;
import lombok.RequiredArgsConstructor;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
/**
* Knowledge graph query API.
* <p>
* Provides graph traversal (neighbors, shortest path, subgraph) and full-text search.
*/
@RestController
@RequestMapping("/knowledge-graph/{graphId}/query")
@RequiredArgsConstructor
@Validated
public class GraphQueryController {
private static final String UUID_REGEX =
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$";
private final GraphQueryService queryService;
// -----------------------------------------------------------------------
// Graph traversal
// -----------------------------------------------------------------------
/**
* Queries the N-hop neighbor subgraph of an entity.
*/
@GetMapping("/neighbors/{entityId}")
public SubgraphVO getNeighborGraph(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@PathVariable @Pattern(regexp = UUID_REGEX, message = "entityId 格式无效") String entityId,
@RequestParam(defaultValue = "2") int depth,
@RequestParam(defaultValue = "50") int limit) {
return queryService.getNeighborGraph(graphId, entityId, depth, limit);
}
/**
* Finds the shortest path between two entities.
*/
@GetMapping("/shortest-path")
public PathVO getShortestPath(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@RequestParam @Pattern(regexp = UUID_REGEX, message = "sourceId 格式无效") String sourceId,
@RequestParam @Pattern(regexp = UUID_REGEX, message = "targetId 格式无效") String targetId,
@RequestParam(defaultValue = "3") int maxDepth) {
return queryService.getShortestPath(graphId, sourceId, targetId, maxDepth);
}
/**
* Extracts the subgraph (relation network) over the given entity set.
*/
@PostMapping("/subgraph")
public SubgraphVO getSubgraph(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@Valid @RequestBody SubgraphRequest request) {
return queryService.getSubgraph(graphId, request.getEntityIds());
}
// -----------------------------------------------------------------------
// Full-text search
// -----------------------------------------------------------------------
/**
* Searches entities via the full-text index.
* <p>
* Searches the name and description fields, ordered by relevance.
*/
@GetMapping("/search")
public PagedResponse<SearchHitVO> fulltextSearch(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@RequestParam String q,
@RequestParam(defaultValue = "0") int page,
@RequestParam(defaultValue = "20") int size) {
return queryService.fulltextSearch(graphId, q, page, size);
}
}


@@ -0,0 +1,65 @@
package com.datamate.knowledgegraph.interfaces.rest;
import com.datamate.common.interfaces.PagedResponse;
import com.datamate.knowledgegraph.application.GraphRelationService;
import com.datamate.knowledgegraph.interfaces.dto.CreateRelationRequest;
import com.datamate.knowledgegraph.interfaces.dto.RelationVO;
import com.datamate.knowledgegraph.interfaces.dto.UpdateRelationRequest;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Pattern;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/knowledge-graph/{graphId}/relations")
@RequiredArgsConstructor
@Validated
public class GraphRelationController {
private static final String UUID_REGEX =
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$";
private final GraphRelationService relationService;
@PostMapping
@ResponseStatus(HttpStatus.CREATED)
public RelationVO createRelation(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@Valid @RequestBody CreateRelationRequest request) {
return relationService.createRelation(graphId, request);
}
@GetMapping
public PagedResponse<RelationVO> listRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@RequestParam(required = false) String type,
@RequestParam(defaultValue = "0") int page,
@RequestParam(defaultValue = "20") int size) {
return relationService.listRelations(graphId, type, page, size);
}
@GetMapping("/{relationId}")
public RelationVO getRelation(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@PathVariable @Pattern(regexp = UUID_REGEX, message = "relationId 格式无效") String relationId) {
return relationService.getRelation(graphId, relationId);
}
@PutMapping("/{relationId}")
public RelationVO updateRelation(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@PathVariable @Pattern(regexp = UUID_REGEX, message = "relationId 格式无效") String relationId,
@Valid @RequestBody UpdateRelationRequest request) {
return relationService.updateRelation(graphId, relationId, request);
}
@DeleteMapping("/{relationId}")
@ResponseStatus(HttpStatus.NO_CONTENT)
public void deleteRelation(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId,
@PathVariable @Pattern(regexp = UUID_REGEX, message = "relationId 格式无效") String relationId) {
relationService.deleteRelation(graphId, relationId);
}
}
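Every endpoint above guards its path variables with `@Pattern(regexp = UUID_REGEX, ...)`, so malformed IDs are rejected before any Cypher runs. A minimal standalone sketch of the same check (the regex is copied verbatim from the controller; the class and method names here are illustrative, not the service's actual code):

```java
import java.util.regex.Pattern;

public class UuidGuard {
    // Same pattern the controllers bind via @Pattern: 8-4-4-4-12 hex groups.
    static final Pattern UUID_PATTERN = Pattern.compile(
            "^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$");

    static boolean isValidGraphId(String graphId) {
        // @Pattern treats null as valid and leaves it to @NotNull; here we reject it explicitly.
        return graphId != null && UUID_PATTERN.matcher(graphId).matches();
    }
}
```

Note that Bean Validation's `@Pattern` skips `null` values by design, so a controller that must reject missing IDs would pair it with `@NotNull`.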


@@ -0,0 +1,214 @@
package com.datamate.knowledgegraph.interfaces.rest;
import com.datamate.knowledgegraph.application.GraphSyncService;
import com.datamate.knowledgegraph.domain.model.SyncResult;
import com.datamate.knowledgegraph.interfaces.dto.SyncResultVO;
import jakarta.validation.constraints.Pattern;
import lombok.RequiredArgsConstructor;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
/**
* Knowledge graph data synchronization API.
* <p>
* Exposes REST endpoints for manually triggering MySQL → Neo4j synchronization.
* In production, synchronization can also be triggered automatically by a scheduled task.
* <p>
* <b>Security note</b>: this API is intended for internal callers only (API Gateway /
* scheduled tasks); external requests must be authenticated by the API Gateway before
* being forwarded. In production, service-to-service authentication should be further
* hardened via mTLS or internal JWTs. Currently a simple internal-call check is
* performed via the {@code X-Internal-Token} request header.
*/
@RestController
@RequestMapping("/knowledge-graph/{graphId}/sync")
@RequiredArgsConstructor
@Validated
public class GraphSyncController {
private static final String UUID_REGEX =
"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$";
private final GraphSyncService syncService;
/**
* Full sync: pulls all entities and builds all relations.
*/
@PostMapping("/full")
public List<SyncResultVO> syncAll(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
List<SyncResult> results = syncService.syncAll(graphId);
return results.stream().map(SyncResultVO::from).toList();
}
/**
* Syncs dataset entities.
*/
@PostMapping("/datasets")
public SyncResultVO syncDatasets(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.syncDatasets(graphId));
}
/**
* Syncs field entities.
*/
@PostMapping("/fields")
public SyncResultVO syncFields(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.syncFields(graphId));
}
/**
* Syncs user entities.
*/
@PostMapping("/users")
public SyncResultVO syncUsers(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.syncUsers(graphId));
}
/**
* Syncs organization entities.
*/
@PostMapping("/orgs")
public SyncResultVO syncOrgs(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.syncOrgs(graphId));
}
/**
* Builds HAS_FIELD relations.
*/
@PostMapping("/relations/has-field")
public SyncResultVO buildHasFieldRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildHasFieldRelations(graphId));
}
/**
* Builds DERIVED_FROM relations.
*/
@PostMapping("/relations/derived-from")
public SyncResultVO buildDerivedFromRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildDerivedFromRelations(graphId));
}
/**
* Builds BELONGS_TO relations.
*/
@PostMapping("/relations/belongs-to")
public SyncResultVO buildBelongsToRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildBelongsToRelations(graphId));
}
// -----------------------------------------------------------------------
// Newly added entity sync endpoints
// -----------------------------------------------------------------------
/**
* Syncs workflow entities.
*/
@PostMapping("/workflows")
public SyncResultVO syncWorkflows(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.syncWorkflows(graphId));
}
/**
* Syncs job entities.
*/
@PostMapping("/jobs")
public SyncResultVO syncJobs(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.syncJobs(graphId));
}
/**
* Syncs label task entities.
*/
@PostMapping("/label-tasks")
public SyncResultVO syncLabelTasks(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.syncLabelTasks(graphId));
}
/**
* Syncs knowledge set entities.
*/
@PostMapping("/knowledge-sets")
public SyncResultVO syncKnowledgeSets(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.syncKnowledgeSets(graphId));
}
// -----------------------------------------------------------------------
// Newly added relation-building endpoints
// -----------------------------------------------------------------------
/**
* Builds USES_DATASET relations.
*/
@PostMapping("/relations/uses-dataset")
public SyncResultVO buildUsesDatasetRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildUsesDatasetRelations(graphId));
}
/**
* Builds PRODUCES relations.
*/
@PostMapping("/relations/produces")
public SyncResultVO buildProducesRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildProducesRelations(graphId));
}
/**
* Builds ASSIGNED_TO relations.
*/
@PostMapping("/relations/assigned-to")
public SyncResultVO buildAssignedToRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildAssignedToRelations(graphId));
}
/**
* Builds TRIGGERS relations.
*/
@PostMapping("/relations/triggers")
public SyncResultVO buildTriggersRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildTriggersRelations(graphId));
}
/**
* Builds DEPENDS_ON relations.
*/
@PostMapping("/relations/depends-on")
public SyncResultVO buildDependsOnRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildDependsOnRelations(graphId));
}
/**
* Builds IMPACTS relations.
*/
@PostMapping("/relations/impacts")
public SyncResultVO buildImpactsRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildImpactsRelations(graphId));
}
/**
* Builds SOURCED_FROM relations.
*/
@PostMapping("/relations/sourced-from")
public SyncResultVO buildSourcedFromRelations(
@PathVariable @Pattern(regexp = UUID_REGEX, message = "graphId 格式无效") String graphId) {
return SyncResultVO.from(syncService.buildSourcedFromRelations(graphId));
}
}
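The class Javadoc above mentions a simple `X-Internal-Token` header check for internal calls. A hypothetical sketch of the core comparison only (the filter wiring, header extraction, and token source are assumptions, not this service's actual code); `MessageDigest.isEqual` is used so the comparison runs in constant time and cannot leak the token through timing:

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;

public class InternalTokenCheck {
    // Hypothetical expected token; in a real service this would come from configuration.
    private final byte[] expected;

    public InternalTokenCheck(String expectedToken) {
        this.expected = expectedToken.getBytes(StandardCharsets.UTF_8);
    }

    /** Returns true only if the presented X-Internal-Token value matches the configured token. */
    public boolean isAuthorized(String headerValue) {
        if (headerValue == null) {
            return false;
        }
        // Constant-time comparison: does not short-circuit on the first differing byte.
        return MessageDigest.isEqual(expected, headerValue.getBytes(StandardCharsets.UTF_8));
    }
}
```

A naive `String.equals` would work functionally, but the constant-time variant is the usual choice for secret comparison.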


@@ -0,0 +1,45 @@
# Knowledge graph service - Neo4j connection configuration
# Imported by main-application via spring.config.import
# Note: in production always set the password via the NEO4J_PASSWORD environment variable; never use the default
spring:
neo4j:
uri: ${NEO4J_URI:bolt://datamate-neo4j:7687}
authentication:
username: ${NEO4J_USERNAME:neo4j}
password: ${NEO4J_PASSWORD:datamate123}
pool:
max-connection-pool-size: ${NEO4J_POOL_MAX_SIZE:50}
connection-acquisition-timeout: 30s
max-connection-lifetime: 1h
log-leaked-sessions: true
# Knowledge graph service configuration
datamate:
knowledge-graph:
# Default query depth (hop) limit
max-depth: ${KG_MAX_DEPTH:3}
# Maximum number of nodes returned per subgraph query
max-nodes-per-query: ${KG_MAX_NODES:500}
# Bulk import batch size
import-batch-size: ${KG_IMPORT_BATCH_SIZE:100}
# MySQL → Neo4j sync configuration
sync:
# Data management service address
data-management-url: ${DATA_MANAGEMENT_URL:http://localhost:8080}
# Annotation service address
annotation-service-url: ${ANNOTATION_SERVICE_URL:http://localhost:8081}
# Number of records fetched per page
page-size: ${KG_SYNC_PAGE_SIZE:200}
# HTTP connect timeout (ms)
connect-timeout: ${KG_SYNC_CONNECT_TIMEOUT:5000}
# HTTP read timeout (ms)
read-timeout: ${KG_SYNC_READ_TIMEOUT:30000}
# Maximum retry attempts on failure
max-retries: ${KG_SYNC_MAX_RETRIES:3}
# Retry interval (ms)
retry-interval: ${KG_SYNC_RETRY_INTERVAL:1000}
# Whether to auto-initialize the schema on startup
auto-init-schema: ${KG_AUTO_INIT_SCHEMA:true}
# Whether an empty snapshot may trigger a purge (default false; prevents an empty upstream list from wiping all synced entities)
allow-purge-on-empty-snapshot: ${KG_ALLOW_PURGE_ON_EMPTY_SNAPSHOT:false}
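Every value above uses Spring's `${ENV_VAR:default}` placeholder form: the environment variable wins when set, otherwise the text after the first colon is used as the default. A minimal sketch of that resolution rule (illustration only, not Spring's actual property resolver):

```java
import java.util.Map;

public class PlaceholderSketch {
    /** Resolves "${NAME:default}" against the given variables; NAME wins when present. */
    static String resolve(String placeholder, Map<String, String> env) {
        if (!placeholder.startsWith("${") || !placeholder.endsWith("}")) {
            return placeholder; // plain literal, no substitution
        }
        String body = placeholder.substring(2, placeholder.length() - 1);
        // Split on the FIRST colon only, so defaults like "bolt://host:7687" survive intact.
        int colon = body.indexOf(':');
        String name = colon >= 0 ? body.substring(0, colon) : body;
        String fallback = colon >= 0 ? body.substring(colon + 1) : null;
        return env.getOrDefault(name, fallback);
    }
}
```

With an empty environment, `resolve("${NEO4J_URI:bolt://datamate-neo4j:7687}", Map.of())` yields the full default URI, colons included.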


@@ -0,0 +1,233 @@
package com.datamate.knowledgegraph.application;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import com.datamate.knowledgegraph.interfaces.dto.CreateEntityRequest;
import com.datamate.knowledgegraph.interfaces.dto.UpdateEntityRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Optional;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class GraphEntityServiceTest {
private static final String GRAPH_ID = "550e8400-e29b-41d4-a716-446655440000";
private static final String ENTITY_ID = "660e8400-e29b-41d4-a716-446655440001";
private static final String INVALID_GRAPH_ID = "not-a-uuid";
@Mock
private GraphEntityRepository entityRepository;
@Mock
private KnowledgeGraphProperties properties;
@InjectMocks
private GraphEntityService entityService;
private GraphEntity sampleEntity;
@BeforeEach
void setUp() {
sampleEntity = GraphEntity.builder()
.id(ENTITY_ID)
.name("TestDataset")
.type("Dataset")
.description("A test dataset")
.graphId(GRAPH_ID)
.confidence(1.0)
.createdAt(LocalDateTime.now())
.updatedAt(LocalDateTime.now())
.build();
}
// -----------------------------------------------------------------------
// graphId validation
// -----------------------------------------------------------------------
@Test
void getEntity_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> entityService.getEntity(INVALID_GRAPH_ID, ENTITY_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void getEntity_nullGraphId_throwsBusinessException() {
assertThatThrownBy(() -> entityService.getEntity(null, ENTITY_ID))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// createEntity
// -----------------------------------------------------------------------
@Test
void createEntity_success() {
CreateEntityRequest request = new CreateEntityRequest();
request.setName("NewEntity");
request.setType("Dataset");
request.setDescription("Desc");
when(entityRepository.save(any(GraphEntity.class))).thenReturn(sampleEntity);
GraphEntity result = entityService.createEntity(GRAPH_ID, request);
assertThat(result).isNotNull();
assertThat(result.getName()).isEqualTo("TestDataset");
verify(entityRepository).save(any(GraphEntity.class));
}
// -----------------------------------------------------------------------
// getEntity
// -----------------------------------------------------------------------
@Test
void getEntity_found() {
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(sampleEntity));
GraphEntity result = entityService.getEntity(GRAPH_ID, ENTITY_ID);
assertThat(result.getId()).isEqualTo(ENTITY_ID);
assertThat(result.getName()).isEqualTo("TestDataset");
}
@Test
void getEntity_notFound_throwsBusinessException() {
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.empty());
assertThatThrownBy(() -> entityService.getEntity(GRAPH_ID, ENTITY_ID))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// listEntities
// -----------------------------------------------------------------------
@Test
void listEntities_returnsAll() {
when(entityRepository.findByGraphId(GRAPH_ID))
.thenReturn(List.of(sampleEntity));
List<GraphEntity> results = entityService.listEntities(GRAPH_ID);
assertThat(results).hasSize(1);
assertThat(results.get(0).getName()).isEqualTo("TestDataset");
}
// -----------------------------------------------------------------------
// updateEntity
// -----------------------------------------------------------------------
@Test
void updateEntity_partialUpdate_onlyChangesProvidedFields() {
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(sampleEntity));
when(entityRepository.save(any(GraphEntity.class)))
.thenAnswer(inv -> inv.getArgument(0));
UpdateEntityRequest request = new UpdateEntityRequest();
request.setName("UpdatedName");
// description not set — should remain unchanged
GraphEntity result = entityService.updateEntity(GRAPH_ID, ENTITY_ID, request);
assertThat(result.getName()).isEqualTo("UpdatedName");
assertThat(result.getDescription()).isEqualTo("A test dataset");
}
// -----------------------------------------------------------------------
// deleteEntity
// -----------------------------------------------------------------------
@Test
void deleteEntity_success() {
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(sampleEntity));
entityService.deleteEntity(GRAPH_ID, ENTITY_ID);
verify(entityRepository).delete(sampleEntity);
}
@Test
void deleteEntity_notFound_throwsBusinessException() {
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.empty());
assertThatThrownBy(() -> entityService.deleteEntity(GRAPH_ID, ENTITY_ID))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// getNeighbors - depth/limit clamping
// -----------------------------------------------------------------------
@Test
void getNeighbors_clampsDepthAndLimit() {
when(properties.getMaxDepth()).thenReturn(3);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
when(entityRepository.findNeighbors(eq(GRAPH_ID), eq(ENTITY_ID), eq(3), eq(500)))
.thenReturn(List.of());
List<GraphEntity> result = entityService.getNeighbors(GRAPH_ID, ENTITY_ID, 100, 99999);
assertThat(result).isEmpty();
// depth clamped to maxDepth=3, limit clamped to maxNodesPerQuery=500
verify(entityRepository).findNeighbors(GRAPH_ID, ENTITY_ID, 3, 500);
}
// -----------------------------------------------------------------------
// Pagination
// -----------------------------------------------------------------------
@Test
void listEntitiesPaged_normalPage() {
when(entityRepository.findByGraphIdPaged(GRAPH_ID, 0L, 20))
.thenReturn(List.of(sampleEntity));
when(entityRepository.countByGraphId(GRAPH_ID)).thenReturn(1L);
var result = entityService.listEntitiesPaged(GRAPH_ID, 0, 20);
assertThat(result.getContent()).hasSize(1);
assertThat(result.getTotalElements()).isEqualTo(1);
}
@Test
void listEntitiesPaged_negativePage_clampedToZero() {
when(entityRepository.findByGraphIdPaged(GRAPH_ID, 0L, 20))
.thenReturn(List.of());
when(entityRepository.countByGraphId(GRAPH_ID)).thenReturn(0L);
var result = entityService.listEntitiesPaged(GRAPH_ID, -1, 20);
assertThat(result.getPage()).isEqualTo(0);
}
@Test
void listEntitiesPaged_oversizedPage_clampedTo200() {
when(entityRepository.findByGraphIdPaged(GRAPH_ID, 0L, 200))
.thenReturn(List.of());
when(entityRepository.countByGraphId(GRAPH_ID)).thenReturn(0L);
entityService.listEntitiesPaged(GRAPH_ID, 0, 999);
verify(entityRepository).findByGraphIdPaged(GRAPH_ID, 0L, 200);
}
}
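The clamping tests above (`getNeighbors_clampsDepthAndLimit` and the paged variants) all exercise the same bound-into-range rule. A minimal sketch of that rule, using the limits the tests configure (maxDepth=3, maxNodesPerQuery=500, page-size cap 200; the helper name is illustrative, not the service's actual method):

```java
public class ClampSketch {
    /** Clamps a requested value into the inclusive range [min, max]. */
    static int clamp(int requested, int min, int max) {
        return Math.max(min, Math.min(requested, max));
    }

    public static void main(String[] args) {
        System.out.println(clamp(100, 1, 3));                // depth 100  -> 3   (maxDepth)
        System.out.println(clamp(99999, 1, 500));            // limit      -> 500 (maxNodesPerQuery)
        System.out.println(clamp(-1, 0, Integer.MAX_VALUE)); // page -1    -> 0
        System.out.println(clamp(999, 1, 200));              // size 999   -> 200
    }
}
```

Clamping at the service boundary (rather than rejecting) keeps the query endpoints forgiving while still bounding Neo4j traversal cost.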


@@ -0,0 +1,597 @@
package com.datamate.knowledgegraph.application;
import com.datamate.common.auth.application.ResourceAccessService;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import com.datamate.knowledgegraph.interfaces.dto.SubgraphVO;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.data.neo4j.core.Neo4jClient;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class GraphQueryServiceTest {
private static final String GRAPH_ID = "550e8400-e29b-41d4-a716-446655440000";
private static final String ENTITY_ID = "660e8400-e29b-41d4-a716-446655440001";
private static final String ENTITY_ID_2 = "660e8400-e29b-41d4-a716-446655440002";
private static final String INVALID_GRAPH_ID = "bad-id";
@Mock
private Neo4jClient neo4jClient;
@Mock
private GraphEntityRepository entityRepository;
@Mock
private KnowledgeGraphProperties properties;
@Mock
private ResourceAccessService resourceAccessService;
@InjectMocks
private GraphQueryService queryService;
@BeforeEach
void setUp() {
}
// -----------------------------------------------------------------------
// graphId validation
// -----------------------------------------------------------------------
@Test
void getNeighborGraph_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> queryService.getNeighborGraph(INVALID_GRAPH_ID, ENTITY_ID, 2, 50))
.isInstanceOf(BusinessException.class);
}
@Test
void getShortestPath_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> queryService.getShortestPath(INVALID_GRAPH_ID, ENTITY_ID, ENTITY_ID_2, 3))
.isInstanceOf(BusinessException.class);
}
@Test
void getSubgraph_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> queryService.getSubgraph(INVALID_GRAPH_ID, List.of(ENTITY_ID)))
.isInstanceOf(BusinessException.class);
}
@Test
void fulltextSearch_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> queryService.fulltextSearch(INVALID_GRAPH_ID, "test", 0, 20))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// getNeighborGraph - entity not found
// -----------------------------------------------------------------------
@Test
void getNeighborGraph_entityNotFound_throwsBusinessException() {
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.empty());
assertThatThrownBy(() -> queryService.getNeighborGraph(GRAPH_ID, ENTITY_ID, 2, 50))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// getShortestPath - same source and target
// -----------------------------------------------------------------------
@Test
void getShortestPath_sameSourceAndTarget_returnsSingleNode() {
GraphEntity entity = GraphEntity.builder()
.id(ENTITY_ID).name("Node").type("Dataset").graphId(GRAPH_ID).build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(entity));
var result = queryService.getShortestPath(GRAPH_ID, ENTITY_ID, ENTITY_ID, 3);
assertThat(result.getPathLength()).isEqualTo(0);
assertThat(result.getNodes()).hasSize(1);
assertThat(result.getEdges()).isEmpty();
}
@Test
void getShortestPath_sourceNotFound_throwsBusinessException() {
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.empty());
assertThatThrownBy(() -> queryService.getShortestPath(GRAPH_ID, ENTITY_ID, ENTITY_ID_2, 3))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// getSubgraph - empty input
// -----------------------------------------------------------------------
@Test
void getSubgraph_nullEntityIds_returnsEmptySubgraph() {
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, null);
assertThat(result.getNodes()).isEmpty();
assertThat(result.getEdges()).isEmpty();
assertThat(result.getNodeCount()).isEqualTo(0);
}
@Test
void getSubgraph_emptyEntityIds_returnsEmptySubgraph() {
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of());
assertThat(result.getNodes()).isEmpty();
assertThat(result.getEdges()).isEmpty();
}
@Test
void getSubgraph_exceedsMaxNodes_throwsBusinessException() {
when(properties.getMaxNodesPerQuery()).thenReturn(5);
List<String> tooManyIds = List.of("1", "2", "3", "4", "5", "6");
assertThatThrownBy(() -> queryService.getSubgraph(GRAPH_ID, tooManyIds))
.isInstanceOf(BusinessException.class);
}
@Test
void getSubgraph_noExistingEntities_returnsEmptySubgraph() {
when(properties.getMaxNodesPerQuery()).thenReturn(500);
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID)))
.thenReturn(List.of());
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of(ENTITY_ID));
assertThat(result.getNodes()).isEmpty();
}
// -----------------------------------------------------------------------
// fulltextSearch - blank query
// -----------------------------------------------------------------------
@Test
void fulltextSearch_blankQuery_returnsEmpty() {
var result = queryService.fulltextSearch(GRAPH_ID, "", 0, 20);
assertThat(result.getContent()).isEmpty();
assertThat(result.getTotalElements()).isEqualTo(0);
verifyNoInteractions(neo4jClient);
}
@Test
void fulltextSearch_nullQuery_returnsEmpty() {
var result = queryService.fulltextSearch(GRAPH_ID, null, 0, 20);
assertThat(result.getContent()).isEmpty();
}
// -----------------------------------------------------------------------
// Permission filtering
// -----------------------------------------------------------------------
@Nested
class PermissionFilteringTest {
private static final String CURRENT_USER_ID = "user-123";
private static final String OTHER_USER_ID = "other-user";
// -- getNeighborGraph permissions --
@Test
void getNeighborGraph_nonAdmin_otherEntity_throwsInsufficientPermissions() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
GraphEntity entity = GraphEntity.builder()
.id(ENTITY_ID).name("Other's Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", OTHER_USER_ID)))
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(entity));
assertThatThrownBy(() -> queryService.getNeighborGraph(GRAPH_ID, ENTITY_ID, 2, 50))
.isInstanceOf(BusinessException.class);
verifyNoInteractions(neo4jClient);
}
@Test
void getNeighborGraph_admin_otherEntity_noPermissionDenied() {
// Admin resolves to null → no filtering
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(null);
GraphEntity entity = GraphEntity.builder()
.id(ENTITY_ID).name("Other's Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", OTHER_USER_ID)))
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(entity));
when(properties.getMaxDepth()).thenReturn(3);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
// Admins are not blocked by the permission check and proceed to the Neo4jClient call.
// Since Neo4jClient is not fully mocked, some other exception (not BusinessException) is thrown.
try {
queryService.getNeighborGraph(GRAPH_ID, ENTITY_ID, 2, 50);
} catch (BusinessException e) {
throw new AssertionError("Admin should not be blocked by permission check", e);
} catch (Exception ignored) {
// Neo4jClient mock chain is incomplete; another exception type is expected
}
}
// -- getShortestPath permissions --
@Test
void getShortestPath_nonAdmin_sourceNotAccessible_throws() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
GraphEntity sourceEntity = GraphEntity.builder()
.id(ENTITY_ID).name("Other's Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", OTHER_USER_ID)))
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(sourceEntity));
assertThatThrownBy(() -> queryService.getShortestPath(GRAPH_ID, ENTITY_ID, ENTITY_ID_2, 3))
.isInstanceOf(BusinessException.class);
verifyNoInteractions(neo4jClient);
}
@Test
void getShortestPath_nonAdmin_targetNotAccessible_throws() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
GraphEntity sourceEntity = GraphEntity.builder()
.id(ENTITY_ID).name("My Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID)))
.build();
GraphEntity targetEntity = GraphEntity.builder()
.id(ENTITY_ID_2).name("Other's Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", OTHER_USER_ID)))
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(sourceEntity));
when(entityRepository.findByIdAndGraphId(ENTITY_ID_2, GRAPH_ID))
.thenReturn(Optional.of(targetEntity));
assertThatThrownBy(() -> queryService.getShortestPath(GRAPH_ID, ENTITY_ID, ENTITY_ID_2, 3))
.isInstanceOf(BusinessException.class);
verifyNoInteractions(neo4jClient);
}
@Test
void getShortestPath_nonAdmin_sameOwnEntity_returnsSingleNode() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
GraphEntity entity = GraphEntity.builder()
.id(ENTITY_ID).name("My Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID)))
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(entity));
var result = queryService.getShortestPath(GRAPH_ID, ENTITY_ID, ENTITY_ID, 3);
assertThat(result.getPathLength()).isEqualTo(0);
assertThat(result.getNodes()).hasSize(1);
assertThat(result.getNodes().get(0).getName()).isEqualTo("My Dataset");
}
@Test
void getShortestPath_nonAdmin_structuralEntity_noPermissionDenied() {
// Structural entities (no created_by) are visible to all users
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
GraphEntity structuralEntity = GraphEntity.builder()
.id(ENTITY_ID).name("Admin User").type("User").graphId(GRAPH_ID)
.properties(new HashMap<>())
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(structuralEntity));
// Same source and target → single-node path, no Neo4jClient needed
var result = queryService.getShortestPath(GRAPH_ID, ENTITY_ID, ENTITY_ID, 3);
assertThat(result.getPathLength()).isEqualTo(0);
assertThat(result.getNodes()).hasSize(1);
assertThat(result.getNodes().get(0).getType()).isEqualTo("User");
}
// -- getSubgraph permission filtering --
@Test
void getSubgraph_nonAdmin_filtersInaccessibleEntities() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
GraphEntity ownEntity = GraphEntity.builder()
.id(ENTITY_ID).name("My Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID)))
.build();
GraphEntity otherEntity = GraphEntity.builder()
.id(ENTITY_ID_2).name("Other Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", OTHER_USER_ID)))
.build();
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID, ENTITY_ID_2)))
.thenReturn(List.of(ownEntity, otherEntity));
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of(ENTITY_ID, ENTITY_ID_2));
// Only the caller's own entity is returned (the other is filtered out); a single node has no edges
assertThat(result.getNodes()).hasSize(1);
assertThat(result.getNodes().get(0).getName()).isEqualTo("My Dataset");
assertThat(result.getEdges()).isEmpty();
assertThat(result.getNodeCount()).isEqualTo(1);
}
@Test
void getSubgraph_nonAdmin_allFiltered_returnsEmptySubgraph() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
GraphEntity otherEntity = GraphEntity.builder()
.id(ENTITY_ID).name("Other Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", OTHER_USER_ID)))
.build();
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID)))
.thenReturn(List.of(otherEntity));
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of(ENTITY_ID));
assertThat(result.getNodes()).isEmpty();
assertThat(result.getEdges()).isEmpty();
assertThat(result.getNodeCount()).isEqualTo(0);
}
@Test
void getSubgraph_nonAdmin_structuralEntitiesVisible() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
// Structural entities have no created_by → visible to all users
GraphEntity structuralEntity = GraphEntity.builder()
.id(ENTITY_ID).name("Default Org").type("Org").graphId(GRAPH_ID)
.properties(new HashMap<>())
.build();
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID)))
.thenReturn(List.of(structuralEntity));
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of(ENTITY_ID));
assertThat(result.getNodes()).hasSize(1);
assertThat(result.getNodes().get(0).getType()).isEqualTo("Org");
}
@Test
void getSubgraph_admin_seesAllEntities() {
// Admin resolves to null → no filtering
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(null);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
GraphEntity otherUserEntity = GraphEntity.builder()
.id(ENTITY_ID).name("Other's Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", "user-1")))
.build();
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID)))
.thenReturn(List.of(otherUserEntity));
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of(ENTITY_ID));
// Admins see other users' entities (no filtering)
assertThat(result.getNodes()).hasSize(1);
assertThat(result.getNodes().get(0).getName()).isEqualTo("Other's Dataset");
}
// -- P1-2: business entities missing created_by (dirty data) are correctly rejected --
@Test
void getNeighborGraph_nonAdmin_businessEntityWithoutCreatedBy_throws() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
// Business entity missing created_by → must be rejected
GraphEntity dirtyEntity = GraphEntity.builder()
.id(ENTITY_ID).name("Dirty Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>())
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(dirtyEntity));
assertThatThrownBy(() -> queryService.getNeighborGraph(GRAPH_ID, ENTITY_ID, 2, 50))
.isInstanceOf(BusinessException.class);
verifyNoInteractions(neo4jClient);
}
@Test
void getSubgraph_nonAdmin_businessEntityWithoutCreatedBy_filtered() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
// Business entity missing created_by → must be filtered out
GraphEntity dirtyEntity = GraphEntity.builder()
.id(ENTITY_ID).name("Dirty Dataset").type("Dataset").graphId(GRAPH_ID)
.properties(new HashMap<>())
.build();
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID)))
.thenReturn(List.of(dirtyEntity));
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of(ENTITY_ID));
assertThat(result.getNodes()).isEmpty();
assertThat(result.getNodeCount()).isEqualTo(0);
}
// -- P1-1: CONFIDENTIAL sensitivity filtering --
@Test
void getNeighborGraph_nonAdmin_confidentialEntity_throwsWithoutPermission() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
// canViewConfidential() returns false by default (Mockito default) → no confidential-view permission
GraphEntity confidentialEntity = GraphEntity.builder()
.id(ENTITY_ID).name("Secret KS").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID, "sensitivity", "CONFIDENTIAL")))
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(confidentialEntity));
assertThatThrownBy(() -> queryService.getNeighborGraph(GRAPH_ID, ENTITY_ID, 2, 50))
.isInstanceOf(BusinessException.class);
verifyNoInteractions(neo4jClient);
}
@Test
void getNeighborGraph_nonAdmin_confidentialEntity_allowedWithPermission() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
when(resourceAccessService.canViewConfidential()).thenReturn(true);
GraphEntity confidentialEntity = GraphEntity.builder()
.id(ENTITY_ID).name("Secret KS").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID, "sensitivity", "CONFIDENTIAL")))
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(confidentialEntity));
when(properties.getMaxDepth()).thenReturn(3);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
// Has confidential-view permission → passes the security check and proceeds to the Neo4jClient call
try {
queryService.getNeighborGraph(GRAPH_ID, ENTITY_ID, 2, 50);
} catch (BusinessException e) {
throw new AssertionError("Should not be blocked by permission check", e);
} catch (Exception ignored) {
// the Neo4jClient mock chain is incomplete, so another exception is expected here
}
}
@Test
void getSubgraph_nonAdmin_confidentialEntity_filteredWithoutPermission() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
// canViewConfidential() defaults to false → no permission to view confidential data
GraphEntity ownNonConfidential = GraphEntity.builder()
.id(ENTITY_ID).name("Normal KS").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID)))
.build();
GraphEntity ownConfidential = GraphEntity.builder()
.id(ENTITY_ID_2).name("Secret KS").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID, "sensitivity", "CONFIDENTIAL")))
.build();
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID, ENTITY_ID_2)))
.thenReturn(List.of(ownNonConfidential, ownConfidential));
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of(ENTITY_ID, ENTITY_ID_2));
// the CONFIDENTIAL entity is filtered out, leaving only the normal one
assertThat(result.getNodes()).hasSize(1);
assertThat(result.getNodes().get(0).getName()).isEqualTo("Normal KS");
}
@Test
void getSubgraph_nonAdmin_confidentialEntity_visibleWithPermission() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
when(resourceAccessService.canViewConfidential()).thenReturn(true);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
GraphEntity ownConfidential = GraphEntity.builder()
.id(ENTITY_ID).name("Secret KS").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID, "sensitivity", "CONFIDENTIAL")))
.build();
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID)))
.thenReturn(List.of(ownConfidential));
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of(ENTITY_ID));
// with confidential permission → the CONFIDENTIAL entity is visible
assertThat(result.getNodes()).hasSize(1);
assertThat(result.getNodes().get(0).getName()).isEqualTo("Secret KS");
}
// -- P2-2: CONFIDENTIAL matching is case-insensitive --
@Test
void getNeighborGraph_nonAdmin_lowercaseConfidential_throwsWithoutPermission() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
GraphEntity entity = GraphEntity.builder()
.id(ENTITY_ID).name("Secret KS").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID, "sensitivity", "confidential")))
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(entity));
assertThatThrownBy(() -> queryService.getNeighborGraph(GRAPH_ID, ENTITY_ID, 2, 50))
.isInstanceOf(BusinessException.class);
verifyNoInteractions(neo4jClient);
}
@Test
void getNeighborGraph_nonAdmin_mixedCaseConfidentialWithSpaces_throwsWithoutPermission() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
GraphEntity entity = GraphEntity.builder()
.id(ENTITY_ID).name("Secret KS").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID, "sensitivity", " Confidential ")))
.build();
when(entityRepository.findByIdAndGraphId(ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(entity));
assertThatThrownBy(() -> queryService.getNeighborGraph(GRAPH_ID, ENTITY_ID, 2, 50))
.isInstanceOf(BusinessException.class);
verifyNoInteractions(neo4jClient);
}
@Test
void getSubgraph_nonAdmin_lowercaseConfidential_filteredWithoutPermission() {
when(resourceAccessService.resolveOwnerFilterUserId()).thenReturn(CURRENT_USER_ID);
when(properties.getMaxNodesPerQuery()).thenReturn(500);
GraphEntity normalKs = GraphEntity.builder()
.id(ENTITY_ID).name("Normal KS").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID)))
.build();
GraphEntity lowercaseConfidential = GraphEntity.builder()
.id(ENTITY_ID_2).name("Secret KS").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", CURRENT_USER_ID, "sensitivity", "confidential")))
.build();
when(entityRepository.findByGraphIdAndIdIn(GRAPH_ID, List.of(ENTITY_ID, ENTITY_ID_2)))
.thenReturn(List.of(normalKs, lowercaseConfidential));
SubgraphVO result = queryService.getSubgraph(GRAPH_ID, List.of(ENTITY_ID, ENTITY_ID_2));
assertThat(result.getNodes()).hasSize(1);
assertThat(result.getNodes().get(0).getName()).isEqualTo("Normal KS");
}
}
}
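The case-insensitivity behavior exercised by the P2-2 tests above can be sketched as follows. This is a hypothetical helper (`SensitivityCheck.isConfidential` is not the service's actual code), mirroring on the Java side what the Cypher `toUpper(trim(...))` comparison does: trim whitespace, then compare ignoring case.

```java
// Hypothetical sketch of the sensitivity check the P2-2 tests exercise:
// normalize with trim + case-insensitive compare, so "confidential" and
// " Confidential " are treated the same as "CONFIDENTIAL".
public class SensitivityCheck {
    static final String CONFIDENTIAL = "CONFIDENTIAL";

    static boolean isConfidential(Object sensitivity) {
        if (sensitivity == null) {
            return false; // missing sensitivity → not confidential
        }
        return CONFIDENTIAL.equalsIgnoreCase(sensitivity.toString().trim());
    }

    public static void main(String[] args) {
        System.out.println(isConfidential("confidential"));   // true
        System.out.println(isConfidential(" Confidential ")); // true
        System.out.println(isConfidential("INTERNAL"));       // false
        System.out.println(isConfidential(null));             // false
    }
}
```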


@@ -0,0 +1,270 @@
package com.datamate.knowledgegraph.application;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import com.datamate.knowledgegraph.domain.model.RelationDetail;
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
import com.datamate.knowledgegraph.domain.repository.GraphRelationRepository;
import com.datamate.knowledgegraph.interfaces.dto.CreateRelationRequest;
import com.datamate.knowledgegraph.interfaces.dto.RelationVO;
import com.datamate.knowledgegraph.interfaces.dto.UpdateRelationRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class GraphRelationServiceTest {
private static final String GRAPH_ID = "550e8400-e29b-41d4-a716-446655440000";
private static final String RELATION_ID = "770e8400-e29b-41d4-a716-446655440002";
private static final String SOURCE_ENTITY_ID = "660e8400-e29b-41d4-a716-446655440001";
private static final String TARGET_ENTITY_ID = "660e8400-e29b-41d4-a716-446655440003";
private static final String INVALID_GRAPH_ID = "not-a-uuid";
@Mock
private GraphRelationRepository relationRepository;
@Mock
private GraphEntityRepository entityRepository;
@InjectMocks
private GraphRelationService relationService;
private RelationDetail sampleDetail;
private GraphEntity sourceEntity;
private GraphEntity targetEntity;
@BeforeEach
void setUp() {
sampleDetail = RelationDetail.builder()
.id(RELATION_ID)
.sourceEntityId(SOURCE_ENTITY_ID)
.sourceEntityName("Source")
.sourceEntityType("Dataset")
.targetEntityId(TARGET_ENTITY_ID)
.targetEntityName("Target")
.targetEntityType("Field")
.relationType("HAS_FIELD")
.properties(Map.of())
.weight(1.0)
.confidence(1.0)
.graphId(GRAPH_ID)
.createdAt(LocalDateTime.now())
.build();
sourceEntity = GraphEntity.builder()
.id(SOURCE_ENTITY_ID).name("Source").type("Dataset").graphId(GRAPH_ID).build();
targetEntity = GraphEntity.builder()
.id(TARGET_ENTITY_ID).name("Target").type("Field").graphId(GRAPH_ID).build();
}
// -----------------------------------------------------------------------
// graphId validation
// -----------------------------------------------------------------------
@Test
void getRelation_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> relationService.getRelation(INVALID_GRAPH_ID, RELATION_ID))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// createRelation
// -----------------------------------------------------------------------
@Test
void createRelation_success() {
when(entityRepository.findByIdAndGraphId(SOURCE_ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(sourceEntity));
when(entityRepository.findByIdAndGraphId(TARGET_ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(targetEntity));
when(relationRepository.create(eq(GRAPH_ID), eq(SOURCE_ENTITY_ID), eq(TARGET_ENTITY_ID),
eq("HAS_FIELD"), anyMap(), isNull(), isNull(), isNull()))
.thenReturn(Optional.of(sampleDetail));
CreateRelationRequest request = new CreateRelationRequest();
request.setSourceEntityId(SOURCE_ENTITY_ID);
request.setTargetEntityId(TARGET_ENTITY_ID);
request.setRelationType("HAS_FIELD");
RelationVO result = relationService.createRelation(GRAPH_ID, request);
assertThat(result.getId()).isEqualTo(RELATION_ID);
assertThat(result.getRelationType()).isEqualTo("HAS_FIELD");
assertThat(result.getSourceEntityId()).isEqualTo(SOURCE_ENTITY_ID);
assertThat(result.getTargetEntityId()).isEqualTo(TARGET_ENTITY_ID);
}
@Test
void createRelation_sourceNotFound_throwsBusinessException() {
when(entityRepository.findByIdAndGraphId(SOURCE_ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.empty());
CreateRelationRequest request = new CreateRelationRequest();
request.setSourceEntityId(SOURCE_ENTITY_ID);
request.setTargetEntityId(TARGET_ENTITY_ID);
request.setRelationType("HAS_FIELD");
assertThatThrownBy(() -> relationService.createRelation(GRAPH_ID, request))
.isInstanceOf(BusinessException.class);
}
@Test
void createRelation_targetNotFound_throwsBusinessException() {
when(entityRepository.findByIdAndGraphId(SOURCE_ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(sourceEntity));
when(entityRepository.findByIdAndGraphId(TARGET_ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.empty());
CreateRelationRequest request = new CreateRelationRequest();
request.setSourceEntityId(SOURCE_ENTITY_ID);
request.setTargetEntityId(TARGET_ENTITY_ID);
request.setRelationType("HAS_FIELD");
assertThatThrownBy(() -> relationService.createRelation(GRAPH_ID, request))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// getRelation
// -----------------------------------------------------------------------
@Test
void getRelation_found() {
when(relationRepository.findByIdAndGraphId(RELATION_ID, GRAPH_ID))
.thenReturn(Optional.of(sampleDetail));
RelationVO result = relationService.getRelation(GRAPH_ID, RELATION_ID);
assertThat(result.getId()).isEqualTo(RELATION_ID);
}
@Test
void getRelation_notFound_throwsBusinessException() {
when(relationRepository.findByIdAndGraphId(RELATION_ID, GRAPH_ID))
.thenReturn(Optional.empty());
assertThatThrownBy(() -> relationService.getRelation(GRAPH_ID, RELATION_ID))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// listRelations (paging)
// -----------------------------------------------------------------------
@Test
void listRelations_returnsPaged() {
when(relationRepository.findByGraphId(GRAPH_ID, null, 0L, 20))
.thenReturn(List.of(sampleDetail));
when(relationRepository.countByGraphId(GRAPH_ID, null))
.thenReturn(1L);
var result = relationService.listRelations(GRAPH_ID, null, 0, 20);
assertThat(result.getContent()).hasSize(1);
assertThat(result.getTotalElements()).isEqualTo(1);
}
@Test
void listRelations_oversizedPage_clampedTo200() {
when(relationRepository.findByGraphId(GRAPH_ID, null, 0L, 200))
.thenReturn(List.of());
when(relationRepository.countByGraphId(GRAPH_ID, null))
.thenReturn(0L);
relationService.listRelations(GRAPH_ID, null, 0, 999);
verify(relationRepository).findByGraphId(GRAPH_ID, null, 0L, 200);
}
// -----------------------------------------------------------------------
// listEntityRelations — direction validation
// -----------------------------------------------------------------------
@Test
void listEntityRelations_invalidDirection_throwsBusinessException() {
when(entityRepository.findByIdAndGraphId(SOURCE_ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(sourceEntity));
assertThatThrownBy(() ->
relationService.listEntityRelations(GRAPH_ID, SOURCE_ENTITY_ID, "invalid", null, 0, 20))
.isInstanceOf(BusinessException.class);
}
@Test
void listEntityRelations_inDirection() {
when(entityRepository.findByIdAndGraphId(SOURCE_ENTITY_ID, GRAPH_ID))
.thenReturn(Optional.of(sourceEntity));
when(relationRepository.findInboundByEntityId(GRAPH_ID, SOURCE_ENTITY_ID, null, 0L, 20))
.thenReturn(List.of(sampleDetail));
when(relationRepository.countByEntityId(GRAPH_ID, SOURCE_ENTITY_ID, null, "in"))
.thenReturn(1L);
var result = relationService.listEntityRelations(
GRAPH_ID, SOURCE_ENTITY_ID, "in", null, 0, 20);
assertThat(result.getContent()).hasSize(1);
}
// -----------------------------------------------------------------------
// updateRelation
// -----------------------------------------------------------------------
@Test
void updateRelation_success() {
when(relationRepository.findByIdAndGraphId(RELATION_ID, GRAPH_ID))
.thenReturn(Optional.of(sampleDetail));
RelationDetail updated = RelationDetail.builder()
.id(RELATION_ID).relationType("USES").weight(0.8)
.sourceEntityId(SOURCE_ENTITY_ID).targetEntityId(TARGET_ENTITY_ID)
.graphId(GRAPH_ID).build();
when(relationRepository.update(eq(RELATION_ID), eq(GRAPH_ID), eq("USES"), isNull(), eq(0.8), isNull()))
.thenReturn(Optional.of(updated));
UpdateRelationRequest request = new UpdateRelationRequest();
request.setRelationType("USES");
request.setWeight(0.8);
RelationVO result = relationService.updateRelation(GRAPH_ID, RELATION_ID, request);
assertThat(result.getRelationType()).isEqualTo("USES");
}
// -----------------------------------------------------------------------
// deleteRelation
// -----------------------------------------------------------------------
@Test
void deleteRelation_success() {
when(relationRepository.findByIdAndGraphId(RELATION_ID, GRAPH_ID))
.thenReturn(Optional.of(sampleDetail));
when(relationRepository.deleteByIdAndGraphId(RELATION_ID, GRAPH_ID))
.thenReturn(1L);
relationService.deleteRelation(GRAPH_ID, RELATION_ID);
verify(relationRepository).deleteByIdAndGraphId(RELATION_ID, GRAPH_ID);
}
@Test
void deleteRelation_notFound_throwsBusinessException() {
when(relationRepository.findByIdAndGraphId(RELATION_ID, GRAPH_ID))
.thenReturn(Optional.empty());
assertThatThrownBy(() -> relationService.deleteRelation(GRAPH_ID, RELATION_ID))
.isInstanceOf(BusinessException.class);
}
}


@@ -0,0 +1,338 @@
package com.datamate.knowledgegraph.application;
import com.datamate.common.infrastructure.exception.BusinessException;
import com.datamate.knowledgegraph.domain.model.SyncResult;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.DatasetDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.WorkflowDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.JobDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.LabelTaskDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.KnowledgeSetDTO;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class GraphSyncServiceTest {
private static final String GRAPH_ID = "550e8400-e29b-41d4-a716-446655440000";
private static final String INVALID_GRAPH_ID = "bad-id";
@Mock
private GraphSyncStepService stepService;
@Mock
private DataManagementClient dataManagementClient;
@Mock
private KnowledgeGraphProperties properties;
@InjectMocks
private GraphSyncService syncService;
private KnowledgeGraphProperties.Sync syncConfig;
@BeforeEach
void setUp() {
syncConfig = new KnowledgeGraphProperties.Sync();
syncConfig.setMaxRetries(1);
syncConfig.setRetryInterval(10);
}
// -----------------------------------------------------------------------
// graphId validation
// -----------------------------------------------------------------------
@Test
void syncAll_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> syncService.syncAll(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void syncAll_nullGraphId_throwsBusinessException() {
assertThatThrownBy(() -> syncService.syncAll(null))
.isInstanceOf(BusinessException.class);
}
@Test
void syncDatasets_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> syncService.syncDatasets(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void syncWorkflows_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> syncService.syncWorkflows(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void syncJobs_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> syncService.syncJobs(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void syncLabelTasks_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> syncService.syncLabelTasks(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void syncKnowledgeSets_invalidGraphId_throwsBusinessException() {
assertThatThrownBy(() -> syncService.syncKnowledgeSets(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
// -----------------------------------------------------------------------
// syncAll — happy path (8 entities + 10 relations = 18 results)
// -----------------------------------------------------------------------
@Test
void syncAll_success_returnsAllResults() {
when(properties.getSync()).thenReturn(syncConfig);
DatasetDTO dto = new DatasetDTO();
dto.setId("ds-001");
dto.setName("Test");
dto.setCreatedBy("admin");
when(dataManagementClient.listAllDatasets()).thenReturn(List.of(dto));
when(dataManagementClient.listAllWorkflows()).thenReturn(List.of());
when(dataManagementClient.listAllJobs()).thenReturn(List.of());
when(dataManagementClient.listAllLabelTasks()).thenReturn(List.of());
when(dataManagementClient.listAllKnowledgeSets()).thenReturn(List.of());
// 8 entity upsert results
when(stepService.upsertDatasetEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("Dataset").build());
when(stepService.upsertFieldEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("Field").build());
when(stepService.upsertUserEntities(eq(GRAPH_ID), anySet(), anyString()))
.thenReturn(SyncResult.builder().syncType("User").build());
when(stepService.upsertOrgEntities(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("Org").build());
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("Workflow").build());
when(stepService.upsertJobEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("Job").build());
when(stepService.upsertLabelTaskEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("LabelTask").build());
when(stepService.upsertKnowledgeSetEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("KnowledgeSet").build());
when(stepService.purgeStaleEntities(eq(GRAPH_ID), anyString(), anySet(), anyString()))
.thenReturn(0);
// 10 relation merge results
when(stepService.mergeHasFieldRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("HAS_FIELD").build());
when(stepService.mergeDerivedFromRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("DERIVED_FROM").build());
when(stepService.mergeBelongsToRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("BELONGS_TO").build());
when(stepService.mergeUsesDatasetRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("USES_DATASET").build());
when(stepService.mergeProducesRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("PRODUCES").build());
when(stepService.mergeAssignedToRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("ASSIGNED_TO").build());
when(stepService.mergeTriggersRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("TRIGGERS").build());
when(stepService.mergeDependsOnRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("DEPENDS_ON").build());
when(stepService.mergeImpactsRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("IMPACTS").build());
when(stepService.mergeSourcedFromRelations(eq(GRAPH_ID), anyString()))
.thenReturn(SyncResult.builder().syncType("SOURCED_FROM").build());
List<SyncResult> results = syncService.syncAll(GRAPH_ID);
// 8 entities + 10 relations = 18
assertThat(results).hasSize(18);
// index results by syncType to avoid depending on fixed positions
Map<String, SyncResult> byType = results.stream()
.collect(Collectors.toMap(SyncResult::getSyncType, Function.identity()));
// verify all 8 entity types are present
assertThat(byType).containsKeys("Dataset", "Field", "User", "Org",
"Workflow", "Job", "LabelTask", "KnowledgeSet");
// verify all 10 relation types are present
assertThat(byType).containsKeys("HAS_FIELD", "DERIVED_FROM", "BELONGS_TO",
"USES_DATASET", "PRODUCES", "ASSIGNED_TO", "TRIGGERS",
"DEPENDS_ON", "IMPACTS", "SOURCED_FROM");
}
// -----------------------------------------------------------------------
// retry exhaustion
// -----------------------------------------------------------------------
@Test
void syncDatasets_fetchRetryExhausted_throwsBusinessException() {
when(properties.getSync()).thenReturn(syncConfig);
when(dataManagementClient.listAllDatasets()).thenThrow(new RuntimeException("connection refused"));
assertThatThrownBy(() -> syncService.syncDatasets(GRAPH_ID))
.isInstanceOf(BusinessException.class)
.hasMessageContaining("datasets");
}
// -----------------------------------------------------------------------
// single-entity sync for the newly added entity types
// -----------------------------------------------------------------------
@Nested
class NewEntitySyncTest {
@Test
void syncWorkflows_success() {
when(properties.getSync()).thenReturn(syncConfig);
WorkflowDTO wf = new WorkflowDTO();
wf.setId("wf-001");
wf.setName("Clean Pipeline");
when(dataManagementClient.listAllWorkflows()).thenReturn(List.of(wf));
when(stepService.upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("Workflow").build());
when(stepService.purgeStaleEntities(eq(GRAPH_ID), eq("Workflow"), anySet(), anyString()))
.thenReturn(0);
SyncResult result = syncService.syncWorkflows(GRAPH_ID);
assertThat(result.getSyncType()).isEqualTo("Workflow");
verify(stepService).upsertWorkflowEntities(eq(GRAPH_ID), anyList(), anyString());
}
@Test
void syncJobs_success() {
when(properties.getSync()).thenReturn(syncConfig);
JobDTO job = new JobDTO();
job.setId("job-001");
job.setName("Clean Job");
when(dataManagementClient.listAllJobs()).thenReturn(List.of(job));
when(stepService.upsertJobEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("Job").build());
when(stepService.purgeStaleEntities(eq(GRAPH_ID), eq("Job"), anySet(), anyString()))
.thenReturn(0);
SyncResult result = syncService.syncJobs(GRAPH_ID);
assertThat(result.getSyncType()).isEqualTo("Job");
verify(stepService).upsertJobEntities(eq(GRAPH_ID), anyList(), anyString());
}
@Test
void syncLabelTasks_success() {
when(properties.getSync()).thenReturn(syncConfig);
LabelTaskDTO task = new LabelTaskDTO();
task.setId("lt-001");
task.setName("Label Task");
when(dataManagementClient.listAllLabelTasks()).thenReturn(List.of(task));
when(stepService.upsertLabelTaskEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("LabelTask").build());
when(stepService.purgeStaleEntities(eq(GRAPH_ID), eq("LabelTask"), anySet(), anyString()))
.thenReturn(0);
SyncResult result = syncService.syncLabelTasks(GRAPH_ID);
assertThat(result.getSyncType()).isEqualTo("LabelTask");
}
@Test
void syncKnowledgeSets_success() {
when(properties.getSync()).thenReturn(syncConfig);
KnowledgeSetDTO ks = new KnowledgeSetDTO();
ks.setId("ks-001");
ks.setName("Knowledge Set");
when(dataManagementClient.listAllKnowledgeSets()).thenReturn(List.of(ks));
when(stepService.upsertKnowledgeSetEntities(eq(GRAPH_ID), anyList(), anyString()))
.thenReturn(SyncResult.builder().syncType("KnowledgeSet").build());
when(stepService.purgeStaleEntities(eq(GRAPH_ID), eq("KnowledgeSet"), anySet(), anyString()))
.thenReturn(0);
SyncResult result = syncService.syncKnowledgeSets(GRAPH_ID);
assertThat(result.getSyncType()).isEqualTo("KnowledgeSet");
}
@Test
void syncWorkflows_fetchFailed_throwsBusinessException() {
when(properties.getSync()).thenReturn(syncConfig);
when(dataManagementClient.listAllWorkflows()).thenThrow(new RuntimeException("timeout"));
assertThatThrownBy(() -> syncService.syncWorkflows(GRAPH_ID))
.isInstanceOf(BusinessException.class)
.hasMessageContaining("workflows");
}
}
// -----------------------------------------------------------------------
// relation building for the newly added relation types
// -----------------------------------------------------------------------
@Nested
class NewRelationBuildTest {
@Test
void buildUsesDatasetRelations_invalidGraphId_throws() {
assertThatThrownBy(() -> syncService.buildUsesDatasetRelations(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void buildProducesRelations_invalidGraphId_throws() {
assertThatThrownBy(() -> syncService.buildProducesRelations(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void buildAssignedToRelations_invalidGraphId_throws() {
assertThatThrownBy(() -> syncService.buildAssignedToRelations(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void buildTriggersRelations_invalidGraphId_throws() {
assertThatThrownBy(() -> syncService.buildTriggersRelations(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void buildDependsOnRelations_invalidGraphId_throws() {
assertThatThrownBy(() -> syncService.buildDependsOnRelations(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void buildImpactsRelations_invalidGraphId_throws() {
assertThatThrownBy(() -> syncService.buildImpactsRelations(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
@Test
void buildSourcedFromRelations_invalidGraphId_throws() {
assertThatThrownBy(() -> syncService.buildSourcedFromRelations(INVALID_GRAPH_ID))
.isInstanceOf(BusinessException.class);
}
}
}
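The "fetchRetryExhausted" tests above configure `maxRetries(1)` and `retryInterval(10)` and assert that a `BusinessException` naming the failed resource (e.g. "datasets") is thrown once retries run out. A minimal sketch of that pattern, assuming a generic retry helper (hypothetical names; not the service's actual implementation):

```java
// Hypothetical sketch of the retry loop the retry-exhaustion tests exercise:
// try the fetch up to maxRetries + 1 times, sleeping retryIntervalMs between
// attempts, then fail with a message that names the resource being fetched.
import java.util.function.Supplier;

public class RetryFetch {
    static <T> T fetchWithRetry(Supplier<T> fetch, String what,
                                int maxRetries, long retryIntervalMs) {
        RuntimeException last = null;
        for (int attempt = 0; attempt <= maxRetries; attempt++) {
            try {
                return fetch.get();
            } catch (RuntimeException e) {
                last = e;
                try {
                    Thread.sleep(retryIntervalMs);
                } catch (InterruptedException ie) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        }
        // The message names the resource (e.g. "datasets") so callers can
        // assert on it, as the tests above do with hasMessageContaining.
        throw new IllegalStateException("Failed to fetch " + what + " after "
                + (maxRetries + 1) + " attempts", last);
    }

    public static void main(String[] args) {
        try {
            fetchWithRetry(() -> { throw new RuntimeException("connection refused"); },
                    "datasets", 1, 10);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```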


@@ -0,0 +1,821 @@
package com.datamate.knowledgegraph.application;
import com.datamate.knowledgegraph.domain.model.GraphEntity;
import com.datamate.knowledgegraph.domain.model.SyncResult;
import com.datamate.knowledgegraph.domain.repository.GraphEntityRepository;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.DatasetDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.TagDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.WorkflowDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.JobDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.LabelTaskDTO;
import com.datamate.knowledgegraph.infrastructure.client.DataManagementClient.KnowledgeSetDTO;
import com.datamate.knowledgegraph.infrastructure.neo4j.KnowledgeGraphProperties;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.data.neo4j.core.Neo4jClient.UnboundRunnableSpec;
import org.springframework.data.neo4j.core.Neo4jClient.RunnableSpec;
import org.springframework.data.neo4j.core.Neo4jClient.RecordFetchSpec;
import org.springframework.data.neo4j.core.Neo4jClient.MappingSpec;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class GraphSyncStepServiceTest {
private static final String GRAPH_ID = "550e8400-e29b-41d4-a716-446655440000";
private static final String SYNC_ID = "test-sync";
@Mock
private GraphEntityRepository entityRepository;
@Mock
private Neo4jClient neo4jClient;
@Mock
private KnowledgeGraphProperties properties;
@InjectMocks
private GraphSyncStepService stepService;
@Captor
private ArgumentCaptor<String> cypherCaptor;
@BeforeEach
void setUp() {
}
// -----------------------------------------------------------------------
// Neo4jClient mock chain helper
// -----------------------------------------------------------------------
@SuppressWarnings("unchecked")
private void setupNeo4jQueryChain(Class<?> fetchType, Object returnValue) {
UnboundRunnableSpec unboundSpec = mock(UnboundRunnableSpec.class);
RunnableSpec runnableSpec = mock(RunnableSpec.class);
MappingSpec mappingSpec = mock(MappingSpec.class);
RecordFetchSpec fetchSpec = mock(RecordFetchSpec.class);
when(neo4jClient.query(anyString())).thenReturn(unboundSpec);
when(unboundSpec.bindAll(anyMap())).thenReturn(runnableSpec);
when(runnableSpec.fetchAs(any(Class.class))).thenReturn(mappingSpec);
when(mappingSpec.mappedBy(any())).thenReturn(fetchSpec);
when(fetchSpec.one()).thenReturn(Optional.ofNullable(returnValue));
}
// -----------------------------------------------------------------------
// purgeStaleEntities — P1-2 empty-snapshot protection
// -----------------------------------------------------------------------
@Nested
class PurgeStaleEntitiesTest {
@Test
void emptySnapshot_defaultConfig_blocksPurge() {
KnowledgeGraphProperties.Sync syncConfig = new KnowledgeGraphProperties.Sync();
syncConfig.setAllowPurgeOnEmptySnapshot(false);
when(properties.getSync()).thenReturn(syncConfig);
int deleted = stepService.purgeStaleEntities(
GRAPH_ID, "Dataset", Collections.emptySet(), SYNC_ID);
assertThat(deleted).isEqualTo(0);
verifyNoInteractions(neo4jClient);
}
@Test
void emptySnapshot_explicitAllow_executesPurge() {
KnowledgeGraphProperties.Sync syncConfig = new KnowledgeGraphProperties.Sync();
syncConfig.setAllowPurgeOnEmptySnapshot(true);
when(properties.getSync()).thenReturn(syncConfig);
setupNeo4jQueryChain(Long.class, 5L);
int deleted = stepService.purgeStaleEntities(
GRAPH_ID, "Dataset", Collections.emptySet(), SYNC_ID);
assertThat(deleted).isEqualTo(5);
verify(neo4jClient).query(anyString());
}
@Test
void nonEmptySnapshot_purgesStaleEntities() {
setupNeo4jQueryChain(Long.class, 2L);
Set<String> activeIds = Set.of("ds-001", "ds-002");
int deleted = stepService.purgeStaleEntities(
GRAPH_ID, "Dataset", activeIds, SYNC_ID);
assertThat(deleted).isEqualTo(2);
verify(neo4jClient).query(contains("NOT e.source_id IN $activeSourceIds"));
}
@Test
void nonEmptySnapshot_nothingToDelete_returnsZero() {
setupNeo4jQueryChain(Long.class, 0L);
Set<String> activeIds = Set.of("ds-001");
int deleted = stepService.purgeStaleEntities(
GRAPH_ID, "Dataset", activeIds, SYNC_ID);
assertThat(deleted).isEqualTo(0);
}
}
// -----------------------------------------------------------------------
// upsertDatasetEntities
// -----------------------------------------------------------------------
@Nested
class UpsertDatasetEntitiesTest {
@Test
void upsert_newEntity_incrementsCreated() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, true);
DatasetDTO dto = new DatasetDTO();
dto.setId("ds-001");
dto.setName("Test Dataset");
dto.setDescription("Desc");
dto.setDatasetType("TEXT");
dto.setStatus("ACTIVE");
SyncResult result = stepService.upsertDatasetEntities(
GRAPH_ID, List.of(dto), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(1);
assertThat(result.getUpdated()).isEqualTo(0);
assertThat(result.getSyncType()).isEqualTo("Dataset");
}
@Test
void upsert_existingEntity_incrementsUpdated() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, false);
DatasetDTO dto = new DatasetDTO();
dto.setId("ds-001");
dto.setName("Updated");
SyncResult result = stepService.upsertDatasetEntities(
GRAPH_ID, List.of(dto), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(0);
assertThat(result.getUpdated()).isEqualTo(1);
}
@Test
void upsert_emptyList_returnsZeroCounts() {
when(properties.getImportBatchSize()).thenReturn(100);
SyncResult result = stepService.upsertDatasetEntities(
GRAPH_ID, List.of(), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(0);
assertThat(result.getUpdated()).isEqualTo(0);
verifyNoInteractions(neo4jClient);
}
@Test
void upsert_cypher_containsPropertiesSetClauses() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, true);
DatasetDTO dto = new DatasetDTO();
dto.setId("ds-001");
dto.setName("Dataset");
dto.setDatasetType("TEXT");
dto.setStatus("ACTIVE");
stepService.upsertDatasetEntities(GRAPH_ID, List.of(dto), SYNC_ID);
verify(neo4jClient).query(cypherCaptor.capture());
String cypher = cypherCaptor.getValue();
assertThat(cypher).contains("MERGE");
assertThat(cypher).contains("properties.");
verifyNoInteractions(entityRepository);
}
@Test
void upsert_multipleEntities_eachGetsSeparateMerge() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, true);
DatasetDTO dto1 = new DatasetDTO();
dto1.setId("ds-001");
dto1.setName("DS1");
DatasetDTO dto2 = new DatasetDTO();
dto2.setId("ds-002");
dto2.setName("DS2");
SyncResult result = stepService.upsertDatasetEntities(
GRAPH_ID, List.of(dto1, dto2), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(2);
verify(neo4jClient, times(2)).query(anyString());
}
}
// -----------------------------------------------------------------------
// upsertFieldEntities
// -----------------------------------------------------------------------
@Nested
class UpsertFieldEntitiesTest {
@Test
void upsertFields_datasetsWithNoTags_returnsZero() {
DatasetDTO dto = new DatasetDTO();
dto.setId("ds-001");
dto.setTags(null);
SyncResult result = stepService.upsertFieldEntities(
GRAPH_ID, List.of(dto), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(0);
assertThat(result.getUpdated()).isEqualTo(0);
}
@Test
void upsertFields_datasetsWithTags_createsFieldPerTag() {
setupNeo4jQueryChain(Boolean.class, true);
DatasetDTO dto = new DatasetDTO();
dto.setId("ds-001");
dto.setName("Dataset1");
TagDTO tag1 = new TagDTO();
tag1.setName("tag_a");
TagDTO tag2 = new TagDTO();
tag2.setName("tag_b");
dto.setTags(List.of(tag1, tag2));
SyncResult result = stepService.upsertFieldEntities(
GRAPH_ID, List.of(dto), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(2);
}
}
// -----------------------------------------------------------------------
// upsertWorkflowEntities
// -----------------------------------------------------------------------
@Nested
class UpsertWorkflowEntitiesTest {
@Test
void upsert_newWorkflow_incrementsCreated() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, true);
WorkflowDTO dto = new WorkflowDTO();
dto.setId("wf-001");
dto.setName("Clean Pipeline");
dto.setWorkflowType("CLEANING");
dto.setStatus("ACTIVE");
dto.setVersion("2.1");
dto.setOperatorCount(3);
SyncResult result = stepService.upsertWorkflowEntities(
GRAPH_ID, List.of(dto), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(1);
assertThat(result.getSyncType()).isEqualTo("Workflow");
}
@Test
void upsert_emptyList_returnsZero() {
when(properties.getImportBatchSize()).thenReturn(100);
SyncResult result = stepService.upsertWorkflowEntities(
GRAPH_ID, List.of(), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(0);
verifyNoInteractions(neo4jClient);
}
@Test
void upsert_withInputDatasetIds_storesProperty() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, true);
WorkflowDTO dto = new WorkflowDTO();
dto.setId("wf-001");
dto.setName("Pipeline");
dto.setWorkflowType("CLEANING");
dto.setInputDatasetIds(List.of("ds-001", "ds-002"));
stepService.upsertWorkflowEntities(GRAPH_ID, List.of(dto), SYNC_ID);
verify(neo4jClient).query(cypherCaptor.capture());
assertThat(cypherCaptor.getValue()).contains("properties.input_dataset_ids");
}
}
// -----------------------------------------------------------------------
// upsertJobEntities
// -----------------------------------------------------------------------
@Nested
class UpsertJobEntitiesTest {
@Test
void upsert_newJob_incrementsCreated() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, true);
JobDTO dto = new JobDTO();
dto.setId("job-001");
dto.setName("Clean Job");
dto.setJobType("CLEANING");
dto.setStatus("COMPLETED");
dto.setDurationSeconds(2100L);
dto.setInputDatasetId("ds-001");
dto.setOutputDatasetId("ds-002");
dto.setWorkflowId("wf-001");
dto.setCreatedBy("admin");
SyncResult result = stepService.upsertJobEntities(
GRAPH_ID, List.of(dto), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(1);
assertThat(result.getSyncType()).isEqualTo("Job");
}
@Test
void upsert_jobWithDependency_storesProperty() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, true);
JobDTO dto = new JobDTO();
dto.setId("job-002");
dto.setName("Downstream Job");
dto.setJobType("SYNTHESIS");
dto.setStatus("PENDING");
dto.setDependsOnJobId("job-001");
stepService.upsertJobEntities(GRAPH_ID, List.of(dto), SYNC_ID);
verify(neo4jClient).query(cypherCaptor.capture());
assertThat(cypherCaptor.getValue()).contains("properties.depends_on_job_id");
}
@Test
void upsert_emptyList_returnsZero() {
when(properties.getImportBatchSize()).thenReturn(100);
SyncResult result = stepService.upsertJobEntities(
GRAPH_ID, List.of(), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(0);
verifyNoInteractions(neo4jClient);
}
}
// -----------------------------------------------------------------------
// upsertLabelTaskEntities
// -----------------------------------------------------------------------
@Nested
class UpsertLabelTaskEntitiesTest {
@Test
void upsert_newLabelTask_incrementsCreated() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, true);
LabelTaskDTO dto = new LabelTaskDTO();
dto.setId("lt-001");
dto.setName("Label Task");
dto.setTaskMode("MANUAL");
dto.setStatus("IN_PROGRESS");
dto.setDataType("image");
dto.setLabelingType("object_detection");
dto.setProgress(45.5);
dto.setDatasetId("ds-001");
dto.setCreatedBy("admin");
SyncResult result = stepService.upsertLabelTaskEntities(
GRAPH_ID, List.of(dto), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(1);
assertThat(result.getSyncType()).isEqualTo("LabelTask");
}
@Test
void upsert_emptyList_returnsZero() {
when(properties.getImportBatchSize()).thenReturn(100);
SyncResult result = stepService.upsertLabelTaskEntities(
GRAPH_ID, List.of(), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(0);
verifyNoInteractions(neo4jClient);
}
}
// -----------------------------------------------------------------------
// upsertKnowledgeSetEntities
// -----------------------------------------------------------------------
@Nested
class UpsertKnowledgeSetEntitiesTest {
@Test
void upsert_newKnowledgeSet_incrementsCreated() {
when(properties.getImportBatchSize()).thenReturn(100);
setupNeo4jQueryChain(Boolean.class, true);
KnowledgeSetDTO dto = new KnowledgeSetDTO();
dto.setId("ks-001");
dto.setName("Medical Knowledge");
dto.setStatus("PUBLISHED");
dto.setDomain("medical");
dto.setSensitivity("INTERNAL");
dto.setItemCount(320);
dto.setSourceDatasetIds(List.of("ds-001", "ds-002"));
SyncResult result = stepService.upsertKnowledgeSetEntities(
GRAPH_ID, List.of(dto), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(1);
assertThat(result.getSyncType()).isEqualTo("KnowledgeSet");
}
@Test
void upsert_emptyList_returnsZero() {
when(properties.getImportBatchSize()).thenReturn(100);
SyncResult result = stepService.upsertKnowledgeSetEntities(
GRAPH_ID, List.of(), SYNC_ID);
assertThat(result.getCreated()).isEqualTo(0);
verifyNoInteractions(neo4jClient);
}
}
// -----------------------------------------------------------------------
// Relation building - existing
// -----------------------------------------------------------------------
@Nested
class MergeRelationsTest {
@Test
void mergeHasField_noFields_returnsEmptyResult() {
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Field"))
.thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset"))
.thenReturn(List.of());
SyncResult result = stepService.mergeHasFieldRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("HAS_FIELD");
assertThat(result.getCreated()).isEqualTo(0);
}
@Test
void mergeDerivedFrom_noParent_skipsRelation() {
GraphEntity dataset = GraphEntity.builder()
.id("entity-1")
.type("Dataset")
.graphId(GRAPH_ID)
.properties(new HashMap<>())
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset"))
.thenReturn(List.of(dataset));
SyncResult result = stepService.mergeDerivedFromRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getCreated()).isEqualTo(0);
}
@Test
void mergeBelongsTo_noDefaultOrg_returnsError() {
when(entityRepository.findByGraphIdAndSourceIdAndType(GRAPH_ID, "org:default", "Org"))
.thenReturn(Optional.empty());
SyncResult result = stepService.mergeBelongsToRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getFailed()).isGreaterThan(0);
assertThat(result.getErrors()).contains("belongs_to:org_missing");
}
}
// -----------------------------------------------------------------------
// Newly added relation building
// -----------------------------------------------------------------------
@Nested
class NewMergeRelationsTest {
@Test
void mergeUsesDataset_noJobs_noTasks_noWorkflows_returnsZero() {
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "LabelTask")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Workflow")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset")).thenReturn(List.of());
SyncResult result = stepService.mergeUsesDatasetRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("USES_DATASET");
assertThat(result.getCreated()).isEqualTo(0);
}
@Test
void mergeUsesDataset_jobWithInputDataset_createsRelation() {
setupNeo4jQueryChain(String.class, "new-rel-id");
GraphEntity job = GraphEntity.builder()
.id("job-entity-1").type("Job").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("input_dataset_id", "ds-001")))
.build();
GraphEntity dataset = GraphEntity.builder()
.id("ds-entity-1").sourceId("ds-001").type("Dataset").graphId(GRAPH_ID)
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of(job));
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "LabelTask")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Workflow")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset")).thenReturn(List.of(dataset));
SyncResult result = stepService.mergeUsesDatasetRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("USES_DATASET");
verify(neo4jClient).query(cypherCaptor.capture());
String cypher = cypherCaptor.getValue();
// Verify relation type and direction: source -> target
assertThat(cypher).contains("RELATED_TO");
assertThat(cypher).contains("relation_type: $relationType");
}
@Test
void mergeUsesDataset_workflowWithSingleStringInput_handledCorrectly() {
setupNeo4jQueryChain(String.class, "new-rel-id");
// A single String value rather than a List; verifies toStringList normalizes both
GraphEntity workflow = GraphEntity.builder()
.id("wf-entity-1").type("Workflow").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("input_dataset_ids", "ds-single")))
.build();
GraphEntity dataset = GraphEntity.builder()
.id("ds-entity-s").sourceId("ds-single").type("Dataset").graphId(GRAPH_ID)
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "LabelTask")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Workflow")).thenReturn(List.of(workflow));
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset")).thenReturn(List.of(dataset));
SyncResult result = stepService.mergeUsesDatasetRelations(GRAPH_ID, SYNC_ID);
verify(neo4jClient, atLeastOnce()).query(anyString());
}
@Test
void mergeUsesDataset_listWithNullElements_filtersNulls() {
setupNeo4jQueryChain(String.class, "new-rel-id");
List<Object> listWithNulls = new ArrayList<>();
listWithNulls.add("ds-good");
listWithNulls.add(null);
listWithNulls.add("");
listWithNulls.add(" ");
GraphEntity workflow = GraphEntity.builder()
.id("wf-entity-2").type("Workflow").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("input_dataset_ids", listWithNulls)))
.build();
GraphEntity dataset = GraphEntity.builder()
.id("ds-entity-g").sourceId("ds-good").type("Dataset").graphId(GRAPH_ID)
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "LabelTask")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Workflow")).thenReturn(List.of(workflow));
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset")).thenReturn(List.of(dataset));
SyncResult result = stepService.mergeUsesDatasetRelations(GRAPH_ID, SYNC_ID);
// Only "ds-good" should be processed (null, empty, and blank entries are filtered); verify exactly one mergeRelation call
verify(neo4jClient, times(1)).query(anyString());
}
@Test
void mergeProduces_noJobs_returnsZero() {
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset")).thenReturn(List.of());
SyncResult result = stepService.mergeProducesRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("PRODUCES");
assertThat(result.getCreated()).isEqualTo(0);
}
@Test
void mergeProduces_jobWithoutOutput_skips() {
GraphEntity job = GraphEntity.builder()
.id("job-entity-1").type("Job").graphId(GRAPH_ID)
.properties(new HashMap<>())
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of(job));
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset")).thenReturn(List.of());
SyncResult result = stepService.mergeProducesRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getCreated()).isEqualTo(0);
}
@Test
void mergeAssignedTo_noTasksOrJobs_returnsZero() {
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "LabelTask")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "User")).thenReturn(List.of());
SyncResult result = stepService.mergeAssignedToRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("ASSIGNED_TO");
assertThat(result.getCreated()).isEqualTo(0);
}
@Test
void mergeAssignedTo_taskWithCreatedBy_verifiesUserLookup() {
setupNeo4jQueryChain(String.class, "new-rel-id");
GraphEntity task = GraphEntity.builder()
.id("lt-entity-1").type("LabelTask").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("created_by", "admin")))
.build();
GraphEntity user = GraphEntity.builder()
.id("user-entity-1").sourceId("user:admin").type("User").graphId(GRAPH_ID)
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "LabelTask")).thenReturn(List.of(task));
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "User")).thenReturn(List.of(user));
SyncResult result = stepService.mergeAssignedToRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("ASSIGNED_TO");
// Verify the User entity is looked up via the preloaded userMap (no more N+1 queries)
verify(neo4jClient).query(cypherCaptor.capture());
assertThat(cypherCaptor.getValue()).contains("RELATED_TO");
}
@Test
void mergeTriggers_noJobs_returnsZero() {
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Workflow")).thenReturn(List.of());
SyncResult result = stepService.mergeTriggersRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("TRIGGERS");
assertThat(result.getCreated()).isEqualTo(0);
}
@Test
void mergeTriggers_jobWithWorkflowId_createsRelationWithCorrectDirection() {
setupNeo4jQueryChain(String.class, "new-rel-id");
GraphEntity job = GraphEntity.builder()
.id("job-entity-1").type("Job").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("workflow_id", "wf-001")))
.build();
GraphEntity workflow = GraphEntity.builder()
.id("wf-entity-1").sourceId("wf-001").type("Workflow").graphId(GRAPH_ID)
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of(job));
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Workflow")).thenReturn(List.of(workflow));
SyncResult result = stepService.mergeTriggersRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("TRIGGERS");
verify(neo4jClient).query(cypherCaptor.capture());
String cypher = cypherCaptor.getValue();
// Verify MERGE uses the RELATED_TO relationship type
assertThat(cypher).contains("RELATED_TO");
// Verify Cypher parameter binding: source should be the Workflow, target the Job
assertThat(cypher).contains("$sourceEntityId");
assertThat(cypher).contains("$targetEntityId");
}
@Test
void mergeDependsOn_noJobs_returnsZero() {
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of());
SyncResult result = stepService.mergeDependsOnRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("DEPENDS_ON");
assertThat(result.getCreated()).isEqualTo(0);
}
@Test
void mergeDependsOn_jobWithDependency_verifiesCypherParams() {
setupNeo4jQueryChain(String.class, "new-rel-id");
GraphEntity job = GraphEntity.builder()
.id("job-entity-2").sourceId("job-002").type("Job").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("depends_on_job_id", "job-001")))
.build();
GraphEntity depJob = GraphEntity.builder()
.id("job-entity-1").sourceId("job-001").type("Job").graphId(GRAPH_ID)
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Job")).thenReturn(List.of(job, depJob));
SyncResult result = stepService.mergeDependsOnRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("DEPENDS_ON");
verify(neo4jClient).query(cypherCaptor.capture());
String cypher = cypherCaptor.getValue();
assertThat(cypher).contains("RELATED_TO");
assertThat(cypher).contains("$relationType");
}
@Test
void mergeImpacts_returnsPlaceholderResult() {
SyncResult result = stepService.mergeImpactsRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("IMPACTS");
assertThat(result.getCreated()).isEqualTo(0);
assertThat(result.isPlaceholder()).isTrue();
verifyNoInteractions(neo4jClient);
verifyNoInteractions(entityRepository);
}
@Test
void mergeSourcedFrom_noKnowledgeSets_returnsZero() {
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "KnowledgeSet")).thenReturn(List.of());
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset")).thenReturn(List.of());
SyncResult result = stepService.mergeSourcedFromRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("SOURCED_FROM");
assertThat(result.getCreated()).isEqualTo(0);
}
@Test
void mergeSourcedFrom_ksWithSources_verifiesCypherAndLookup() {
setupNeo4jQueryChain(String.class, "new-rel-id");
GraphEntity ks = GraphEntity.builder()
.id("ks-entity-1").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("source_dataset_ids", List.of("ds-001"))))
.build();
GraphEntity dataset = GraphEntity.builder()
.id("ds-entity-1").sourceId("ds-001").type("Dataset").graphId(GRAPH_ID)
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "KnowledgeSet")).thenReturn(List.of(ks));
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset")).thenReturn(List.of(dataset));
SyncResult result = stepService.mergeSourcedFromRelations(GRAPH_ID, SYNC_ID);
assertThat(result.getSyncType()).isEqualTo("SOURCED_FROM");
verify(neo4jClient).query(cypherCaptor.capture());
assertThat(cypherCaptor.getValue()).contains("RELATED_TO");
}
@Test
void mergeSourcedFrom_listWithNullElements_filtersNulls() {
setupNeo4jQueryChain(String.class, "new-rel-id");
List<Object> listWithNulls = new ArrayList<>();
listWithNulls.add("ds-valid");
listWithNulls.add(null);
listWithNulls.add("");
GraphEntity ks = GraphEntity.builder()
.id("ks-entity-2").type("KnowledgeSet").graphId(GRAPH_ID)
.properties(new HashMap<>(Map.of("source_dataset_ids", listWithNulls)))
.build();
GraphEntity dataset = GraphEntity.builder()
.id("ds-entity-v").sourceId("ds-valid").type("Dataset").graphId(GRAPH_ID)
.build();
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "KnowledgeSet")).thenReturn(List.of(ks));
when(entityRepository.findByGraphIdAndType(GRAPH_ID, "Dataset")).thenReturn(List.of(dataset));
SyncResult result = stepService.mergeSourcedFromRelations(GRAPH_ID, SYNC_ID);
// Only "ds-valid" should be processed (null and empty strings are filtered); verify exactly one mergeRelation call
verify(neo4jClient, times(1)).query(anyString());
}
}
}

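The empty-snapshot purge guard exercised by the tests above (`allowPurgeOnEmptySnapshot`) can be sketched language-neutrally. This is a hypothetical Python rendering of the decision logic only; the function and parameter names are illustrative, not the service's actual API:

```python
def should_purge(active_ids: set, allow_purge_on_empty_snapshot: bool) -> bool:
    """Mirror of the empty-snapshot guard: an empty snapshot only triggers a
    purge when explicitly allowed, protecting against accidental mass deletion
    when an upstream source returns no data."""
    if not active_ids:
        return allow_purge_on_empty_snapshot
    return True

# An empty snapshot is a no-op by default...
assert should_purge(set(), allow_purge_on_empty_snapshot=False) is False
# ...but purges when explicitly allowed.
assert should_purge(set(), allow_purge_on_empty_snapshot=True) is True
# A non-empty snapshot always purges entities outside the active set.
assert should_purge({"ds-001", "ds-002"}, allow_purge_on_empty_snapshot=False) is True
```

The default-deny choice matches the P1-2 fix: losing stale rows for one cycle is recoverable, while purging the whole graph on a bad snapshot is not.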
View File

@@ -0,0 +1,157 @@
package com.datamate.knowledgegraph.infrastructure.neo4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.boot.DefaultApplicationArguments;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.data.neo4j.core.Neo4jClient.UnboundRunnableSpec;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class GraphInitializerTest {
@Mock
private Neo4jClient neo4jClient;
private GraphInitializer createInitializer(String password, String profile, boolean autoInit) {
KnowledgeGraphProperties properties = new KnowledgeGraphProperties();
properties.getSync().setAutoInitSchema(autoInit);
GraphInitializer initializer = new GraphInitializer(neo4jClient, properties);
ReflectionTestUtils.setField(initializer, "neo4jPassword", password);
ReflectionTestUtils.setField(initializer, "activeProfile", profile);
return initializer;
}
// -----------------------------------------------------------------------
// P1-3: default credential detection
// -----------------------------------------------------------------------
@Test
void run_defaultPassword_prodProfile_throwsException() {
GraphInitializer initializer = createInitializer("datamate123", "prod", false);
assertThatThrownBy(() -> initializer.run(new DefaultApplicationArguments()))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("SECURITY")
.hasMessageContaining("default");
}
@Test
void run_defaultPassword_stagingProfile_throwsException() {
GraphInitializer initializer = createInitializer("neo4j", "staging", false);
assertThatThrownBy(() -> initializer.run(new DefaultApplicationArguments()))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("SECURITY");
}
@Test
void run_defaultPassword_devProfile_warnsButContinues() {
GraphInitializer initializer = createInitializer("datamate123", "dev", false);
// Should not throw — just warn
assertThatCode(() -> initializer.run(new DefaultApplicationArguments()))
.doesNotThrowAnyException();
}
@Test
void run_defaultPassword_testProfile_warnsButContinues() {
GraphInitializer initializer = createInitializer("datamate123", "test", false);
assertThatCode(() -> initializer.run(new DefaultApplicationArguments()))
.doesNotThrowAnyException();
}
@Test
void run_defaultPassword_localProfile_warnsButContinues() {
GraphInitializer initializer = createInitializer("password", "local", false);
assertThatCode(() -> initializer.run(new DefaultApplicationArguments()))
.doesNotThrowAnyException();
}
@Test
void run_securePassword_prodProfile_succeeds() {
GraphInitializer initializer = createInitializer("s3cure!P@ssw0rd", "prod", false);
// Schema init disabled, so no queries. Should succeed.
assertThatCode(() -> initializer.run(new DefaultApplicationArguments()))
.doesNotThrowAnyException();
}
@Test
void run_blankPassword_skipsValidation() {
GraphInitializer initializer = createInitializer("", "prod", false);
assertThatCode(() -> initializer.run(new DefaultApplicationArguments()))
.doesNotThrowAnyException();
}
// -----------------------------------------------------------------------
// Schema initialization - success
// -----------------------------------------------------------------------
@Test
void run_autoInitEnabled_executesAllStatements() {
GraphInitializer initializer = createInitializer("s3cure!P@ss", "dev", true);
UnboundRunnableSpec spec = mock(UnboundRunnableSpec.class);
when(neo4jClient.query(anyString())).thenReturn(spec);
initializer.run(new DefaultApplicationArguments());
// Should execute all schema statements (constraints + indexes + fulltext)
verify(neo4jClient, atLeast(10)).query(anyString());
}
@Test
void run_autoInitDisabled_skipsSchemaInit() {
GraphInitializer initializer = createInitializer("s3cure!P@ss", "dev", false);
initializer.run(new DefaultApplicationArguments());
verifyNoInteractions(neo4jClient);
}
// -----------------------------------------------------------------------
// P2-7: schema initialization error handling
// -----------------------------------------------------------------------
@Test
void run_alreadyExistsError_safelyIgnored() {
GraphInitializer initializer = createInitializer("s3cure!P@ss", "dev", true);
UnboundRunnableSpec spec = mock(UnboundRunnableSpec.class);
when(neo4jClient.query(anyString())).thenReturn(spec);
doThrow(new RuntimeException("Constraint already exists"))
.when(spec).run();
// Should not throw — "already exists" errors are safely ignored
assertThatCode(() -> initializer.run(new DefaultApplicationArguments()))
.doesNotThrowAnyException();
}
@Test
void run_nonExistenceError_throwsException() {
GraphInitializer initializer = createInitializer("s3cure!P@ss", "dev", true);
UnboundRunnableSpec spec = mock(UnboundRunnableSpec.class);
when(neo4jClient.query(anyString())).thenReturn(spec);
doThrow(new RuntimeException("Connection refused to Neo4j"))
.when(spec).run();
// Non-"already exists" errors should propagate
assertThatThrownBy(() -> initializer.run(new DefaultApplicationArguments()))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("schema initialization failed");
}
}

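The error-handling policy asserted by the last two tests (tolerate idempotent "already exists" failures, propagate everything else) can be sketched as a small Python function. This is an illustrative sketch of the policy, not the GraphInitializer's actual implementation; the names are hypothetical:

```python
def run_schema_statement(execute, statement: str) -> None:
    """Run one schema DDL statement, tolerating idempotent re-creation.

    'already exists' errors mean the constraint/index is already in place and
    are safely ignored; any other failure is re-raised as a hard init error.
    """
    try:
        execute(statement)
    except RuntimeError as exc:
        if "already exists" in str(exc).lower():
            return  # idempotent: the schema object already exists
        raise RuntimeError(f"schema initialization failed: {statement}") from exc

# Tolerated: the constraint was created by a previous run.
def _raise_exists(_stmt):
    raise RuntimeError("Constraint already exists")
run_schema_statement(_raise_exists, "CREATE CONSTRAINT ...")  # no exception

# Propagated: a connection failure is not an idempotency case.
def _raise_conn(_stmt):
    raise RuntimeError("Connection refused to Neo4j")
try:
    run_schema_statement(_raise_conn, "CREATE INDEX ...")
    raised = False
except RuntimeError:
    raised = True
assert raised
```

Matching on the error message is fragile across driver versions, which is why the tests pin both the tolerated and the propagated paths.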
View File

@@ -109,6 +109,13 @@
<version>${project.version}</version>
</dependency>
<!-- Knowledge graph service dependency -->
<dependency>
<groupId>com.datamate</groupId>
<artifactId>knowledge-graph-service</artifactId>
<version>${project.version}</version>
</dependency>
<!-- Database -->
<dependency>
<groupId>com.mysql</groupId>

View File

@@ -52,6 +52,7 @@ spring:
import:
- classpath:config/application-datacollection.yml
- classpath:config/application-datamanagement.yml
- optional:classpath:application-knowledgegraph.yml
# Redis configuration
data:

View File

@@ -37,6 +37,9 @@
<module>rag-indexer-service</module>
<module>rag-query-service</module>
<!-- Knowledge graph service -->
<module>knowledge-graph-service</module>
<!-- Main application startup module -->
<module>main-application</module>
</modules>

View File

@@ -0,0 +1,39 @@
services:
datamate-neo4j:
container_name: datamate-neo4j
image: neo4j:5-community
restart: on-failure
ports:
- "7474:7474" # HTTP (Neo4j Browser)
- "7687:7687" # Bolt protocol
environment:
NEO4J_AUTH: neo4j/${NEO4J_PASSWORD:-datamate123}
# Memory configuration
NEO4J_server_memory_heap_initial__size: 512m
NEO4J_server_memory_heap_max__size: 1G
NEO4J_server_memory_pagecache_size: 512m
# Enable APOC plugin
NEO4J_PLUGINS: '["apoc"]'
# Transaction timeout
NEO4J_db_transaction_timeout: 60s
volumes:
- neo4j_data:/data
- neo4j_logs:/logs
networks: [ datamate ]
healthcheck:
test: ["CMD-SHELL", "wget --no-verbose --tries=1 --spider http://localhost:7474 || exit 1"]
interval: 15s
timeout: 10s
retries: 5
start_period: 30s
volumes:
neo4j_data:
name: datamate-neo4j-data-volume
neo4j_logs:
name: datamate-neo4j-logs-volume
networks:
datamate:
driver: bridge
name: datamate-network

View File

@@ -1,6 +1,15 @@
from pydantic_settings import BaseSettings
from pydantic import SecretStr, model_validator
from typing import Optional
import logging
import os
_logger = logging.getLogger(__name__)
# Known weak default credentials; must not be used in production
_BLOCKED_DEFAULT_PASSWORDS = {"password", "123456", "admin", "root", "datamate123"}
_BLOCKED_DEFAULT_TOKENS = {"abc123abc123", "EMPTY"}
class Settings(BaseSettings):
"""Application configuration."""
@@ -62,11 +71,51 @@ class Settings(BaseSettings):
# DataMate
dm_file_path_prefix: str = "/dataset" # DM storage folder prefix
# DataMate Backend (Java) - used to read file content via the "download/preview" endpoints
datamate_backend_base_url: str = "http://datamate-backend:8080/api"
# Knowledge Graph - LLM triple extraction configuration
kg_llm_api_key: SecretStr = SecretStr("EMPTY")
kg_llm_base_url: Optional[str] = None
kg_llm_model: str = "gpt-4o-mini"
kg_llm_temperature: float = 0.0
kg_llm_timeout_seconds: int = 60
kg_llm_max_retries: int = 2
# Label Studio Editor settings
editor_max_text_bytes: int = 0 # <=0 means unlimited; a positive value sets the max byte count
@model_validator(mode='after')
def check_default_credentials(self):
"""Detect weak default credentials and refuse to start in production.
The environment is determined via the DATAMATE_ENV environment variable:
- dev/test/local: only emit a warning
- anything else (prod/staging, etc.): raise an exception to block startup
"""
env = os.environ.get("DATAMATE_ENV", "dev").lower()
is_dev = env in ("dev", "test", "local", "development")
issues: list[str] = []
if self.mysql_password in _BLOCKED_DEFAULT_PASSWORDS:
issues.append(f"mysql_password is set to a weak default ('{self.mysql_password}')")
if self.label_studio_password and self.label_studio_password in _BLOCKED_DEFAULT_PASSWORDS:
issues.append("label_studio_password is set to a weak default")
if self.label_studio_user_token and self.label_studio_user_token in _BLOCKED_DEFAULT_TOKENS:
issues.append("label_studio_user_token is set to a weak default")
if issues:
msg = "SECURITY: Weak default credentials detected: " + "; ".join(issues)
if is_dev:
_logger.warning(msg + " (acceptable in dev/test, MUST change for production)")
else:
raise ValueError(
msg + ". Set proper credentials via environment variables "
"before deploying to production."
)
return self
# Global settings instance
settings = Settings()

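The validator's environment-gated rejection can be demonstrated with a minimal standalone sketch, independent of the real Settings class; `check_credentials` and its return convention are illustrative assumptions, not the module's API:

```python
BLOCKED_DEFAULTS = {"password", "123456", "admin", "root", "datamate123"}

def check_credentials(password: str, env: str):
    """Return None for a safe password, a warning string in dev-like
    environments, and raise ValueError everywhere else (prod/staging)."""
    if password not in BLOCKED_DEFAULTS:
        return None
    msg = "SECURITY: Weak default credentials detected"
    if env.lower() in ("dev", "test", "local", "development"):
        return msg + " (acceptable in dev/test, MUST change for production)"
    raise ValueError(msg)

# Safe password: accepted everywhere.
assert check_credentials("s3cure!P@ssw0rd", "prod") is None
# Weak default in dev: warn but continue.
assert check_credentials("datamate123", "dev") is not None
# Weak default in prod: startup is blocked.
try:
    check_credentials("datamate123", "prod")
    blocked = False
except ValueError:
    blocked = True
assert blocked
```

Treating unknown environments as production (the `else` branch raising) is the fail-safe default: a missing or misspelled DATAMATE_ENV cannot silently downgrade the check.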
View File

@@ -7,6 +7,7 @@ from .generation.interface import router as generation_router
from .evaluation.interface import router as evaluation_router
from .collection.interface import router as collection_route
from .dataset.interface import router as dataset_router
from .kg_extraction.interface import router as kg_extraction_router
router = APIRouter(
prefix="/api"
@@ -19,5 +20,6 @@ router.include_router(generation_router)
router.include_router(evaluation_router)
router.include_router(collection_route)
router.include_router(dataset_router)
router.include_router(kg_extraction_router)
__all__ = ["router"]

View File

@@ -0,0 +1,19 @@
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
from app.module.kg_extraction.models import (
ExtractionRequest,
ExtractionResult,
Triple,
GraphNode,
GraphEdge,
)
from app.module.kg_extraction.interface import router
__all__ = [
"KnowledgeGraphExtractor",
"ExtractionRequest",
"ExtractionResult",
"Triple",
"GraphNode",
"GraphEdge",
"router",
]

View File

@@ -0,0 +1,228 @@
"""LLM-based knowledge graph triple extractor.
Uses LangChain's LLMGraphTransformer to extract entities and relations from
unstructured text, with schema-guided extraction supported to improve accuracy.
"""
from __future__ import annotations
import hashlib
from typing import Sequence
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from pydantic import SecretStr
from app.core.logging import get_logger
from app.module.kg_extraction.models import (
ExtractionRequest,
ExtractionResult,
ExtractionSchema,
GraphEdge,
GraphNode,
Triple,
)
logger = get_logger(__name__)
def _text_fingerprint(text: str) -> str:
"""Return a short SHA-256 digest of the text, for log correlation without leaking the original content."""
return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
class KnowledgeGraphExtractor:
"""Triple extractor built on LLMGraphTransformer.
Create instances via the ``from_settings()`` factory to pick up the global
configuration, or construct directly to override the defaults.
"""
def __init__(
self,
model_name: str = "gpt-4o-mini",
base_url: str | None = None,
api_key: SecretStr = SecretStr("EMPTY"),
temperature: float = 0.0,
timeout: int = 60,
max_retries: int = 2,
) -> None:
logger.info(
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
model_name,
base_url or "default",
timeout,
max_retries,
)
self._llm = ChatOpenAI(
model=model_name,
base_url=base_url,
api_key=api_key,
temperature=temperature,
timeout=timeout,
max_retries=max_retries,
)
@classmethod
def from_settings(cls) -> KnowledgeGraphExtractor:
"""Create an extractor instance from the global Settings."""
from app.core.config import settings
return cls(
model_name=settings.kg_llm_model,
base_url=settings.kg_llm_base_url,
api_key=settings.kg_llm_api_key,
temperature=settings.kg_llm_temperature,
timeout=settings.kg_llm_timeout_seconds,
max_retries=settings.kg_llm_max_retries,
)
def _build_transformer(
self,
schema: ExtractionSchema | None = None,
) -> LLMGraphTransformer:
"""Build an LLMGraphTransformer honoring the optional schema constraints."""
kwargs: dict = {"llm": self._llm}
if schema:
if schema.entity_types:
kwargs["allowed_nodes"] = [et.name for et in schema.entity_types]
if schema.relation_types:
kwargs["allowed_relationships"] = [rt.name for rt in schema.relation_types]
return LLMGraphTransformer(**kwargs)
async def extract(self, request: ExtractionRequest) -> ExtractionResult:
"""Asynchronously extract triples from the text."""
text_hash = _text_fingerprint(request.text)
logger.info(
"Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
request.graph_id,
request.source_id,
len(request.text),
text_hash,
)
transformer = self._build_transformer(request.schema)
documents = [Document(page_content=request.text)]
try:
graph_documents = await transformer.aconvert_to_graph_documents(documents)
except Exception:
logger.exception(
"LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
request.graph_id,
request.source_id,
text_hash,
)
raise
result = self._convert_result(graph_documents, request)
logger.info(
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
request.graph_id,
len(result.nodes),
len(result.edges),
len(result.triples),
)
return result
def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
"""Synchronous variant of triple extraction."""
text_hash = _text_fingerprint(request.text)
logger.info(
"Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
request.graph_id,
request.source_id,
len(request.text),
text_hash,
)
transformer = self._build_transformer(request.schema)
documents = [Document(page_content=request.text)]
try:
graph_documents = transformer.convert_to_graph_documents(documents)
except Exception:
logger.exception(
"LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
request.graph_id,
request.source_id,
text_hash,
)
raise
result = self._convert_result(graph_documents, request)
logger.info(
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
request.graph_id,
len(result.nodes),
len(result.edges),
)
return result
async def extract_batch(
self,
requests: Sequence[ExtractionRequest],
) -> list[ExtractionResult]:
"""Extract a batch of requests, one at a time.
For higher throughput, call ``extract`` concurrently with asyncio.gather on the caller side.
"""
logger.info("Starting batch extraction: count=%d", len(requests))
results: list[ExtractionResult] = []
for i, req in enumerate(requests):
logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id)
result = await self.extract(req)
results.append(result)
logger.info("Batch extraction complete: count=%d", len(results))
return results
@staticmethod
def _convert_result(
graph_documents: list,
request: ExtractionRequest,
) -> ExtractionResult:
"""Convert LangChain GraphDocument objects into the internal data model."""
nodes: list[GraphNode] = []
edges: list[GraphEdge] = []
triples: list[Triple] = []
seen_nodes: set[str] = set()
for doc in graph_documents:
for node in doc.nodes:
node_key = f"{node.id}:{node.type}"
if node_key not in seen_nodes:
seen_nodes.add(node_key)
nodes.append(
GraphNode(
name=node.id,
type=node.type,
properties=node.properties if hasattr(node, "properties") else {},
)
)
for rel in doc.relationships:
source_node = GraphNode(name=rel.source.id, type=rel.source.type)
target_node = GraphNode(name=rel.target.id, type=rel.target.type)
edges.append(
GraphEdge(
source=rel.source.id,
target=rel.target.id,
relation_type=rel.type,
properties=rel.properties if hasattr(rel, "properties") else {},
)
)
triples.append(
Triple(subject=source_node, predicate=rel.type, object=target_node)
)
return ExtractionResult(
nodes=nodes,
edges=edges,
triples=triples,
raw_text=request.text,
source_id=request.source_id,
)
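Two pieces of the extractor above are easy to verify in isolation: the fingerprint helper and the `(id, type)` node-deduplication key used by `_convert_result`. A minimal sketch (module-level names here are illustrative):

```python
import hashlib

def text_fingerprint(text: str) -> str:
    # Same scheme as _text_fingerprint: a short SHA-256 digest for log
    # correlation that never logs the raw text itself.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]

# Node dedup mirrors _convert_result: identity is the (id, type) pair, so the
# same name under two different types still yields two distinct nodes.
seen: set[str] = set()
nodes: list[tuple[str, str]] = []
for name, typ in [("Berlin", "Location"), ("Berlin", "Location"), ("Berlin", "Organization")]:
    key = f"{name}:{typ}"
    if key not in seen:
        seen.add(key)
        nodes.append((name, typ))
```

Keying on `name:type` rather than name alone is what lets an ambiguous entity surface once per type instead of being silently merged.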


@@ -0,0 +1,193 @@
"""Knowledge-graph triple extraction API.
Note: the endpoints in this module are called over the internal network by the
Java backend (datamate-backend). External requests are authenticated at the API
Gateway and forwarded by the Java side; they are not exposed to end users
directly. The caller identity is currently taken from the X-User-Id request
header and recorded in the audit log.
"""
from __future__ import annotations
import uuid
from enum import Enum
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, Header, HTTPException
from pydantic import BaseModel, Field
from app.core.logging import get_logger
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
from app.module.kg_extraction.models import (
ExtractionRequest,
ExtractionResult,
ExtractionSchema,
EntityTypeConstraint,
RelationTypeConstraint,
)
from app.module.shared.schema import StandardResponse
router = APIRouter(prefix="/kg", tags=["knowledge-graph"])
logger = get_logger(__name__)
# Lazy initialization: created on the first request, so startup does not connect to the LLM
_extractor: KnowledgeGraphExtractor | None = None
_UUID_PATTERN = (
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)
# Allowed entity/relation type names: letters, digits, underscores, hyphens; 1-50 chars
_TYPE_NAME_PATTERN = r"^[A-Za-z0-9_\-]{1,50}$"
def _get_extractor() -> KnowledgeGraphExtractor:
global _extractor
if _extractor is None:
_extractor = KnowledgeGraphExtractor.from_settings()
return _extractor
def _require_caller_id(
x_user_id: Annotated[str, Header(min_length=1, description="Caller user ID, passed by the upstream Java backend")],
) -> str:
"""Extract the caller user ID from the request header for audit logging.
This is an internal service endpoint: the caller identity is passed by the
upstream Java backend via the X-User-Id header. A missing or blank header
yields a 401.
"""
caller = x_user_id.strip()
if not caller:
raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
return caller
# ---------------------------------------------------------------------------
# Request / Response DTOs (API layer, decoupled from the internal models)
# ---------------------------------------------------------------------------
class SourceType(str, Enum):
ANNOTATION = "ANNOTATION"
KNOWLEDGE_BASE = "KNOWLEDGE_BASE"
IMPORT = "IMPORT"
MANUAL = "MANUAL"
class ExtractRequest(BaseModel):
"""Triple extraction request."""
text: str = Field(
...,
min_length=1,
max_length=50000,
description="Text to extract from",
examples=["Zhang San is a professor at Peking University whose research focuses on artificial intelligence."],
)
graph_id: str = Field(
...,
pattern=_UUID_PATTERN,
description="Target graph ID (UUID format)",
examples=["550e8400-e29b-41d4-a716-446655440000"],
)
allowed_nodes: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
default=None,
max_length=50,
description="Allowed entity types (schema-guided extraction); each 1-50 letters/digits/underscores/hyphens",
examples=[["Person", "Organization", "Location"]],
)
allowed_relationships: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
default=None,
max_length=50,
description="Allowed relation types (schema-guided extraction)",
examples=[["works_at", "located_in"]],
)
source_id: Optional[str] = Field(
default=None,
pattern=_UUID_PATTERN,
description="Source ID (dataset/knowledge-base entry, UUID format)",
)
source_type: SourceType = Field(
default=SourceType.KNOWLEDGE_BASE,
description="Source type",
)
class BatchExtractRequest(BaseModel):
"""Batch triple extraction request."""
items: list[ExtractRequest] = Field(
...,
min_length=1,
max_length=50,
description="List of extraction requests; at most 50 per call",
)
def _to_extraction_request(req: ExtractRequest) -> ExtractionRequest:
"""Convert the API DTO into an internal extraction request."""
schema: ExtractionSchema | None = None
if req.allowed_nodes or req.allowed_relationships:
schema = ExtractionSchema(
entity_types=[EntityTypeConstraint(name=n) for n in (req.allowed_nodes or [])],
relation_types=[
RelationTypeConstraint(name=r) for r in (req.allowed_relationships or [])
],
)
return ExtractionRequest(
text=req.text,
graph_id=req.graph_id,
schema=schema,
source_id=req.source_id,
source_type=req.source_type.value,
)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post(
"/extract",
response_model=StandardResponse[ExtractionResult],
summary="Triple extraction",
description="Extract entities and relations from text and return knowledge-graph triples. The extraction scope can be constrained via allowed_nodes and allowed_relationships.",
)
async def extract(req: ExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
"""Extract triples from a single text."""
trace_id = uuid.uuid4().hex[:16]
logger.info("[%s] Extract request: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)
extractor = _get_extractor()
extraction_req = _to_extraction_request(req)
try:
result = await extractor.extract(extraction_req)
except Exception:
logger.exception("[%s] Extraction failed: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)
raise HTTPException(status_code=502, detail=f"Extraction service temporarily unavailable (trace: {trace_id})")
return StandardResponse(code=200, message="success", data=result)
@router.post(
"/extract/batch",
response_model=StandardResponse[list[ExtractionResult]],
summary="Batch triple extraction",
description="Extract triples from multiple texts one by one; at most 50 per call.",
)
async def extract_batch(req: BatchExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
"""Extract triples from a batch of texts."""
trace_id = uuid.uuid4().hex[:16]
logger.info("[%s] Batch extract request: count=%d, caller=%s", trace_id, len(req.items), caller)
extractor = _get_extractor()
extraction_reqs = [_to_extraction_request(item) for item in req.items]
try:
results = await extractor.extract_batch(extraction_reqs)
except Exception:
logger.exception("[%s] Batch extraction failed: caller=%s", trace_id, caller)
raise HTTPException(status_code=502, detail=f"Extraction service temporarily unavailable (trace: {trace_id})")
return StandardResponse(code=200, message="success", data=results)


@@ -0,0 +1,75 @@
"""Data models for knowledge-graph triple extraction."""
from __future__ import annotations
from pydantic import BaseModel, Field
class GraphNode(BaseModel):
"""Graph node (entity)."""
name: str = Field(..., description="Entity name")
type: str = Field(..., description="Entity type, e.g. Person, Organization, Location")
properties: dict[str, object] = Field(default_factory=dict, description="Extra properties")
class GraphEdge(BaseModel):
"""Graph edge (relation)."""
source: str = Field(..., description="Source entity name")
target: str = Field(..., description="Target entity name")
relation_type: str = Field(..., description="Relation type, e.g. works_at, located_in")
properties: dict[str, object] = Field(default_factory=dict, description="Relation properties")
class Triple(BaseModel):
"""Knowledge triple: (subject, predicate, object)."""
subject: GraphNode
predicate: str = Field(..., description="Relation type")
object: GraphNode
class EntityTypeConstraint(BaseModel):
"""Entity type constraint for schema-guided extraction."""
name: str = Field(..., description="Type name")
description: str = Field(default="", description="Type description")
class RelationTypeConstraint(BaseModel):
"""Relation type constraint."""
name: str = Field(..., description="Relation type name")
source_types: list[str] = Field(default_factory=list, description="Allowed source entity types")
target_types: list[str] = Field(default_factory=list, description="Allowed target entity types")
description: str = Field(default="", description="Relation description")
class ExtractionSchema(BaseModel):
"""Extraction schema constraining the entity and relation types the LLM may output."""
entity_types: list[EntityTypeConstraint] = Field(default_factory=list)
relation_types: list[RelationTypeConstraint] = Field(default_factory=list)
class ExtractionRequest(BaseModel):
"""Triple extraction request."""
text: str = Field(..., description="Text to extract from")
graph_id: str = Field(..., description="Target graph ID")
schema: ExtractionSchema | None = Field(
default=None, description="Optional schema constraint; when provided, extraction is schema-guided"
)
source_id: str | None = Field(default=None, description="Source ID (dataset/knowledge-base entry)")
source_type: str = Field(default="KNOWLEDGE_BASE", description="Source type")
class ExtractionResult(BaseModel):
"""Triple extraction result."""
nodes: list[GraphNode] = Field(default_factory=list)
edges: list[GraphEdge] = Field(default_factory=list)
triples: list[Triple] = Field(default_factory=list)
raw_text: str = Field(default="", description="Original text")
source_id: str | None = None
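The schema handling in these models mirrors `_to_extraction_request` in the API layer: a schema object is built only when at least one constraint list is non-empty. A plain-dict sketch of that rule (the `to_schema` helper name is illustrative, not part of the module):

```python
# Sketch of the schema-construction rule using plain dicts in place of the
# pydantic models: no constraints at all means no schema (free-form extraction).
def to_schema(allowed_nodes, allowed_relationships):
    if not (allowed_nodes or allowed_relationships):
        return None
    return {
        "entity_types": [{"name": n} for n in (allowed_nodes or [])],
        "relation_types": [{"name": r} for r in (allowed_relationships or [])],
    }
```

Returning `None` rather than an empty schema matters: a `None` schema leaves the transformer unconstrained, while an empty constraint list would restrict it to nothing.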


@@ -31,6 +31,7 @@ dependencies = [
"openai (>=2.9.0,<3.0.0)",
"langchain-openai (>=1.1.1,<2.0.0)",
"langchain (>=1.1.3,<2.0.0)",
"langchain-experimental (>=0.3.0,<1.0.0)",
"pydantic (>=2.12.5,<3.0.0)",
"sqlalchemy (>=2.0.45,<3.0.0)",
"fastapi (>=0.124.0,<0.125.0)",