You've already forked DataMate
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
301 lines
11 KiB
Python
"""GraphRAG API 端点回归测试。

验证 /graphrag/query、/graphrag/retrieve、/graphrag/query/stream 端点
的权限校验行为,确保 collection_name 不一致时返回 403 且不进入检索链路。
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.testclient import TestClient
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
|
|
from app.exception import (
|
|
fastapi_http_exception_handler,
|
|
starlette_http_exception_handler,
|
|
validation_exception_handler,
|
|
)
|
|
from app.module.kg_graphrag.interface import router
|
|
from app.module.kg_graphrag.models import (
|
|
GraphContext,
|
|
RetrievalContext,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
# FastAPI app used by the tests (mounts only the graphrag router + exception handlers)
# ---------------------------------------------------------------------------
|
|
|
|
# Minimal application under test: the three project exception handlers are
# registered on a bare FastAPI instance, and the graphrag router is mounted
# under the "/api" prefix.
_app = FastAPI()
_app.add_exception_handler(RequestValidationError, validation_exception_handler)
_app.add_exception_handler(HTTPException, fastapi_http_exception_handler)
_app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler)
_app.include_router(router, prefix="/api")
|
|
|
|
|
|
# Graph identifier reused by the canonical request body below.
_VALID_GRAPH_ID: str = "12345678-1234-1234-1234-123456789abc"

# Canonical JSON body sent to every endpoint in this module.
_VALID_BODY: dict[str, str] = {
    "query": "测试查询",
    "knowledge_base_id": "kb-1",
    "collection_name": "test-collection",
    "graph_id": _VALID_GRAPH_ID,
}

# Request headers identifying the caller; omitted by the 422 tests.
_HEADERS: dict[str, str] = {"X-User-Id": "user-1"}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _fake_retrieval_context() -> RetrievalContext:
    """Build a canned ``RetrievalContext`` for the endpoint tests.

    Returns:
        A context with no vector chunks, a default-constructed
        ``GraphContext`` and the fixed merged text ``"test context"``.
    """
    empty_graph = GraphContext()
    return RetrievalContext(
        vector_chunks=[],
        graph_context=empty_graph,
        merged_text="test context",
    )
|
|
|
|
|
|
def _make_retriever_mock() -> AsyncMock:
    """Create a retriever double whose awaited ``retrieve`` returns the fake context.

    Returns:
        An ``AsyncMock`` with an ``AsyncMock`` ``retrieve`` attribute that
        resolves to the result of ``_fake_retrieval_context()``.
    """
    canned_context = _fake_retrieval_context()
    # Keyword arguments to the mock constructor are set as attributes on it.
    return AsyncMock(retrieve=AsyncMock(return_value=canned_context))
|
|
|
|
|
|
def _make_generator_mock() -> AsyncMock:
    """Create an LLM-generator double for the endpoint tests.

    Returns:
        An ``AsyncMock`` configured so that:

        * ``generate`` is awaitable and resolves to ``"test answer"``;
        * ``model_name`` is ``"test-model"``;
        * ``generate_stream(query=..., context=...)`` returns an async
          generator yielding ``"hello"``, ``" "``, ``"world"``.

        ``generate_stream`` is a ``MagicMock`` wrapping the async generator
        function, so tests can assert on it (for example
        ``generate_stream.assert_not_called()``), which a bare function
        does not allow.
    """
    # Local import: only this helper needs MagicMock.
    from unittest.mock import MagicMock

    async def _stream(*, query: str, context: str):  # noqa: ARG001
        for token in ("hello", " ", "world"):
            yield token

    m = AsyncMock()
    m.generate = AsyncMock(return_value="test answer")
    m.model_name = "test-model"
    # A synchronous MagicMock is used on purpose: calling it must return the
    # async generator object directly (not a coroutine), exactly as calling
    # the undecorated ``_stream`` would, while still recording the call.
    m.generate_stream = MagicMock(side_effect=_stream)
    return m
|
|
|
|
|
|
def _make_kb_validator_mock(*, access_granted: bool = True) -> AsyncMock:
    """Create a knowledge-base access validator double.

    Args:
        access_granted: Value the awaited ``check_access`` call resolves to.

    Returns:
        An ``AsyncMock`` whose ``check_access`` attribute is an ``AsyncMock``
        returning ``access_granted``.
    """
    check_access = AsyncMock(return_value=access_granted)
    # Keyword arguments to the mock constructor are set as attributes on it.
    return AsyncMock(check_access=check_access)
|
|
|
|
|
|
def _patch_all(
    *,
    access_granted: bool = True,
    retriever: AsyncMock | None = None,
    generator: AsyncMock | None = None,
    validator: AsyncMock | None = None,
):
    """Return a context manager that patches the three lazy factory functions.

    While the returned object is active, ``_get_retriever``,
    ``_get_generator`` and ``_get_kb_validator`` in
    ``app.module.kg_graphrag.interface`` are patched to return the doubles
    held on the object.

    Args:
        access_granted: Result of ``check_access`` on the default validator
            double; ignored when ``validator`` is supplied.
        retriever: Retriever double to use; a default one is built when ``None``.
        generator: Generator double to use; a default one is built when ``None``.
        validator: Validator double to use; a default one is built when ``None``.

    Returns:
        A context manager exposing the doubles as ``.retriever``,
        ``.generator`` and ``.validator``. Entering it returns the object
        itself; exceptions raised inside the ``with`` body are not suppressed.
    """
    # Local import keeps this helper independent of the file's import header.
    from contextlib import ExitStack

    # Explicit ``is None`` checks: a caller-supplied double is always kept,
    # whatever its truthiness.
    if retriever is None:
        retriever = _make_retriever_mock()
    if generator is None:
        generator = _make_generator_mock()
    if validator is None:
        validator = _make_kb_validator_mock(access_granted=access_granted)

    class _Ctx:
        def __init__(self):
            self.retriever = retriever
            self.generator = generator
            self.validator = validator
            self._patches = [
                patch("app.module.kg_graphrag.interface._get_retriever", return_value=retriever),
                patch("app.module.kg_graphrag.interface._get_generator", return_value=generator),
                patch("app.module.kg_graphrag.interface._get_kb_validator", return_value=validator),
            ]
            # Holds the patches that are currently active (none until entered).
            self._stack = ExitStack()

        def __enter__(self):
            # Start the patches inside a temporary stack: if one of them fails
            # to start, the ones already started are undone before the error
            # propagates. On success, ownership moves to ``self._stack``.
            with ExitStack() as pending:
                for p in self._patches:
                    pending.enter_context(p)
                self._stack = pending.pop_all()
            return self

        def __exit__(self, *args):
            # Undo the patches in reverse order. The result is deliberately
            # not returned so exceptions from the ``with`` body propagate.
            self._stack.__exit__(*args)

    return _Ctx()
|
|
|
|
|
|
@pytest.fixture
def client() -> TestClient:
    """Provide a ``TestClient`` bound to the module-level test app."""
    test_client = TestClient(_app)
    return test_client
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# POST /api/graphrag/query
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQueryEndpoint:
    """Tests for the ``POST /api/graphrag/query`` endpoint."""

    _URL = "/api/graphrag/query"

    def _post(self, client: TestClient, *, with_user: bool = True):
        """POST the canonical body, optionally without the ``X-User-Id`` header."""
        if not with_user:
            return client.post(self._URL, json=_VALID_BODY)
        return client.post(self._URL, json=_VALID_BODY, headers=_HEADERS)

    def test_success(self, client: TestClient):
        """Access granted + retrieval + generation -> 200."""
        with _patch_all(access_granted=True) as mocks:
            response = self._post(client)

        assert response.status_code == 200
        payload = response.json()
        assert payload["code"] == 200
        data = payload["data"]
        assert data["answer"] == "test answer"
        assert data["model"] == "test-model"
        mocks.retriever.retrieve.assert_awaited_once()
        mocks.generator.generate.assert_awaited_once()

    def test_access_denied_returns_403(self, client: TestClient):
        """``check_access`` returning False -> 403 in the standard error format."""
        with _patch_all(access_granted=False):
            response = self._post(client)

        assert response.status_code == 403
        payload = response.json()
        assert payload["code"] == 403
        assert "kb-1" in payload["data"]["detail"]

    def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
        """On denial neither ``retriever.retrieve`` nor ``generator.generate`` is called."""
        with _patch_all(access_granted=False) as mocks:
            response = self._post(client)

        assert response.status_code == 403
        mocks.retriever.retrieve.assert_not_called()
        mocks.generator.generate.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """``check_access`` is awaited with the request's ``collection_name``."""
        with _patch_all(access_granted=True) as mocks:
            response = self._post(client)

        assert response.status_code == 200
        mocks.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """Missing ``X-User-Id`` header -> 422 validation error."""
        with _patch_all(access_granted=True):
            response = self._post(client, with_user=False)

        assert response.status_code == 422
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# POST /api/graphrag/retrieve
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRetrieveEndpoint:
    """Tests for the ``POST /api/graphrag/retrieve`` endpoint."""

    _URL = "/api/graphrag/retrieve"

    def _post(self, client: TestClient, *, with_user: bool = True):
        """POST the canonical body, optionally without the ``X-User-Id`` header."""
        if not with_user:
            return client.post(self._URL, json=_VALID_BODY)
        return client.post(self._URL, json=_VALID_BODY, headers=_HEADERS)

    def test_success(self, client: TestClient):
        """Access granted -> retrieval runs -> the RetrievalContext is returned."""
        with _patch_all(access_granted=True) as mocks:
            response = self._post(client)

        assert response.status_code == 200
        payload = response.json()
        assert payload["code"] == 200
        assert payload["data"]["merged_text"] == "test context"
        mocks.retriever.retrieve.assert_awaited_once()

    def test_access_denied_returns_403(self, client: TestClient):
        """Access denied -> 403."""
        with _patch_all(access_granted=False):
            response = self._post(client)

        assert response.status_code == 403
        payload = response.json()
        assert payload["code"] == 403

    def test_access_denied_skips_retrieval(self, client: TestClient):
        """On denial ``retriever.retrieve`` is not called."""
        with _patch_all(access_granted=False) as mocks:
            response = self._post(client)

        assert response.status_code == 403
        mocks.retriever.retrieve.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """``check_access`` is awaited with the request's ``collection_name``."""
        with _patch_all(access_granted=True) as mocks:
            response = self._post(client)

        assert response.status_code == 200
        mocks.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """Missing ``X-User-Id`` header -> 422."""
        with _patch_all(access_granted=True):
            response = self._post(client, with_user=False)

        assert response.status_code == 422
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# POST /api/graphrag/query/stream
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQueryStreamEndpoint:
    """Tests for the ``POST /api/graphrag/query/stream`` endpoint."""

    _URL = "/api/graphrag/query/stream"

    def _post(self, client: TestClient, *, with_user: bool = True):
        """POST the canonical body, optionally without the ``X-User-Id`` header."""
        if not with_user:
            return client.post(self._URL, json=_VALID_BODY)
        return client.post(self._URL, json=_VALID_BODY, headers=_HEADERS)

    def test_success_returns_sse(self, client: TestClient):
        """Access granted -> SSE response containing token and done events."""
        with _patch_all(access_granted=True):
            resp = self._post(client)

        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        text = resp.text
        assert '"token"' in text
        # Accept both JSON separator styles for the final event.
        assert '"done": true' in text or '"done":true' in text

    def test_access_denied_returns_403(self, client: TestClient):
        """Access denied -> 403."""
        with _patch_all(access_granted=False):
            resp = self._post(client)

        assert resp.status_code == 403
        body = resp.json()
        assert body["code"] == 403

    def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
        """On denial neither retrieval nor (streaming) generation is invoked."""
        generator = _make_generator_mock()
        real_stream = generator.generate_stream
        stream_calls = []

        def _recording_stream(*args, **kwargs):
            # Record the invocation, then delegate so behavior is unchanged
            # should the endpoint (incorrectly) reach generation.
            stream_calls.append((args, kwargs))
            return real_stream(*args, **kwargs)

        generator.generate_stream = _recording_stream

        with _patch_all(access_granted=False, generator=generator) as ctx:
            resp = self._post(client)

        assert resp.status_code == 403
        ctx.retriever.retrieve.assert_not_called()
        ctx.generator.generate.assert_not_called()
        assert stream_calls == []

    def test_check_access_receives_collection_name(self, client: TestClient):
        """``check_access`` is awaited with the request's ``collection_name``."""
        with _patch_all(access_granted=True) as ctx:
            resp = self._post(client)

        assert resp.status_code == 200
        ctx.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """Missing ``X-User-Id`` header -> 422."""
        with _patch_all(access_granted=True):
            resp = self._post(client, with_user=False)

        assert resp.status_code == 422
|