You've already forked DataMate
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
P1 问题修复(含安全修复):
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
146 lines
5.1 KiB
Python
146 lines
5.1 KiB
Python
"""Milvus 向量检索客户端的单元测试。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever
|
|
|
|
|
|
@pytest.fixture
def retriever() -> MilvusVectorRetriever:
    """Return a ``MilvusVectorRetriever`` built with placeholder settings.

    The URI and embedding model name are dummy values. Tests in this module
    inject mocks through the private ``_milvus_client`` and ``_embeddings``
    attributes before calling into the retriever.
    """
    return MilvusVectorRetriever(
        uri="http://test-milvus:19530",
        embedding_model="text-embedding-test",
    )
|
|
|
|
|
|
def _run(coro):
    """Run the awaitable *coro* to completion and return its result.

    Thin wrapper around ``asyncio.run`` so the synchronous test functions in
    this module can drive the retriever's coroutine methods.
    """
    return asyncio.run(coro)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# has_collection tests
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHasCollection:
    """Tests for ``MilvusVectorRetriever.has_collection`` with a mocked client."""

    def test_collection_exists(self, retriever: MilvusVectorRetriever):
        """The client reporting ``True`` is surfaced as ``True``."""
        mock_client = MagicMock()
        mock_client.has_collection = MagicMock(return_value=True)
        retriever._milvus_client = mock_client

        result = _run(retriever.has_collection("my_collection"))

        assert result is True

    def test_collection_not_exists(self, retriever: MilvusVectorRetriever):
        """The client reporting ``False`` is surfaced as ``False``."""
        mock_client = MagicMock()
        mock_client.has_collection = MagicMock(return_value=False)
        retriever._milvus_client = mock_client

        result = _run(retriever.has_collection("nonexistent"))

        assert result is False

    def test_fail_open_on_error(self, retriever: MilvusVectorRetriever):
        """An exception raised by the client is swallowed and yields ``False``."""
        mock_client = MagicMock()
        mock_client.has_collection = MagicMock(side_effect=Exception("connection error"))
        retriever._milvus_client = mock_client

        result = _run(retriever.has_collection("test"))

        assert result is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# search tests
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSearch:
    """Tests for ``MilvusVectorRetriever.search`` with mocked embeddings and client."""

    def test_successful_search(self, retriever: MilvusVectorRetriever):
        """A normal search returns a list of VectorChunk."""
        mock_embeddings = AsyncMock()
        mock_embeddings.aembed_query = AsyncMock(return_value=[0.1, 0.2, 0.3])
        retriever._embeddings = mock_embeddings

        mock_milvus = MagicMock()
        # The stubbed result is a nested list: one outer entry holding the hit dicts.
        mock_milvus.search = MagicMock(return_value=[
            [
                {"id": "doc1", "distance": 0.95, "entity": {"text": "文档片段一", "metadata": {"source": "kb1"}}},
                {"id": "doc2", "distance": 0.82, "entity": {"text": "文档片段二", "metadata": {}}},
            ]
        ])
        retriever._milvus_client = mock_milvus

        chunks = _run(retriever.search("my_collection", "用户数据", top_k=5))

        assert len(chunks) == 2
        assert chunks[0].id == "doc1"
        assert chunks[0].text == "文档片段一"
        assert chunks[0].score == 0.95
        assert chunks[0].metadata == {"source": "kb1"}
        assert chunks[1].id == "doc2"
        assert chunks[1].score == 0.82

    def test_empty_results(self, retriever: MilvusVectorRetriever):
        """An empty hit list from the client yields an empty chunk list."""
        mock_embeddings = AsyncMock()
        mock_embeddings.aembed_query = AsyncMock(return_value=[0.1])
        retriever._embeddings = mock_embeddings

        mock_milvus = MagicMock()
        mock_milvus.search = MagicMock(return_value=[[]])
        retriever._milvus_client = mock_milvus

        chunks = _run(retriever.search("col", "query"))

        assert chunks == []

    def test_fail_open_on_embedding_error(self, retriever: MilvusVectorRetriever):
        """When embedding fails, search fails open and returns an empty list."""
        mock_embeddings = AsyncMock()
        mock_embeddings.aembed_query = AsyncMock(side_effect=Exception("API error"))
        retriever._embeddings = mock_embeddings

        chunks = _run(retriever.search("col", "query"))

        assert chunks == []

    def test_fail_open_on_milvus_error(self, retriever: MilvusVectorRetriever):
        """When the Milvus search fails, search fails open and returns an empty list."""
        mock_embeddings = AsyncMock()
        mock_embeddings.aembed_query = AsyncMock(return_value=[0.1])
        retriever._embeddings = mock_embeddings

        mock_milvus = MagicMock()
        mock_milvus.search = MagicMock(side_effect=Exception("Milvus down"))
        retriever._milvus_client = mock_milvus

        chunks = _run(retriever.search("col", "query"))

        assert chunks == []

    def test_search_uses_to_thread(self, retriever: MilvusVectorRetriever):
        """Verify that search runs the synchronous Milvus I/O via asyncio.to_thread."""
        mock_embeddings = AsyncMock()
        mock_embeddings.aembed_query = AsyncMock(return_value=[0.1])
        retriever._embeddings = mock_embeddings

        mock_milvus = MagicMock()
        mock_milvus.search = MagicMock(return_value=[[]])
        retriever._milvus_client = mock_milvus

        # NOTE(review): the patch target goes through the client module's
        # ``asyncio`` attribute; presumably that is the shared asyncio module,
        # so ``to_thread`` is replaced for everything inside this ``with``.
        with patch("app.module.kg_graphrag.milvus_client.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread:
            mock_to_thread.return_value = [[]]

            chunks = _run(retriever.search("col", "query"))

            # asyncio.to_thread should be the wrapper around the sync Milvus call
            mock_to_thread.assert_called_once()
            call_args = mock_to_thread.call_args
            assert call_args.args[0] == mock_milvus.search
            # The stubbed to_thread result has no hits, so no chunks come back.
            assert chunks == []
|