You've already forked DataMate
feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
This commit is contained in:
234
runtime/datamate-python/app/module/kg_graphrag/test_retriever.py
Normal file
234
runtime/datamate-python/app/module/kg_graphrag/test_retriever.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""GraphRAG 检索编排器的单元测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.module.kg_graphrag.models import (
|
||||
EntitySummary,
|
||||
RelationSummary,
|
||||
RetrievalStrategy,
|
||||
VectorChunk,
|
||||
)
|
||||
from app.module.kg_graphrag.retriever import GraphRAGRetriever
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _make_retriever(
|
||||
*,
|
||||
milvus_search_result: list[VectorChunk] | None = None,
|
||||
milvus_has_collection: bool = True,
|
||||
kg_fulltext_result: list[EntitySummary] | None = None,
|
||||
kg_subgraph_result: tuple[list[EntitySummary], list[RelationSummary]] | None = None,
|
||||
) -> GraphRAGRetriever:
|
||||
"""创建带 mock 依赖的 retriever。"""
|
||||
mock_milvus = AsyncMock()
|
||||
mock_milvus.has_collection = AsyncMock(return_value=milvus_has_collection)
|
||||
mock_milvus.search = AsyncMock(return_value=milvus_search_result or [])
|
||||
|
||||
mock_kg = AsyncMock()
|
||||
mock_kg.fulltext_search = AsyncMock(return_value=kg_fulltext_result or [])
|
||||
mock_kg.get_subgraph = AsyncMock(return_value=kg_subgraph_result or ([], []))
|
||||
|
||||
return GraphRAGRetriever(milvus_client=mock_milvus, kg_client=mock_kg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retrieve 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRetrieve:
|
||||
"""retrieve 方法的测试。"""
|
||||
|
||||
def test_both_vector_and_graph(self):
|
||||
"""同时启用向量和图谱检索。"""
|
||||
chunks = [
|
||||
VectorChunk(id="c1", text="文档片段关于用户数据", score=0.9),
|
||||
VectorChunk(id="c2", text="其他内容", score=0.7),
|
||||
]
|
||||
seed = [EntitySummary(id="e1", name="用户数据", type="Dataset")]
|
||||
entities = [
|
||||
EntitySummary(id="e1", name="用户数据", type="Dataset"),
|
||||
EntitySummary(id="e2", name="user_id", type="Field"),
|
||||
]
|
||||
relations = [
|
||||
RelationSummary(
|
||||
source_name="用户数据", source_type="Dataset",
|
||||
target_name="user_id", target_type="Field",
|
||||
relation_type="HAS_FIELD",
|
||||
),
|
||||
]
|
||||
retriever = _make_retriever(
|
||||
milvus_search_result=chunks,
|
||||
kg_fulltext_result=seed,
|
||||
kg_subgraph_result=(entities, relations),
|
||||
)
|
||||
|
||||
ctx = _run(retriever.retrieve(
|
||||
query="用户数据有哪些字段",
|
||||
collection_name="kb1",
|
||||
graph_id="graph-1",
|
||||
strategy=RetrievalStrategy(),
|
||||
user_id="u1",
|
||||
))
|
||||
|
||||
assert len(ctx.vector_chunks) == 2
|
||||
assert len(ctx.graph_context.entities) == 2
|
||||
assert len(ctx.graph_context.relations) == 1
|
||||
assert "用户数据" in ctx.graph_context.textualized
|
||||
assert "## 相关文档" in ctx.merged_text
|
||||
assert "## 知识图谱上下文" in ctx.merged_text
|
||||
|
||||
def test_vector_only(self):
|
||||
"""仅启用向量检索。"""
|
||||
chunks = [VectorChunk(id="c1", text="doc", score=0.9)]
|
||||
retriever = _make_retriever(milvus_search_result=chunks)
|
||||
strategy = RetrievalStrategy(enable_graph=False)
|
||||
|
||||
ctx = _run(retriever.retrieve(
|
||||
query="test", collection_name="kb", graph_id="g",
|
||||
strategy=strategy, user_id="u",
|
||||
))
|
||||
|
||||
assert len(ctx.vector_chunks) == 1
|
||||
assert ctx.graph_context.entities == []
|
||||
# KG client should not be called
|
||||
retriever._kg.fulltext_search.assert_not_called()
|
||||
|
||||
def test_graph_only(self):
|
||||
"""仅启用图谱检索。"""
|
||||
seed = [EntitySummary(id="e1", name="A", type="T")]
|
||||
entities = [EntitySummary(id="e1", name="A", type="T")]
|
||||
retriever = _make_retriever(
|
||||
kg_fulltext_result=seed,
|
||||
kg_subgraph_result=(entities, []),
|
||||
)
|
||||
strategy = RetrievalStrategy(enable_vector=False)
|
||||
|
||||
ctx = _run(retriever.retrieve(
|
||||
query="test", collection_name="kb", graph_id="g",
|
||||
strategy=strategy, user_id="u",
|
||||
))
|
||||
|
||||
assert ctx.vector_chunks == []
|
||||
assert len(ctx.graph_context.entities) == 1
|
||||
retriever._milvus.search.assert_not_called()
|
||||
|
||||
def test_no_seed_entities(self):
|
||||
"""图谱全文检索无结果时,不调用子图查询。"""
|
||||
retriever = _make_retriever(kg_fulltext_result=[])
|
||||
|
||||
ctx = _run(retriever.retrieve(
|
||||
query="test", collection_name="kb", graph_id="g",
|
||||
strategy=RetrievalStrategy(enable_vector=False), user_id="u",
|
||||
))
|
||||
|
||||
assert ctx.graph_context.entities == []
|
||||
retriever._kg.get_subgraph.assert_not_called()
|
||||
|
||||
def test_collection_not_found_skips_vector(self):
|
||||
"""collection 不存在时跳过向量检索。"""
|
||||
retriever = _make_retriever(milvus_has_collection=False)
|
||||
strategy = RetrievalStrategy(enable_graph=False)
|
||||
|
||||
ctx = _run(retriever.retrieve(
|
||||
query="test", collection_name="nonexistent", graph_id="g",
|
||||
strategy=strategy, user_id="u",
|
||||
))
|
||||
|
||||
assert ctx.vector_chunks == []
|
||||
retriever._milvus.search.assert_not_called()
|
||||
|
||||
def test_both_empty(self):
|
||||
"""两条检索路径都无结果。"""
|
||||
retriever = _make_retriever()
|
||||
|
||||
ctx = _run(retriever.retrieve(
|
||||
query="nothing", collection_name="kb", graph_id="g",
|
||||
strategy=RetrievalStrategy(), user_id="u",
|
||||
))
|
||||
|
||||
assert ctx.vector_chunks == []
|
||||
assert ctx.graph_context.entities == []
|
||||
assert "未检索到相关上下文信息" in ctx.merged_text
|
||||
|
||||
def test_vector_error_fail_open(self):
|
||||
"""向量检索异常时 fail-open,图谱检索仍可正常返回。"""
|
||||
retriever = _make_retriever()
|
||||
retriever._milvus.search = AsyncMock(side_effect=Exception("milvus down"))
|
||||
|
||||
seed = [EntitySummary(id="e1", name="A", type="T")]
|
||||
retriever._kg.fulltext_search = AsyncMock(return_value=seed)
|
||||
retriever._kg.get_subgraph = AsyncMock(
|
||||
return_value=([EntitySummary(id="e1", name="A", type="T")], [])
|
||||
)
|
||||
|
||||
ctx = _run(retriever.retrieve(
|
||||
query="test", collection_name="kb", graph_id="g",
|
||||
strategy=RetrievalStrategy(), user_id="u",
|
||||
))
|
||||
|
||||
# 向量检索失败,但图谱检索仍有结果
|
||||
assert ctx.vector_chunks == []
|
||||
assert len(ctx.graph_context.entities) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _rank_results 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRankResults:
|
||||
"""_rank_results 方法的测试。"""
|
||||
|
||||
def _make_retriever_instance(self) -> GraphRAGRetriever:
|
||||
return GraphRAGRetriever(
|
||||
milvus_client=MagicMock(),
|
||||
kg_client=MagicMock(),
|
||||
)
|
||||
|
||||
def test_empty_chunks(self):
|
||||
r = self._make_retriever_instance()
|
||||
result = r._rank_results([], [], [], RetrievalStrategy())
|
||||
assert result == []
|
||||
|
||||
def test_single_chunk(self):
|
||||
r = self._make_retriever_instance()
|
||||
chunks = [VectorChunk(id="1", text="text", score=0.9)]
|
||||
result = r._rank_results(chunks, [], [], RetrievalStrategy())
|
||||
assert len(result) == 1
|
||||
assert result[0].id == "1"
|
||||
|
||||
def test_graph_boost_reorders(self):
|
||||
"""图谱实体命中应提升文档片段排名。"""
|
||||
r = self._make_retriever_instance()
|
||||
# chunk1 向量分高但无图谱命中
|
||||
# chunk2 向量分低但命中图谱实体
|
||||
chunks = [
|
||||
VectorChunk(id="1", text="无关内容", score=0.9),
|
||||
VectorChunk(id="2", text="包含用户数据的内容", score=0.5),
|
||||
]
|
||||
entities = [EntitySummary(id="e1", name="用户数据", type="Dataset")]
|
||||
strategy = RetrievalStrategy(vector_weight=0.3, graph_weight=0.7)
|
||||
|
||||
result = r._rank_results(chunks, entities, [], strategy)
|
||||
|
||||
# chunk2 应该排在前面(graph_boost 更高)
|
||||
assert result[0].id == "2"
|
||||
|
||||
def test_all_same_score(self):
|
||||
"""所有 chunk 分数相同时不崩溃。"""
|
||||
r = self._make_retriever_instance()
|
||||
chunks = [
|
||||
VectorChunk(id="1", text="a", score=0.5),
|
||||
VectorChunk(id="2", text="b", score=0.5),
|
||||
]
|
||||
result = r._rank_results(chunks, [], [], RetrievalStrategy())
|
||||
assert len(result) == 2
|
||||
Reference in New Issue
Block a user