You've already forked DataMate
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O 阻塞事件循环(改用 asyncio.to_thread 在线程中执行)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
331 lines · 13 KiB · Python
"""知识库访问权限校验的单元测试。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
|
|
|
|
|
|
@pytest.fixture
def validator() -> KnowledgeBaseAccessValidator:
    """Provide a validator configured against a fake Java backend URL.

    The tests in this module patch ``_get_client`` before calling
    ``check_access``, so the URL below is never actually contacted.
    """
    backend_url = "http://test-backend:8080/api"
    return KnowledgeBaseAccessValidator(base_url=backend_url, timeout=5.0)
|
|
|
|
|
|
def _run(coro):
|
|
return asyncio.run(coro)
|
|
|
|
|
|
# Shared dummy request. httpx refuses to run request-dependent helpers such as
# ``Response.raise_for_status()`` on a response with no request attached, so
# every fake response built below is bound to this one.
_FAKE_REQUEST = httpx.Request(method="GET", url="http://test")


def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
    """Build an ``httpx.Response`` bound to ``_FAKE_REQUEST``.

    Args:
        status_code: HTTP status code of the fake response.
        json: Payload serialized as the JSON body. Used whenever it is not
            ``None`` (it takes precedence over ``text``).
        text: Plain-text body used when ``json`` is ``None``; a falsy value
            becomes the empty string.

    Returns:
        A response whose ``.request`` is set.
    """
    if json is None:
        body_kwargs = {"text": text or ""}
    else:
        body_kwargs = {"json": json}
    return httpx.Response(status_code, request=_FAKE_REQUEST, **body_kwargs)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# check_access 测试
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCheckAccess:
    """Tests for ``check_access``: the permission decision comes from the Java backend."""

    @staticmethod
    def _check(
        validator: KnowledgeBaseAccessValidator,
        outcome: httpx.Response | BaseException,
        *args,
        **kwargs,
    ) -> tuple[bool, AsyncMock]:
        """Call ``validator.check_access(*args, **kwargs)`` against a mocked HTTP client.

        ``_get_client`` is patched so the client's ``get`` returns ``outcome``
        when it is a response, or raises it when it is an exception instance.

        Returns:
            ``(result, mock_http)``. ``mock_http`` keeps its recorded calls after
            the patch is undone, so callers can inspect what was sent.
        """
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            if isinstance(outcome, BaseException):
                mock_http.get = AsyncMock(side_effect=outcome)
            else:
                mock_http.get = AsyncMock(return_value=outcome)
            mock_get.return_value = mock_http

            result = _run(validator.check_access(*args, **kwargs))
        return result, mock_http

    def test_access_granted(self, validator: KnowledgeBaseAccessValidator):
        """Java answers HTTP 200 + code=200: the user may access the KB."""
        resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "test-kb"}})
        result, _ = self._check(validator, resp, "kb-1", "user-1")

        assert result is True

    def test_access_granted_with_matching_collection(self, validator: KnowledgeBaseAccessValidator):
        """Access granted and collection_name equals the KB name: allowed."""
        resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "my-collection"}})
        result, _ = self._check(
            validator, resp, "kb-1", "user-1", collection_name="my-collection",
        )

        assert result is True

    def test_access_denied_by_biz_code(self, validator: KnowledgeBaseAccessValidator):
        """Java answers HTTP 200 but code != 200 (insufficient permission, sys.0005)."""
        resp = _resp(200, json={"code": "sys.0005", "message": "权限不足"})
        result, mock_http = self._check(validator, resp, "kb-1", "other-user")

        assert result is False
        # The denial must come from the backend's answer, not a short-circuit.
        mock_http.get.assert_awaited()

    def test_access_denied_http_403(self, validator: KnowledgeBaseAccessValidator):
        """Java answers HTTP 403."""
        resp = _resp(403, text="Forbidden")
        result, mock_http = self._check(validator, resp, "kb-1", "user-1")

        assert result is False
        mock_http.get.assert_awaited()

    def test_kb_not_found_http_404(self, validator: KnowledgeBaseAccessValidator):
        """The knowledge base does not exist; Java answers 404."""
        resp = _resp(404, text="Not Found")
        result, mock_http = self._check(validator, resp, "nonexistent-kb", "user-1")

        assert result is False
        mock_http.get.assert_awaited()

    def test_server_error_http_500(self, validator: KnowledgeBaseAccessValidator):
        """The Java backend answers 500."""
        resp = _resp(500, text="Internal Server Error")
        result, mock_http = self._check(validator, resp, "kb-1", "user-1")

        assert result is False
        mock_http.get.assert_awaited()

    def test_fail_close_on_connection_error(self, validator: KnowledgeBaseAccessValidator):
        """A network error fails closed (access denied) so the check cannot be bypassed."""
        error = httpx.ConnectError("connection refused")
        result, mock_http = self._check(validator, error, "kb-1", "user-1")

        assert result is False
        # Ensure the error path was really exercised (the request was attempted).
        mock_http.get.assert_awaited()

    def test_fail_close_on_timeout(self, validator: KnowledgeBaseAccessValidator):
        """A timeout fails closed (access denied)."""
        error = httpx.ReadTimeout("timeout")
        result, mock_http = self._check(validator, error, "kb-1", "user-1")

        assert result is False
        mock_http.get.assert_awaited()

    def test_request_headers(self, validator: KnowledgeBaseAccessValidator):
        """The request targets the KB's URL and carries the caller's X-User-Id header."""
        resp = _resp(200, json={"code": 200, "data": {}})
        _, mock_http = self._check(validator, resp, "kb-123", "user-456")

        call = mock_http.get.call_args
        assert "/knowledge-base/kb-123" in call.args[0]
        assert call.kwargs["headers"]["X-User-Id"] == "user-456"

    def test_cross_user_access_denied(self, validator: KnowledgeBaseAccessValidator):
        """Cross-user access: user B tries to read user A's KB and must be denied.

        Simulates the Java backend returning an insufficient-permission
        business error.
        """
        # KB created by user A, requested by user B.
        resp = _resp(200, json={
            "code": "sys.0005",
            "message": "权限不足",
            "data": None,
        })
        result, mock_http = self._check(validator, resp, "kb-user-a", "user-b")

        assert result is False
        # Confirm the request identified itself as user B.
        assert mock_http.get.call_args.kwargs["headers"]["X-User-Id"] == "user-b"

    def test_admin_access_granted(self, validator: KnowledgeBaseAccessValidator):
        """Admin reads another user's KB: the Java side skips the owner check for admins."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-user-a", "name": "用户A的知识库", "createdBy": "user-a"},
        })
        result, mock_http = self._check(validator, resp, "kb-user-a", "admin-user")

        # Java approved the admin (200 + code=200), so access is granted.
        assert result is True
        # The decision is delegated to Java, so the admin's own id — not the
        # KB owner's — must be what gets forwarded.
        assert mock_http.get.call_args.kwargs["headers"]["X-User-Id"] == "admin-user"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# collection_name 绑定校验测试
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCollectionNameBinding:
    """Tests binding collection_name to knowledge_base_id.

    Guards against a user submitting a KB ID they legitimately own while
    tampering with collection_name to read another knowledge base's Milvus
    data.
    """

    @staticmethod
    def _check(
        validator: KnowledgeBaseAccessValidator,
        outcome: httpx.Response | BaseException,
        *args,
        **kwargs,
    ) -> tuple[bool, AsyncMock]:
        """Call ``validator.check_access(*args, **kwargs)`` against a mocked HTTP client.

        ``_get_client`` is patched so the client's ``get`` returns ``outcome``
        when it is a response, or raises it when it is an exception instance.

        Returns:
            ``(result, mock_http)``. ``mock_http`` keeps its recorded calls after
            the patch is undone, so callers can inspect them.
        """
        with patch.object(validator, "_get_client") as mock_get:
            mock_http = AsyncMock()
            if isinstance(outcome, BaseException):
                mock_http.get = AsyncMock(side_effect=outcome)
            else:
                mock_http.get = AsyncMock(return_value=outcome)
            mock_get.return_value = mock_http

            result = _run(validator.check_access(*args, **kwargs))
        return result, mock_http

    def test_collection_name_mismatch_denied(self, validator: KnowledgeBaseAccessValidator):
        """KB-A is named 'collection-a' but the request sends 'collection-b': denied."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-a", "name": "collection-a"},
        })
        result, mock_http = self._check(
            validator, resp, "kb-a", "user-1", collection_name="collection-b",
        )

        assert result is False
        # The KB was looked up, so the denial comes from the binding check.
        mock_http.get.assert_awaited()

    def test_collection_name_none_skips_check(self, validator: KnowledgeBaseAccessValidator):
        """Without collection_name no binding check is done (backward compatible)."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-1", "name": "some-name"},
        })
        # No collection_name -> only the permission is checked, not the binding.
        result, _ = self._check(validator, resp, "kb-1", "user-1")

        assert result is True

    def test_response_data_missing_name_denied(self, validator: KnowledgeBaseAccessValidator):
        """The Java response's data has no name field: fail closed."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-1"},
        })
        result, mock_http = self._check(
            validator, resp, "kb-1", "user-1", collection_name="any-collection",
        )

        # With no KB name to compare against, the binding cannot be proven.
        assert result is False
        mock_http.get.assert_awaited()

    def test_response_data_null_denied(self, validator: KnowledgeBaseAccessValidator):
        """The Java response's data is null: fail closed."""
        resp = _resp(200, json={
            "code": 200,
            "data": None,
        })
        result, mock_http = self._check(
            validator, resp, "kb-1", "user-1", collection_name="any-collection",
        )

        assert result is False
        mock_http.get.assert_awaited()

    def test_response_data_empty_dict_denied(self, validator: KnowledgeBaseAccessValidator):
        """The Java response's data is an empty dict {}: fail closed."""
        resp = _resp(200, json={
            "code": 200,
            "data": {},
        })
        result, mock_http = self._check(
            validator, resp, "kb-1", "user-1", collection_name="any-collection",
        )

        assert result is False
        mock_http.get.assert_awaited()

    def test_cross_kb_collection_swap_denied(self, validator: KnowledgeBaseAccessValidator):
        """The user may access KB-A (name='kb-a-data') but pairs KB-A's ID with KB-B's
        collection_name='kb-b-data': denied.

        Full simulation of the core privilege-escalation scenario.
        """
        # The user is allowed to access KB-A...
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-a", "name": "kb-a-data", "createdBy": "user-1"},
        })
        # ...but collection_name points at KB-B's data.
        result, mock_http = self._check(
            validator, resp, "kb-a", "user-1", collection_name="kb-b-data",
        )

        assert result is False
        mock_http.get.assert_awaited()

    def test_chinese_collection_name_match(self, validator: KnowledgeBaseAccessValidator):
        """A non-ASCII (Chinese) collection_name matches exactly."""
        resp = _resp(200, json={
            "code": 200,
            "data": {"id": "kb-1", "name": "用户行为数据"},
        })
        result, _ = self._check(
            validator, resp, "kb-1", "user-1", collection_name="用户行为数据",
        )

        assert result is True
|