"""Milvus 向量检索客户端的单元测试。""" from __future__ import annotations import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever @pytest.fixture def retriever() -> MilvusVectorRetriever: return MilvusVectorRetriever( uri="http://test-milvus:19530", embedding_model="text-embedding-test", ) def _run(coro): return asyncio.run(coro) # --------------------------------------------------------------------------- # has_collection 测试 # --------------------------------------------------------------------------- class TestHasCollection: def test_collection_exists(self, retriever: MilvusVectorRetriever): mock_client = MagicMock() mock_client.has_collection = MagicMock(return_value=True) retriever._milvus_client = mock_client result = _run(retriever.has_collection("my_collection")) assert result is True def test_collection_not_exists(self, retriever: MilvusVectorRetriever): mock_client = MagicMock() mock_client.has_collection = MagicMock(return_value=False) retriever._milvus_client = mock_client result = _run(retriever.has_collection("nonexistent")) assert result is False def test_fail_open_on_error(self, retriever: MilvusVectorRetriever): mock_client = MagicMock() mock_client.has_collection = MagicMock(side_effect=Exception("connection error")) retriever._milvus_client = mock_client result = _run(retriever.has_collection("test")) assert result is False # --------------------------------------------------------------------------- # search 测试 # --------------------------------------------------------------------------- class TestSearch: def test_successful_search(self, retriever: MilvusVectorRetriever): """正常搜索返回 VectorChunk 列表。""" mock_embeddings = AsyncMock() mock_embeddings.aembed_query = AsyncMock(return_value=[0.1, 0.2, 0.3]) retriever._embeddings = mock_embeddings mock_milvus = MagicMock() mock_milvus.search = MagicMock(return_value=[ [ {"id": "doc1", "distance": 0.95, "entity": {"text": "文档片段一", "metadata": {"source": "kb1"}}}, {"id": "doc2", "distance": 0.82, "entity": {"text": "文档片段二", "metadata": {}}}, ] ]) retriever._milvus_client = mock_milvus chunks = _run(retriever.search("my_collection", "用户数据", top_k=5)) assert len(chunks) == 2 assert chunks[0].id == "doc1" assert chunks[0].text == "文档片段一" assert chunks[0].score == 0.95 assert chunks[0].metadata == {"source": "kb1"} assert chunks[1].id == "doc2" assert chunks[1].score == 0.82 def test_empty_results(self, retriever: MilvusVectorRetriever): mock_embeddings = AsyncMock() mock_embeddings.aembed_query = AsyncMock(return_value=[0.1]) retriever._embeddings = mock_embeddings mock_milvus = MagicMock() mock_milvus.search = MagicMock(return_value=[[]]) retriever._milvus_client = mock_milvus chunks = _run(retriever.search("col", "query")) assert chunks == [] def test_fail_open_on_embedding_error(self, retriever: MilvusVectorRetriever): """Embedding 失败时 fail-open 返回空列表。""" mock_embeddings = AsyncMock() mock_embeddings.aembed_query = AsyncMock(side_effect=Exception("API error")) retriever._embeddings = mock_embeddings chunks = _run(retriever.search("col", "query")) assert chunks == [] def test_fail_open_on_milvus_error(self, retriever: MilvusVectorRetriever): """Milvus 搜索失败时 fail-open 返回空列表。""" mock_embeddings = AsyncMock() mock_embeddings.aembed_query = AsyncMock(return_value=[0.1]) retriever._embeddings = mock_embeddings mock_milvus = MagicMock() mock_milvus.search = MagicMock(side_effect=Exception("Milvus down")) retriever._milvus_client = mock_milvus chunks = _run(retriever.search("col", "query")) assert chunks == [] def test_search_uses_to_thread(self, retriever: MilvusVectorRetriever): """验证搜索通过 asyncio.to_thread 执行同步 Milvus I/O。""" mock_embeddings = AsyncMock() mock_embeddings.aembed_query = AsyncMock(return_value=[0.1]) retriever._embeddings = mock_embeddings mock_milvus = MagicMock() mock_milvus.search = MagicMock(return_value=[[]]) retriever._milvus_client = mock_milvus with patch("app.module.kg_graphrag.milvus_client.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread: mock_to_thread.return_value = [[]] chunks = _run(retriever.search("col", "query")) # asyncio.to_thread 应该被调用来包装同步 Milvus 调用 mock_to_thread.assert_called_once() call_args = mock_to_thread.call_args assert call_args.args[0] == mock_milvus.search