"""Unit tests for knowledge-base access validation (``KnowledgeBaseAccessValidator``)."""

from __future__ import annotations

import asyncio
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import AsyncMock, patch

import httpx
import pytest

from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator


@pytest.fixture
def validator() -> KnowledgeBaseAccessValidator:
    """Validator aimed at a fake backend URL; tests patch its HTTP client."""
    return KnowledgeBaseAccessValidator(
        base_url="http://test-backend:8080/api",
        timeout=5.0,
    )


def _run(coro):
    """Run *coro* to completion on a fresh event loop and return its result."""
    return asyncio.run(coro)


# Every fake response is bound to this placeholder request. httpx only allows
# ``Response.request`` / ``raise_for_status()`` when a request is attached —
# presumably the validator relies on one of them.
_FAKE_REQUEST = httpx.Request("GET", "http://test")


def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
    """Build an ``httpx.Response`` bound to ``_FAKE_REQUEST``.

    When *json* is given it becomes the body; otherwise *text* (default
    empty string) is used as the raw body.
    """
    if json is not None:
        return httpx.Response(status_code, json=json, request=_FAKE_REQUEST)
    return httpx.Response(status_code, text=text or "", request=_FAKE_REQUEST)


@contextmanager
def _mock_client(
    validator: KnowledgeBaseAccessValidator,
    *,
    response: httpx.Response | None = None,
    error: Exception | None = None,
) -> Iterator[AsyncMock]:
    """Patch ``validator._get_client`` to hand out a mock HTTP client.

    The mock client's ``get`` coroutine returns *response*, or raises *error*
    when one is given. Yields the mock client so tests can inspect how
    ``get`` was called (URL, headers).
    """
    http = AsyncMock()
    http.get = AsyncMock(return_value=response, side_effect=error)
    # Passing return_value as a kwarg keeps patch's automatic choice between
    # MagicMock and AsyncMock for the patched ``_get_client``.
    with patch.object(validator, "_get_client", return_value=http):
        yield http


# ---------------------------------------------------------------------------
# check_access tests
# ---------------------------------------------------------------------------


