"""GraphRAG 检索编排器的单元测试。""" from __future__ import annotations import asyncio from unittest.mock import AsyncMock, MagicMock import pytest from app.module.kg_graphrag.models import ( EntitySummary, RelationSummary, RetrievalStrategy, VectorChunk, ) from app.module.kg_graphrag.retriever import GraphRAGRetriever def _run(coro): return asyncio.run(coro) def _make_retriever( *, milvus_search_result: list[VectorChunk] | None = None, milvus_has_collection: bool = True, kg_fulltext_result: list[EntitySummary] | None = None, kg_subgraph_result: tuple[list[EntitySummary], list[RelationSummary]] | None = None, ) -> GraphRAGRetriever: """创建带 mock 依赖的 retriever。""" mock_milvus = AsyncMock() mock_milvus.has_collection = AsyncMock(return_value=milvus_has_collection) mock_milvus.search = AsyncMock(return_value=milvus_search_result or []) mock_kg = AsyncMock() mock_kg.fulltext_search = AsyncMock(return_value=kg_fulltext_result or []) mock_kg.get_subgraph = AsyncMock(return_value=kg_subgraph_result or ([], [])) return GraphRAGRetriever(milvus_client=mock_milvus, kg_client=mock_kg) # --------------------------------------------------------------------------- # retrieve 测试 # --------------------------------------------------------------------------- class TestRetrieve: """retrieve 方法的测试。""" def test_both_vector_and_graph(self): """同时启用向量和图谱检索。""" chunks = [ VectorChunk(id="c1", text="文档片段关于用户数据", score=0.9), VectorChunk(id="c2", text="其他内容", score=0.7), ] seed = [EntitySummary(id="e1", name="用户数据", type="Dataset")] entities = [ EntitySummary(id="e1", name="用户数据", type="Dataset"), EntitySummary(id="e2", name="user_id", type="Field"), ] relations = [ RelationSummary( source_name="用户数据", source_type="Dataset", target_name="user_id", target_type="Field", relation_type="HAS_FIELD", ), ] retriever = _make_retriever( milvus_search_result=chunks, kg_fulltext_result=seed, kg_subgraph_result=(entities, relations), ) ctx = _run(retriever.retrieve( query="用户数据有哪些字段", collection_name="kb1", graph_id="graph-1", strategy=RetrievalStrategy(), user_id="u1", )) assert len(ctx.vector_chunks) == 2 assert len(ctx.graph_context.entities) == 2 assert len(ctx.graph_context.relations) == 1 assert "用户数据" in ctx.graph_context.textualized assert "## 相关文档" in ctx.merged_text assert "## 知识图谱上下文" in ctx.merged_text def test_vector_only(self): """仅启用向量检索。""" chunks = [VectorChunk(id="c1", text="doc", score=0.9)] retriever = _make_retriever(milvus_search_result=chunks) strategy = RetrievalStrategy(enable_graph=False) ctx = _run(retriever.retrieve( query="test", collection_name="kb", graph_id="g", strategy=strategy, user_id="u", )) assert len(ctx.vector_chunks) == 1 assert ctx.graph_context.entities == [] # KG client should not be called retriever._kg.fulltext_search.assert_not_called() def test_graph_only(self): """仅启用图谱检索。""" seed = [EntitySummary(id="e1", name="A", type="T")] entities = [EntitySummary(id="e1", name="A", type="T")] retriever = _make_retriever( kg_fulltext_result=seed, kg_subgraph_result=(entities, []), ) strategy = RetrievalStrategy(enable_vector=False) ctx = _run(retriever.retrieve( query="test", collection_name="kb", graph_id="g", strategy=strategy, user_id="u", )) assert ctx.vector_chunks == [] assert len(ctx.graph_context.entities) == 1 retriever._milvus.search.assert_not_called() def test_no_seed_entities(self): """图谱全文检索无结果时,不调用子图查询。""" retriever = _make_retriever(kg_fulltext_result=[]) ctx = _run(retriever.retrieve( query="test", collection_name="kb", graph_id="g", strategy=RetrievalStrategy(enable_vector=False), user_id="u", )) assert ctx.graph_context.entities == [] retriever._kg.get_subgraph.assert_not_called() def test_collection_not_found_skips_vector(self): """collection 不存在时跳过向量检索。""" retriever = _make_retriever(milvus_has_collection=False) strategy = RetrievalStrategy(enable_graph=False) ctx = _run(retriever.retrieve( query="test", collection_name="nonexistent", graph_id="g", strategy=strategy, user_id="u", )) assert ctx.vector_chunks == [] retriever._milvus.search.assert_not_called() def test_both_empty(self): """两条检索路径都无结果。""" retriever = _make_retriever() ctx = _run(retriever.retrieve( query="nothing", collection_name="kb", graph_id="g", strategy=RetrievalStrategy(), user_id="u", )) assert ctx.vector_chunks == [] assert ctx.graph_context.entities == [] assert "未检索到相关上下文信息" in ctx.merged_text def test_vector_error_fail_open(self): """向量检索异常时 fail-open,图谱检索仍可正常返回。""" retriever = _make_retriever() retriever._milvus.search = AsyncMock(side_effect=Exception("milvus down")) seed = [EntitySummary(id="e1", name="A", type="T")] retriever._kg.fulltext_search = AsyncMock(return_value=seed) retriever._kg.get_subgraph = AsyncMock( return_value=([EntitySummary(id="e1", name="A", type="T")], []) ) ctx = _run(retriever.retrieve( query="test", collection_name="kb", graph_id="g", strategy=RetrievalStrategy(), user_id="u", )) # 向量检索失败,但图谱检索仍有结果 assert ctx.vector_chunks == [] assert len(ctx.graph_context.entities) == 1 # --------------------------------------------------------------------------- # _rank_results 测试 # --------------------------------------------------------------------------- class TestRankResults: """_rank_results 方法的测试。""" def _make_retriever_instance(self) -> GraphRAGRetriever: return GraphRAGRetriever( milvus_client=MagicMock(), kg_client=MagicMock(), ) def test_empty_chunks(self): r = self._make_retriever_instance() result = r._rank_results([], [], [], RetrievalStrategy()) assert result == [] def test_single_chunk(self): r = self._make_retriever_instance() chunks = [VectorChunk(id="1", text="text", score=0.9)] result = r._rank_results(chunks, [], [], RetrievalStrategy()) assert len(result) == 1 assert result[0].id == "1" def test_graph_boost_reorders(self): """图谱实体命中应提升文档片段排名。""" r = self._make_retriever_instance() # chunk1 向量分高但无图谱命中 # chunk2 向量分低但命中图谱实体 chunks = [ VectorChunk(id="1", text="无关内容", score=0.9), VectorChunk(id="2", text="包含用户数据的内容", score=0.5), ] entities = [EntitySummary(id="e1", name="用户数据", type="Dataset")] strategy = RetrievalStrategy(vector_weight=0.3, graph_weight=0.7) result = r._rank_results(chunks, entities, [], strategy) # chunk2 应该排在前面(graph_boost 更高) assert result[0].id == "2" def test_all_same_score(self): """所有 chunk 分数相同时不崩溃。""" r = self._make_retriever_instance() chunks = [ VectorChunk(id="1", text="a", score=0.5), VectorChunk(id="2", text="b", score=0.5), ] result = r._rank_results(chunks, [], [], RetrievalStrategy()) assert len(result) == 2