Files
DataMate/runtime/datamate-python/app/module/kg_graphrag/test_interface.py
Jerry Yan 39338df808 feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证

新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口

修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖

测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)

关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*

安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)

测试结果:79 个测试全部通过
2026-02-20 09:41:55 +08:00

301 lines
11 KiB
Python

"""GraphRAG API 端点回归测试。
验证 /graphrag/query、/graphrag/retrieve、/graphrag/query/stream 端点
的权限校验行为,确保 collection_name 不一致时返回 403 且不进入检索链路。
"""
from __future__ import annotations

from contextlib import ExitStack
from unittest.mock import AsyncMock, patch

import pytest
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.testclient import TestClient
from starlette.exceptions import HTTPException as StarletteHTTPException

from app.exception import (
    fastapi_http_exception_handler,
    starlette_http_exception_handler,
    validation_exception_handler,
)
from app.module.kg_graphrag.interface import router
from app.module.kg_graphrag.models import (
    GraphContext,
    RetrievalContext,
)
# ---------------------------------------------------------------------------
# Test-only FastAPI application: mounts just the graphrag router under /api
# and registers the exception handlers imported from app.exception.
# ---------------------------------------------------------------------------
_app = FastAPI()
_app.include_router(router, prefix="/api")
for _exc_type, _handler in (
    (StarletteHTTPException, starlette_http_exception_handler),
    (HTTPException, fastapi_http_exception_handler),
    (RequestValidationError, validation_exception_handler),
):
    _app.add_exception_handler(_exc_type, _handler)
# Drop the loop variables so the module namespace matches the explicit form.
del _exc_type, _handler
# Request fixtures shared by the endpoint tests in this module.
_VALID_GRAPH_ID = "12345678-1234-1234-1234-123456789abc"
_VALID_BODY = dict(
    query="测试查询",
    knowledge_base_id="kb-1",
    collection_name="test-collection",
    graph_id=_VALID_GRAPH_ID,
)
# Header carrying the caller identity sent with authorised requests.
_HEADERS = {"X-User-Id": "user-1"}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fake_retrieval_context() -> RetrievalContext:
    """Build the canned retrieval result used by the retriever double.

    Returns:
        A ``RetrievalContext`` with an empty ``vector_chunks`` list, a
        default-constructed ``GraphContext`` and ``merged_text`` set to
        ``"test context"``.
    """
    return RetrievalContext(
        vector_chunks=[],
        graph_context=GraphContext(),
        merged_text="test context",
    )
def _make_retriever_mock() -> AsyncMock:
    """Create a retriever double.

    Returns:
        An ``AsyncMock`` whose awaitable ``retrieve`` attribute resolves to the
        value produced by ``_fake_retrieval_context()``.
    """
    m = AsyncMock()
    m.retrieve = AsyncMock(return_value=_fake_retrieval_context())
    return m
def _make_generator_mock() -> AsyncMock:
    """Create a generator double covering both sync-style and streaming use.

    Returns:
        An ``AsyncMock`` where:

        * awaiting ``generate(...)`` resolves to ``"test answer"``;
        * ``model_name`` is ``"test-model"``;
        * ``generate_stream(query=..., context=...)`` is a real async
          generator yielding ``"hello"``, ``" "``, ``"world"``.
    """
    m = AsyncMock()
    m.generate = AsyncMock(return_value="test answer")
    m.model_name = "test-model"

    # A plain async generator function is used (rather than a mock) so that the
    # caller can iterate it with ``async for``; both arguments are ignored.
    async def _stream(*, query: str, context: str):  # noqa: ARG001
        for token in ["hello", " ", "world"]:
            yield token

    m.generate_stream = _stream
    return m
def _make_kb_validator_mock(*, access_granted: bool = True) -> AsyncMock:
    """Create a knowledge-base access validator double.

    Args:
        access_granted: Value that awaiting ``check_access(...)`` resolves to.

    Returns:
        An ``AsyncMock`` whose awaitable ``check_access`` attribute returns
        ``access_granted`` regardless of the arguments it is called with.
    """
    m = AsyncMock()
    m.check_access = AsyncMock(return_value=access_granted)
    return m
def _patch_all(
*,
access_granted: bool = True,
retriever: AsyncMock | None = None,
generator: AsyncMock | None = None,
validator: AsyncMock | None = None,
):
"""返回 context manager,统一 patch 三个懒加载工厂函数。"""
retriever = retriever or _make_retriever_mock()
generator = generator or _make_generator_mock()
validator = validator or _make_kb_validator_mock(access_granted=access_granted)
class _Ctx:
def __init__(self):
self.retriever = retriever
self.generator = generator
self.validator = validator
self._patches = [
patch("app.module.kg_graphrag.interface._get_retriever", return_value=retriever),
patch("app.module.kg_graphrag.interface._get_generator", return_value=generator),
patch("app.module.kg_graphrag.interface._get_kb_validator", return_value=validator),
]
def __enter__(self):
for p in self._patches:
p.__enter__()
return self
def __exit__(self, *args):
for p in reversed(self._patches):
p.__exit__(*args)
return _Ctx()
@pytest.fixture
def client() -> TestClient:
    """Provide a ``TestClient`` bound to the test-only FastAPI app ``_app``."""
    return TestClient(_app)