class TestCheckAccess:
    """Tests for ``check_access``."""

    def test_access_granted(self, validator: KnowledgeBaseAccessValidator):
        """Backend answers HTTP 200 with code=200: the user may access the KB."""
        resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "test-kb"}})
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access("kb-1", "user-1"))
        assert result is True

    def test_access_granted_with_matching_collection(self, validator: KnowledgeBaseAccessValidator):
        """Permission granted and collection_name equals the KB name: allowed."""
        resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "my-collection"}})
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="my-collection",
            ))
        assert result is True

    def test_access_denied_by_biz_code(self, validator: KnowledgeBaseAccessValidator):
        """Backend answers HTTP 200 but with a business code other than 200 (sys.0005, no permission)."""
        resp = _resp(200, json={"code": "sys.0005", "message": "权限不足"})
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access("kb-1", "other-user"))
        assert result is False

    def test_access_denied_http_403(self, validator: KnowledgeBaseAccessValidator):
        """Backend answers HTTP 403."""
        resp = _resp(403, text="Forbidden")
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access("kb-1", "user-1"))
        assert result is False

    def test_kb_not_found_http_404(self, validator: KnowledgeBaseAccessValidator):
        """Knowledge base does not exist; backend answers HTTP 404."""
        resp = _resp(404, text="Not Found")
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access("nonexistent-kb", "user-1"))
        assert result is False

    def test_server_error_http_500(self, validator: KnowledgeBaseAccessValidator):
        """Backend answers HTTP 500."""
        resp = _resp(500, text="Internal Server Error")
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access("kb-1", "user-1"))
        assert result is False

    def test_fail_close_on_connection_error(self, validator: KnowledgeBaseAccessValidator):
        """A network error must fail closed (deny) so the check cannot be bypassed."""
        with _mock_client(validator, error=httpx.ConnectError("connection refused")):
            result = _run(validator.check_access("kb-1", "user-1"))
        assert result is False

    def test_fail_close_on_timeout(self, validator: KnowledgeBaseAccessValidator):
        """A timeout must fail closed (deny)."""
        with _mock_client(validator, error=httpx.ReadTimeout("timeout")):
            result = _run(validator.check_access("kb-1", "user-1"))
        assert result is False

    def test_request_headers(self, validator: KnowledgeBaseAccessValidator):
        """The request targets the KB endpoint and carries the X-User-Id header."""
        resp = _resp(200, json={"code": 200, "data": {}})
        with _mock_client(validator, response=resp) as http:
            _run(validator.check_access("kb-123", "user-456"))
            call = http.get.call_args
            assert "/api/knowledge-base/kb-123" in call.args[0]
            assert call.kwargs["headers"]["X-User-Id"] == "user-456"

    def test_cross_user_access_denied(self, validator: KnowledgeBaseAccessValidator):
        """User B tries to read user A's knowledge base and must be denied.

        Simulates the backend returning the "insufficient permission"
        business error.
        """
        # KB created by user A, requested by user B.
        resp = _resp(200, json={
            "code": "sys.0005",
            "message": "权限不足",
            "data": None,
        })
        with _mock_client(validator, response=resp) as http:
            result = _run(validator.check_access("kb-user-a", "user-b"))
            assert result is False
            # The request must carry user B's id, not the KB owner's.
            call = http.get.call_args
            assert call.kwargs["headers"]["X-User-Id"] == "user-b"

    def test_admin_access_granted(self, validator: KnowledgeBaseAccessValidator):
        """An admin reads another user's KB; the backend skips the owner check for admins."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-user-a", "name": "用户A的知识库", "createdBy": "user-a"},
        })
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access("kb-user-a", "admin-user"))
        # The backend's admin check passed and answered 200 + code=200.
        assert result is True


# ---------------------------------------------------------------------------
# collection_name binding tests
# ---------------------------------------------------------------------------


class TestCollectionNameBinding:
    """Tests binding collection_name to knowledge_base_id.

    Guards against a user submitting a KB id they may access while
    tampering with collection_name to read another knowledge base's
    Milvus data.
    """

    def test_collection_name_mismatch_denied(self, validator: KnowledgeBaseAccessValidator):
        """KB-A has name='collection-a' but the request sends collection_name='collection-b': denied."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-a", "name": "collection-a"},
        })
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access(
                "kb-a", "user-1", collection_name="collection-b",
            ))
        assert result is False

    def test_collection_name_none_skips_check(self, validator: KnowledgeBaseAccessValidator):
        """collection_name=None skips the binding check (backward compatible)."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-1", "name": "some-name"},
        })
        with _mock_client(validator, response=resp):
            # No collection_name -> only the permission check runs.
            result = _run(validator.check_access("kb-1", "user-1"))
        assert result is True

    def test_response_data_missing_name_denied(self, validator: KnowledgeBaseAccessValidator):
        """Response data has no name field: fail closed (deny)."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-1"},
        })
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="any-collection",
            ))
        # data.name is missing, so it cannot match "any-collection" -> denied.
        assert result is False

    def test_response_data_null_denied(self, validator: KnowledgeBaseAccessValidator):
        """Response data is null: fail closed (deny)."""
        resp = _resp(200, json={
            "code": 200,
            "data": None,
        })
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="any-collection",
            ))
        assert result is False

    def test_response_data_empty_dict_denied(self, validator: KnowledgeBaseAccessValidator):
        """Response data is an empty dict {}: fail closed (deny)."""
        resp = _resp(200, json={
            "code": 200,
            "data": {},
        })
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="any-collection",
            ))
        assert result is False

    def test_cross_kb_collection_swap_denied(self, validator: KnowledgeBaseAccessValidator):
        """User may access KB-A (name='kb-a-data') but pairs KB-A's id with
        KB-B's collection_name='kb-b-data': must be denied.

        This is the full simulation of the core privilege-escalation scenario.
        """
        # The user is allowed to access KB-A ...
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-a", "name": "kb-a-data", "createdBy": "user-1"},
        })
        with _mock_client(validator, response=resp):
            # ... but collection_name points at KB-B's data.
            result = _run(validator.check_access(
                "kb-a", "user-1", collection_name="kb-b-data",
            ))
        assert result is False

    def test_chinese_collection_name_match(self, validator: KnowledgeBaseAccessValidator):
        """A non-ASCII (Chinese) collection_name matches exactly."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-1", "name": "用户行为数据"},
        })
        with _mock_client(validator, response=resp):
            result = _run(validator.check_access(
                "kb-1", "user-1", collection_name="用户行为数据",
            ))
        assert result is True