"""GraphRAG API 端点回归测试。 验证 /graphrag/query、/graphrag/retrieve、/graphrag/query/stream 端点 的权限校验行为,确保 collection_name 不一致时返回 403 且不进入检索链路。 """ from __future__ import annotations from unittest.mock import AsyncMock, patch import pytest from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.testclient import TestClient from starlette.exceptions import HTTPException as StarletteHTTPException from app.exception import ( fastapi_http_exception_handler, starlette_http_exception_handler, validation_exception_handler, ) from app.module.kg_graphrag.interface import router from app.module.kg_graphrag.models import ( GraphContext, RetrievalContext, ) # --------------------------------------------------------------------------- # 测试用 FastAPI 应用(仅挂载 graphrag router + 异常处理器) # --------------------------------------------------------------------------- _app = FastAPI() _app.include_router(router, prefix="/api") _app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler) _app.add_exception_handler(HTTPException, fastapi_http_exception_handler) _app.add_exception_handler(RequestValidationError, validation_exception_handler) _VALID_GRAPH_ID = "12345678-1234-1234-1234-123456789abc" _VALID_BODY = { "query": "测试查询", "knowledge_base_id": "kb-1", "collection_name": "test-collection", "graph_id": _VALID_GRAPH_ID, } _HEADERS = {"X-User-Id": "user-1"} # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _fake_retrieval_context() -> RetrievalContext: return RetrievalContext( vector_chunks=[], graph_context=GraphContext(), merged_text="test context", ) def _make_retriever_mock() -> AsyncMock: m = AsyncMock() m.retrieve = AsyncMock(return_value=_fake_retrieval_context()) return m def _make_generator_mock() -> AsyncMock: m = AsyncMock() m.generate = AsyncMock(return_value="test answer") m.model_name = "test-model" async def _stream(*, query: str, context: str): # noqa: ARG001 for token in ["hello", " ", "world"]: yield token m.generate_stream = _stream return m def _make_kb_validator_mock(*, access_granted: bool = True) -> AsyncMock: m = AsyncMock() m.check_access = AsyncMock(return_value=access_granted) return m def _patch_all( *, access_granted: bool = True, retriever: AsyncMock | None = None, generator: AsyncMock | None = None, validator: AsyncMock | None = None, ): """返回 context manager,统一 patch 三个懒加载工厂函数。""" retriever = retriever or _make_retriever_mock() generator = generator or _make_generator_mock() validator = validator or _make_kb_validator_mock(access_granted=access_granted) class _Ctx: def __init__(self): self.retriever = retriever self.generator = generator self.validator = validator self._patches = [ patch("app.module.kg_graphrag.interface._get_retriever", return_value=retriever), patch("app.module.kg_graphrag.interface._get_generator", return_value=generator), patch("app.module.kg_graphrag.interface._get_kb_validator", return_value=validator), ] def __enter__(self): for p in self._patches: p.__enter__() return self def __exit__(self, *args): for p in reversed(self._patches): p.__exit__(*args) return _Ctx() @pytest.fixture def client(): return TestClient(_app) # --------------------------------------------------------------------------- # POST /api/graphrag/query # --------------------------------------------------------------------------- class TestQueryEndpoint: """POST /api/graphrag/query 端点测试。""" def test_success(self, client: TestClient): """权限校验通过 + 检索 + 生成 → 200。""" with _patch_all(access_granted=True) as ctx: resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS) assert resp.status_code == 200 body = resp.json() assert body["code"] == 200 assert body["data"]["answer"] == "test answer" assert body["data"]["model"] == "test-model" ctx.retriever.retrieve.assert_awaited_once() ctx.generator.generate.assert_awaited_once() def test_access_denied_returns_403(self, client: TestClient): """check_access 返回 False → 403 + 标准错误格式。""" with _patch_all(access_granted=False): resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS) assert resp.status_code == 403 body = resp.json() assert body["code"] == 403 assert "kb-1" in body["data"]["detail"] def test_access_denied_skips_retrieval_and_generation(self, client: TestClient): """权限拒绝时,retriever.retrieve 和 generator.generate 均不调用。""" with _patch_all(access_granted=False) as ctx: resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS) assert resp.status_code == 403 ctx.retriever.retrieve.assert_not_called() ctx.generator.generate.assert_not_called() def test_check_access_receives_collection_name(self, client: TestClient): """验证 check_access 被调用时携带正确的 collection_name 参数。""" with _patch_all(access_granted=True) as ctx: resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS) assert resp.status_code == 200 ctx.validator.check_access.assert_awaited_once_with( "kb-1", "user-1", collection_name="test-collection", ) def test_missing_user_id_returns_422(self, client: TestClient): """缺少 X-User-Id 请求头 → 422 验证错误。""" with _patch_all(access_granted=True): resp = client.post("/api/graphrag/query", json=_VALID_BODY) assert resp.status_code == 422 # --------------------------------------------------------------------------- # POST /api/graphrag/retrieve # --------------------------------------------------------------------------- class TestRetrieveEndpoint: """POST /api/graphrag/retrieve 端点测试。""" def test_success(self, client: TestClient): """权限通过 → 检索 → 返回 RetrievalContext。""" with _patch_all(access_granted=True) as ctx: resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS) assert resp.status_code == 200 body = resp.json() assert body["code"] == 200 assert body["data"]["merged_text"] == "test context" ctx.retriever.retrieve.assert_awaited_once() def test_access_denied_returns_403(self, client: TestClient): """权限拒绝 → 403。""" with _patch_all(access_granted=False): resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS) assert resp.status_code == 403 body = resp.json() assert body["code"] == 403 def test_access_denied_skips_retrieval(self, client: TestClient): """权限拒绝时不调用 retriever.retrieve。""" with _patch_all(access_granted=False) as ctx: resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS) assert resp.status_code == 403 ctx.retriever.retrieve.assert_not_called() def test_check_access_receives_collection_name(self, client: TestClient): """验证 check_access 收到 collection_name 参数。""" with _patch_all(access_granted=True) as ctx: resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS) assert resp.status_code == 200 ctx.validator.check_access.assert_awaited_once_with( "kb-1", "user-1", collection_name="test-collection", ) def test_missing_user_id_returns_422(self, client: TestClient): """缺少 X-User-Id → 422。""" with _patch_all(access_granted=True): resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY) assert resp.status_code == 422 # --------------------------------------------------------------------------- # POST /api/graphrag/query/stream # --------------------------------------------------------------------------- class TestQueryStreamEndpoint: """POST /api/graphrag/query/stream 端点测试。""" def test_success_returns_sse(self, client: TestClient): """权限通过 → SSE 流式响应,包含 token 和 done 事件。""" with _patch_all(access_granted=True): resp = client.post( "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS, ) assert resp.status_code == 200 assert resp.headers["content-type"].startswith("text/event-stream") text = resp.text assert '"token"' in text assert '"done": true' in text or '"done":true' in text def test_access_denied_returns_403(self, client: TestClient): """权限拒绝 → 403。""" with _patch_all(access_granted=False): resp = client.post( "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS, ) assert resp.status_code == 403 body = resp.json() assert body["code"] == 403 def test_access_denied_skips_retrieval_and_generation(self, client: TestClient): """权限拒绝时不调用检索和生成。""" with _patch_all(access_granted=False) as ctx: resp = client.post( "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS, ) assert resp.status_code == 403 ctx.retriever.retrieve.assert_not_called() def test_check_access_receives_collection_name(self, client: TestClient): """验证 check_access 收到 collection_name 参数。""" with _patch_all(access_granted=True) as ctx: resp = client.post( "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS, ) assert resp.status_code == 200 ctx.validator.check_access.assert_awaited_once_with( "kb-1", "user-1", collection_name="test-collection", ) def test_missing_user_id_returns_422(self, client: TestClient): """缺少 X-User-Id → 422。""" with _patch_all(access_granted=True): resp = client.post("/api/graphrag/query/stream", json=_VALID_BODY) assert resp.status_code == 422