You've already forked DataMate
核心功能:
- Neo4j 索引优化(entityType, graphId, properties.name)
- Redis 缓存(Java 侧,3 个缓存区,TTL 可配置)
- LRU 缓存(Python 侧,KG + Embedding,线程安全)
- 细粒度缓存清除(graphId 前缀匹配)
- 失败路径缓存清除(finally 块)

新增文件(Java 侧,7 个):
- V2__PerformanceIndexes.java - Flyway 迁移,创建 3 个索引
- IndexHealthService.java - 索引健康监控
- RedisCacheConfig.java - Spring Cache + Redis 配置
- GraphCacheService.java - 缓存清除管理器
- CacheableIntegrationTest.java - 集成测试(10 tests)
- GraphCacheServiceTest.java - 单元测试(19 tests)
- V2__PerformanceIndexesTest.java, IndexHealthServiceTest.java

新增文件(Python 侧,2 个):
- cache.py - 内存 TTL+LRU 缓存(cachetools)
- test_cache.py - 单元测试(20 tests)

修改文件(Java 侧,9 个):
- GraphEntityService.java - 添加 @Cacheable,缓存清除
- GraphQueryService.java - 添加 @Cacheable(包含用户权限上下文)
- GraphRelationService.java - 添加缓存清除
- GraphSyncService.java - 添加缓存清除(finally 块,失败路径)
- KnowledgeGraphProperties.java - 添加 Cache 配置类
- application-knowledgegraph.yml - 添加 Redis 和缓存 TTL 配置
- GraphEntityServiceTest.java - 添加 verify(cacheService) 断言
- GraphRelationServiceTest.java - 添加 verify(cacheService) 断言
- GraphSyncServiceTest.java - 添加失败路径缓存清除测试

修改文件(Python 侧,5 个):
- kg_client.py - 集成缓存(fulltext_search, get_subgraph)
- interface.py - 添加 /cache/stats 和 /cache/clear 端点
- config.py - 添加缓存配置字段
- pyproject.toml - 添加 cachetools 依赖
- test_kg_client.py - 添加 _disable_cache fixture

安全修复(3 轮迭代):
- P0: 缓存 key 用户隔离(防止跨用户数据泄露)
- P1-1: 同步子步骤后的缓存清除(18 个方法)
- P1-2: 实体创建后的搜索缓存清除
- P1-3: 失败路径缓存清除(finally 块)
- P2-1: 细粒度缓存清除(graphId 前缀匹配,避免跨图谱冲刷)
- P2-2: 服务层测试添加 verify(cacheService) 断言

测试结果:
- Java: 280 tests pass ✅ (270 → 280, +10 new)
- Python: 154 tests pass ✅ (140 → 154, +14 new)