# ---------------------------------------------------------------------------
# POST /api/graphrag/query
# ---------------------------------------------------------------------------
class TestQueryEndpoint:
    """Tests for POST /api/graphrag/query."""

    def test_success(self, client: TestClient):
        """Access granted, then retrieval and generation -> 200."""
        with _patch_all(access_granted=True) as ctx:
            resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 200
        assert body["data"]["answer"] == "test answer"
        assert body["data"]["model"] == "test-model"
        ctx.retriever.retrieve.assert_awaited_once()
        ctx.generator.generate.assert_awaited_once()

    def test_access_denied_returns_403(self, client: TestClient):
        """check_access returning False -> 403 in the standard error format."""
        with _patch_all(access_granted=False):
            resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 403
        body = resp.json()
        assert body["code"] == 403
        # The error detail is expected to name the rejected knowledge base.
        assert "kb-1" in body["data"]["detail"]

    def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
        """On denial neither retriever.retrieve nor generator.generate is called."""
        with _patch_all(access_granted=False) as ctx:
            resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 403
        ctx.retriever.retrieve.assert_not_called()
        ctx.generator.generate.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """check_access is awaited with the request's collection_name."""
        with _patch_all(access_granted=True) as ctx:
            resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 200
        ctx.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """A request without the X-User-Id header -> 422 validation error."""
        with _patch_all(access_granted=True):
            resp = client.post("/api/graphrag/query", json=_VALID_BODY)

        assert resp.status_code == 422
# ---------------------------------------------------------------------------
# POST /api/graphrag/retrieve
# ---------------------------------------------------------------------------
class TestRetrieveEndpoint:
    """Tests for POST /api/graphrag/retrieve."""

    def test_success(self, client: TestClient):
        """Access granted -> retrieval runs and the RetrievalContext is returned."""
        with _patch_all(access_granted=True) as ctx:
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 200
        assert body["data"]["merged_text"] == "test context"
        ctx.retriever.retrieve.assert_awaited_once()

    def test_access_denied_returns_403(self, client: TestClient):
        """Access denied -> 403."""
        with _patch_all(access_granted=False):
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 403
        body = resp.json()
        assert body["code"] == 403

    def test_access_denied_skips_retrieval(self, client: TestClient):
        """On denial retriever.retrieve is not called."""
        with _patch_all(access_granted=False) as ctx:
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 403
        ctx.retriever.retrieve.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """check_access is awaited with the request's collection_name."""
        with _patch_all(access_granted=True) as ctx:
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)

        assert resp.status_code == 200
        ctx.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """A request without the X-User-Id header -> 422."""
        with _patch_all(access_granted=True):
            resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY)

        assert resp.status_code == 422
# ---------------------------------------------------------------------------
# POST /api/graphrag/query/stream
# ---------------------------------------------------------------------------
class TestQueryStreamEndpoint:
    """Tests for POST /api/graphrag/query/stream."""

    def test_success_returns_sse(self, client: TestClient):
        """Access granted -> SSE response containing token and done events."""
        with _patch_all(access_granted=True):
            resp = client.post(
                "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
            )

        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        text = resp.text
        assert '"token"' in text
        # Accept either JSON separator style for the final event.
        assert '"done": true' in text or '"done":true' in text

    def test_access_denied_returns_403(self, client: TestClient):
        """Access denied -> 403."""
        with _patch_all(access_granted=False):
            resp = client.post(
                "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
            )

        assert resp.status_code == 403
        body = resp.json()
        assert body["code"] == 403

    def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
        """On denial neither retrieval nor any generation entry point is called."""
        # The default generator double exposes ``generate_stream`` as a plain
        # async generator function, which records no calls. A bare AsyncMock is
        # injected here so both ``generate`` and ``generate_stream`` are mock
        # attributes that can be asserted on.
        generator = AsyncMock()
        with _patch_all(access_granted=False, generator=generator) as ctx:
            resp = client.post(
                "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
            )

        assert resp.status_code == 403
        ctx.retriever.retrieve.assert_not_called()
        generator.generate.assert_not_called()
        generator.generate_stream.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """check_access is awaited with the request's collection_name."""
        with _patch_all(access_granted=True) as ctx:
            resp = client.post(
                "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
            )

        assert resp.status_code == 200
        ctx.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """A request without the X-User-Id header -> 422."""
        with _patch_all(access_granted=True):
            resp = client.post("/api/graphrag/query/stream", json=_VALID_BODY)

        assert resp.status_code == 422