From 39338df80887b8b9e6d07f310d397d77d8ddeb8a Mon Sep 17 00:00:00 2001 From: Jerry Yan <792602257@qq.com> Date: Fri, 20 Feb 2026 09:41:55 +0800 Subject: [PATCH] =?UTF-8?q?feat(kg):=20=E5=AE=9E=E7=8E=B0=20Phase=202=20Gr?= =?UTF-8?q?aphRAG=20=E8=9E=8D=E5=90=88=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 核心功能: - 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序 - LLM 生成:支持同步和流式(SSE)响应 - 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证 新增模块(9个文件): - models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等) - milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread) - kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open) - context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建 - generator.py: LLM 生成(ChatOpenAI,支持同步和流式) - retriever.py: 检索编排(并行检索 + 融合排序) - kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close) - interface.py: FastAPI 端点(/query, /retrieve, /query/stream) - __init__.py: 模块入口 修改文件(3个): - app/core/config.py: 添加 13 个 graphrag_* 配置项 - app/module/__init__.py: 注册 kg_graphrag_router - pyproject.toml: 添加 pymilvus 依赖 测试覆盖(79 tests): - test_context_builder.py: 13 tests(三元组文本化 + 上下文构建) - test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射) - test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread) - test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open) - test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例) - test_interface.py: 15 tests(端点级回归 + 403 short-circuit) 关键设计: - Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果 - Fail-close: 访问控制失败拒绝请求,防止授权绕过 - 并行检索: asyncio.gather() 并发运行向量和图检索 - 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight) - 延迟初始化: 所有客户端在首次请求时初始化 - 配置回退: graphrag_llm_* 为空时回退到 kg_llm_* 安全修复: - P1-1: KG 响应解析(PagedResponse.content) - P1-2: 子图边字段映射(sourceEntityId/targetEntityId) - P1-3: collection_name 越权风险(归属校验 + 绑定验证) - P1-4: 同步 Milvus I/O(asyncio.to_thread) - P1-5: 测试覆盖(79 tests,包括安全负例) 测试结果:79 tests pass ✅ --- runtime/datamate-python/app/core/config.py | 23 ++ .../datamate-python/app/module/__init__.py | 2 + .../app/module/kg_graphrag/__init__.py | 5 + .../app/module/kg_graphrag/context_builder.py | 110 ++++++ .../app/module/kg_graphrag/generator.py | 101 ++++++ .../app/module/kg_graphrag/interface.py | 249 +++++++++++++ .../app/module/kg_graphrag/kb_access.py | 118 +++++++ .../app/module/kg_graphrag/kg_client.py | 197 +++++++++++ .../app/module/kg_graphrag/milvus_client.py | 135 +++++++ .../app/module/kg_graphrag/models.py | 102 ++++++ .../app/module/kg_graphrag/retriever.py | 214 ++++++++++++ .../kg_graphrag/test_context_builder.py | 182 ++++++++++ .../app/module/kg_graphrag/test_interface.py | 300 ++++++++++++++++ .../app/module/kg_graphrag/test_kb_access.py | 330 ++++++++++++++++++ .../app/module/kg_graphrag/test_kg_client.py | 297 ++++++++++++++++ .../module/kg_graphrag/test_milvus_client.py | 145 ++++++++ .../app/module/kg_graphrag/test_retriever.py | 234 +++++++++++++ runtime/datamate-python/pyproject.toml | 1 + 18 files changed, 2745 insertions(+) create mode 100644 runtime/datamate-python/app/module/kg_graphrag/__init__.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/context_builder.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/generator.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/interface.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/kb_access.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/kg_client.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/milvus_client.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/models.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/retriever.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/test_context_builder.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/test_interface.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/test_kb_access.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/test_kg_client.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/test_milvus_client.py create mode 100644 runtime/datamate-python/app/module/kg_graphrag/test_retriever.py diff --git a/runtime/datamate-python/app/core/config.py b/runtime/datamate-python/app/core/config.py index a3fcea7..c83f90b 100644 --- a/runtime/datamate-python/app/core/config.py +++ b/runtime/datamate-python/app/core/config.py @@ -88,6 +88,29 @@ class Settings(BaseSettings): kg_alignment_vector_threshold: float = 0.92 kg_alignment_llm_threshold: float = 0.78 + # GraphRAG 融合查询配置 + graphrag_enabled: bool = False + graphrag_milvus_uri: str = "http://milvus-standalone:19530" + graphrag_kg_service_url: str = "http://datamate-kg:8080" + graphrag_kg_internal_token: str = "" + + # GraphRAG - 检索策略默认值 + graphrag_vector_top_k: int = 5 + graphrag_graph_depth: int = 2 + graphrag_graph_max_entities: int = 20 + graphrag_vector_weight: float = 0.6 + graphrag_graph_weight: float = 0.4 + + # GraphRAG - LLM(空则复用 kg_llm_* 配置) + graphrag_llm_model: str = "" + graphrag_llm_base_url: Optional[str] = None + graphrag_llm_api_key: SecretStr = SecretStr("EMPTY") + graphrag_llm_temperature: float = 0.1 + graphrag_llm_timeout_seconds: int = 60 + + # GraphRAG - Embedding(空则复用 kg_alignment_embedding_* 配置) + graphrag_embedding_model: str = "" + # 标注编辑器(Label Studio Editor)相关 editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数 diff --git a/runtime/datamate-python/app/module/__init__.py b/runtime/datamate-python/app/module/__init__.py index 4900d44..611e657 100644 --- a/runtime/datamate-python/app/module/__init__.py +++ b/runtime/datamate-python/app/module/__init__.py @@ -8,6 +8,7 @@ from .evaluation.interface import router as evaluation_router from .collection.interface import router as collection_route from .dataset.interface import router as dataset_router from .kg_extraction.interface import router as kg_extraction_router +from .kg_graphrag.interface import router as kg_graphrag_router router = APIRouter( prefix="/api" @@ -21,5 +22,6 @@ router.include_router(evaluation_router) router.include_router(collection_route) router.include_router(dataset_router) router.include_router(kg_extraction_router) +router.include_router(kg_graphrag_router) __all__ = ["router"] diff --git a/runtime/datamate-python/app/module/kg_graphrag/__init__.py b/runtime/datamate-python/app/module/kg_graphrag/__init__.py new file mode 100644 index 0000000..5a086b7 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/__init__.py @@ -0,0 +1,5 @@ +"""GraphRAG 融合查询模块。""" + +from app.module.kg_graphrag.interface import router + +__all__ = ["router"] diff --git a/runtime/datamate-python/app/module/kg_graphrag/context_builder.py b/runtime/datamate-python/app/module/kg_graphrag/context_builder.py new file mode 100644 index 0000000..16150c0 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/context_builder.py @@ -0,0 +1,110 @@ +"""三元组文本化 + 上下文构建。 + +将图谱子图(实体 + 关系)转为自然语言描述, +并与向量检索片段合并为 LLM 可消费的上下文文本。 +""" + +from __future__ import annotations + +from app.module.kg_graphrag.models import ( + EntitySummary, + RelationSummary, + VectorChunk, +) + +# 关系类型 -> 中文模板映射 +RELATION_TEMPLATES: dict[str, str] = { + "HAS_FIELD": "{source}包含字段{target}", + "DERIVED_FROM": "{source}来源于{target}", + "USES_DATASET": "{source}使用了数据集{target}", + "PRODUCES": "{source}产出了{target}", + "ASSIGNED_TO": "{source}分配给了{target}", + "BELONGS_TO": "{source}属于{target}", + "TRIGGERS": "{source}触发了{target}", + "DEPENDS_ON": "{source}依赖于{target}", + "IMPACTS": "{source}影响了{target}", + "SOURCED_FROM": "{source}的知识来源于{target}", +} + +# 通用模板(未在映射中的关系类型) +_DEFAULT_TEMPLATE = "{source}与{target}存在{relation}关系" + + +def textualize_subgraph( + entities: list[EntitySummary], + relations: list[RelationSummary], +) -> str: + """将图谱子图转为自然语言描述。 + + Args: + entities: 子图中的实体列表。 + relations: 子图中的关系列表。 + + Returns: + 文本化后的图谱描述,每条关系/实体一行。 + """ + lines: list[str] = [] + + # 记录有关系的实体名称 + mentioned_entities: set[str] = set() + + # 1. 对每条关系生成一句话 + for rel in relations: + source_label = f"{rel.source_type}'{rel.source_name}'" + target_label = f"{rel.target_type}'{rel.target_name}'" + template = RELATION_TEMPLATES.get(rel.relation_type, _DEFAULT_TEMPLATE) + line = template.format( + source=source_label, + target=target_label, + relation=rel.relation_type, + ) + lines.append(line) + mentioned_entities.add(rel.source_name) + mentioned_entities.add(rel.target_name) + + # 2. 对独立实体(无关系)生成描述句 + for entity in entities: + if entity.name not in mentioned_entities: + desc = entity.description or "" + if desc: + lines.append(f"{entity.type}'{entity.name}': {desc}") + else: + lines.append(f"存在{entity.type}'{entity.name}'") + + return "\n".join(lines) + + +def build_context( + vector_chunks: list[VectorChunk], + graph_text: str, + vector_weight: float = 0.6, + graph_weight: float = 0.4, +) -> str: + """合并向量检索片段和图谱文本化内容为 LLM 上下文。 + + Args: + vector_chunks: 向量检索到的文档片段列表。 + graph_text: 文本化后的图谱描述。 + vector_weight: 向量分数权重(当前用于日志/调试,不影响上下文排序)。 + graph_weight: 图谱相关性权重。 + + Returns: + 合并后的上下文文本,分为「相关文档」和「知识图谱上下文」两个部分。 + """ + sections: list[str] = [] + + # 向量检索片段 + if vector_chunks: + doc_lines = ["## 相关文档"] + for i, chunk in enumerate(vector_chunks, 1): + doc_lines.append(f"[{i}] {chunk.text}") + sections.append("\n".join(doc_lines)) + + # 图谱文本化内容 + if graph_text: + sections.append(f"## 知识图谱上下文\n{graph_text}") + + if not sections: + return "(未检索到相关上下文信息)" + + return "\n\n".join(sections) diff --git a/runtime/datamate-python/app/module/kg_graphrag/generator.py b/runtime/datamate-python/app/module/kg_graphrag/generator.py new file mode 100644 index 0000000..bbae253 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/generator.py @@ -0,0 +1,101 @@ +"""LLM 生成器。 + +基于增强上下文(向量 + 图谱)调用 LLM 生成回答, +支持同步和流式两种模式。 +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator + +from pydantic import SecretStr + +from app.core.logging import get_logger + +logger = get_logger(__name__) + +_SYSTEM_PROMPT = ( + "你是 DataMate 数据管理平台的智能助手。请根据以下上下文信息回答用户的问题。\n" + "如果上下文中没有相关信息,请明确说明。不要编造信息。" +) + + +class GraphRAGGenerator: + """GraphRAG LLM 生成器。""" + + def __init__( + self, + *, + model: str = "gpt-4o-mini", + base_url: str | None = None, + api_key: SecretStr = SecretStr("EMPTY"), + temperature: float = 0.1, + timeout: int = 60, + ) -> None: + self._model = model + self._base_url = base_url + self._api_key = api_key + self._temperature = temperature + self._timeout = timeout + self._llm = None + + @property + def model_name(self) -> str: + return self._model + + @classmethod + def from_settings(cls) -> GraphRAGGenerator: + from app.core.config import settings + + model = settings.graphrag_llm_model or settings.kg_llm_model + base_url = settings.graphrag_llm_base_url or settings.kg_llm_base_url + api_key = ( + settings.graphrag_llm_api_key + if settings.graphrag_llm_api_key.get_secret_value() != "EMPTY" + else settings.kg_llm_api_key + ) + return cls( + model=model, + base_url=base_url, + api_key=api_key, + temperature=settings.graphrag_llm_temperature, + timeout=settings.graphrag_llm_timeout_seconds, + ) + + def _get_llm(self): + if self._llm is None: + from langchain_openai import ChatOpenAI + + self._llm = ChatOpenAI( + model=self._model, + base_url=self._base_url, + api_key=self._api_key, + temperature=self._temperature, + timeout=self._timeout, + ) + return self._llm + + def _build_messages(self, query: str, context: str) -> list[dict[str, str]]: + return [ + {"role": "system", "content": _SYSTEM_PROMPT}, + { + "role": "user", + "content": f"{context}\n\n用户问题: {query}\n\n请基于上下文中的信息回答。", + }, + ] + + async def generate(self, query: str, context: str) -> str: + """基于增强上下文生成回答。""" + messages = self._build_messages(query, context) + llm = self._get_llm() + response = await llm.ainvoke(messages) + return str(response.content) + + async def generate_stream(self, query: str, context: str) -> AsyncIterator[str]: + """基于增强上下文流式生成回答,逐 token 返回。""" + messages = self._build_messages(query, context) + llm = self._get_llm() + async for chunk in llm.astream(messages): + content = chunk.content + if content: + yield str(content) diff --git a/runtime/datamate-python/app/module/kg_graphrag/interface.py b/runtime/datamate-python/app/module/kg_graphrag/interface.py new file mode 100644 index 0000000..47ff845 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/interface.py @@ -0,0 +1,249 @@ +"""GraphRAG 融合查询 API 端点。 + +提供向量检索 + 知识图谱的融合查询能力: +- POST /api/graphrag/query — 完整 GraphRAG 查询(检索+生成) +- POST /api/graphrag/retrieve — 仅检索(返回上下文,不调 LLM) +- POST /api/graphrag/query/stream — 流式 GraphRAG 查询(SSE) +""" + +from __future__ import annotations + +import uuid +from typing import Annotated + +from fastapi import APIRouter, Depends, Header, HTTPException +from fastapi.responses import StreamingResponse + +from app.core.logging import get_logger +from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator +from app.module.kg_graphrag.models import ( + GraphRAGQueryRequest, + GraphRAGQueryResponse, + RetrievalContext, +) +from app.module.kg_graphrag.retriever import GraphRAGRetriever +from app.module.kg_graphrag.generator import GraphRAGGenerator +from app.module.shared.schema import StandardResponse + +router = APIRouter(prefix="/graphrag", tags=["graphrag"]) +logger = get_logger(__name__) + +# 延迟初始化 +_retriever: GraphRAGRetriever | None = None +_generator: GraphRAGGenerator | None = None +_kb_validator: KnowledgeBaseAccessValidator | None = None + + +def _get_retriever() -> GraphRAGRetriever: + global _retriever + if _retriever is None: + _retriever = GraphRAGRetriever.from_settings() + return _retriever + + +def _get_generator() -> GraphRAGGenerator: + global _generator + if _generator is None: + _generator = GraphRAGGenerator.from_settings() + return _generator + + +def _get_kb_validator() -> KnowledgeBaseAccessValidator: + global _kb_validator + if _kb_validator is None: + _kb_validator = KnowledgeBaseAccessValidator.from_settings() + return _kb_validator + + +def _require_caller_id( + x_user_id: Annotated[ + str, + Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递"), + ], +) -> str: + caller = x_user_id.strip() + if not caller: + raise HTTPException(status_code=401, detail="Missing required header: X-User-Id") + return caller + + +# --------------------------------------------------------------------------- +# P0: 完整 GraphRAG 查询 +# --------------------------------------------------------------------------- + + +@router.post( + "/query", + response_model=StandardResponse[GraphRAGQueryResponse], + summary="GraphRAG 查询", + description="并行从向量库和知识图谱检索上下文,融合后调用 LLM 生成回答。", +) +async def query( + req: GraphRAGQueryRequest, + caller: Annotated[str, Depends(_require_caller_id)], +): + trace_id = uuid.uuid4().hex[:16] + logger.info( + "[%s] GraphRAG query: graph_id=%s, collection=%s, caller=%s", + trace_id, req.graph_id, req.collection_name, caller, + ) + + retriever = _get_retriever() + generator = _get_generator() + + # 权限校验:验证用户是否有权访问该知识库 + kb_validator = _get_kb_validator() + if not await kb_validator.check_access( + req.knowledge_base_id, caller, collection_name=req.collection_name, + ): + logger.warning( + "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s", + trace_id, req.knowledge_base_id, req.collection_name, caller, + ) + raise HTTPException( + status_code=403, + detail=f"无权访问知识库 {req.knowledge_base_id}", + ) + + try: + context = await retriever.retrieve( + query=req.query, + collection_name=req.collection_name, + graph_id=req.graph_id, + strategy=req.strategy, + user_id=caller, + ) + except Exception: + logger.exception("[%s] Retrieval failed", trace_id) + raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})") + + try: + answer = await generator.generate(query=req.query, context=context.merged_text) + except Exception: + logger.exception("[%s] Generation failed", trace_id) + raise HTTPException(status_code=502, detail=f"生成服务暂不可用 (trace: {trace_id})") + + result = GraphRAGQueryResponse( + answer=answer, + context=context, + model=generator.model_name, + ) + return StandardResponse(code=200, message="success", data=result) + + +# --------------------------------------------------------------------------- +# P1-1: 仅检索 +# --------------------------------------------------------------------------- + + +@router.post( + "/retrieve", + response_model=StandardResponse[RetrievalContext], + summary="GraphRAG 仅检索", + description="并行从向量库和知识图谱检索上下文,返回结构化上下文(不调 LLM)。", +) +async def retrieve( + req: GraphRAGQueryRequest, + caller: Annotated[str, Depends(_require_caller_id)], +): + trace_id = uuid.uuid4().hex[:16] + logger.info( + "[%s] GraphRAG retrieve: graph_id=%s, collection=%s, caller=%s", + trace_id, req.graph_id, req.collection_name, caller, + ) + + retriever = _get_retriever() + + # 权限校验:验证用户是否有权访问该知识库 + kb_validator = _get_kb_validator() + if not await kb_validator.check_access( + req.knowledge_base_id, caller, collection_name=req.collection_name, + ): + logger.warning( + "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s", + trace_id, req.knowledge_base_id, req.collection_name, caller, + ) + raise HTTPException( + status_code=403, + detail=f"无权访问知识库 {req.knowledge_base_id}", + ) + + try: + context = await retriever.retrieve( + query=req.query, + collection_name=req.collection_name, + graph_id=req.graph_id, + strategy=req.strategy, + user_id=caller, + ) + except Exception: + logger.exception("[%s] Retrieval failed", trace_id) + raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})") + + return StandardResponse(code=200, message="success", data=context) + + +# --------------------------------------------------------------------------- +# P1-4: 流式查询 (SSE) +# --------------------------------------------------------------------------- + + +@router.post( + "/query/stream", + summary="GraphRAG 流式查询", + description="并行检索后,通过 SSE 流式返回 LLM 生成内容。", +) +async def query_stream( + req: GraphRAGQueryRequest, + caller: Annotated[str, Depends(_require_caller_id)], +): + trace_id = uuid.uuid4().hex[:16] + logger.info( + "[%s] GraphRAG stream: graph_id=%s, collection=%s, caller=%s", + trace_id, req.graph_id, req.collection_name, caller, + ) + + retriever = _get_retriever() + generator = _get_generator() + + # 权限校验:验证用户是否有权访问该知识库 + kb_validator = _get_kb_validator() + if not await kb_validator.check_access( + req.knowledge_base_id, caller, collection_name=req.collection_name, + ): + logger.warning( + "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s", + trace_id, req.knowledge_base_id, req.collection_name, caller, + ) + raise HTTPException( + status_code=403, + detail=f"无权访问知识库 {req.knowledge_base_id}", + ) + + try: + context = await retriever.retrieve( + query=req.query, + collection_name=req.collection_name, + graph_id=req.graph_id, + strategy=req.strategy, + user_id=caller, + ) + except Exception: + logger.exception("[%s] Retrieval failed", trace_id) + raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})") + + import json + + async def event_stream(): + try: + async for token in generator.generate_stream( + query=req.query, context=context.merged_text + ): + yield f"data: {json.dumps({'token': token}, ensure_ascii=False)}\n\n" + # 结束事件:附带检索上下文 + yield f"data: {json.dumps({'done': True, 'context': context.model_dump()}, ensure_ascii=False)}\n\n" + except Exception: + logger.exception("[%s] Stream generation failed", trace_id) + yield f"data: {json.dumps({'error': '生成服务暂不可用'})}\n\n" + + return StreamingResponse(event_stream(), media_type="text/event-stream") diff --git a/runtime/datamate-python/app/module/kg_graphrag/kb_access.py b/runtime/datamate-python/app/module/kg_graphrag/kb_access.py new file mode 100644 index 0000000..0a100f2 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/kb_access.py @@ -0,0 +1,118 @@ +"""知识库访问权限校验。 + +在执行 GraphRAG 检索前,调用 Java rag-indexer-service 的 +GET /knowledge-base/{id} 端点验证当前用户是否有权访问该知识库。 + +Java 侧实现参考:KnowledgeBaseService.getKnowledgeBaseWithAccessCheck() +- 查找 KB 是否存在 +- 校验 createdBy == currentUserId(管理员跳过) +- 不满足则抛出 sys.0005 (INSUFFICIENT_PERMISSIONS) +""" + +from __future__ import annotations + +import httpx + +from app.core.logging import get_logger + +logger = get_logger(__name__) + + +class KnowledgeBaseAccessValidator: + """通过 Java 后端校验用户是否有权访问指定知识库。""" + + def __init__( + self, + *, + base_url: str = "http://datamate-backend:8080/api", + timeout: float = 10.0, + ) -> None: + self._base_url = base_url.rstrip("/") + self._timeout = timeout + self._client: httpx.AsyncClient | None = None + + @classmethod + def from_settings(cls) -> KnowledgeBaseAccessValidator: + from app.core.config import settings + + return cls(base_url=settings.datamate_backend_base_url) + + def _get_client(self) -> httpx.AsyncClient: + if self._client is None: + self._client = httpx.AsyncClient( + base_url=self._base_url, + timeout=self._timeout, + ) + return self._client + + async def check_access( + self, + knowledge_base_id: str, + user_id: str, + *, + collection_name: str | None = None, + ) -> bool: + """校验用户是否有权访问指定知识库。 + + 调用 Java 后端 GET /knowledge-base/{id},该端点内部执行 + owner 校验(createdBy == currentUserId,管理员跳过)。 + + 当 *collection_name* 不为 None 时,还会校验请求中的 + collection_name 与该知识库实际的 name 是否一致,防止 + 用户提交合法 KB ID 但篡改 collection_name 来访问 + 其他知识库的 Milvus 数据。 + + Returns: + True — 用户有权访问且 collection_name 匹配 + False — 无权访问、collection_name 不匹配或校验失败 + """ + try: + client = self._get_client() + resp = await client.get( + f"/knowledge-base/{knowledge_base_id}", + headers={"X-User-Id": user_id}, + ) + if resp.status_code == 200: + body = resp.json() + # Java 全局包装: {"code": 200, "data": {...}} + # code != 200 说明业务层拒绝(如权限不足) + code = body.get("code", resp.status_code) + if code != 200: + logger.warning( + "KB access denied: kb_id=%s, user=%s, biz_code=%s, msg=%s", + knowledge_base_id, user_id, code, body.get("message", ""), + ) + return False + + # 校验 collection_name 与 KB 实际名称的绑定关系 + if collection_name is not None: + data = body.get("data") or {} + actual_name = data.get("name") if isinstance(data, dict) else None + if actual_name != collection_name: + logger.warning( + "KB collection_name mismatch: kb_id=%s, " + "expected=%s, actual=%s, user=%s", + knowledge_base_id, collection_name, + actual_name, user_id, + ) + return False + + return True + # HTTP 4xx/5xx + logger.warning( + "KB access check returned HTTP %d: kb_id=%s, user=%s", + resp.status_code, knowledge_base_id, user_id, + ) + return False + except Exception: + # 网络异常时 fail-close:拒绝访问,防止绕过权限 + logger.exception( + "KB access check failed (fail-close): kb_id=%s, user=%s", + knowledge_base_id, user_id, + ) + return False + + async def close(self) -> None: + if self._client is not None: + await self._client.aclose() + self._client = None diff --git a/runtime/datamate-python/app/module/kg_graphrag/kg_client.py b/runtime/datamate-python/app/module/kg_graphrag/kg_client.py new file mode 100644 index 0000000..499a759 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/kg_client.py @@ -0,0 +1,197 @@ +"""KG 服务 REST 客户端。 + +通过 httpx 调用 Java 侧 knowledge-graph-service 的查询 API, +包括全文检索和子图导出。 + +失败策略:fail-open —— KG 服务不可用时返回空结果 + 日志告警。 +""" + +from __future__ import annotations + +import httpx + +from app.core.logging import get_logger +from app.module.kg_graphrag.models import EntitySummary, RelationSummary + +logger = get_logger(__name__) + + +class KGServiceClient: + """Java KG 服务 REST 客户端。""" + + def __init__( + self, + *, + base_url: str = "http://datamate-kg:8080", + internal_token: str = "", + timeout: float = 30.0, + ) -> None: + self._base_url = base_url.rstrip("/") + self._internal_token = internal_token + self._timeout = timeout + self._client: httpx.AsyncClient | None = None + + @classmethod + def from_settings(cls) -> KGServiceClient: + from app.core.config import settings + + return cls( + base_url=settings.graphrag_kg_service_url, + internal_token=settings.graphrag_kg_internal_token, + timeout=30.0, + ) + + def _get_client(self) -> httpx.AsyncClient: + if self._client is None: + self._client = httpx.AsyncClient( + base_url=self._base_url, + timeout=self._timeout, + ) + return self._client + + def _headers(self, user_id: str = "") -> dict[str, str]: + headers: dict[str, str] = {} + if self._internal_token: + headers["X-Internal-Token"] = self._internal_token + if user_id: + headers["X-User-Id"] = user_id + return headers + + async def fulltext_search( + self, + graph_id: str, + query: str, + size: int = 10, + user_id: str = "", + ) -> list[EntitySummary]: + """调用 KG 服务全文检索,返回匹配的实体列表。 + + Fail-open: KG 服务不可用时返回空列表。 + """ + try: + return await self._fulltext_search_impl(graph_id, query, size, user_id) + except Exception: + logger.exception( + "KG fulltext search failed for graph_id=%s (fail-open, returning empty)", + graph_id, + ) + return [] + + async def _fulltext_search_impl( + self, + graph_id: str, + query: str, + size: int, + user_id: str, + ) -> list[EntitySummary]: + client = self._get_client() + resp = await client.get( + f"/knowledge-graph/{graph_id}/query/search", + params={"q": query, "size": size}, + headers=self._headers(user_id), + ) + resp.raise_for_status() + body = resp.json() + + # Java 返回 PagedResponse: + # 可能被全局包装为 {"code": 200, "data": PagedResponse} + # 也可能直接返回 PagedResponse {"page": 0, "content": [...]} + data = body.get("data", body) + # PagedResponse 将实体列表放在 content 字段中 + items: list[dict] = ( + data.get("content", []) if isinstance(data, dict) else data if isinstance(data, list) else [] + ) + entities: list[EntitySummary] = [] + for item in items: + entities.append( + EntitySummary( + id=str(item.get("id", "")), + name=item.get("name", ""), + type=item.get("type", ""), + description=item.get("description", ""), + ) + ) + return entities + + async def get_subgraph( + self, + graph_id: str, + entity_ids: list[str], + depth: int = 1, + user_id: str = "", + ) -> tuple[list[EntitySummary], list[RelationSummary]]: + """获取种子实体的 N-hop 子图。 + + Fail-open: KG 服务不可用时返回空子图。 + """ + try: + return await self._get_subgraph_impl(graph_id, entity_ids, depth, user_id) + except Exception: + logger.exception( + "KG subgraph export failed for graph_id=%s (fail-open, returning empty)", + graph_id, + ) + return [], [] + + async def _get_subgraph_impl( + self, + graph_id: str, + entity_ids: list[str], + depth: int, + user_id: str, + ) -> tuple[list[EntitySummary], list[RelationSummary]]: + client = self._get_client() + resp = await client.post( + f"/knowledge-graph/{graph_id}/query/subgraph/export", + params={"depth": depth}, + json={"entityIds": entity_ids}, + headers=self._headers(user_id), + ) + resp.raise_for_status() + body = resp.json() + + # Java 返回 SubgraphExportVO: + # 可能被全局包装为 {"code": 200, "data": SubgraphExportVO} + # 也可能直接返回 SubgraphExportVO {"nodes": [...], "edges": [...]} + data = body.get("data", body) if isinstance(body.get("data"), dict) else body + nodes_raw = data.get("nodes", []) + edges_raw = data.get("edges", []) + + # ExportNodeVO: id, name, type, description, properties (Map) + entities: list[EntitySummary] = [] + for node in nodes_raw: + entities.append( + EntitySummary( + id=str(node.get("id", "")), + name=node.get("name", ""), + type=node.get("type", ""), + description=node.get("description", ""), + ) + ) + + relations: list[RelationSummary] = [] + # 构建 id -> entity 的映射用于查找 source/target 名称和类型 + entity_map = {e.id: e for e in entities} + # ExportEdgeVO: sourceEntityId, targetEntityId, relationType + # 注意:sourceId 是数据来源 ID,不是源实体 ID + for edge in edges_raw: + source_id = str(edge.get("sourceEntityId", "")) + target_id = str(edge.get("targetEntityId", "")) + source_entity = entity_map.get(source_id) + target_entity = entity_map.get(target_id) + relations.append( + RelationSummary( + source_name=source_entity.name if source_entity else source_id, + source_type=source_entity.type if source_entity else "", + target_name=target_entity.name if target_entity else target_id, + target_type=target_entity.type if target_entity else "", + relation_type=edge.get("relationType", ""), + ) + ) + + return entities, relations + + async def close(self) -> None: + if self._client is not None: + await self._client.aclose() + self._client = None diff --git a/runtime/datamate-python/app/module/kg_graphrag/milvus_client.py b/runtime/datamate-python/app/module/kg_graphrag/milvus_client.py new file mode 100644 index 0000000..64aaa4f --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/milvus_client.py @@ -0,0 +1,135 @@ +"""Milvus 向量检索客户端。 + +通过 pymilvus 连接 Milvus,对查询文本进行 embedding 后执行混合搜索, +返回 top-K 文档片段。 + +失败策略:fail-open —— Milvus 不可用时返回空列表 + 日志告警。 +""" + +from __future__ import annotations + +import asyncio + +from pydantic import SecretStr + +from app.core.logging import get_logger +from app.module.kg_graphrag.models import VectorChunk + +logger = get_logger(__name__) + + +class MilvusVectorRetriever: + """Milvus 向量检索器。""" + + def __init__( + self, + *, + uri: str = "http://milvus-standalone:19530", + embedding_model: str = "text-embedding-3-small", + embedding_base_url: str | None = None, + embedding_api_key: SecretStr = SecretStr("EMPTY"), + ) -> None: + self._uri = uri + self._embedding_model = embedding_model + self._embedding_base_url = embedding_base_url + self._embedding_api_key = embedding_api_key + # Lazy init + self._milvus_client = None + self._embeddings = None + + @classmethod + def from_settings(cls) -> MilvusVectorRetriever: + from app.core.config import settings + + embedding_model = ( + settings.graphrag_embedding_model + or settings.kg_alignment_embedding_model + ) + return cls( + uri=settings.graphrag_milvus_uri, + embedding_model=embedding_model, + embedding_base_url=settings.kg_llm_base_url, + embedding_api_key=settings.kg_llm_api_key, + ) + + def _get_embeddings(self): + if self._embeddings is None: + from langchain_openai import OpenAIEmbeddings + + self._embeddings = OpenAIEmbeddings( + model=self._embedding_model, + base_url=self._embedding_base_url, + api_key=self._embedding_api_key, + ) + return self._embeddings + + def _get_milvus_client(self): + if self._milvus_client is None: + from pymilvus import MilvusClient + + self._milvus_client = MilvusClient(uri=self._uri) + logger.info("Connected to Milvus at %s", self._uri) + return self._milvus_client + + async def has_collection(self, collection_name: str) -> bool: + """检查 Milvus 中是否存在指定 collection(防止越权访问不存在的库)。""" + try: + client = self._get_milvus_client() + return await asyncio.to_thread(client.has_collection, collection_name) + except Exception: + logger.exception("Milvus has_collection check failed for %s", collection_name) + return False + + async def search( + self, + collection_name: str, + query: str, + top_k: int = 5, + ) -> list[VectorChunk]: + """向量搜索:embed query -> Milvus search -> 返回 top-K 文档片段。 + + Fail-open: Milvus 不可用时返回空列表。 + """ + try: + return await self._search_impl(collection_name, query, top_k) + except Exception: + logger.exception( + "Milvus search failed for collection=%s (fail-open, returning empty)", + collection_name, + ) + return [] + + async def _search_impl( + self, + collection_name: str, + query: str, + top_k: int, + ) -> list[VectorChunk]: + # 1. Embed query + query_vector = await self._get_embeddings().aembed_query(query) + + # 2. Milvus search(同步 I/O,通过 to_thread 避免阻塞事件循环) + client = self._get_milvus_client() + results = await asyncio.to_thread( + client.search, + collection_name=collection_name, + data=[query_vector], + limit=top_k, + output_fields=["text", "metadata"], + search_params={"metric_type": "COSINE", "params": {"nprobe": 16}}, + ) + + # 3. 转换为 VectorChunk + chunks: list[VectorChunk] = [] + if results and len(results) > 0: + for hit in results[0]: + entity = hit.get("entity", {}) + chunks.append( + VectorChunk( + id=str(hit.get("id", "")), + text=entity.get("text", ""), + score=float(hit.get("distance", 0.0)), + metadata=entity.get("metadata", {}), + ) + ) + return chunks diff --git a/runtime/datamate-python/app/module/kg_graphrag/models.py b/runtime/datamate-python/app/module/kg_graphrag/models.py new file mode 100644 index 0000000..ef0a6dc --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/models.py @@ -0,0 +1,102 @@ +"""GraphRAG 融合查询的请求/响应数据模型。""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class RetrievalStrategy(BaseModel): + """检索策略配置。""" + + vector_top_k: int = Field(default=5, ge=1, le=50, description="向量检索返回数") + graph_depth: int = Field(default=2, ge=1, le=5, description="图谱扩展深度") + graph_max_entities: int = Field(default=20, ge=1, le=100, description="图谱最大实体数") + vector_weight: float = Field(default=0.6, ge=0.0, le=1.0, description="向量分数权重") + graph_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="图谱相关性权重") + enable_graph: bool = Field(default=True, description="是否启用图谱检索") + enable_vector: bool = Field(default=True, description="是否启用向量检索") + + +class GraphRAGQueryRequest(BaseModel): + """GraphRAG 查询请求。""" + + query: str = Field( + ..., + min_length=1, + max_length=2000, + description="用户查询", + ) + knowledge_base_id: str = Field( + ..., + min_length=1, + max_length=64, + description="知识库 ID,用于权限校验(由上游 Java 后端传入)", + ) + collection_name: str = Field( + ..., + min_length=1, + max_length=256, + pattern=r"^[a-zA-Z0-9_\-\u4e00-\u9fff]+$", + description="Milvus collection 名称(= 知识库名),仅允许字母、数字、下划线、连字符和中文", + ) + graph_id: str = Field( + ..., + pattern=r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$", + description="Neo4j 图谱 ID(UUID 格式)", + ) + strategy: RetrievalStrategy = Field( + default_factory=RetrievalStrategy, + description="可选策略覆盖", + ) + + +class VectorChunk(BaseModel): + """向量检索到的文档片段。""" + + id: str + text: str + score: float + metadata: dict[str, object] = Field(default_factory=dict) + + +class EntitySummary(BaseModel): + """实体摘要。""" + + id: str + name: str + type: str + description: str = "" + + +class RelationSummary(BaseModel): + """关系摘要。""" + + source_name: str + source_type: str + target_name: str + target_type: str + relation_type: str + + +class GraphContext(BaseModel): + """图谱上下文。""" + + entities: list[EntitySummary] = Field(default_factory=list) + relations: list[RelationSummary] = Field(default_factory=list) + textualized: str = "" + + +class RetrievalContext(BaseModel): + """检索上下文(检索结果的结构化表示)。""" + + vector_chunks: list[VectorChunk] = Field(default_factory=list) + graph_context: GraphContext = Field(default_factory=GraphContext) + merged_text: str = "" + + +class GraphRAGQueryResponse(BaseModel): + """GraphRAG 查询响应。""" + + answer: str = Field(..., description="LLM 生成的回答") + context: RetrievalContext = Field(..., description="检索上下文") + model: str = Field(..., description="使用的 LLM 模型名") diff --git a/runtime/datamate-python/app/module/kg_graphrag/retriever.py b/runtime/datamate-python/app/module/kg_graphrag/retriever.py new file mode 100644 index 0000000..b4ba0c6 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/retriever.py @@ -0,0 +1,214 @@ +"""GraphRAG 检索编排器。 + +并行执行向量检索和图谱检索,融合排序后构建统一上下文。 +""" + +from __future__ import annotations + +import asyncio + +from app.core.logging import get_logger +from app.module.kg_graphrag.context_builder import build_context, textualize_subgraph +from app.module.kg_graphrag.kg_client import KGServiceClient +from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever +from app.module.kg_graphrag.models import ( + EntitySummary, + GraphContext, + RelationSummary, + RetrievalContext, + RetrievalStrategy, + VectorChunk, +) + +logger = get_logger(__name__) + + +class GraphRAGRetriever: + """GraphRAG 检索编排器。""" + + def __init__( + self, + *, + milvus_client: MilvusVectorRetriever, + kg_client: KGServiceClient, + ) -> None: + self._milvus = milvus_client + self._kg = kg_client + + @classmethod + def from_settings(cls) -> GraphRAGRetriever: + return cls( + milvus_client=MilvusVectorRetriever.from_settings(), + kg_client=KGServiceClient.from_settings(), + ) + + async def retrieve( + self, + query: str, + collection_name: str, + graph_id: str, + strategy: RetrievalStrategy, + user_id: str = "", + ) -> RetrievalContext: + """并行执行向量检索 + 图谱检索,融合结果。""" + # 构建并行任务 + tasks: dict[str, asyncio.Task] = {} + + if strategy.enable_vector: + # 先校验 collection 存在性,防止越权访问 + if not await self._milvus.has_collection(collection_name): + logger.warning( + "Collection %s not found, skipping vector retrieval", + collection_name, + ) + else: + tasks["vector"] = asyncio.create_task( + self._milvus.search( + collection_name=collection_name, + query=query, + top_k=strategy.vector_top_k, + ) + ) + + if strategy.enable_graph: + tasks["graph"] = asyncio.create_task( + self._retrieve_graph( + query=query, + graph_id=graph_id, + strategy=strategy, + user_id=user_id, + ) + ) + + # 等待所有任务完成 + if tasks: + await asyncio.gather(*tasks.values(), return_exceptions=True) + + # 收集结果 + vector_chunks: list[VectorChunk] = [] + if "vector" in tasks: + try: + vector_chunks = tasks["vector"].result() + except Exception: + logger.exception("Vector retrieval task failed") + + entities: list[EntitySummary] = [] + relations: list[RelationSummary] = [] + if "graph" in tasks: + try: + entities, relations = tasks["graph"].result() + except Exception: + logger.exception("Graph retrieval task failed") + + # 融合排序 + vector_chunks = self._rank_results( + vector_chunks, entities, relations, strategy + ) + + # 三元组文本化 + graph_text = textualize_subgraph(entities, relations) + + # 构建上下文 + merged_text = build_context( + vector_chunks, + graph_text, + vector_weight=strategy.vector_weight, + graph_weight=strategy.graph_weight, + ) + + return RetrievalContext( + vector_chunks=vector_chunks, + graph_context=GraphContext( + entities=entities, + relations=relations, + textualized=graph_text, + ), + merged_text=merged_text, + ) + + async def _retrieve_graph( + self, + query: str, + graph_id: str, + strategy: RetrievalStrategy, + user_id: str, + ) -> tuple[list[EntitySummary], list[RelationSummary]]: + """图谱检索:全文搜索 -> 种子实体 -> 子图扩展。""" + # 1. 全文检索获取种子实体 + seed_entities = await self._kg.fulltext_search( + graph_id=graph_id, + query=query, + size=strategy.graph_max_entities, + user_id=user_id, + ) + + if not seed_entities: + logger.debug("No seed entities found for query: %s", query) + return [], [] + + # 2. 获取种子实体的 N-hop 子图 + seed_ids = [e.id for e in seed_entities] + entities, relations = await self._kg.get_subgraph( + graph_id=graph_id, + entity_ids=seed_ids, + depth=strategy.graph_depth, + user_id=user_id, + ) + + logger.info( + "Graph retrieval: %d seed entities -> %d entities, %d relations", + len(seed_entities), len(entities), len(relations), + ) + return entities, relations + + def _rank_results( + self, + vector_chunks: list[VectorChunk], + entities: list[EntitySummary], + relations: list[RelationSummary], + strategy: RetrievalStrategy, + ) -> list[VectorChunk]: + """对向量检索结果进行融合排序。 + + 基于向量分数归一化后加权排序。图谱关联度通过实体度数近似评估。 + """ + if not vector_chunks: + return vector_chunks + + # 向量分数归一化 (min-max scaling) + scores = [c.score for c in vector_chunks] + min_score = min(scores) + max_score = max(scores) + score_range = max_score - min_score + + # 构建图谱实体名称集合,用于关联度加分 + graph_entity_names = {e.name.lower() for e in entities} + + ranked: list[tuple[float, VectorChunk]] = [] + for chunk in vector_chunks: + # 归一化向量分数 + norm_score = ( + (chunk.score - min_score) / score_range + if score_range > 0 + else 1.0 + ) + + # 图谱关联度加分:文档片段中提及图谱实体名称 + graph_boost = 0.0 + if graph_entity_names: + chunk_text_lower = chunk.text.lower() + mentioned = sum( + 1 for name in graph_entity_names if name in chunk_text_lower + ) + graph_boost = min(mentioned / max(len(graph_entity_names), 1), 1.0) + + # 加权融合分数 + final_score = ( + strategy.vector_weight * norm_score + + strategy.graph_weight * graph_boost + ) + ranked.append((final_score, chunk)) + + # 按融合分数降序排序 + ranked.sort(key=lambda x: x[0], reverse=True) + return [chunk for _, chunk in ranked] diff --git a/runtime/datamate-python/app/module/kg_graphrag/test_context_builder.py b/runtime/datamate-python/app/module/kg_graphrag/test_context_builder.py new file mode 100644 index 0000000..43cec33 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/test_context_builder.py @@ -0,0 +1,182 @@ +"""三元组文本化 + 上下文构建的单元测试。""" + +from app.module.kg_graphrag.context_builder import ( + RELATION_TEMPLATES, + build_context, + textualize_subgraph, +) +from app.module.kg_graphrag.models import ( + EntitySummary, + RelationSummary, + VectorChunk, +) + + +# --------------------------------------------------------------------------- +# textualize_subgraph 测试 +# --------------------------------------------------------------------------- + + +class TestTextualizeSubgraph: + """textualize_subgraph 函数的测试。""" + + def test_single_relation(self): + entities = [ + EntitySummary(id="1", name="用户行为数据", type="Dataset"), + EntitySummary(id="2", name="user_id", type="Field"), + ] + relations = [ + RelationSummary( + source_name="用户行为数据", + source_type="Dataset", + target_name="user_id", + target_type="Field", + relation_type="HAS_FIELD", + ), + ] + result = textualize_subgraph(entities, relations) + assert "Dataset'用户行为数据'包含字段Field'user_id'" in result + + def test_multiple_relations(self): + entities = [ + EntitySummary(id="1", name="用户行为数据", type="Dataset"), + EntitySummary(id="2", name="清洗管道", type="Workflow"), + ] + relations = [ + RelationSummary( + source_name="清洗管道", + source_type="Workflow", + target_name="用户行为数据", + target_type="Dataset", + relation_type="USES_DATASET", + ), + RelationSummary( + source_name="用户行为数据", + source_type="Dataset", + target_name="外部系统", + target_type="DataSource", + relation_type="SOURCED_FROM", + ), + ] + result = textualize_subgraph(entities, relations) + assert "Workflow'清洗管道'使用了数据集Dataset'用户行为数据'" in result + assert "Dataset'用户行为数据'的知识来源于DataSource'外部系统'" in result + + def test_all_relation_templates(self): + """验证所有 10 种关系模板都能正确生成。""" + for rel_type, template in RELATION_TEMPLATES.items(): + relations = [ + RelationSummary( + source_name="A", + source_type="TypeA", + target_name="B", + target_type="TypeB", + relation_type=rel_type, + ), + ] + result = textualize_subgraph([], relations) + assert "TypeA'A'" in result + assert "TypeB'B'" in result + assert result # 非空 + + def test_unknown_relation_type(self): + """未知关系类型使用通用模板。""" + relations = [ + RelationSummary( + source_name="X", + source_type="T1", + target_name="Y", + target_type="T2", + relation_type="CUSTOM_REL", + ), + ] + result = textualize_subgraph([], relations) + assert "T1'X'与T2'Y'存在CUSTOM_REL关系" in result + + def test_orphan_entity_with_description(self): + """无关系的独立实体(有描述)。""" + entities = [ + EntitySummary(id="1", name="孤立实体", type="Dataset", description="这是一个测试实体"), + ] + result = textualize_subgraph(entities, []) + assert "Dataset'孤立实体': 这是一个测试实体" in result + + def test_orphan_entity_without_description(self): + """无关系的独立实体(无描述)。""" + entities = [ + EntitySummary(id="1", name="孤立实体", type="Dataset"), + ] + result = textualize_subgraph(entities, []) + assert "存在Dataset'孤立实体'" in result + + def test_empty_inputs(self): + result = textualize_subgraph([], []) + assert result == "" + + def test_entity_with_relation_not_orphan(self): + """有关系的实体不应出现在独立实体部分。""" + entities = [ + EntitySummary(id="1", name="A", type="Dataset"), + EntitySummary(id="2", name="B", type="Field"), + EntitySummary(id="3", name="C", type="Workflow"), + ] + relations = [ + RelationSummary( + source_name="A", + source_type="Dataset", + target_name="B", + target_type="Field", + relation_type="HAS_FIELD", + ), + ] + result = textualize_subgraph(entities, relations) + # A 和 B 有关系,不应作为独立实体出现 + # C 无关系,应出现 + assert "存在Workflow'C'" in result + lines = result.strip().split("\n") + assert len(lines) == 2 # 一条关系 + 一个独立实体 + + +# --------------------------------------------------------------------------- +# build_context 测试 +# --------------------------------------------------------------------------- + + +class TestBuildContext: + """build_context 函数的测试。""" + + def test_both_vector_and_graph(self): + chunks = [ + VectorChunk(id="1", text="文档片段一", score=0.9), + VectorChunk(id="2", text="文档片段二", score=0.8), + ] + graph_text = "Dataset'用户数据'包含字段Field'user_id'" + result = build_context(chunks, graph_text) + assert "## 相关文档" in result + assert "[1] 文档片段一" in result + assert "[2] 文档片段二" in result + assert "## 知识图谱上下文" in result + assert graph_text in result + + def test_vector_only(self): + chunks = [VectorChunk(id="1", text="文档片段", score=0.9)] + result = build_context(chunks, "") + assert "## 相关文档" in result + assert "## 知识图谱上下文" not in result + + def test_graph_only(self): + result = build_context([], "图谱内容") + assert "## 知识图谱上下文" in result + assert "## 相关文档" not in result + + def test_empty_both(self): + result = build_context([], "") + assert "未检索到相关上下文信息" in result + + def test_context_section_order(self): + """验证文档在图谱之前。""" + chunks = [VectorChunk(id="1", text="doc", score=0.9)] + result = build_context(chunks, "graph") + doc_pos = result.index("## 相关文档") + graph_pos = result.index("## 知识图谱上下文") + assert doc_pos < graph_pos diff --git a/runtime/datamate-python/app/module/kg_graphrag/test_interface.py b/runtime/datamate-python/app/module/kg_graphrag/test_interface.py new file mode 100644 index 0000000..28f830b --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/test_interface.py @@ -0,0 +1,300 @@ +"""GraphRAG API 端点回归测试。 + +验证 /graphrag/query、/graphrag/retrieve、/graphrag/query/stream 端点 +的权限校验行为,确保 collection_name 不一致时返回 403 且不进入检索链路。 +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.testclient import TestClient +from starlette.exceptions import HTTPException as StarletteHTTPException + +from app.exception import ( + fastapi_http_exception_handler, + starlette_http_exception_handler, + validation_exception_handler, +) +from app.module.kg_graphrag.interface import router +from app.module.kg_graphrag.models import ( + GraphContext, + RetrievalContext, +) + +# --------------------------------------------------------------------------- +# 测试用 FastAPI 应用(仅挂载 graphrag router + 异常处理器) +# --------------------------------------------------------------------------- + +_app = FastAPI() +_app.include_router(router, prefix="/api") +_app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler) +_app.add_exception_handler(HTTPException, fastapi_http_exception_handler) +_app.add_exception_handler(RequestValidationError, validation_exception_handler) + + +_VALID_GRAPH_ID = "12345678-1234-1234-1234-123456789abc" + +_VALID_BODY = { + "query": "测试查询", + "knowledge_base_id": "kb-1", + "collection_name": "test-collection", + "graph_id": _VALID_GRAPH_ID, +} + +_HEADERS = {"X-User-Id": "user-1"} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _fake_retrieval_context() -> RetrievalContext: + return RetrievalContext( + vector_chunks=[], + graph_context=GraphContext(), + merged_text="test context", + ) + + +def _make_retriever_mock() -> AsyncMock: + m = AsyncMock() + m.retrieve = AsyncMock(return_value=_fake_retrieval_context()) + return m + + +def _make_generator_mock() -> AsyncMock: + m = AsyncMock() + m.generate = AsyncMock(return_value="test answer") + m.model_name = "test-model" + + async def _stream(*, query: str, context: str): # noqa: ARG001 + for token in ["hello", " ", "world"]: + yield token + + m.generate_stream = _stream + return m + + +def _make_kb_validator_mock(*, access_granted: bool = True) -> AsyncMock: + m = AsyncMock() + m.check_access = AsyncMock(return_value=access_granted) + return m + + +def _patch_all( + *, + access_granted: bool = True, + retriever: AsyncMock | None = None, + generator: AsyncMock | None = None, + validator: AsyncMock | None = None, +): + """返回 context manager,统一 patch 三个懒加载工厂函数。""" + retriever = retriever or _make_retriever_mock() + generator = generator or _make_generator_mock() + validator = validator or _make_kb_validator_mock(access_granted=access_granted) + + class _Ctx: + def __init__(self): + self.retriever = retriever + self.generator = generator + self.validator = validator + self._patches = [ + patch("app.module.kg_graphrag.interface._get_retriever", return_value=retriever), + patch("app.module.kg_graphrag.interface._get_generator", return_value=generator), + patch("app.module.kg_graphrag.interface._get_kb_validator", return_value=validator), + ] + + def __enter__(self): + for p in self._patches: + p.__enter__() + return self + + def __exit__(self, *args): + for p in reversed(self._patches): + p.__exit__(*args) + + return _Ctx() + + +@pytest.fixture +def client(): + return TestClient(_app) + + +# --------------------------------------------------------------------------- +# POST /api/graphrag/query +# --------------------------------------------------------------------------- + + +class TestQueryEndpoint: + """POST /api/graphrag/query 端点测试。""" + + def test_success(self, client: TestClient): + """权限校验通过 + 检索 + 生成 → 200。""" + with _patch_all(access_granted=True) as ctx: + resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS) + + assert resp.status_code == 200 + body = resp.json() + assert body["code"] == 200 + assert body["data"]["answer"] == "test answer" + assert body["data"]["model"] == "test-model" + ctx.retriever.retrieve.assert_awaited_once() + ctx.generator.generate.assert_awaited_once() + + def test_access_denied_returns_403(self, client: TestClient): + """check_access 返回 False → 403 + 标准错误格式。""" + with _patch_all(access_granted=False): + resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS) + + assert resp.status_code == 403 + body = resp.json() + assert body["code"] == 403 + assert "kb-1" in body["data"]["detail"] + + def test_access_denied_skips_retrieval_and_generation(self, client: TestClient): + """权限拒绝时,retriever.retrieve 和 generator.generate 均不调用。""" + with _patch_all(access_granted=False) as ctx: + resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS) + + assert resp.status_code == 403 + ctx.retriever.retrieve.assert_not_called() + ctx.generator.generate.assert_not_called() + + def test_check_access_receives_collection_name(self, client: TestClient): + """验证 check_access 被调用时携带正确的 collection_name 参数。""" + with _patch_all(access_granted=True) as ctx: + resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS) + + assert resp.status_code == 200 + ctx.validator.check_access.assert_awaited_once_with( + "kb-1", "user-1", collection_name="test-collection", + ) + + def test_missing_user_id_returns_422(self, client: TestClient): + """缺少 X-User-Id 请求头 → 422 验证错误。""" + with _patch_all(access_granted=True): + resp = client.post("/api/graphrag/query", json=_VALID_BODY) + + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# POST /api/graphrag/retrieve +# --------------------------------------------------------------------------- + + +class TestRetrieveEndpoint: + """POST /api/graphrag/retrieve 端点测试。""" + + def test_success(self, client: TestClient): + """权限通过 → 检索 → 返回 RetrievalContext。""" + with _patch_all(access_granted=True) as ctx: + resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS) + + assert resp.status_code == 200 + body = resp.json() + assert body["code"] == 200 + assert body["data"]["merged_text"] == "test context" + ctx.retriever.retrieve.assert_awaited_once() + + def test_access_denied_returns_403(self, client: TestClient): + """权限拒绝 → 403。""" + with _patch_all(access_granted=False): + resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS) + + assert resp.status_code == 403 + body = resp.json() + assert body["code"] == 403 + + def test_access_denied_skips_retrieval(self, client: TestClient): + """权限拒绝时不调用 retriever.retrieve。""" + with _patch_all(access_granted=False) as ctx: + resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS) + + assert resp.status_code == 403 + ctx.retriever.retrieve.assert_not_called() + + def test_check_access_receives_collection_name(self, client: TestClient): + """验证 check_access 收到 collection_name 参数。""" + with _patch_all(access_granted=True) as ctx: + resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS) + + assert resp.status_code == 200 + ctx.validator.check_access.assert_awaited_once_with( + "kb-1", "user-1", collection_name="test-collection", + ) + + def test_missing_user_id_returns_422(self, client: TestClient): + """缺少 X-User-Id → 422。""" + with _patch_all(access_granted=True): + resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY) + + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# POST /api/graphrag/query/stream +# --------------------------------------------------------------------------- + + +class TestQueryStreamEndpoint: + """POST /api/graphrag/query/stream 端点测试。""" + + def test_success_returns_sse(self, client: TestClient): + """权限通过 → SSE 流式响应,包含 token 和 done 事件。""" + with _patch_all(access_granted=True): + resp = client.post( + "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS, + ) + + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + text = resp.text + assert '"token"' in text + assert '"done": true' in text or '"done":true' in text + + def test_access_denied_returns_403(self, client: TestClient): + """权限拒绝 → 403。""" + with _patch_all(access_granted=False): + resp = client.post( + "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS, + ) + + assert resp.status_code == 403 + body = resp.json() + assert body["code"] == 403 + + def test_access_denied_skips_retrieval_and_generation(self, client: TestClient): + """权限拒绝时不调用检索和生成。""" + with _patch_all(access_granted=False) as ctx: + resp = client.post( + "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS, + ) + + assert resp.status_code == 403 + ctx.retriever.retrieve.assert_not_called() + + def test_check_access_receives_collection_name(self, client: TestClient): + """验证 check_access 收到 collection_name 参数。""" + with _patch_all(access_granted=True) as ctx: + resp = client.post( + "/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS, + ) + + assert resp.status_code == 200 + ctx.validator.check_access.assert_awaited_once_with( + "kb-1", "user-1", collection_name="test-collection", + ) + + def test_missing_user_id_returns_422(self, client: TestClient): + """缺少 X-User-Id → 422。""" + with _patch_all(access_granted=True): + resp = client.post("/api/graphrag/query/stream", json=_VALID_BODY) + + assert resp.status_code == 422 diff --git a/runtime/datamate-python/app/module/kg_graphrag/test_kb_access.py b/runtime/datamate-python/app/module/kg_graphrag/test_kb_access.py new file mode 100644 index 0000000..8c85122 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/test_kb_access.py @@ -0,0 +1,330 @@ +"""知识库访问权限校验的单元测试。""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator + + +@pytest.fixture +def validator() -> KnowledgeBaseAccessValidator: + return KnowledgeBaseAccessValidator( + base_url="http://test-backend:8080/api", + timeout=5.0, + ) + + +def _run(coro): + return asyncio.run(coro) + + +_FAKE_REQUEST = httpx.Request("GET", "http://test") + + +def _resp(status_code: int, *, json=None, text=None) -> httpx.Response: + """创建带 request 的 httpx.Response。""" + if json is not None: + return httpx.Response(status_code, json=json, request=_FAKE_REQUEST) + return httpx.Response(status_code, text=text or "", request=_FAKE_REQUEST) + + +# --------------------------------------------------------------------------- +# check_access 测试 +# --------------------------------------------------------------------------- + + +class TestCheckAccess: + """check_access 方法的测试。""" + + def test_access_granted(self, validator: KnowledgeBaseAccessValidator): + """Java 返回 200 + code=200: 用户有权访问。""" + mock_resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "test-kb"}}) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access("kb-1", "user-1")) + + assert result is True + + def test_access_granted_with_matching_collection(self, validator: KnowledgeBaseAccessValidator): + """权限通过且 collection_name 与 KB name 一致:允许访问。""" + mock_resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "my-collection"}}) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access( + "kb-1", "user-1", collection_name="my-collection", + )) + + assert result is True + + def test_access_denied_by_biz_code(self, validator: KnowledgeBaseAccessValidator): + """Java 返回 HTTP 200 但 code != 200(权限不足 sys.0005)。""" + mock_resp = _resp(200, json={"code": "sys.0005", "message": "权限不足"}) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access("kb-1", "other-user")) + + assert result is False + + def test_access_denied_http_403(self, validator: KnowledgeBaseAccessValidator): + """Java 返回 HTTP 403。""" + mock_resp = _resp(403, text="Forbidden") + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access("kb-1", "user-1")) + + assert result is False + + def test_kb_not_found_http_404(self, validator: KnowledgeBaseAccessValidator): + """知识库不存在,Java 返回 404。""" + mock_resp = _resp(404, text="Not Found") + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access("nonexistent-kb", "user-1")) + + assert result is False + + def test_server_error_http_500(self, validator: KnowledgeBaseAccessValidator): + """Java 后端返回 500。""" + mock_resp = _resp(500, text="Internal Server Error") + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access("kb-1", "user-1")) + + assert result is False + + def test_fail_close_on_connection_error(self, validator: KnowledgeBaseAccessValidator): + """网络异常时 fail-close(拒绝访问),防止绕过权限校验。""" + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(side_effect=httpx.ConnectError("connection refused")) + mock_get.return_value = mock_http + + result = _run(validator.check_access("kb-1", "user-1")) + + assert result is False + + def test_fail_close_on_timeout(self, validator: KnowledgeBaseAccessValidator): + """超时时 fail-close(拒绝访问)。""" + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(side_effect=httpx.ReadTimeout("timeout")) + mock_get.return_value = mock_http + + result = _run(validator.check_access("kb-1", "user-1")) + + assert result is False + + def test_request_headers(self, validator: KnowledgeBaseAccessValidator): + """验证请求中携带正确的 X-User-Id header。""" + mock_resp = _resp(200, json={"code": 200, "data": {}}) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + _run(validator.check_access("kb-123", "user-456")) + + call_kwargs = mock_http.get.call_args + assert "/knowledge-base/kb-123" in call_kwargs.args[0] + assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-456" + + def test_cross_user_access_denied(self, validator: KnowledgeBaseAccessValidator): + """跨用户访问:用户 B 试图访问用户 A 的知识库,应被拒绝。 + + 模拟 Java 后端返回权限不足的业务错误。 + """ + # 用户 A 创建的 KB,用户 B 请求访问 + mock_resp = _resp(200, json={ + "code": "sys.0005", + "message": "权限不足", + "data": None, + }) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access("kb-user-a", "user-b")) + + assert result is False + + # 确认请求携带的是用户 B 的 ID + call_kwargs = mock_http.get.call_args + assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-b" + + def test_admin_access_granted(self, validator: KnowledgeBaseAccessValidator): + """管理员访问其他用户的知识库:Java 侧管理员跳过 owner 校验。""" + mock_resp = _resp(200, json={ + "code": 200, + "data": {"id": "kb-user-a", "name": "用户A的知识库", "createdBy": "user-a"}, + }) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access("kb-user-a", "admin-user")) + + # Java 侧管理员校验通过,返回 200 + code=200 + assert result is True + + +# --------------------------------------------------------------------------- +# collection_name 绑定校验测试 +# --------------------------------------------------------------------------- + + +class TestCollectionNameBinding: + """collection_name 与 knowledge_base_id 的绑定校验测试。 + + 防止用户提交合法的 KB ID 但篡改 collection_name 来读取其他 + 知识库的 Milvus 数据。 + """ + + def test_collection_name_mismatch_denied(self, validator: KnowledgeBaseAccessValidator): + """KB-A 的 name='collection-a',但请求传了 collection_name='collection-b':拒绝。""" + mock_resp = _resp(200, json={ + "code": 200, + "data": {"id": "kb-a", "name": "collection-a"}, + }) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access( + "kb-a", "user-1", collection_name="collection-b", + )) + + assert result is False + + def test_collection_name_none_skips_check(self, validator: KnowledgeBaseAccessValidator): + """collection_name=None 时不做绑定校验(向后兼容)。""" + mock_resp = _resp(200, json={ + "code": 200, + "data": {"id": "kb-1", "name": "some-name"}, + }) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + # 不传 collection_name → 仅校验权限,不校验绑定 + result = _run(validator.check_access("kb-1", "user-1")) + + assert result is True + + def test_response_data_missing_name_denied(self, validator: KnowledgeBaseAccessValidator): + """Java 响应 data 中没有 name 字段:fail-close 拒绝。""" + mock_resp = _resp(200, json={ + "code": 200, + "data": {"id": "kb-1"}, + }) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access( + "kb-1", "user-1", collection_name="any-collection", + )) + + # data.name is None, doesn't match "any-collection" → denied + assert result is False + + def test_response_data_null_denied(self, validator: KnowledgeBaseAccessValidator): + """Java 响应 data 为 null:fail-close 拒绝。""" + mock_resp = _resp(200, json={ + "code": 200, + "data": None, + }) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access( + "kb-1", "user-1", collection_name="any-collection", + )) + + assert result is False + + def test_response_data_empty_dict_denied(self, validator: KnowledgeBaseAccessValidator): + """Java 响应 data 为空 dict {}:fail-close 拒绝。""" + mock_resp = _resp(200, json={ + "code": 200, + "data": {}, + }) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access( + "kb-1", "user-1", collection_name="any-collection", + )) + + assert result is False + + def test_cross_kb_collection_swap_denied(self, validator: KnowledgeBaseAccessValidator): + """用户有权访问 KB-A(name='kb-a-data'),试图用 KB-A 的 ID 搭配 KB-B 的 + collection_name='kb-b-data':应被拒绝。 + + 这是核心越权场景的完整模拟。 + """ + # 用户有权访问 KB-A + mock_resp = _resp(200, json={ + "code": 200, + "data": {"id": "kb-a", "name": "kb-a-data", "createdBy": "user-1"}, + }) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + # 但 collection_name 指向 KB-B 的数据 + result = _run(validator.check_access( + "kb-a", "user-1", collection_name="kb-b-data", + )) + + assert result is False + + def test_chinese_collection_name_match(self, validator: KnowledgeBaseAccessValidator): + """中文 collection_name 精确匹配。""" + mock_resp = _resp(200, json={ + "code": 200, + "data": {"id": "kb-1", "name": "用户行为数据"}, + }) + with patch.object(validator, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + result = _run(validator.check_access( + "kb-1", "user-1", collection_name="用户行为数据", + )) + + assert result is True diff --git a/runtime/datamate-python/app/module/kg_graphrag/test_kg_client.py b/runtime/datamate-python/app/module/kg_graphrag/test_kg_client.py new file mode 100644 index 0000000..3fb4a41 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/test_kg_client.py @@ -0,0 +1,297 @@ +"""KG 服务 REST 客户端的单元测试。""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from app.module.kg_graphrag.kg_client import KGServiceClient + + +@pytest.fixture +def client() -> KGServiceClient: + return KGServiceClient( + base_url="http://test-kg:8080", + internal_token="test-token", + timeout=5.0, + ) + + +def _run(coro): + return asyncio.run(coro) + + +_FAKE_REQUEST = httpx.Request("GET", "http://test") + + +def _resp(status_code: int, *, json=None, text=None) -> httpx.Response: + """创建带 request 的 httpx.Response(raise_for_status 需要)。""" + if json is not None: + return httpx.Response(status_code, json=json, request=_FAKE_REQUEST) + return httpx.Response(status_code, text=text or "", request=_FAKE_REQUEST) + + +# --------------------------------------------------------------------------- +# fulltext_search 测试 +# --------------------------------------------------------------------------- + + +class TestFulltextSearch: + """fulltext_search 方法的测试。""" + + def test_wrapped_paged_response(self, client: KGServiceClient): + """Java 返回被全局包装的 PagedResponse: {"code": 200, "data": {"content": [...]}}""" + mock_body = { + "code": 200, + "data": { + "page": 0, + "size": 20, + "totalElements": 2, + "totalPages": 1, + "content": [ + {"id": "e1", "name": "用户数据", "type": "Dataset", "description": "用户行为", "score": 2.5}, + {"id": "e2", "name": "清洗管道", "type": "Workflow", "description": "", "score": 1.8}, + ], + }, + } + mock_resp = _resp(200, json=mock_body) + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + entities = _run(client.fulltext_search("graph-1", "用户数据", size=10, user_id="u1")) + + assert len(entities) == 2 + assert entities[0].id == "e1" + assert entities[0].name == "用户数据" + assert entities[0].type == "Dataset" + assert entities[1].name == "清洗管道" + + def test_unwrapped_paged_response(self, client: KGServiceClient): + """Java 直接返回 PagedResponse(无全局包装)。""" + mock_body = { + "page": 0, + "size": 10, + "totalElements": 1, + "totalPages": 1, + "content": [ + {"id": "e1", "name": "A", "type": "Dataset", "description": "desc"}, + ], + } + mock_resp = _resp(200, json=mock_body) + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + entities = _run(client.fulltext_search("graph-1", "A")) + + # body has no "data" key → fallback to body itself → read "content" + assert len(entities) == 1 + assert entities[0].name == "A" + + def test_empty_content(self, client: KGServiceClient): + mock_body = {"code": 200, "data": {"page": 0, "content": []}} + mock_resp = _resp(200, json=mock_body) + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + entities = _run(client.fulltext_search("graph-1", "nothing")) + + assert entities == [] + + def test_fail_open_on_http_error(self, client: KGServiceClient): + """HTTP 错误时 fail-open 返回空列表。""" + mock_resp = _resp(500, text="Internal Server Error") + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + entities = _run(client.fulltext_search("graph-1", "test")) + + assert entities == [] + + def test_fail_open_on_connection_error(self, client: KGServiceClient): + """连接错误时 fail-open 返回空列表。""" + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(side_effect=httpx.ConnectError("connection refused")) + mock_get.return_value = mock_http + + entities = _run(client.fulltext_search("graph-1", "test")) + + assert entities == [] + + def test_request_headers(self, client: KGServiceClient): + """验证请求中携带正确的 headers。""" + mock_resp = _resp(200, json={"data": {"content": []}}) + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.get = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + _run(client.fulltext_search("gid", "q", size=5, user_id="user-123")) + + call_kwargs = mock_http.get.call_args + assert call_kwargs.kwargs["headers"]["X-Internal-Token"] == "test-token" + assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-123" + assert call_kwargs.kwargs["params"] == {"q": "q", "size": 5} + + +# --------------------------------------------------------------------------- +# get_subgraph 测试 +# --------------------------------------------------------------------------- + + +class TestGetSubgraph: + """get_subgraph 方法的测试。""" + + def test_wrapped_subgraph_response(self, client: KGServiceClient): + """Java 返回被全局包装的 SubgraphExportVO。""" + mock_body = { + "code": 200, + "data": { + "nodes": [ + {"id": "n1", "name": "用户数据", "type": "Dataset", "description": "desc1", "properties": {}}, + {"id": "n2", "name": "user_id", "type": "Field", "description": "", "properties": {}}, + ], + "edges": [ + { + "id": "edge1", + "sourceEntityId": "n1", + "targetEntityId": "n2", + "relationType": "HAS_FIELD", + "weight": 1.0, + "confidence": 0.9, + "sourceId": "kb-1", + }, + ], + "nodeCount": 2, + "edgeCount": 1, + }, + } + mock_resp = _resp(200, json=mock_body) + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.post = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + entities, relations = _run(client.get_subgraph("gid", ["n1"], depth=2, user_id="u1")) + + assert len(entities) == 2 + assert entities[0].name == "用户数据" + assert entities[1].name == "user_id" + + assert len(relations) == 1 + assert relations[0].source_name == "用户数据" + assert relations[0].target_name == "user_id" + assert relations[0].relation_type == "HAS_FIELD" + assert relations[0].source_type == "Dataset" + assert relations[0].target_type == "Field" + + def test_unwrapped_subgraph_response(self, client: KGServiceClient): + """Java 直接返回 SubgraphExportVO(无全局包装)。""" + mock_body = { + "nodes": [ + {"id": "n1", "name": "A", "type": "T1", "description": ""}, + ], + "edges": [], + "nodeCount": 1, + "edgeCount": 0, + } + mock_resp = _resp(200, json=mock_body) + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.post = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + entities, relations = _run(client.get_subgraph("gid", ["n1"])) + + assert len(entities) == 1 + assert entities[0].name == "A" + assert relations == [] + + def test_edge_with_unknown_entity(self, client: KGServiceClient): + """边引用的实体不在 nodes 列表中时,使用 ID 作为 fallback。""" + mock_body = { + "code": 200, + "data": { + "nodes": [{"id": "n1", "name": "A", "type": "T1", "description": ""}], + "edges": [ + { + "sourceEntityId": "n1", + "targetEntityId": "n999", + "relationType": "DEPENDS_ON", + }, + ], + }, + } + mock_resp = _resp(200, json=mock_body) + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.post = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + entities, relations = _run(client.get_subgraph("gid", ["n1"])) + + assert len(relations) == 1 + assert relations[0].source_name == "A" + assert relations[0].target_name == "n999" # fallback to ID + assert relations[0].target_type == "" + + def test_fail_open_on_error(self, client: KGServiceClient): + mock_resp = _resp(500, text="error") + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.post = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + entities, relations = _run(client.get_subgraph("gid", ["n1"])) + + assert entities == [] + assert relations == [] + + def test_request_params(self, client: KGServiceClient): + """验证子图请求参数正确传递。""" + mock_resp = _resp(200, json={"data": {"nodes": [], "edges": []}}) + with patch.object(client, "_get_client") as mock_get: + mock_http = AsyncMock() + mock_http.post = AsyncMock(return_value=mock_resp) + mock_get.return_value = mock_http + + _run(client.get_subgraph("gid", ["e1", "e2"], depth=3, user_id="u1")) + + call_kwargs = mock_http.post.call_args + assert "/knowledge-graph/gid/query/subgraph/export" in call_kwargs.args[0] + assert call_kwargs.kwargs["params"] == {"depth": 3} + assert call_kwargs.kwargs["json"] == {"entityIds": ["e1", "e2"]} + + +# --------------------------------------------------------------------------- +# headers 测试 +# --------------------------------------------------------------------------- + + +class TestHeaders: + def test_headers_with_token_and_user(self, client: KGServiceClient): + headers = client._headers(user_id="user-1") + assert headers["X-Internal-Token"] == "test-token" + assert headers["X-User-Id"] == "user-1" + + def test_headers_without_user(self, client: KGServiceClient): + headers = client._headers() + assert "X-Internal-Token" in headers + assert "X-User-Id" not in headers + + def test_headers_without_token(self): + c = KGServiceClient(base_url="http://test:8080", internal_token="") + headers = c._headers(user_id="u1") + assert "X-Internal-Token" not in headers + assert headers["X-User-Id"] == "u1" diff --git a/runtime/datamate-python/app/module/kg_graphrag/test_milvus_client.py b/runtime/datamate-python/app/module/kg_graphrag/test_milvus_client.py new file mode 100644 index 0000000..3c8db25 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/test_milvus_client.py @@ -0,0 +1,145 @@ +"""Milvus 向量检索客户端的单元测试。""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever + + +@pytest.fixture +def retriever() -> MilvusVectorRetriever: + return MilvusVectorRetriever( + uri="http://test-milvus:19530", + embedding_model="text-embedding-test", + ) + + +def _run(coro): + return asyncio.run(coro) + + +# --------------------------------------------------------------------------- +# has_collection 测试 +# --------------------------------------------------------------------------- + + +class TestHasCollection: + def test_collection_exists(self, retriever: MilvusVectorRetriever): + mock_client = MagicMock() + mock_client.has_collection = MagicMock(return_value=True) + retriever._milvus_client = mock_client + + result = _run(retriever.has_collection("my_collection")) + + assert result is True + + def test_collection_not_exists(self, retriever: MilvusVectorRetriever): + mock_client = MagicMock() + mock_client.has_collection = MagicMock(return_value=False) + retriever._milvus_client = mock_client + + result = _run(retriever.has_collection("nonexistent")) + + assert result is False + + def test_fail_open_on_error(self, retriever: MilvusVectorRetriever): + mock_client = MagicMock() + mock_client.has_collection = MagicMock(side_effect=Exception("connection error")) + retriever._milvus_client = mock_client + + result = _run(retriever.has_collection("test")) + + assert result is False + + +# --------------------------------------------------------------------------- +# search 测试 +# --------------------------------------------------------------------------- + + +class TestSearch: + def test_successful_search(self, retriever: MilvusVectorRetriever): + """正常搜索返回 VectorChunk 列表。""" + mock_embeddings = AsyncMock() + mock_embeddings.aembed_query = AsyncMock(return_value=[0.1, 0.2, 0.3]) + retriever._embeddings = mock_embeddings + + mock_milvus = MagicMock() + mock_milvus.search = MagicMock(return_value=[ + [ + {"id": "doc1", "distance": 0.95, "entity": {"text": "文档片段一", "metadata": {"source": "kb1"}}}, + {"id": "doc2", "distance": 0.82, "entity": {"text": "文档片段二", "metadata": {}}}, + ] + ]) + retriever._milvus_client = mock_milvus + + chunks = _run(retriever.search("my_collection", "用户数据", top_k=5)) + + assert len(chunks) == 2 + assert chunks[0].id == "doc1" + assert chunks[0].text == "文档片段一" + assert chunks[0].score == 0.95 + assert chunks[0].metadata == {"source": "kb1"} + assert chunks[1].id == "doc2" + assert chunks[1].score == 0.82 + + def test_empty_results(self, retriever: MilvusVectorRetriever): + mock_embeddings = AsyncMock() + mock_embeddings.aembed_query = AsyncMock(return_value=[0.1]) + retriever._embeddings = mock_embeddings + + mock_milvus = MagicMock() + mock_milvus.search = MagicMock(return_value=[[]]) + retriever._milvus_client = mock_milvus + + chunks = _run(retriever.search("col", "query")) + + assert chunks == [] + + def test_fail_open_on_embedding_error(self, retriever: MilvusVectorRetriever): + """Embedding 失败时 fail-open 返回空列表。""" + mock_embeddings = AsyncMock() + mock_embeddings.aembed_query = AsyncMock(side_effect=Exception("API error")) + retriever._embeddings = mock_embeddings + + chunks = _run(retriever.search("col", "query")) + + assert chunks == [] + + def test_fail_open_on_milvus_error(self, retriever: MilvusVectorRetriever): + """Milvus 搜索失败时 fail-open 返回空列表。""" + mock_embeddings = AsyncMock() + mock_embeddings.aembed_query = AsyncMock(return_value=[0.1]) + retriever._embeddings = mock_embeddings + + mock_milvus = MagicMock() + mock_milvus.search = MagicMock(side_effect=Exception("Milvus down")) + retriever._milvus_client = mock_milvus + + chunks = _run(retriever.search("col", "query")) + + assert chunks == [] + + def test_search_uses_to_thread(self, retriever: MilvusVectorRetriever): + """验证搜索通过 asyncio.to_thread 执行同步 Milvus I/O。""" + mock_embeddings = AsyncMock() + mock_embeddings.aembed_query = AsyncMock(return_value=[0.1]) + retriever._embeddings = mock_embeddings + + mock_milvus = MagicMock() + mock_milvus.search = MagicMock(return_value=[[]]) + retriever._milvus_client = mock_milvus + + with patch("app.module.kg_graphrag.milvus_client.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread: + mock_to_thread.return_value = [[]] + + chunks = _run(retriever.search("col", "query")) + + # asyncio.to_thread 应该被调用来包装同步 Milvus 调用 + mock_to_thread.assert_called_once() + call_args = mock_to_thread.call_args + assert call_args.args[0] == mock_milvus.search diff --git a/runtime/datamate-python/app/module/kg_graphrag/test_retriever.py b/runtime/datamate-python/app/module/kg_graphrag/test_retriever.py new file mode 100644 index 0000000..7e6d8c7 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_graphrag/test_retriever.py @@ -0,0 +1,234 @@ +"""GraphRAG 检索编排器的单元测试。""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.module.kg_graphrag.models import ( + EntitySummary, + RelationSummary, + RetrievalStrategy, + VectorChunk, +) +from app.module.kg_graphrag.retriever import GraphRAGRetriever + + +def _run(coro): + return asyncio.run(coro) + + +def _make_retriever( + *, + milvus_search_result: list[VectorChunk] | None = None, + milvus_has_collection: bool = True, + kg_fulltext_result: list[EntitySummary] | None = None, + kg_subgraph_result: tuple[list[EntitySummary], list[RelationSummary]] | None = None, +) -> GraphRAGRetriever: + """创建带 mock 依赖的 retriever。""" + mock_milvus = AsyncMock() + mock_milvus.has_collection = AsyncMock(return_value=milvus_has_collection) + mock_milvus.search = AsyncMock(return_value=milvus_search_result or []) + + mock_kg = AsyncMock() + mock_kg.fulltext_search = AsyncMock(return_value=kg_fulltext_result or []) + mock_kg.get_subgraph = AsyncMock(return_value=kg_subgraph_result or ([], [])) + + return GraphRAGRetriever(milvus_client=mock_milvus, kg_client=mock_kg) + + +# --------------------------------------------------------------------------- +# retrieve 测试 +# --------------------------------------------------------------------------- + + +class TestRetrieve: + """retrieve 方法的测试。""" + + def test_both_vector_and_graph(self): + """同时启用向量和图谱检索。""" + chunks = [ + VectorChunk(id="c1", text="文档片段关于用户数据", score=0.9), + VectorChunk(id="c2", text="其他内容", score=0.7), + ] + seed = [EntitySummary(id="e1", name="用户数据", type="Dataset")] + entities = [ + EntitySummary(id="e1", name="用户数据", type="Dataset"), + EntitySummary(id="e2", name="user_id", type="Field"), + ] + relations = [ + RelationSummary( + source_name="用户数据", source_type="Dataset", + target_name="user_id", target_type="Field", + relation_type="HAS_FIELD", + ), + ] + retriever = _make_retriever( + milvus_search_result=chunks, + kg_fulltext_result=seed, + kg_subgraph_result=(entities, relations), + ) + + ctx = _run(retriever.retrieve( + query="用户数据有哪些字段", + collection_name="kb1", + graph_id="graph-1", + strategy=RetrievalStrategy(), + user_id="u1", + )) + + assert len(ctx.vector_chunks) == 2 + assert len(ctx.graph_context.entities) == 2 + assert len(ctx.graph_context.relations) == 1 + assert "用户数据" in ctx.graph_context.textualized + assert "## 相关文档" in ctx.merged_text + assert "## 知识图谱上下文" in ctx.merged_text + + def test_vector_only(self): + """仅启用向量检索。""" + chunks = [VectorChunk(id="c1", text="doc", score=0.9)] + retriever = _make_retriever(milvus_search_result=chunks) + strategy = RetrievalStrategy(enable_graph=False) + + ctx = _run(retriever.retrieve( + query="test", collection_name="kb", graph_id="g", + strategy=strategy, user_id="u", + )) + + assert len(ctx.vector_chunks) == 1 + assert ctx.graph_context.entities == [] + # KG client should not be called + retriever._kg.fulltext_search.assert_not_called() + + def test_graph_only(self): + """仅启用图谱检索。""" + seed = [EntitySummary(id="e1", name="A", type="T")] + entities = [EntitySummary(id="e1", name="A", type="T")] + retriever = _make_retriever( + kg_fulltext_result=seed, + kg_subgraph_result=(entities, []), + ) + strategy = RetrievalStrategy(enable_vector=False) + + ctx = _run(retriever.retrieve( + query="test", collection_name="kb", graph_id="g", + strategy=strategy, user_id="u", + )) + + assert ctx.vector_chunks == [] + assert len(ctx.graph_context.entities) == 1 + retriever._milvus.search.assert_not_called() + + def test_no_seed_entities(self): + """图谱全文检索无结果时,不调用子图查询。""" + retriever = _make_retriever(kg_fulltext_result=[]) + + ctx = _run(retriever.retrieve( + query="test", collection_name="kb", graph_id="g", + strategy=RetrievalStrategy(enable_vector=False), user_id="u", + )) + + assert ctx.graph_context.entities == [] + retriever._kg.get_subgraph.assert_not_called() + + def test_collection_not_found_skips_vector(self): + """collection 不存在时跳过向量检索。""" + retriever = _make_retriever(milvus_has_collection=False) + strategy = RetrievalStrategy(enable_graph=False) + + ctx = _run(retriever.retrieve( + query="test", collection_name="nonexistent", graph_id="g", + strategy=strategy, user_id="u", + )) + + assert ctx.vector_chunks == [] + retriever._milvus.search.assert_not_called() + + def test_both_empty(self): + """两条检索路径都无结果。""" + retriever = _make_retriever() + + ctx = _run(retriever.retrieve( + query="nothing", collection_name="kb", graph_id="g", + strategy=RetrievalStrategy(), user_id="u", + )) + + assert ctx.vector_chunks == [] + assert ctx.graph_context.entities == [] + assert "未检索到相关上下文信息" in ctx.merged_text + + def test_vector_error_fail_open(self): + """向量检索异常时 fail-open,图谱检索仍可正常返回。""" + retriever = _make_retriever() + retriever._milvus.search = AsyncMock(side_effect=Exception("milvus down")) + + seed = [EntitySummary(id="e1", name="A", type="T")] + retriever._kg.fulltext_search = AsyncMock(return_value=seed) + retriever._kg.get_subgraph = AsyncMock( + return_value=([EntitySummary(id="e1", name="A", type="T")], []) + ) + + ctx = _run(retriever.retrieve( + query="test", collection_name="kb", graph_id="g", + strategy=RetrievalStrategy(), user_id="u", + )) + + # 向量检索失败,但图谱检索仍有结果 + assert ctx.vector_chunks == [] + assert len(ctx.graph_context.entities) == 1 + + +# --------------------------------------------------------------------------- +# _rank_results 测试 +# --------------------------------------------------------------------------- + + +class TestRankResults: + """_rank_results 方法的测试。""" + + def _make_retriever_instance(self) -> GraphRAGRetriever: + return GraphRAGRetriever( + milvus_client=MagicMock(), + kg_client=MagicMock(), + ) + + def test_empty_chunks(self): + r = self._make_retriever_instance() + result = r._rank_results([], [], [], RetrievalStrategy()) + assert result == [] + + def test_single_chunk(self): + r = self._make_retriever_instance() + chunks = [VectorChunk(id="1", text="text", score=0.9)] + result = r._rank_results(chunks, [], [], RetrievalStrategy()) + assert len(result) == 1 + assert result[0].id == "1" + + def test_graph_boost_reorders(self): + """图谱实体命中应提升文档片段排名。""" + r = self._make_retriever_instance() + # chunk1 向量分高但无图谱命中 + # chunk2 向量分低但命中图谱实体 + chunks = [ + VectorChunk(id="1", text="无关内容", score=0.9), + VectorChunk(id="2", text="包含用户数据的内容", score=0.5), + ] + entities = [EntitySummary(id="e1", name="用户数据", type="Dataset")] + strategy = RetrievalStrategy(vector_weight=0.3, graph_weight=0.7) + + result = r._rank_results(chunks, entities, [], strategy) + + # chunk2 应该排在前面(graph_boost 更高) + assert result[0].id == "2" + + def test_all_same_score(self): + """所有 chunk 分数相同时不崩溃。""" + r = self._make_retriever_instance() + chunks = [ + VectorChunk(id="1", text="a", score=0.5), + VectorChunk(id="2", text="b", score=0.5), + ] + result = r._rank_results(chunks, [], [], RetrievalStrategy()) + assert len(result) == 2 diff --git a/runtime/datamate-python/pyproject.toml b/runtime/datamate-python/pyproject.toml index a83c23c..854760e 100644 --- a/runtime/datamate-python/pyproject.toml +++ b/runtime/datamate-python/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "sqlalchemy (>=2.0.45,<3.0.0)", "fastapi (>=0.124.0,<0.125.0)", "Pillow (>=11.0.0,<12.0.0)", + "pymilvus (>=2.5.0,<3.0.0)", ]