"""KG 服务 REST 客户端的单元测试。""" from __future__ import annotations import asyncio from unittest.mock import AsyncMock, patch import httpx import pytest from app.module.kg_graphrag.kg_client import KGServiceClient @pytest.fixture def client() -> KGServiceClient: return KGServiceClient( base_url="http://test-kg:8080", internal_token="test-token", timeout=5.0, ) def _run(coro): return asyncio.run(coro) _FAKE_REQUEST = httpx.Request("GET", "http://test") def _resp(status_code: int, *, json=None, text=None) -> httpx.Response: """创建带 request 的 httpx.Response(raise_for_status 需要)。""" if json is not None: return httpx.Response(status_code, json=json, request=_FAKE_REQUEST) return httpx.Response(status_code, text=text or "", request=_FAKE_REQUEST) # --------------------------------------------------------------------------- # fulltext_search 测试 # --------------------------------------------------------------------------- class TestFulltextSearch: """fulltext_search 方法的测试。""" def test_wrapped_paged_response(self, client: KGServiceClient): """Java 返回被全局包装的 PagedResponse: {"code": 200, "data": {"content": [...]}}""" mock_body = { "code": 200, "data": { "page": 0, "size": 20, "totalElements": 2, "totalPages": 1, "content": [ {"id": "e1", "name": "用户数据", "type": "Dataset", "description": "用户行为", "score": 2.5}, {"id": "e2", "name": "清洗管道", "type": "Workflow", "description": "", "score": 1.8}, ], }, } mock_resp = _resp(200, json=mock_body) with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.get = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http entities = _run(client.fulltext_search("graph-1", "用户数据", size=10, user_id="u1")) assert len(entities) == 2 assert entities[0].id == "e1" assert entities[0].name == "用户数据" assert entities[0].type == "Dataset" assert entities[1].name == "清洗管道" def test_unwrapped_paged_response(self, client: KGServiceClient): """Java 直接返回 PagedResponse(无全局包装)。""" mock_body = { "page": 0, "size": 10, "totalElements": 1, "totalPages": 1, "content": [ {"id": "e1", "name": "A", "type": "Dataset", "description": "desc"}, ], } mock_resp = _resp(200, json=mock_body) with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.get = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http entities = _run(client.fulltext_search("graph-1", "A")) # body has no "data" key → fallback to body itself → read "content" assert len(entities) == 1 assert entities[0].name == "A" def test_empty_content(self, client: KGServiceClient): mock_body = {"code": 200, "data": {"page": 0, "content": []}} mock_resp = _resp(200, json=mock_body) with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.get = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http entities = _run(client.fulltext_search("graph-1", "nothing")) assert entities == [] def test_fail_open_on_http_error(self, client: KGServiceClient): """HTTP 错误时 fail-open 返回空列表。""" mock_resp = _resp(500, text="Internal Server Error") with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.get = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http entities = _run(client.fulltext_search("graph-1", "test")) assert entities == [] def test_fail_open_on_connection_error(self, client: KGServiceClient): """连接错误时 fail-open 返回空列表。""" with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.get = AsyncMock(side_effect=httpx.ConnectError("connection refused")) mock_get.return_value = mock_http entities = _run(client.fulltext_search("graph-1", "test")) assert entities == [] def test_request_headers(self, client: KGServiceClient): """验证请求中携带正确的 headers。""" mock_resp = _resp(200, json={"data": {"content": []}}) with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.get = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http _run(client.fulltext_search("gid", "q", size=5, user_id="user-123")) call_kwargs = mock_http.get.call_args assert call_kwargs.kwargs["headers"]["X-Internal-Token"] == "test-token" assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-123" assert call_kwargs.kwargs["params"] == {"q": "q", "size": 5} # --------------------------------------------------------------------------- # get_subgraph 测试 # --------------------------------------------------------------------------- class TestGetSubgraph: """get_subgraph 方法的测试。""" def test_wrapped_subgraph_response(self, client: KGServiceClient): """Java 返回被全局包装的 SubgraphExportVO。""" mock_body = { "code": 200, "data": { "nodes": [ {"id": "n1", "name": "用户数据", "type": "Dataset", "description": "desc1", "properties": {}}, {"id": "n2", "name": "user_id", "type": "Field", "description": "", "properties": {}}, ], "edges": [ { "id": "edge1", "sourceEntityId": "n1", "targetEntityId": "n2", "relationType": "HAS_FIELD", "weight": 1.0, "confidence": 0.9, "sourceId": "kb-1", }, ], "nodeCount": 2, "edgeCount": 1, }, } mock_resp = _resp(200, json=mock_body) with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.post = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http entities, relations = _run(client.get_subgraph("gid", ["n1"], depth=2, user_id="u1")) assert len(entities) == 2 assert entities[0].name == "用户数据" assert entities[1].name == "user_id" assert len(relations) == 1 assert relations[0].source_name == "用户数据" assert relations[0].target_name == "user_id" assert relations[0].relation_type == "HAS_FIELD" assert relations[0].source_type == "Dataset" assert relations[0].target_type == "Field" def test_unwrapped_subgraph_response(self, client: KGServiceClient): """Java 直接返回 SubgraphExportVO(无全局包装)。""" mock_body = { "nodes": [ {"id": "n1", "name": "A", "type": "T1", "description": ""}, ], "edges": [], "nodeCount": 1, "edgeCount": 0, } mock_resp = _resp(200, json=mock_body) with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.post = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http entities, relations = _run(client.get_subgraph("gid", ["n1"])) assert len(entities) == 1 assert entities[0].name == "A" assert relations == [] def test_edge_with_unknown_entity(self, client: KGServiceClient): """边引用的实体不在 nodes 列表中时,使用 ID 作为 fallback。""" mock_body = { "code": 200, "data": { "nodes": [{"id": "n1", "name": "A", "type": "T1", "description": ""}], "edges": [ { "sourceEntityId": "n1", "targetEntityId": "n999", "relationType": "DEPENDS_ON", }, ], }, } mock_resp = _resp(200, json=mock_body) with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.post = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http entities, relations = _run(client.get_subgraph("gid", ["n1"])) assert len(relations) == 1 assert relations[0].source_name == "A" assert relations[0].target_name == "n999" # fallback to ID assert relations[0].target_type == "" def test_fail_open_on_error(self, client: KGServiceClient): mock_resp = _resp(500, text="error") with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.post = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http entities, relations = _run(client.get_subgraph("gid", ["n1"])) assert entities == [] assert relations == [] def test_request_params(self, client: KGServiceClient): """验证子图请求参数正确传递。""" mock_resp = _resp(200, json={"data": {"nodes": [], "edges": []}}) with patch.object(client, "_get_client") as mock_get: mock_http = AsyncMock() mock_http.post = AsyncMock(return_value=mock_resp) mock_get.return_value = mock_http _run(client.get_subgraph("gid", ["e1", "e2"], depth=3, user_id="u1")) call_kwargs = mock_http.post.call_args assert "/knowledge-graph/gid/query/subgraph/export" in call_kwargs.args[0] assert call_kwargs.kwargs["params"] == {"depth": 3} assert call_kwargs.kwargs["json"] == {"entityIds": ["e1", "e2"]} # --------------------------------------------------------------------------- # headers 测试 # --------------------------------------------------------------------------- class TestHeaders: def test_headers_with_token_and_user(self, client: KGServiceClient): headers = client._headers(user_id="user-1") assert headers["X-Internal-Token"] == "test-token" assert headers["X-User-Id"] == "user-1" def test_headers_without_user(self, client: KGServiceClient): headers = client._headers() assert "X-Internal-Token" in headers assert "X-User-Id" not in headers def test_headers_without_token(self): c = KGServiceClient(base_url="http://test:8080", internal_token="") headers = c._headers(user_id="u1") assert "X-Internal-Token" not in headers assert headers["X-User-Id"] == "u1"