缓存配置:
- kg:entities - 实体缓存,TTL 1h
- kg:queries - 查询结果缓存,TTL 5min
- kg:search - 全文搜索缓存,TTL 3min
- KG cache (Python) - 256 entries, 5min TTL
- Embedding cache (Python) - 512 entries, 10min TTL
307 lines
12 KiB
Python
307 lines
12 KiB
Python
"""KG 服务 REST 客户端的单元测试。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from app.module.kg_graphrag.cache import GraphRAGCache
|
|
from app.module.kg_graphrag.kg_client import KGServiceClient
|
|
|
|
|
|
@pytest.fixture
def client() -> KGServiceClient:
    """Provide a KGServiceClient built with fixed, test-only connection settings."""
    settings = {
        "base_url": "http://test-kg:8080",
        "internal_token": "test-token",
        "timeout": 5.0,
    }
    return KGServiceClient(**settings)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _disable_cache():
    """Disable caching for every test.

    Replaces ``get_cache`` as seen by kg_client with a stub returning a
    GraphRAGCache whose two maxsize arguments are 0, so that cache hits
    carried over from another test cannot interfere with mock verification.
    """
    disabled_cache = GraphRAGCache(
        kg_maxsize=0,
        kg_ttl=1,
        embedding_maxsize=0,
        embedding_ttl=1,
    )
    cache_patch = patch(
        "app.module.kg_graphrag.kg_client.get_cache",
        return_value=disabled_cache,
    )
    with cache_patch:
        yield
|
|
|
|
|
|
def _run(coro):
    """Drive the coroutine ``coro`` to completion with ``asyncio.run`` and return its result."""
    outcome = asyncio.run(coro)
    return outcome
|
|
|
|
|
|
# Placeholder request; _resp attaches it to every httpx.Response it builds.
_FAKE_REQUEST = httpx.Request(method="GET", url="http://test")
|
|
|
|
|
|
def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
    """Build an httpx.Response with a request attached (raise_for_status needs one).

    Args:
        status_code: HTTP status code of the fake response.
        json: Body to send as JSON; used whenever it is not None.
        text: Plain-text body used when ``json`` is None; a falsy value becomes "".

    Returns:
        The constructed response, bound to the module's placeholder request.
    """
    body_kwargs = {"text": text or ""} if json is None else {"json": json}
    return httpx.Response(status_code, request=_FAKE_REQUEST, **body_kwargs)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# fulltext_search tests
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestFulltextSearch:
    """Tests for the fulltext_search method."""

    def test_wrapped_paged_response(self, client: KGServiceClient):
        """Java returns a PagedResponse inside the global wrapper: {"code": 200, "data": {"content": [...]}}"""
        hits = [
            {"id": "e1", "name": "用户数据", "type": "Dataset", "description": "用户行为", "score": 2.5},
            {"id": "e2", "name": "清洗管道", "type": "Workflow", "description": "", "score": 1.8},
        ]
        page = {
            "page": 0,
            "size": 20,
            "totalElements": 2,
            "totalPages": 1,
            "content": hits,
        }
        http = AsyncMock()
        http.get = AsyncMock(return_value=_resp(200, json={"code": 200, "data": page}))
        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "用户数据", size=10, user_id="u1"))

        assert len(found) == 2
        first, second = found
        assert first.id == "e1"
        assert first.name == "用户数据"
        assert first.type == "Dataset"
        assert second.name == "清洗管道"

    def test_unwrapped_paged_response(self, client: KGServiceClient):
        """Java returns the PagedResponse directly (no global wrapper)."""
        page = {
            "page": 0,
            "size": 10,
            "totalElements": 1,
            "totalPages": 1,
            "content": [
                {"id": "e1", "name": "A", "type": "Dataset", "description": "desc"},
            ],
        }
        http = AsyncMock()
        http.get = AsyncMock(return_value=_resp(200, json=page))
        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "A"))

        # The body has no "data" key, so the client is expected to fall back
        # to the body itself and read "content" from it.
        assert len(found) == 1
        assert found[0].name == "A"

    def test_empty_content(self, client: KGServiceClient):
        """An empty "content" list produces an empty result list."""
        empty_page = {"code": 200, "data": {"page": 0, "content": []}}
        http = AsyncMock()
        http.get = AsyncMock(return_value=_resp(200, json=empty_page))
        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "nothing"))

        assert found == []

    def test_fail_open_on_http_error(self, client: KGServiceClient):
        """On an HTTP error the call fails open and returns an empty list."""
        http = AsyncMock()
        http.get = AsyncMock(return_value=_resp(500, text="Internal Server Error"))
        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "test"))

        assert found == []

    def test_fail_open_on_connection_error(self, client: KGServiceClient):
        """On a connection error the call fails open and returns an empty list."""
        http = AsyncMock()
        http.get = AsyncMock(side_effect=httpx.ConnectError("connection refused"))
        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "test"))

        assert found == []

    def test_request_headers(self, client: KGServiceClient):
        """Verify that the request carries the expected headers and query params."""
        http = AsyncMock()
        http.get = AsyncMock(return_value=_resp(200, json={"data": {"content": []}}))
        with patch.object(client, "_get_client", return_value=http):
            _run(client.fulltext_search("gid", "q", size=5, user_id="user-123"))

        sent = http.get.call_args
        sent_headers = sent.kwargs["headers"]
        assert sent_headers["X-Internal-Token"] == "test-token"
        assert sent_headers["X-User-Id"] == "user-123"
        assert sent.kwargs["params"] == {"q": "q", "size": 5}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# get_subgraph tests
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetSubgraph:
    """Tests for the get_subgraph method."""

    def test_wrapped_subgraph_response(self, client: KGServiceClient):
        """Java returns a SubgraphExportVO inside the global wrapper."""
        nodes = [
            {"id": "n1", "name": "用户数据", "type": "Dataset", "description": "desc1", "properties": {}},
            {"id": "n2", "name": "user_id", "type": "Field", "description": "", "properties": {}},
        ]
        edge = {
            "id": "edge1",
            "sourceEntityId": "n1",
            "targetEntityId": "n2",
            "relationType": "HAS_FIELD",
            "weight": 1.0,
            "confidence": 0.9,
            "sourceId": "kb-1",
        }
        payload = {
            "code": 200,
            "data": {
                "nodes": nodes,
                "edges": [edge],
                "nodeCount": 2,
                "edgeCount": 1,
            },
        }
        http = AsyncMock()
        http.post = AsyncMock(return_value=_resp(200, json=payload))
        with patch.object(client, "_get_client", return_value=http):
            entities, relations = _run(client.get_subgraph("gid", ["n1"], depth=2, user_id="u1"))

        assert len(entities) == 2
        dataset, field = entities
        assert dataset.name == "用户数据"
        assert field.name == "user_id"

        assert len(relations) == 1
        (relation,) = relations
        assert relation.source_name == "用户数据"
        assert relation.target_name == "user_id"
        assert relation.relation_type == "HAS_FIELD"
        assert relation.source_type == "Dataset"
        assert relation.target_type == "Field"

    def test_unwrapped_subgraph_response(self, client: KGServiceClient):
        """Java returns the SubgraphExportVO directly (no global wrapper)."""
        payload = {
            "nodes": [
                {"id": "n1", "name": "A", "type": "T1", "description": ""},
            ],
            "edges": [],
            "nodeCount": 1,
            "edgeCount": 0,
        }
        http = AsyncMock()
        http.post = AsyncMock(return_value=_resp(200, json=payload))
        with patch.object(client, "_get_client", return_value=http):
            entities, relations = _run(client.get_subgraph("gid", ["n1"]))

        assert len(entities) == 1
        assert entities[0].name == "A"
        assert relations == []

    def test_edge_with_unknown_entity(self, client: KGServiceClient):
        """When an edge references an entity missing from nodes, its ID is used as the fallback name."""
        dangling_edge = {
            "sourceEntityId": "n1",
            "targetEntityId": "n999",
            "relationType": "DEPENDS_ON",
        }
        payload = {
            "code": 200,
            "data": {
                "nodes": [{"id": "n1", "name": "A", "type": "T1", "description": ""}],
                "edges": [dangling_edge],
            },
        }
        http = AsyncMock()
        http.post = AsyncMock(return_value=_resp(200, json=payload))
        with patch.object(client, "_get_client", return_value=http):
            entities, relations = _run(client.get_subgraph("gid", ["n1"]))

        assert len(relations) == 1
        (relation,) = relations
        assert relation.source_name == "A"
        assert relation.target_name == "n999"  # falls back to the entity ID
        assert relation.target_type == ""

    def test_fail_open_on_error(self, client: KGServiceClient):
        """On an HTTP 500 response both returned lists are empty."""
        http = AsyncMock()
        http.post = AsyncMock(return_value=_resp(500, text="error"))
        with patch.object(client, "_get_client", return_value=http):
            entities, relations = _run(client.get_subgraph("gid", ["n1"]))

        assert entities == []
        assert relations == []

    def test_request_params(self, client: KGServiceClient):
        """Verify that the subgraph request passes its URL, query params and JSON body correctly."""
        http = AsyncMock()
        http.post = AsyncMock(return_value=_resp(200, json={"data": {"nodes": [], "edges": []}}))
        with patch.object(client, "_get_client", return_value=http):
            _run(client.get_subgraph("gid", ["e1", "e2"], depth=3, user_id="u1"))

        sent = http.post.call_args
        assert "/knowledge-graph/gid/query/subgraph/export" in sent.args[0]
        assert sent.kwargs["params"] == {"depth": 3}
        assert sent.kwargs["json"] == {"entityIds": ["e1", "e2"]}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# headers tests
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHeaders:
    """Tests for the header dict returned by KGServiceClient._headers."""

    def test_headers_with_token_and_user(self, client: KGServiceClient):
        """Both the internal token and the user id appear in the headers."""
        built = client._headers(user_id="user-1")
        assert built["X-Internal-Token"] == "test-token"
        assert built["X-User-Id"] == "user-1"

    def test_headers_without_user(self, client: KGServiceClient):
        """Without a user id, the token header is present and X-User-Id is absent."""
        built = client._headers()
        assert "X-Internal-Token" in built
        assert "X-User-Id" not in built

    def test_headers_without_token(self):
        """With an empty internal token, X-Internal-Token is absent and X-User-Id is still set."""
        tokenless = KGServiceClient(base_url="http://test:8080", internal_token="")
        built = tokenless._headers(user_id="u1")
        assert "X-Internal-Token" not in built
        assert built["X-User-Id"] == "u1"
|