Files
DataMate/runtime/datamate-python/app/module/kg_graphrag/test_kg_client.py
Jerry Yan 39338df808 feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证

新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口

修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖

测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)

关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*

安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O 改为通过 asyncio.to_thread 执行
- P1-5: 测试覆盖(79 tests,包括安全负例)

测试结果:79 tests passed
2026-02-20 09:41:55 +08:00

298 lines
11 KiB
Python

"""KG 服务 REST 客户端的单元测试。"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from app.module.kg_graphrag.kg_client import KGServiceClient
@pytest.fixture
def client() -> KGServiceClient:
    """Provide a ``KGServiceClient`` aimed at a fake KG host with a test token."""
    connection_settings = {
        "base_url": "http://test-kg:8080",
        "internal_token": "test-token",
        "timeout": 5.0,
    }
    return KGServiceClient(**connection_settings)
def _run(coro):
return asyncio.run(coro)
# Dummy request bound to every fake response; httpx needs a request attached
# for ``Response.raise_for_status()`` to work.
_FAKE_REQUEST = httpx.Request("GET", "http://test")


def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
    """Build an ``httpx.Response`` bound to ``_FAKE_REQUEST``.

    When *json* is given the response carries a JSON body; otherwise the body
    is *text* (an empty string when *text* is ``None`` or empty).
    """
    if json is None:
        body_kwargs = {"text": text or ""}
    else:
        body_kwargs = {"json": json}
    return httpx.Response(status_code, request=_FAKE_REQUEST, **body_kwargs)
# ---------------------------------------------------------------------------
# fulltext_search 测试
# ---------------------------------------------------------------------------
class TestFulltextSearch:
    """Tests for ``KGServiceClient.fulltext_search``.

    Each test replaces ``client._get_client`` with a mock returning a fake HTTP
    client, so no real network traffic is issued.
    """

    @staticmethod
    def _fake_http(resp=None, *, side_effect=None) -> AsyncMock:
        """Return a fake HTTP client whose awaited ``get`` yields *resp* or raises *side_effect*."""
        mock_http = AsyncMock()
        mock_http.get = AsyncMock(return_value=resp, side_effect=side_effect)
        return mock_http

    def test_wrapped_paged_response(self, client: KGServiceClient):
        """Java returns a globally wrapped PagedResponse: {"code": 200, "data": {"content": [...]}}."""
        mock_body = {
            "code": 200,
            "data": {
                "page": 0,
                "size": 20,
                "totalElements": 2,
                "totalPages": 1,
                "content": [
                    {"id": "e1", "name": "用户数据", "type": "Dataset", "description": "用户行为", "score": 2.5},
                    {"id": "e2", "name": "清洗管道", "type": "Workflow", "description": "", "score": 1.8},
                ],
            },
        }
        mock_http = self._fake_http(_resp(200, json=mock_body))
        with patch.object(client, "_get_client", return_value=mock_http):
            entities = _run(client.fulltext_search("graph-1", "用户数据", size=10, user_id="u1"))

        assert len(entities) == 2
        assert entities[0].id == "e1"
        assert entities[0].name == "用户数据"
        assert entities[0].type == "Dataset"
        assert entities[1].name == "清洗管道"

    def test_unwrapped_paged_response(self, client: KGServiceClient):
        """Java returns the PagedResponse directly (no global wrapper)."""
        mock_body = {
            "page": 0,
            "size": 10,
            "totalElements": 1,
            "totalPages": 1,
            "content": [
                {"id": "e1", "name": "A", "type": "Dataset", "description": "desc"},
            ],
        }
        mock_http = self._fake_http(_resp(200, json=mock_body))
        with patch.object(client, "_get_client", return_value=mock_http):
            entities = _run(client.fulltext_search("graph-1", "A"))

        # body has no "data" key -> fallback to body itself -> read "content"
        assert len(entities) == 1
        assert entities[0].name == "A"

    def test_empty_content(self, client: KGServiceClient):
        """An empty ``content`` page yields an empty entity list."""
        mock_body = {"code": 200, "data": {"page": 0, "content": []}}
        mock_http = self._fake_http(_resp(200, json=mock_body))
        with patch.object(client, "_get_client", return_value=mock_http):
            entities = _run(client.fulltext_search("graph-1", "nothing"))

        assert entities == []

    def test_fail_open_on_http_error(self, client: KGServiceClient):
        """An HTTP error fails open and returns an empty list."""
        mock_http = self._fake_http(_resp(500, text="Internal Server Error"))
        with patch.object(client, "_get_client", return_value=mock_http):
            entities = _run(client.fulltext_search("graph-1", "test"))

        assert entities == []
        # Guard against a vacuous pass: the empty result must come from the
        # swallowed HTTP error, not from the request never being sent.
        mock_http.get.assert_awaited()

    def test_fail_open_on_connection_error(self, client: KGServiceClient):
        """A connection error fails open and returns an empty list."""
        mock_http = self._fake_http(side_effect=httpx.ConnectError("connection refused"))
        with patch.object(client, "_get_client", return_value=mock_http):
            entities = _run(client.fulltext_search("graph-1", "test"))

        assert entities == []
        # The request must actually have been attempted for this to test fail-open.
        mock_http.get.assert_awaited()

    def test_request_headers(self, client: KGServiceClient):
        """The request carries the expected headers and query params."""
        mock_http = self._fake_http(_resp(200, json={"data": {"content": []}}))
        with patch.object(client, "_get_client", return_value=mock_http):
            _run(client.fulltext_search("gid", "q", size=5, user_id="user-123"))

        call_kwargs = mock_http.get.call_args
        assert call_kwargs.kwargs["headers"]["X-Internal-Token"] == "test-token"
        assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-123"
        assert call_kwargs.kwargs["params"] == {"q": "q", "size": 5}
# ---------------------------------------------------------------------------
# get_subgraph 测试
# ---------------------------------------------------------------------------
class TestGetSubgraph:
    """Tests for ``KGServiceClient.get_subgraph``.

    Each test replaces ``client._get_client`` with a mock returning a fake HTTP
    client, so no real network traffic is issued.
    """

    @staticmethod
    def _fake_http(resp) -> AsyncMock:
        """Return a fake HTTP client whose awaited ``post`` yields *resp*."""
        mock_http = AsyncMock()
        mock_http.post = AsyncMock(return_value=resp)
        return mock_http

    def test_wrapped_subgraph_response(self, client: KGServiceClient):
        """Java returns a globally wrapped SubgraphExportVO."""
        mock_body = {
            "code": 200,
            "data": {
                "nodes": [
                    {"id": "n1", "name": "用户数据", "type": "Dataset", "description": "desc1", "properties": {}},
                    {"id": "n2", "name": "user_id", "type": "Field", "description": "", "properties": {}},
                ],
                "edges": [
                    {
                        "id": "edge1",
                        "sourceEntityId": "n1",
                        "targetEntityId": "n2",
                        "relationType": "HAS_FIELD",
                        "weight": 1.0,
                        "confidence": 0.9,
                        "sourceId": "kb-1",
                    },
                ],
                "nodeCount": 2,
                "edgeCount": 1,
            },
        }
        mock_http = self._fake_http(_resp(200, json=mock_body))
        with patch.object(client, "_get_client", return_value=mock_http):
            entities, relations = _run(client.get_subgraph("gid", ["n1"], depth=2, user_id="u1"))

        assert len(entities) == 2
        assert entities[0].name == "用户数据"
        assert entities[1].name == "user_id"
        assert len(relations) == 1
        assert relations[0].source_name == "用户数据"
        assert relations[0].target_name == "user_id"
        assert relations[0].relation_type == "HAS_FIELD"
        assert relations[0].source_type == "Dataset"
        assert relations[0].target_type == "Field"

    def test_unwrapped_subgraph_response(self, client: KGServiceClient):
        """Java returns the SubgraphExportVO directly (no global wrapper)."""
        mock_body = {
            "nodes": [
                {"id": "n1", "name": "A", "type": "T1", "description": ""},
            ],
            "edges": [],
            "nodeCount": 1,
            "edgeCount": 0,
        }
        mock_http = self._fake_http(_resp(200, json=mock_body))
        with patch.object(client, "_get_client", return_value=mock_http):
            entities, relations = _run(client.get_subgraph("gid", ["n1"]))

        assert len(entities) == 1
        assert entities[0].name == "A"
        assert relations == []

    def test_edge_with_unknown_entity(self, client: KGServiceClient):
        """An edge referencing an entity missing from ``nodes`` falls back to the raw ID."""
        mock_body = {
            "code": 200,
            "data": {
                "nodes": [{"id": "n1", "name": "A", "type": "T1", "description": ""}],
                "edges": [
                    {
                        "sourceEntityId": "n1",
                        "targetEntityId": "n999",
                        "relationType": "DEPENDS_ON",
                    },
                ],
            },
        }
        mock_http = self._fake_http(_resp(200, json=mock_body))
        with patch.object(client, "_get_client", return_value=mock_http):
            _entities, relations = _run(client.get_subgraph("gid", ["n1"]))

        assert len(relations) == 1
        assert relations[0].source_name == "A"
        assert relations[0].target_name == "n999"  # fallback to ID
        assert relations[0].target_type == ""

    def test_fail_open_on_error(self, client: KGServiceClient):
        """An HTTP error fails open and returns empty entity and relation lists."""
        mock_http = self._fake_http(_resp(500, text="error"))
        with patch.object(client, "_get_client", return_value=mock_http):
            entities, relations = _run(client.get_subgraph("gid", ["n1"]))

        assert entities == []
        assert relations == []
        # Guard against a vacuous pass: the request must actually have been sent.
        mock_http.post.assert_awaited()

    def test_request_params(self, client: KGServiceClient):
        """Subgraph request path, query params and JSON body are passed correctly."""
        mock_http = self._fake_http(_resp(200, json={"data": {"nodes": [], "edges": []}}))
        with patch.object(client, "_get_client", return_value=mock_http):
            _run(client.get_subgraph("gid", ["e1", "e2"], depth=3, user_id="u1"))

        call_kwargs = mock_http.post.call_args
        assert "/knowledge-graph/gid/query/subgraph/export" in call_kwargs.args[0]
        assert call_kwargs.kwargs["params"] == {"depth": 3}
        assert call_kwargs.kwargs["json"] == {"entityIds": ["e1", "e2"]}
# ---------------------------------------------------------------------------
# headers 测试
# ---------------------------------------------------------------------------
class TestHeaders:
    """Tests for ``KGServiceClient._headers``."""

    def test_headers_with_token_and_user(self, client: KGServiceClient):
        """Both the internal token and the user id are sent when available."""
        headers = client._headers(user_id="user-1")
        assert headers["X-Internal-Token"] == "test-token"
        assert headers["X-User-Id"] == "user-1"

    def test_headers_without_user(self, client: KGServiceClient):
        """Without a user id only the internal token header is sent."""
        headers = client._headers()
        # Pin the value, not just the key, so a wrong token would be caught.
        assert headers["X-Internal-Token"] == "test-token"
        assert "X-User-Id" not in headers

    def test_headers_without_token(self):
        """An empty internal token omits the token header entirely."""
        c = KGServiceClient(base_url="http://test:8080", internal_token="")
        headers = c._headers(user_id="u1")
        assert "X-Internal-Token" not in headers
        assert headers["X-User-Id"] == "u1"