You've already forked DataMate
feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
This commit is contained in:
@@ -88,6 +88,29 @@ class Settings(BaseSettings):
|
||||
kg_alignment_vector_threshold: float = 0.92
|
||||
kg_alignment_llm_threshold: float = 0.78
|
||||
|
||||
# GraphRAG 融合查询配置
|
||||
graphrag_enabled: bool = False
|
||||
graphrag_milvus_uri: str = "http://milvus-standalone:19530"
|
||||
graphrag_kg_service_url: str = "http://datamate-kg:8080"
|
||||
graphrag_kg_internal_token: str = ""
|
||||
|
||||
# GraphRAG - 检索策略默认值
|
||||
graphrag_vector_top_k: int = 5
|
||||
graphrag_graph_depth: int = 2
|
||||
graphrag_graph_max_entities: int = 20
|
||||
graphrag_vector_weight: float = 0.6
|
||||
graphrag_graph_weight: float = 0.4
|
||||
|
||||
# GraphRAG - LLM(空则复用 kg_llm_* 配置)
|
||||
graphrag_llm_model: str = ""
|
||||
graphrag_llm_base_url: Optional[str] = None
|
||||
graphrag_llm_api_key: SecretStr = SecretStr("EMPTY")
|
||||
graphrag_llm_temperature: float = 0.1
|
||||
graphrag_llm_timeout_seconds: int = 60
|
||||
|
||||
# GraphRAG - Embedding(空则复用 kg_alignment_embedding_* 配置)
|
||||
graphrag_embedding_model: str = ""
|
||||
|
||||
# 标注编辑器(Label Studio Editor)相关
|
||||
editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from .evaluation.interface import router as evaluation_router
|
||||
from .collection.interface import router as collection_route
|
||||
from .dataset.interface import router as dataset_router
|
||||
from .kg_extraction.interface import router as kg_extraction_router
|
||||
from .kg_graphrag.interface import router as kg_graphrag_router
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api"
|
||||
@@ -21,5 +22,6 @@ router.include_router(evaluation_router)
|
||||
router.include_router(collection_route)
|
||||
router.include_router(dataset_router)
|
||||
router.include_router(kg_extraction_router)
|
||||
router.include_router(kg_graphrag_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""GraphRAG 融合查询模块。"""
|
||||
|
||||
from app.module.kg_graphrag.interface import router
|
||||
|
||||
__all__ = ["router"]
|
||||
@@ -0,0 +1,110 @@
|
||||
"""三元组文本化 + 上下文构建。
|
||||
|
||||
将图谱子图(实体 + 关系)转为自然语言描述,
|
||||
并与向量检索片段合并为 LLM 可消费的上下文文本。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.module.kg_graphrag.models import (
|
||||
EntitySummary,
|
||||
RelationSummary,
|
||||
VectorChunk,
|
||||
)
|
||||
|
||||
# Relation type -> Chinese sentence template. Each template is rendered with
# the keyword arguments ``source``, ``target`` and ``relation``.
RELATION_TEMPLATES: dict[str, str] = {
    "HAS_FIELD": "{source}包含字段{target}",
    "DERIVED_FROM": "{source}来源于{target}",
    "USES_DATASET": "{source}使用了数据集{target}",
    "PRODUCES": "{source}产出了{target}",
    "ASSIGNED_TO": "{source}分配给了{target}",
    "BELONGS_TO": "{source}属于{target}",
    "TRIGGERS": "{source}触发了{target}",
    "DEPENDS_ON": "{source}依赖于{target}",
    "IMPACTS": "{source}影响了{target}",
    "SOURCED_FROM": "{source}的知识来源于{target}",
}

# Fallback template for relation types that have no entry in the mapping above.
_DEFAULT_TEMPLATE = "{source}与{target}存在{relation}关系"


def textualize_subgraph(
    entities: list[EntitySummary],
    relations: list[RelationSummary],
) -> str:
    """Render a knowledge-graph subgraph as natural-language lines.

    One sentence is produced per relation, followed by one line for every
    entity that does not appear as an endpoint of any relation.

    Args:
        entities: Entities of the subgraph.
        relations: Relations of the subgraph.

    Returns:
        The rendered lines joined with ``"\\n"``; an empty string when both
        inputs are empty.
    """
    lines: list[str] = []

    # Endpoints already covered by a relation sentence. They are keyed by
    # (type, name) rather than by name alone, so that a standalone entity that
    # merely shares its name with a related entity of a different type is
    # still listed instead of being silently dropped.
    mentioned: set[tuple[str, str]] = set()

    # 1. One sentence per relation.
    for rel in relations:
        template = RELATION_TEMPLATES.get(rel.relation_type, _DEFAULT_TEMPLATE)
        lines.append(
            template.format(
                source=f"{rel.source_type}'{rel.source_name}'",
                target=f"{rel.target_type}'{rel.target_name}'",
                relation=rel.relation_type,
            )
        )
        mentioned.add((rel.source_type, rel.source_name))
        mentioned.add((rel.target_type, rel.target_name))

    # 2. One line per entity that no relation refers to.
    for entity in entities:
        if (entity.type, entity.name) in mentioned:
            continue
        label = f"{entity.type}'{entity.name}'"
        description = entity.description or ""
        lines.append(f"{label}: {description}" if description else f"存在{label}")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def build_context(
    vector_chunks: list[VectorChunk],
    graph_text: str,
    vector_weight: float = 0.6,
    graph_weight: float = 0.4,
) -> str:
    """Merge vector-search chunks and textualized graph content into one context.

    Args:
        vector_chunks: Document chunks returned by vector retrieval; only the
            ``text`` attribute of each chunk is used.
        graph_text: Textualized knowledge-graph description.
        vector_weight: Accepted for interface compatibility; not read by this
            function.
        graph_weight: Accepted for interface compatibility; not read by this
            function.

    Returns:
        Up to two sections ("relevant documents" and "knowledge graph
        context") separated by a blank line, or a fixed placeholder sentence
        when neither input contributes anything.
    """
    parts: list[str] = []

    # Section 1: numbered document chunks (numbering starts at 1).
    if vector_chunks:
        numbered = [
            f"[{position}] {chunk.text}"
            for position, chunk in enumerate(vector_chunks, start=1)
        ]
        parts.append("\n".join(["## 相关文档", *numbered]))

    # Section 2: textualized graph content.
    if graph_text:
        parts.append(f"## 知识图谱上下文\n{graph_text}")

    return "\n\n".join(parts) if parts else "(未检索到相关上下文信息)"
|
||||
101
runtime/datamate-python/app/module/kg_graphrag/generator.py
Normal file
101
runtime/datamate-python/app/module/kg_graphrag/generator.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""LLM 生成器。
|
||||
|
||||
基于增强上下文(向量 + 图谱)调用 LLM 生成回答,
|
||||
支持同步和流式两种模式。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# System message sent ahead of every user message built by GraphRAGGenerator.
_SYSTEM_PROMPT = (
    "你是 DataMate 数据管理平台的智能助手。请根据以下上下文信息回答用户的问题。\n"
    "如果上下文中没有相关信息,请明确说明。不要编造信息。"
)


class GraphRAGGenerator:
    """LLM answer generator for GraphRAG (blocking and streaming)."""

    def __init__(
        self,
        *,
        model: str = "gpt-4o-mini",
        base_url: str | None = None,
        api_key: SecretStr = SecretStr("EMPTY"),
        temperature: float = 0.1,
        timeout: int = 60,
    ) -> None:
        """Store the connection parameters; the LLM client is built lazily."""
        self._model, self._base_url, self._api_key = model, base_url, api_key
        self._temperature, self._timeout = temperature, timeout
        # Created by _get_llm() on first use.
        self._llm = None

    @property
    def model_name(self) -> str:
        """Name of the model this generator was configured with."""
        return self._model

    @classmethod
    def from_settings(cls) -> GraphRAGGenerator:
        """Build a generator from application settings.

        ``graphrag_llm_*`` values win; an empty model / base URL, or an API
        key equal to ``"EMPTY"``, falls back to the matching ``kg_llm_*``
        setting.
        """
        from app.core.config import settings

        dedicated_key = settings.graphrag_llm_api_key
        if dedicated_key.get_secret_value() == "EMPTY":
            chosen_key = settings.kg_llm_api_key
        else:
            chosen_key = dedicated_key

        return cls(
            model=settings.graphrag_llm_model or settings.kg_llm_model,
            base_url=settings.graphrag_llm_base_url or settings.kg_llm_base_url,
            api_key=chosen_key,
            temperature=settings.graphrag_llm_temperature,
            timeout=settings.graphrag_llm_timeout_seconds,
        )

    def _get_llm(self):
        """Return the cached ChatOpenAI client, creating it on first call."""
        if self._llm is not None:
            return self._llm

        # Imported here so the dependency is only loaded when first needed.
        from langchain_openai import ChatOpenAI

        self._llm = ChatOpenAI(
            model=self._model,
            base_url=self._base_url,
            api_key=self._api_key,
            temperature=self._temperature,
            timeout=self._timeout,
        )
        return self._llm

    def _build_messages(self, query: str, context: str) -> list[dict[str, str]]:
        """Build the system + user message pair for one question."""
        user_content = f"{context}\n\n用户问题: {query}\n\n请基于上下文中的信息回答。"
        system_message = {"role": "system", "content": _SYSTEM_PROMPT}
        user_message = {"role": "user", "content": user_content}
        return [system_message, user_message]

    async def generate(self, query: str, context: str) -> str:
        """Generate a complete answer from the augmented context."""
        prompt = self._build_messages(query, context)
        reply = await self._get_llm().ainvoke(prompt)
        return str(reply.content)

    async def generate_stream(self, query: str, context: str) -> AsyncIterator[str]:
        """Stream the answer piece by piece, skipping empty chunks."""
        prompt = self._build_messages(query, context)
        async for piece in self._get_llm().astream(prompt):
            text = piece.content
            if not text:
                continue
            yield str(text)
|
||||
249
runtime/datamate-python/app/module/kg_graphrag/interface.py
Normal file
249
runtime/datamate-python/app/module/kg_graphrag/interface.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""GraphRAG 融合查询 API 端点。
|
||||
|
||||
提供向量检索 + 知识图谱的融合查询能力:
|
||||
- POST /api/graphrag/query — 完整 GraphRAG 查询(检索+生成)
|
||||
- POST /api/graphrag/retrieve — 仅检索(返回上下文,不调 LLM)
|
||||
- POST /api/graphrag/query/stream — 流式 GraphRAG 查询(SSE)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
|
||||
from app.module.kg_graphrag.models import (
|
||||
GraphRAGQueryRequest,
|
||||
GraphRAGQueryResponse,
|
||||
RetrievalContext,
|
||||
)
|
||||
from app.module.kg_graphrag.retriever import GraphRAGRetriever
|
||||
from app.module.kg_graphrag.generator import GraphRAGGenerator
|
||||
from app.module.shared.schema import StandardResponse
|
||||
|
||||
router = APIRouter(prefix="/graphrag", tags=["graphrag"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Lazily created module-level singletons; each accessor below fills its slot
# on first call.
_retriever: GraphRAGRetriever | None = None
_generator: GraphRAGGenerator | None = None
_kb_validator: KnowledgeBaseAccessValidator | None = None


def _get_retriever() -> GraphRAGRetriever:
    """Return the shared retriever, building it from settings on first use."""
    global _retriever
    instance = _retriever
    if instance is None:
        instance = _retriever = GraphRAGRetriever.from_settings()
    return instance


def _get_generator() -> GraphRAGGenerator:
    """Return the shared generator, building it from settings on first use."""
    global _generator
    instance = _generator
    if instance is None:
        instance = _generator = GraphRAGGenerator.from_settings()
    return instance


def _get_kb_validator() -> KnowledgeBaseAccessValidator:
    """Return the shared KB access validator, building it on first use."""
    global _kb_validator
    instance = _kb_validator
    if instance is None:
        instance = _kb_validator = KnowledgeBaseAccessValidator.from_settings()
    return instance
|
||||
|
||||
|
||||
def _require_caller_id(
|
||||
x_user_id: Annotated[
|
||||
str,
|
||||
Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递"),
|
||||
],
|
||||
) -> str:
|
||||
caller = x_user_id.strip()
|
||||
if not caller:
|
||||
raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
|
||||
return caller
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# P0: 完整 GraphRAG 查询
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/query",
    response_model=StandardResponse[GraphRAGQueryResponse],
    summary="GraphRAG 查询",
    description="并行从向量库和知识图谱检索上下文,融合后调用 LLM 生成回答。",
)
async def query(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Run a full GraphRAG query: access check, retrieval, then generation.

    Raises:
        HTTPException: 403 when the knowledge-base access check fails, 502
            when retrieval or generation raises.
    """
    request_tag = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG query: graph_id=%s, collection=%s, caller=%s",
        request_tag,
        req.graph_id,
        req.collection_name,
        caller,
    )

    retriever = _get_retriever()
    generator = _get_generator()

    # Authorization: the caller must be allowed to use this knowledge base
    # together with the requested collection.
    has_access = await _get_kb_validator().check_access(
        req.knowledge_base_id,
        caller,
        collection_name=req.collection_name,
    )
    if not has_access:
        logger.warning(
            "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
            request_tag,
            req.knowledge_base_id,
            req.collection_name,
            caller,
        )
        raise HTTPException(
            status_code=403,
            detail=f"无权访问知识库 {req.knowledge_base_id}",
        )

    # Retrieval step.
    try:
        context = await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception:
        logger.exception("[%s] Retrieval failed", request_tag)
        raise HTTPException(
            status_code=502,
            detail=f"检索服务暂不可用 (trace: {request_tag})",
        )

    # Generation step.
    try:
        answer = await generator.generate(
            query=req.query,
            context=context.merged_text,
        )
    except Exception:
        logger.exception("[%s] Generation failed", request_tag)
        raise HTTPException(
            status_code=502,
            detail=f"生成服务暂不可用 (trace: {request_tag})",
        )

    payload = GraphRAGQueryResponse(
        answer=answer,
        context=context,
        model=generator.model_name,
    )
    return StandardResponse(code=200, message="success", data=payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# P1-1: 仅检索
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/retrieve",
    response_model=StandardResponse[RetrievalContext],
    summary="GraphRAG 仅检索",
    description="并行从向量库和知识图谱检索上下文,返回结构化上下文(不调 LLM)。",
)
async def retrieve(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Run retrieval only and return the structured context (no LLM call).

    Raises:
        HTTPException: 403 when the knowledge-base access check fails, 502
            when retrieval raises.
    """
    request_tag = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG retrieve: graph_id=%s, collection=%s, caller=%s",
        request_tag,
        req.graph_id,
        req.collection_name,
        caller,
    )

    retriever = _get_retriever()

    # Authorization: the caller must be allowed to use this knowledge base
    # together with the requested collection.
    has_access = await _get_kb_validator().check_access(
        req.knowledge_base_id,
        caller,
        collection_name=req.collection_name,
    )
    if not has_access:
        logger.warning(
            "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
            request_tag,
            req.knowledge_base_id,
            req.collection_name,
            caller,
        )
        raise HTTPException(
            status_code=403,
            detail=f"无权访问知识库 {req.knowledge_base_id}",
        )

    try:
        context = await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception:
        logger.exception("[%s] Retrieval failed", request_tag)
        raise HTTPException(
            status_code=502,
            detail=f"检索服务暂不可用 (trace: {request_tag})",
        )

    return StandardResponse(code=200, message="success", data=context)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# P1-4: 流式查询 (SSE)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/query/stream",
    summary="GraphRAG 流式查询",
    description="并行检索后,通过 SSE 流式返回 LLM 生成内容。",
)
async def query_stream(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Run retrieval, then stream the LLM answer as Server-Sent Events.

    Event payloads (each sent as one ``data: <json>`` frame):
        ``{"token": ...}`` for every generated piece,
        ``{"done": true, "context": ...}`` once generation finishes,
        ``{"error": ...}`` if generation raises part-way through.

    Raises:
        HTTPException: 403 when the knowledge-base access check fails, 502
            when retrieval raises. Both happen before the stream starts.
    """
    # Function-scope import: this module has no file-level ``json`` import.
    import json

    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG stream: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    retriever = _get_retriever()
    generator = _get_generator()

    # Authorization: the caller must be allowed to use this knowledge base
    # together with the requested collection.
    kb_validator = _get_kb_validator()
    if not await kb_validator.check_access(
        req.knowledge_base_id, caller, collection_name=req.collection_name,
    ):
        logger.warning(
            "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
            trace_id, req.knowledge_base_id, req.collection_name, caller,
        )
        raise HTTPException(
            status_code=403,
            detail=f"无权访问知识库 {req.knowledge_base_id}",
        )

    try:
        context = await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception:
        logger.exception("[%s] Retrieval failed", trace_id)
        raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")

    def sse_frame(payload: dict) -> str:
        # One SSE frame. All events share the same encoding so non-ASCII text
        # is emitted verbatim in every frame, including the error frame.
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def event_stream():
        try:
            async for token in generator.generate_stream(
                query=req.query, context=context.merged_text
            ):
                yield sse_frame({"token": token})
            # Final event carries the retrieval context. mode="json" makes
            # pydantic emit JSON-compatible values only, so json.dumps cannot
            # fail on a non-native field type after the answer was streamed.
            yield sse_frame({"done": True, "context": context.model_dump(mode="json")})
        except Exception:
            logger.exception("[%s] Stream generation failed", trace_id)
            yield sse_frame({"error": "生成服务暂不可用"})

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||
118
runtime/datamate-python/app/module/kg_graphrag/kb_access.py
Normal file
118
runtime/datamate-python/app/module/kg_graphrag/kb_access.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""知识库访问权限校验。
|
||||
|
||||
在执行 GraphRAG 检索前,调用 Java rag-indexer-service 的
|
||||
GET /knowledge-base/{id} 端点验证当前用户是否有权访问该知识库。
|
||||
|
||||
Java 侧实现参考:KnowledgeBaseService.getKnowledgeBaseWithAccessCheck()
|
||||
- 查找 KB 是否存在
|
||||
- 校验 createdBy == currentUserId(管理员跳过)
|
||||
- 不满足则抛出 sys.0005 (INSUFFICIENT_PERMISSIONS)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KnowledgeBaseAccessValidator:
    """Checks through the Java backend whether a user may access a knowledge base."""

    def __init__(
        self,
        *,
        base_url: str = "http://datamate-backend:8080/api",
        timeout: float = 10.0,
    ) -> None:
        """Store connection parameters; the HTTP client is created lazily."""
        self._base_url = base_url.rstrip("/")
        self._timeout = timeout
        self._client: httpx.AsyncClient | None = None

    @classmethod
    def from_settings(cls) -> KnowledgeBaseAccessValidator:
        """Build a validator pointing at ``settings.datamate_backend_base_url``."""
        from app.core.config import settings

        return cls(base_url=settings.datamate_backend_base_url)

    def _get_client(self) -> httpx.AsyncClient:
        """Return the cached AsyncClient, creating it on first call."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                timeout=self._timeout,
            )
        return self._client

    async def check_access(
        self,
        knowledge_base_id: str,
        user_id: str,
        *,
        collection_name: str | None = None,
    ) -> bool:
        """Check whether *user_id* may access *knowledge_base_id*.

        Issues ``GET /knowledge-base/{id}`` with the ``X-User-Id`` header and
        treats anything other than HTTP 200 with business ``code == 200`` as a
        denial. When *collection_name* is not None it must also equal the
        ``data.name`` returned for the knowledge base, so that a valid KB id
        cannot be combined with a different collection name.

        The id is percent-encoded as a single path segment before it is placed
        in the URL; blank ids and the dot segments ``.`` / ``..`` are denied
        outright, so a caller-supplied id cannot redirect the request to
        another backend path or append query parameters.

        Returns:
            True when access is granted and the collection name (if given)
            matches; False on denial, mismatch, invalid id, or any error
            (fail-close).
        """
        # Function-scope import keeps this block self-contained.
        from urllib.parse import quote

        try:
            raw_id = knowledge_base_id.strip() if isinstance(knowledge_base_id, str) else ""
            if not raw_id or raw_id in {".", ".."}:
                logger.warning(
                    "KB access denied (invalid kb_id): kb_id=%r, user=%s",
                    knowledge_base_id, user_id,
                )
                return False

            client = self._get_client()
            resp = await client.get(
                f"/knowledge-base/{quote(knowledge_base_id, safe='')}",
                headers={"X-User-Id": user_id},
            )
            if resp.status_code == 200:
                body = resp.json()
                # Global Java envelope: {"code": 200, "data": {...}}.
                # A code other than 200 means the business layer refused
                # (for example insufficient permissions).
                code = body.get("code", resp.status_code)
                if code != 200:
                    logger.warning(
                        "KB access denied: kb_id=%s, user=%s, biz_code=%s, msg=%s",
                        knowledge_base_id, user_id, code, body.get("message", ""),
                    )
                    return False

                # Verify the collection_name <-> knowledge base binding.
                if collection_name is not None:
                    data = body.get("data") or {}
                    actual_name = data.get("name") if isinstance(data, dict) else None
                    if actual_name != collection_name:
                        logger.warning(
                            "KB collection_name mismatch: kb_id=%s, "
                            "expected=%s, actual=%s, user=%s",
                            knowledge_base_id, collection_name,
                            actual_name, user_id,
                        )
                        return False

                return True

            # Any non-200 HTTP status is a denial.
            logger.warning(
                "KB access check returned HTTP %d: kb_id=%s, user=%s",
                resp.status_code, knowledge_base_id, user_id,
            )
            return False
        except Exception:
            # Fail-close: any error (network, malformed body, ...) denies
            # access so that a failing check can never bypass authorization.
            logger.exception(
                "KB access check failed (fail-close): kb_id=%s, user=%s",
                knowledge_base_id, user_id,
            )
            return False

    async def close(self) -> None:
        """Close the HTTP client if one was created."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
|
||||
197
runtime/datamate-python/app/module/kg_graphrag/kg_client.py
Normal file
197
runtime/datamate-python/app/module/kg_graphrag/kg_client.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""KG 服务 REST 客户端。
|
||||
|
||||
通过 httpx 调用 Java 侧 knowledge-graph-service 的查询 API,
|
||||
包括全文检索和子图导出。
|
||||
|
||||
失败策略:fail-open —— KG 服务不可用时返回空结果 + 日志告警。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_graphrag.models import EntitySummary, RelationSummary
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KGServiceClient:
    """REST client for the Java knowledge-graph service.

    The public methods are fail-open: any error is logged and an empty result
    is returned.
    """

    def __init__(
        self,
        *,
        base_url: str = "http://datamate-kg:8080",
        internal_token: str = "",
        timeout: float = 30.0,
    ) -> None:
        """Store connection parameters; the HTTP client is created lazily."""
        self._base_url = base_url.rstrip("/")
        self._internal_token = internal_token
        self._timeout = timeout
        self._client: httpx.AsyncClient | None = None

    @classmethod
    def from_settings(cls) -> KGServiceClient:
        """Build a client from the ``graphrag_kg_*`` settings (30 s timeout)."""
        from app.core.config import settings

        return cls(
            base_url=settings.graphrag_kg_service_url,
            internal_token=settings.graphrag_kg_internal_token,
            timeout=30.0,
        )

    def _get_client(self) -> httpx.AsyncClient:
        """Return the cached AsyncClient, creating it on first call."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                base_url=self._base_url,
                timeout=self._timeout,
            )
        return self._client

    def _headers(self, user_id: str = "") -> dict[str, str]:
        """Build request headers; empty token / user id are left out."""
        headers: dict[str, str] = {}
        if self._internal_token:
            headers["X-Internal-Token"] = self._internal_token
        if user_id:
            headers["X-User-Id"] = user_id
        return headers

    @staticmethod
    def _graph_path(graph_id: str, suffix: str) -> str:
        """Return ``/knowledge-graph/{graph_id}/{suffix}`` with the id escaped.

        The id is percent-encoded as a single path segment so a caller-supplied
        value cannot add path segments or query parameters.
        """
        # Function-scope import keeps this block self-contained.
        from urllib.parse import quote

        return f"/knowledge-graph/{quote(str(graph_id), safe='')}/{suffix}"

    @staticmethod
    def _or_empty(value):
        """Map a JSON ``null`` (None) to ``""``; return other values unchanged."""
        return "" if value is None else value

    @classmethod
    def _to_entity(cls, raw: dict) -> EntitySummary:
        """Convert one search hit / export node dict into an EntitySummary.

        Missing or null fields become empty strings, so a single sparse record
        does not abort parsing of the whole response.
        """
        return EntitySummary(
            id=str(cls._or_empty(raw.get("id"))),
            name=cls._or_empty(raw.get("name")),
            type=cls._or_empty(raw.get("type")),
            description=cls._or_empty(raw.get("description")),
        )

    async def fulltext_search(
        self,
        graph_id: str,
        query: str,
        size: int = 10,
        user_id: str = "",
    ) -> list[EntitySummary]:
        """Full-text search in the KG service; returns the matching entities.

        Fail-open: returns an empty list when the call or the parsing fails.
        """
        try:
            return await self._fulltext_search_impl(graph_id, query, size, user_id)
        except Exception:
            logger.exception(
                "KG fulltext search failed for graph_id=%s (fail-open, returning empty)",
                graph_id,
            )
            return []

    async def _fulltext_search_impl(
        self,
        graph_id: str,
        query: str,
        size: int,
        user_id: str,
    ) -> list[EntitySummary]:
        """Perform the search request and parse the response (may raise)."""
        client = self._get_client()
        resp = await client.get(
            self._graph_path(graph_id, "query/search"),
            params={"q": query, "size": size},
            headers=self._headers(user_id),
        )
        resp.raise_for_status()
        body = resp.json()

        # The service answers with a PagedResponse, either wrapped in the
        # global envelope {"code": 200, "data": PagedResponse} or bare
        # {"page": 0, "content": [...]}.
        data = body.get("data", body)
        if isinstance(data, dict):
            # PagedResponse keeps the hits under "content"; tolerate null.
            items: list[dict] = data.get("content") or []
        elif isinstance(data, list):
            items = data
        else:
            items = []

        return [self._to_entity(item) for item in items]

    async def get_subgraph(
        self,
        graph_id: str,
        entity_ids: list[str],
        depth: int = 1,
        user_id: str = "",
    ) -> tuple[list[EntitySummary], list[RelationSummary]]:
        """Fetch the N-hop subgraph around the seed entities.

        Fail-open: returns ``([], [])`` when the call or the parsing fails.
        """
        try:
            return await self._get_subgraph_impl(graph_id, entity_ids, depth, user_id)
        except Exception:
            logger.exception(
                "KG subgraph export failed for graph_id=%s (fail-open, returning empty)",
                graph_id,
            )
            return [], []

    async def _get_subgraph_impl(
        self,
        graph_id: str,
        entity_ids: list[str],
        depth: int,
        user_id: str,
    ) -> tuple[list[EntitySummary], list[RelationSummary]]:
        """Perform the subgraph export request and parse it (may raise)."""
        client = self._get_client()
        resp = await client.post(
            self._graph_path(graph_id, "query/subgraph/export"),
            params={"depth": depth},
            json={"entityIds": entity_ids},
            headers=self._headers(user_id),
        )
        resp.raise_for_status()
        body = resp.json()

        # The export is either wrapped in the global envelope
        # {"code": 200, "data": {...}} or returned bare
        # {"nodes": [...], "edges": [...]}.
        wrapped = body.get("data")
        data = wrapped if isinstance(wrapped, dict) else body
        nodes_raw = data.get("nodes") or []
        edges_raw = data.get("edges") or []

        # Nodes carry: id, name, type, description.
        entities = [self._to_entity(node) for node in nodes_raw]

        # id -> entity, used to resolve edge endpoints to names and types.
        entity_map = {entity.id: entity for entity in entities}

        # Edges carry sourceEntityId / targetEntityId / relationType.
        # Note: "sourceId" is the data-source id, not the source entity id.
        relations: list[RelationSummary] = []
        for edge in edges_raw:
            source_id = str(self._or_empty(edge.get("sourceEntityId")))
            target_id = str(self._or_empty(edge.get("targetEntityId")))
            source_entity = entity_map.get(source_id)
            target_entity = entity_map.get(target_id)
            relations.append(
                RelationSummary(
                    # Fall back to the raw id when the endpoint is not among
                    # the exported nodes.
                    source_name=source_entity.name if source_entity else source_id,
                    source_type=source_entity.type if source_entity else "",
                    target_name=target_entity.name if target_entity else target_id,
                    target_type=target_entity.type if target_entity else "",
                    relation_type=self._or_empty(edge.get("relationType")),
                )
            )

        return entities, relations

    async def close(self) -> None:
        """Close the HTTP client if one was created."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
|
||||
135
runtime/datamate-python/app/module/kg_graphrag/milvus_client.py
Normal file
135
runtime/datamate-python/app/module/kg_graphrag/milvus_client.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Milvus 向量检索客户端。
|
||||
|
||||
通过 pymilvus 连接 Milvus,对查询文本进行 embedding 后执行混合搜索,
|
||||
返回 top-K 文档片段。
|
||||
|
||||
失败策略:fail-open —— Milvus 不可用时返回空列表 + 日志告警。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_graphrag.models import VectorChunk
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MilvusVectorRetriever:
|
||||
"""Milvus 向量检索器。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
uri: str = "http://milvus-standalone:19530",
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
embedding_base_url: str | None = None,
|
||||
embedding_api_key: SecretStr = SecretStr("EMPTY"),
|
||||
) -> None:
|
||||
self._uri = uri
|
||||
self._embedding_model = embedding_model
|
||||
self._embedding_base_url = embedding_base_url
|
||||
self._embedding_api_key = embedding_api_key
|
||||
# Lazy init
|
||||
self._milvus_client = None
|
||||
self._embeddings = None
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls) -> MilvusVectorRetriever:
|
||||
from app.core.config import settings
|
||||
|
||||
embedding_model = (
|
||||
settings.graphrag_embedding_model
|
||||
or settings.kg_alignment_embedding_model
|
||||
)
|
||||
return cls(
|
||||
uri=settings.graphrag_milvus_uri,
|
||||
embedding_model=embedding_model,
|
||||
embedding_base_url=settings.kg_llm_base_url,
|
||||
embedding_api_key=settings.kg_llm_api_key,
|
||||
)
|
||||
|
||||
def _get_embeddings(self):
|
||||
if self._embeddings is None:
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
self._embeddings = OpenAIEmbeddings(
|
||||
model=self._embedding_model,
|
||||
base_url=self._embedding_base_url,
|
||||
api_key=self._embedding_api_key,
|
||||
)
|
||||
return self._embeddings
|
||||
|
||||
def _get_milvus_client(self):
|
||||
if self._milvus_client is None:
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
self._milvus_client = MilvusClient(uri=self._uri)
|
||||
logger.info("Connected to Milvus at %s", self._uri)
|
||||
return self._milvus_client
|
||||
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
"""检查 Milvus 中是否存在指定 collection(防止越权访问不存在的库)。"""
|
||||
try:
|
||||
client = self._get_milvus_client()
|
||||
return await asyncio.to_thread(client.has_collection, collection_name)
|
||||
except Exception:
|
||||
logger.exception("Milvus has_collection check failed for %s", collection_name)
|
||||
return False
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> list[VectorChunk]:
|
||||
"""向量搜索:embed query -> Milvus search -> 返回 top-K 文档片段。
|
||||
|
||||
Fail-open: Milvus 不可用时返回空列表。
|
||||
"""
|
||||
try:
|
||||
return await self._search_impl(collection_name, query, top_k)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Milvus search failed for collection=%s (fail-open, returning empty)",
|
||||
collection_name,
|
||||
)
|
||||
return []
|
||||
|
||||
async def _search_impl(
    self,
    collection_name: str,
    query: str,
    top_k: int,
) -> list[VectorChunk]:
    """Embed *query*, run the Milvus search and convert hits to ``VectorChunk``.

    Exceptions are not handled here; they propagate to the caller.

    Args:
        collection_name: Milvus collection to search.
        query: Text to embed and search for.
        top_k: Maximum number of hits requested from Milvus.

    Returns:
        One ``VectorChunk`` per hit of the single query vector, in the order
        Milvus returned them; an empty list when there are no results.
    """
    # 1. Embed the query text.
    query_vector = await self._get_embeddings().aembed_query(query)

    # 2. The Milvus search is synchronous I/O; run it in a worker thread so
    #    the event loop is not blocked.
    client = self._get_milvus_client()
    results = await asyncio.to_thread(
        client.search,
        collection_name=collection_name,
        data=[query_vector],
        limit=top_k,
        output_fields=["text", "metadata"],
        search_params={"metric_type": "COSINE", "params": {"nprobe": 16}},
    )

    # 3. Only one query vector was sent, so only the first result list is used.
    if not results:
        return []

    chunks: list[VectorChunk] = []
    for hit in results[0]:
        # A field that is present but NULL comes back as None. Coalesce it so
        # one incomplete hit cannot fail VectorChunk validation and discard
        # the whole result set.
        entity = hit.get("entity") or {}
        chunks.append(
            VectorChunk(
                id=str(hit.get("id", "")),
                text=entity.get("text") or "",
                score=float(hit.get("distance") or 0.0),
                metadata=entity.get("metadata") or {},
            )
        )
    return chunks
|
||||
102
runtime/datamate-python/app/module/kg_graphrag/models.py
Normal file
102
runtime/datamate-python/app/module/kg_graphrag/models.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""GraphRAG 融合查询的请求/响应数据模型。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RetrievalStrategy(BaseModel):
    """检索策略配置。"""

    # Retrieval strategy configuration.
    # NOTE(review): the class docstring is left untranslated because pydantic
    # appears to publish it as the schema description — confirm before changing.

    # Number of vector hits to return (1-50, default 5).
    vector_top_k: int = Field(default=5, ge=1, le=50, description="向量检索返回数")
    # Graph expansion depth in hops (1-5, default 2).
    graph_depth: int = Field(default=2, ge=1, le=5, description="图谱扩展深度")
    # Maximum number of graph entities (1-100, default 20).
    graph_max_entities: int = Field(default=20, ge=1, le=100, description="图谱最大实体数")
    # Weight of the vector score, 0.0-1.0.
    vector_weight: float = Field(default=0.6, ge=0.0, le=1.0, description="向量分数权重")
    # Weight of the graph relevance, 0.0-1.0. The two weights are validated
    # independently; nothing here forces them to sum to 1.
    graph_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="图谱相关性权重")
    # Toggle for graph retrieval.
    enable_graph: bool = Field(default=True, description="是否启用图谱检索")
    # Toggle for vector retrieval.
    enable_vector: bool = Field(default=True, description="是否启用向量检索")
|
||||
|
||||
|
||||
class GraphRAGQueryRequest(BaseModel):
    """GraphRAG 查询请求。"""

    # GraphRAG query request.
    # NOTE(review): the class docstring is left untranslated because pydantic
    # appears to publish it as the schema description — confirm before changing.

    # User query text, 1-2000 characters, required.
    query: str = Field(
        ...,
        min_length=1,
        max_length=2000,
        description="用户查询",
    )
    # Knowledge base ID used for the permission check (per the description,
    # supplied by the upstream Java backend); 1-64 characters, required.
    knowledge_base_id: str = Field(
        ...,
        min_length=1,
        max_length=64,
        description="知识库 ID,用于权限校验(由上游 Java 后端传入)",
    )
    # Milvus collection name; 1-256 characters restricted to ASCII letters,
    # digits, underscore, hyphen and code points U+4E00-U+9FFF.
    collection_name: str = Field(
        ...,
        min_length=1,
        max_length=256,
        pattern=r"^[a-zA-Z0-9_\-\u4e00-\u9fff]+$",
        description="Milvus collection 名称(= 知识库名),仅允许字母、数字、下划线、连字符和中文",
    )
    # Graph ID; must be an 8-4-4-4-12 hexadecimal UUID string.
    graph_id: str = Field(
        ...,
        pattern=r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$",
        description="Neo4j 图谱 ID(UUID 格式)",
    )
    # Optional strategy override; defaults to a fresh RetrievalStrategy().
    strategy: RetrievalStrategy = Field(
        default_factory=RetrievalStrategy,
        description="可选策略覆盖",
    )
|
||||
|
||||
|
||||
class VectorChunk(BaseModel):
    """向量检索到的文档片段。"""

    # A document chunk returned by vector retrieval.
    # NOTE(review): the class docstring is left untranslated because pydantic
    # appears to publish it as the schema description — confirm before changing.

    # Chunk identifier.
    id: str
    # Chunk text.
    text: str
    # Retrieval score of the chunk.
    score: float
    # Free-form metadata; defaults to a new empty dict per instance.
    metadata: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EntitySummary(BaseModel):
    """实体摘要。"""

    # Summary of a knowledge-graph entity.
    # NOTE(review): the class docstring is left untranslated because pydantic
    # appears to publish it as the schema description — confirm before changing.

    # Entity identifier.
    id: str
    # Entity display name.
    name: str
    # Entity type label.
    type: str
    # Optional description; empty string when absent.
    description: str = ""
|
||||
|
||||
|
||||
class RelationSummary(BaseModel):
    """关系摘要。"""

    # Summary of one relation (edge) between two entities, stored by name and
    # type rather than by entity ID.
    # NOTE(review): the class docstring is left untranslated because pydantic
    # appears to publish it as the schema description — confirm before changing.

    # Name and type of the source entity.
    source_name: str
    source_type: str
    # Name and type of the target entity.
    target_name: str
    target_type: str
    # Relation type label.
    relation_type: str
|
||||
|
||||
|
||||
class GraphContext(BaseModel):
    """图谱上下文。"""

    # Graph-side retrieval context.
    # NOTE(review): the class docstring is left untranslated because pydantic
    # appears to publish it as the schema description — confirm before changing.

    # Entities of the retrieved subgraph; empty list by default.
    entities: list[EntitySummary] = Field(default_factory=list)
    # Relations of the retrieved subgraph; empty list by default.
    relations: list[RelationSummary] = Field(default_factory=list)
    # Text rendering of the subgraph; empty string by default.
    textualized: str = ""
|
||||
|
||||
|
||||
class RetrievalContext(BaseModel):
    """检索上下文(检索结果的结构化表示)。"""

    # Structured representation of the retrieval results.
    # NOTE(review): the class docstring is left untranslated because pydantic
    # appears to publish it as the schema description — confirm before changing.

    # Chunks from vector retrieval; empty list by default.
    vector_chunks: list[VectorChunk] = Field(default_factory=list)
    # Graph retrieval results; an empty GraphContext by default.
    graph_context: GraphContext = Field(default_factory=GraphContext)
    # Combined context text; empty string by default.
    merged_text: str = ""
|
||||
|
||||
|
||||
class GraphRAGQueryResponse(BaseModel):
    """GraphRAG 查询响应。"""

    # GraphRAG query response; all three fields are required.
    # NOTE(review): the class docstring is left untranslated because pydantic
    # appears to publish it as the schema description — confirm before changing.

    # Answer generated by the LLM.
    answer: str = Field(..., description="LLM 生成的回答")
    # Retrieval context.
    context: RetrievalContext = Field(..., description="检索上下文")
    # Name of the LLM model used.
    model: str = Field(..., description="使用的 LLM 模型名")
|
||||
214
runtime/datamate-python/app/module/kg_graphrag/retriever.py
Normal file
214
runtime/datamate-python/app/module/kg_graphrag/retriever.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""GraphRAG 检索编排器。
|
||||
|
||||
并行执行向量检索和图谱检索,融合排序后构建统一上下文。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_graphrag.context_builder import build_context, textualize_subgraph
|
||||
from app.module.kg_graphrag.kg_client import KGServiceClient
|
||||
from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever
|
||||
from app.module.kg_graphrag.models import (
|
||||
EntitySummary,
|
||||
GraphContext,
|
||||
RelationSummary,
|
||||
RetrievalContext,
|
||||
RetrievalStrategy,
|
||||
VectorChunk,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GraphRAGRetriever:
    """Orchestrates GraphRAG retrieval.

    Runs vector retrieval and graph retrieval concurrently, re-ranks the
    vector chunks using the retrieved graph entities, and builds the merged
    context text.
    """

    def __init__(
        self,
        *,
        milvus_client: MilvusVectorRetriever,
        kg_client: KGServiceClient,
    ) -> None:
        """Store the vector-store client and the knowledge-graph client."""
        self._milvus = milvus_client
        self._kg = kg_client

    @classmethod
    def from_settings(cls) -> GraphRAGRetriever:
        """Build a retriever whose two clients are created from settings."""
        return cls(
            milvus_client=MilvusVectorRetriever.from_settings(),
            kg_client=KGServiceClient.from_settings(),
        )

    async def retrieve(
        self,
        query: str,
        collection_name: str,
        graph_id: str,
        strategy: RetrievalStrategy,
        user_id: str = "",
    ) -> RetrievalContext:
        """Run vector and graph retrieval concurrently and merge the results.

        Args:
            query: User query text.
            collection_name: Milvus collection to search.
            graph_id: Graph to search in the KG service.
            strategy: Toggles, limits and weights for the retrieval.
            user_id: Caller identity forwarded to the KG service.

        Returns:
            The ranked vector chunks, the graph context and the merged text.
            A leg that is disabled or whose task raised contributes empty
            results; the failure is logged, not re-raised.
        """
        tasks: dict[str, asyncio.Task] = {}

        if strategy.enable_vector:
            # The collection-existence check lives inside the vector task so
            # it no longer delays the start of the graph task.
            tasks["vector"] = asyncio.create_task(
                self._retrieve_vector(
                    collection_name=collection_name,
                    query=query,
                    top_k=strategy.vector_top_k,
                )
            )

        if strategy.enable_graph:
            tasks["graph"] = asyncio.create_task(
                self._retrieve_graph(
                    query=query,
                    graph_id=graph_id,
                    strategy=strategy,
                    user_id=user_id,
                )
            )

        # Wait for every task; exceptions are inspected per task below.
        if tasks:
            await asyncio.gather(*tasks.values(), return_exceptions=True)

        vector_chunks: list[VectorChunk] = []
        if "vector" in tasks:
            try:
                vector_chunks = tasks["vector"].result()
            except Exception:
                logger.exception("Vector retrieval task failed")

        entities: list[EntitySummary] = []
        relations: list[RelationSummary] = []
        if "graph" in tasks:
            try:
                entities, relations = tasks["graph"].result()
            except Exception:
                logger.exception("Graph retrieval task failed")

        # Fused ranking of the vector chunks.
        vector_chunks = self._rank_results(
            vector_chunks, entities, relations, strategy
        )

        # Render the subgraph as text.
        graph_text = textualize_subgraph(entities, relations)

        # Build the merged context.
        merged_text = build_context(
            vector_chunks,
            graph_text,
            vector_weight=strategy.vector_weight,
            graph_weight=strategy.graph_weight,
        )

        return RetrievalContext(
            vector_chunks=vector_chunks,
            graph_context=GraphContext(
                entities=entities,
                relations=relations,
                textualized=graph_text,
            ),
            merged_text=merged_text,
        )

    async def _retrieve_vector(
        self,
        collection_name: str,
        query: str,
        top_k: int,
    ) -> list[VectorChunk]:
        """Search the collection if it exists; otherwise return no chunks."""
        # Check that the collection exists before searching it.
        if not await self._milvus.has_collection(collection_name):
            logger.warning(
                "Collection %s not found, skipping vector retrieval",
                collection_name,
            )
            return []
        return await self._milvus.search(
            collection_name=collection_name,
            query=query,
            top_k=top_k,
        )

    async def _retrieve_graph(
        self,
        query: str,
        graph_id: str,
        strategy: RetrievalStrategy,
        user_id: str,
    ) -> tuple[list[EntitySummary], list[RelationSummary]]:
        """Graph retrieval: full-text search -> seed entities -> subgraph."""
        # 1. Full-text search yields the seed entities.
        seed_entities = await self._kg.fulltext_search(
            graph_id=graph_id,
            query=query,
            size=strategy.graph_max_entities,
            user_id=user_id,
        )

        if not seed_entities:
            logger.debug("No seed entities found for query: %s", query)
            return [], []

        # 2. Fetch the subgraph around the seeds, to the configured depth.
        seed_ids = [e.id for e in seed_entities]
        entities, relations = await self._kg.get_subgraph(
            graph_id=graph_id,
            entity_ids=seed_ids,
            depth=strategy.graph_depth,
            user_id=user_id,
        )

        logger.info(
            "Graph retrieval: %d seed entities -> %d entities, %d relations",
            len(seed_entities), len(entities), len(relations),
        )
        return entities, relations

    def _rank_results(
        self,
        vector_chunks: list[VectorChunk],
        entities: list[EntitySummary],
        relations: list[RelationSummary],
        strategy: RetrievalStrategy,
    ) -> list[VectorChunk]:
        """Re-order the vector chunks by a fused vector/graph score.

        The vector part is the min-max normalised chunk score (1.0 for every
        chunk when all scores are equal). The graph part is the fraction of
        distinct, non-blank entity names that occur, case-insensitively, as a
        substring of the chunk text. The two parts are combined using
        ``strategy.vector_weight`` and ``strategy.graph_weight``.

        Args:
            vector_chunks: Chunks to rank; returned unchanged when empty.
            entities: Graph entities whose names drive the graph part.
            relations: Accepted for interface compatibility; not used.
            strategy: Supplies the two weights.

        Returns:
            The same chunk objects, highest fused score first. Chunk scores
            are not modified and ties keep their input order.
        """
        if not vector_chunks:
            return vector_chunks

        # Min-max scaling of the vector scores.
        scores = [c.score for c in vector_chunks]
        min_score = min(scores)
        score_range = max(scores) - min_score

        # Blank names are dropped: an empty string is a substring of every
        # text and would count as a "mention" in every chunk.
        graph_entity_names = {
            name
            for name in (e.name.strip().lower() for e in entities)
            if name
        }
        name_count = len(graph_entity_names)

        ranked: list[tuple[float, VectorChunk]] = []
        for chunk in vector_chunks:
            norm_score = (
                (chunk.score - min_score) / score_range
                if score_range > 0
                else 1.0
            )

            graph_boost = 0.0
            if name_count:
                chunk_text_lower = chunk.text.lower()
                mentioned = sum(
                    1 for name in graph_entity_names if name in chunk_text_lower
                )
                graph_boost = mentioned / name_count

            final_score = (
                strategy.vector_weight * norm_score
                + strategy.graph_weight * graph_boost
            )
            ranked.append((final_score, chunk))

        # Stable sort, descending by fused score.
        ranked.sort(key=lambda x: x[0], reverse=True)
        return [chunk for _, chunk in ranked]
|
||||
@@ -0,0 +1,182 @@
|
||||
"""三元组文本化 + 上下文构建的单元测试。"""
|
||||
|
||||
from app.module.kg_graphrag.context_builder import (
|
||||
RELATION_TEMPLATES,
|
||||
build_context,
|
||||
textualize_subgraph,
|
||||
)
|
||||
from app.module.kg_graphrag.models import (
|
||||
EntitySummary,
|
||||
RelationSummary,
|
||||
VectorChunk,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# textualize_subgraph 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTextualizeSubgraph:
    """Tests for textualize_subgraph."""

    @staticmethod
    def _rel(src, dst, rel_type):
        """Build a RelationSummary from two (name, type) pairs."""
        src_name, src_type = src
        dst_name, dst_type = dst
        return RelationSummary(
            source_name=src_name,
            source_type=src_type,
            target_name=dst_name,
            target_type=dst_type,
            relation_type=rel_type,
        )

    def test_single_relation(self):
        nodes = [
            EntitySummary(id="1", name="用户行为数据", type="Dataset"),
            EntitySummary(id="2", name="user_id", type="Field"),
        ]
        edges = [
            self._rel(("用户行为数据", "Dataset"), ("user_id", "Field"), "HAS_FIELD"),
        ]
        text = textualize_subgraph(nodes, edges)
        assert "Dataset'用户行为数据'包含字段Field'user_id'" in text

    def test_multiple_relations(self):
        nodes = [
            EntitySummary(id="1", name="用户行为数据", type="Dataset"),
            EntitySummary(id="2", name="清洗管道", type="Workflow"),
        ]
        edges = [
            self._rel(("清洗管道", "Workflow"), ("用户行为数据", "Dataset"), "USES_DATASET"),
            self._rel(("用户行为数据", "Dataset"), ("外部系统", "DataSource"), "SOURCED_FROM"),
        ]
        text = textualize_subgraph(nodes, edges)
        assert "Workflow'清洗管道'使用了数据集Dataset'用户行为数据'" in text
        assert "Dataset'用户行为数据'的知识来源于DataSource'外部系统'" in text

    def test_all_relation_templates(self):
        """Every relation type in RELATION_TEMPLATES renders both endpoints."""
        for rel_type in RELATION_TEMPLATES:
            edges = [self._rel(("A", "TypeA"), ("B", "TypeB"), rel_type)]
            text = textualize_subgraph([], edges)
            assert "TypeA'A'" in text
            assert "TypeB'B'" in text
            assert text  # non-empty

    def test_unknown_relation_type(self):
        """An unknown relation type falls back to the generic template."""
        edges = [self._rel(("X", "T1"), ("Y", "T2"), "CUSTOM_REL")]
        text = textualize_subgraph([], edges)
        assert "T1'X'与T2'Y'存在CUSTOM_REL关系" in text

    def test_orphan_entity_with_description(self):
        """A standalone entity (no relations) that has a description."""
        lonely = EntitySummary(
            id="1", name="孤立实体", type="Dataset", description="这是一个测试实体",
        )
        text = textualize_subgraph([lonely], [])
        assert "Dataset'孤立实体': 这是一个测试实体" in text

    def test_orphan_entity_without_description(self):
        """A standalone entity (no relations) without a description."""
        lonely = EntitySummary(id="1", name="孤立实体", type="Dataset")
        text = textualize_subgraph([lonely], [])
        assert "存在Dataset'孤立实体'" in text

    def test_empty_inputs(self):
        assert textualize_subgraph([], []) == ""

    def test_entity_with_relation_not_orphan(self):
        """Entities that take part in a relation are not listed as standalone."""
        nodes = [
            EntitySummary(id="1", name="A", type="Dataset"),
            EntitySummary(id="2", name="B", type="Field"),
            EntitySummary(id="3", name="C", type="Workflow"),
        ]
        edges = [self._rel(("A", "Dataset"), ("B", "Field"), "HAS_FIELD")]
        text = textualize_subgraph(nodes, edges)
        # A and B are related and must not appear as standalone entities;
        # C has no relation and must appear.
        assert "存在Workflow'C'" in text
        rendered_lines = text.strip().split("\n")
        assert len(rendered_lines) == 2  # one relation + one standalone entity
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_context 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildContext:
    """Tests for build_context."""

    def test_both_vector_and_graph(self):
        docs = [
            VectorChunk(id="1", text="文档片段一", score=0.9),
            VectorChunk(id="2", text="文档片段二", score=0.8),
        ]
        graph_text = "Dataset'用户数据'包含字段Field'user_id'"
        merged = build_context(docs, graph_text)
        assert "## 相关文档" in merged
        assert "[1] 文档片段一" in merged
        assert "[2] 文档片段二" in merged
        assert "## 知识图谱上下文" in merged
        assert graph_text in merged

    def test_vector_only(self):
        merged = build_context([VectorChunk(id="1", text="文档片段", score=0.9)], "")
        assert "## 相关文档" in merged
        assert "## 知识图谱上下文" not in merged

    def test_graph_only(self):
        merged = build_context([], "图谱内容")
        assert "## 知识图谱上下文" in merged
        assert "## 相关文档" not in merged

    def test_empty_both(self):
        merged = build_context([], "")
        assert "未检索到相关上下文信息" in merged

    def test_context_section_order(self):
        """The document section comes before the graph section."""
        merged = build_context([VectorChunk(id="1", text="doc", score=0.9)], "graph")
        doc_pos = merged.index("## 相关文档")
        graph_pos = merged.index("## 知识图谱上下文")
        assert doc_pos < graph_pos
|
||||
300
runtime/datamate-python/app/module/kg_graphrag/test_interface.py
Normal file
300
runtime/datamate-python/app/module/kg_graphrag/test_interface.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""GraphRAG API 端点回归测试。
|
||||
|
||||
验证 /graphrag/query、/graphrag/retrieve、/graphrag/query/stream 端点
|
||||
的权限校验行为,确保 collection_name 不一致时返回 403 且不进入检索链路。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from app.exception import (
|
||||
fastapi_http_exception_handler,
|
||||
starlette_http_exception_handler,
|
||||
validation_exception_handler,
|
||||
)
|
||||
from app.module.kg_graphrag.interface import router
|
||||
from app.module.kg_graphrag.models import (
|
||||
GraphContext,
|
||||
RetrievalContext,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
# FastAPI app used by the tests (mounts only the graphrag router + exception handlers)
# ---------------------------------------------------------------------------

_app = FastAPI()
_app.include_router(router, prefix="/api")
_app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler)
_app.add_exception_handler(HTTPException, fastapi_http_exception_handler)
_app.add_exception_handler(RequestValidationError, validation_exception_handler)


# A graph_id in 8-4-4-4-12 hexadecimal UUID form.
_VALID_GRAPH_ID = "12345678-1234-1234-1234-123456789abc"

# Request body shared by every endpoint test in this module.
_VALID_BODY = {
    "query": "测试查询",
    "knowledge_base_id": "kb-1",
    "collection_name": "test-collection",
    "graph_id": _VALID_GRAPH_ID,
}

# Caller identity header; the tests below expect 422 when it is missing.
_HEADERS = {"X-User-Id": "user-1"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fake_retrieval_context() -> RetrievalContext:
    """Return a canned RetrievalContext: no chunks, empty graph, fixed text."""
    empty_graph = GraphContext()
    return RetrievalContext(
        merged_text="test context",
        graph_context=empty_graph,
        vector_chunks=[],
    )
|
||||
|
||||
|
||||
def _make_retriever_mock() -> AsyncMock:
    """Return a retriever mock whose ``retrieve`` resolves to the canned context."""
    retriever = AsyncMock()
    retrieve_stub = AsyncMock(return_value=_fake_retrieval_context())
    retriever.retrieve = retrieve_stub
    return retriever
|
||||
|
||||
|
||||
def _make_generator_mock() -> AsyncMock:
|
||||
m = AsyncMock()
|
||||
m.generate = AsyncMock(return_value="test answer")
|
||||
m.model_name = "test-model"
|
||||
|
||||
async def _stream(*, query: str, context: str): # noqa: ARG001
|
||||
for token in ["hello", " ", "world"]:
|
||||
yield token
|
||||
|
||||
m.generate_stream = _stream
|
||||
return m
|
||||
|
||||
|
||||
def _make_kb_validator_mock(*, access_granted: bool = True) -> AsyncMock:
|
||||
m = AsyncMock()
|
||||
m.check_access = AsyncMock(return_value=access_granted)
|
||||
return m
|
||||
|
||||
|
||||
def _patch_all(
|
||||
*,
|
||||
access_granted: bool = True,
|
||||
retriever: AsyncMock | None = None,
|
||||
generator: AsyncMock | None = None,
|
||||
validator: AsyncMock | None = None,
|
||||
):
|
||||
"""返回 context manager,统一 patch 三个懒加载工厂函数。"""
|
||||
retriever = retriever or _make_retriever_mock()
|
||||
generator = generator or _make_generator_mock()
|
||||
validator = validator or _make_kb_validator_mock(access_granted=access_granted)
|
||||
|
||||
class _Ctx:
|
||||
def __init__(self):
|
||||
self.retriever = retriever
|
||||
self.generator = generator
|
||||
self.validator = validator
|
||||
self._patches = [
|
||||
patch("app.module.kg_graphrag.interface._get_retriever", return_value=retriever),
|
||||
patch("app.module.kg_graphrag.interface._get_generator", return_value=generator),
|
||||
patch("app.module.kg_graphrag.interface._get_kb_validator", return_value=validator),
|
||||
]
|
||||
|
||||
def __enter__(self):
|
||||
for p in self._patches:
|
||||
p.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
for p in reversed(self._patches):
|
||||
p.__exit__(*args)
|
||||
|
||||
return _Ctx()
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Provide a TestClient bound to the module-level test application."""
    test_client = TestClient(_app)
    return test_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/graphrag/query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQueryEndpoint:
    """Tests for the POST /api/graphrag/query endpoint."""

    _URL = "/api/graphrag/query"

    def _post(self, client: TestClient, *, with_user: bool = True):
        """POST the shared valid body, with or without the X-User-Id header."""
        if with_user:
            return client.post(self._URL, json=_VALID_BODY, headers=_HEADERS)
        return client.post(self._URL, json=_VALID_BODY)

    def test_success(self, client: TestClient):
        """Access granted + retrieval + generation -> 200."""
        with _patch_all(access_granted=True) as ctx:
            resp = self._post(client)

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["code"] == 200
        assert payload["data"]["answer"] == "test answer"
        assert payload["data"]["model"] == "test-model"
        ctx.retriever.retrieve.assert_awaited_once()
        ctx.generator.generate.assert_awaited_once()

    def test_access_denied_returns_403(self, client: TestClient):
        """check_access returns False -> 403 in the standard error format."""
        with _patch_all(access_granted=False):
            resp = self._post(client)

        assert resp.status_code == 403
        payload = resp.json()
        assert payload["code"] == 403
        assert "kb-1" in payload["data"]["detail"]

    def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
        """On denial neither retriever.retrieve nor generator.generate is called."""
        with _patch_all(access_granted=False) as ctx:
            resp = self._post(client)

        assert resp.status_code == 403
        ctx.retriever.retrieve.assert_not_called()
        ctx.generator.generate.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """check_access is called with the request's collection_name."""
        with _patch_all(access_granted=True) as ctx:
            resp = self._post(client)

        assert resp.status_code == 200
        ctx.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """Missing X-User-Id request header -> 422 validation error."""
        with _patch_all(access_granted=True):
            resp = self._post(client, with_user=False)

        assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/graphrag/retrieve
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRetrieveEndpoint:
    """Tests for the POST /api/graphrag/retrieve endpoint."""

    _URL = "/api/graphrag/retrieve"

    def _post(self, client: TestClient, *, with_user: bool = True):
        """POST the shared valid body, with or without the X-User-Id header."""
        if with_user:
            return client.post(self._URL, json=_VALID_BODY, headers=_HEADERS)
        return client.post(self._URL, json=_VALID_BODY)

    def test_success(self, client: TestClient):
        """Access granted -> retrieval -> RetrievalContext is returned."""
        with _patch_all(access_granted=True) as ctx:
            resp = self._post(client)

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["code"] == 200
        assert payload["data"]["merged_text"] == "test context"
        ctx.retriever.retrieve.assert_awaited_once()

    def test_access_denied_returns_403(self, client: TestClient):
        """Access denied -> 403."""
        with _patch_all(access_granted=False):
            resp = self._post(client)

        assert resp.status_code == 403
        payload = resp.json()
        assert payload["code"] == 403

    def test_access_denied_skips_retrieval(self, client: TestClient):
        """On denial retriever.retrieve is not called."""
        with _patch_all(access_granted=False) as ctx:
            resp = self._post(client)

        assert resp.status_code == 403
        ctx.retriever.retrieve.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """check_access receives the collection_name argument."""
        with _patch_all(access_granted=True) as ctx:
            resp = self._post(client)

        assert resp.status_code == 200
        ctx.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """Missing X-User-Id -> 422."""
        with _patch_all(access_granted=True):
            resp = self._post(client, with_user=False)

        assert resp.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/graphrag/query/stream
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQueryStreamEndpoint:
    """Tests for the POST /api/graphrag/query/stream endpoint."""

    _URL = "/api/graphrag/query/stream"

    def _post(self, client: TestClient, *, with_user: bool = True):
        """POST the shared valid body, with or without the X-User-Id header."""
        if with_user:
            return client.post(self._URL, json=_VALID_BODY, headers=_HEADERS)
        return client.post(self._URL, json=_VALID_BODY)

    def test_success_returns_sse(self, client: TestClient):
        """Access granted -> SSE response containing token and done events."""
        with _patch_all(access_granted=True):
            resp = self._post(client)

        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        stream_text = resp.text
        assert '"token"' in stream_text
        assert '"done": true' in stream_text or '"done":true' in stream_text

    def test_access_denied_returns_403(self, client: TestClient):
        """Access denied -> 403."""
        with _patch_all(access_granted=False):
            resp = self._post(client)

        assert resp.status_code == 403
        payload = resp.json()
        assert payload["code"] == 403

    def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
        """On denial retrieval is not invoked.

        Only ``retriever.retrieve`` is asserted; the streaming generator stub
        is a plain function, so calls to it are not checked here.
        """
        with _patch_all(access_granted=False) as ctx:
            resp = self._post(client)

        assert resp.status_code == 403
        ctx.retriever.retrieve.assert_not_called()

    def test_check_access_receives_collection_name(self, client: TestClient):
        """check_access receives the collection_name argument."""
        with _patch_all(access_granted=True) as ctx:
            resp = self._post(client)

        assert resp.status_code == 200
        ctx.validator.check_access.assert_awaited_once_with(
            "kb-1", "user-1", collection_name="test-collection",
        )

    def test_missing_user_id_returns_422(self, client: TestClient):
        """Missing X-User-Id -> 422."""
        with _patch_all(access_granted=True):
            resp = self._post(client, with_user=False)

        assert resp.status_code == 422
|
||||
330
runtime/datamate-python/app/module/kg_graphrag/test_kb_access.py
Normal file
330
runtime/datamate-python/app/module/kg_graphrag/test_kb_access.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""知识库访问权限校验的单元测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
|
||||
|
||||
|
||||
@pytest.fixture
def validator() -> KnowledgeBaseAccessValidator:
    """Provide a validator pointed at a fake backend URL with a 5 s timeout."""
    options = {
        "base_url": "http://test-backend:8080/api",
        "timeout": 5.0,
    }
    return KnowledgeBaseAccessValidator(**options)
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
_FAKE_REQUEST = httpx.Request("GET", "http://test")
|
||||
|
||||
|
||||
def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
    """Create an httpx.Response that has a request attached.

    A JSON body is used when *json* is not None; otherwise the body is
    *text*, or an empty string when *text* is falsy.
    """
    if json is None:
        body = {"text": text or ""}
    else:
        body = {"json": json}
    return httpx.Response(status_code, request=_FAKE_REQUEST, **body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_access 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckAccess:
|
||||
"""check_access 方法的测试。"""
|
||||
|
||||
def test_access_granted(self, validator: KnowledgeBaseAccessValidator):
|
||||
"""Java 返回 200 + code=200: 用户有权访问。"""
|
||||
mock_resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "test-kb"}})
|
||||
with patch.object(validator, "_get_client") as mock_get:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_resp)
|
||||
mock_get.return_value = mock_http
|
||||
|
||||
result = _run(validator.check_access("kb-1", "user-1"))
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_access_granted_with_matching_collection(self, validator: KnowledgeBaseAccessValidator):
|
||||
"""权限通过且 collection_name 与 KB name 一致:允许访问。"""
|
||||
mock_resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "my-collection"}})
|
||||
with patch.object(validator, "_get_client") as mock_get:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_resp)
|
||||
mock_get.return_value = mock_http
|
||||
|
||||
result = _run(validator.check_access(
|
||||
"kb-1", "user-1", collection_name="my-collection",
|
||||
))
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_access_denied_by_biz_code(self, validator: KnowledgeBaseAccessValidator):
|
||||
"""Java 返回 HTTP 200 但 code != 200(权限不足 sys.0005)。"""
|
||||
mock_resp = _resp(200, json={"code": "sys.0005", "message": "权限不足"})
|
||||
with patch.object(validator, "_get_client") as mock_get:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_resp)
|
||||
mock_get.return_value = mock_http
|
||||
|
||||
result = _run(validator.check_access("kb-1", "other-user"))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_access_denied_http_403(self, validator: KnowledgeBaseAccessValidator):
|
||||
"""Java 返回 HTTP 403。"""
|
||||
mock_resp = _resp(403, text="Forbidden")
|
||||
with patch.object(validator, "_get_client") as mock_get:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_resp)
|
||||
mock_get.return_value = mock_http
|
||||
|
||||
result = _run(validator.check_access("kb-1", "user-1"))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_kb_not_found_http_404(self, validator: KnowledgeBaseAccessValidator):
|
||||
"""知识库不存在,Java 返回 404。"""
|
||||
mock_resp = _resp(404, text="Not Found")
|
||||
with patch.object(validator, "_get_client") as mock_get:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_resp)
|
||||
mock_get.return_value = mock_http
|
||||
|
||||
result = _run(validator.check_access("nonexistent-kb", "user-1"))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_server_error_http_500(self, validator: KnowledgeBaseAccessValidator):
|
||||
"""Java 后端返回 500。"""
|
||||
mock_resp = _resp(500, text="Internal Server Error")
|
||||
with patch.object(validator, "_get_client") as mock_get:
|
||||
mock_http = AsyncMock()
|
||||
mock_http.get = AsyncMock(return_value=mock_resp)
|
||||
mock_get.return_value = mock_http
|
||||
|
||||
result = _run(validator.check_access("kb-1", "user-1"))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_fail_close_on_connection_error(self, validator: KnowledgeBaseAccessValidator):
    """A network failure must fail closed (deny) so the permission check cannot be bypassed."""
    http = AsyncMock()
    http.get.side_effect = httpx.ConnectError("connection refused")

    with patch.object(validator, "_get_client", return_value=http):
        allowed = _run(validator.check_access("kb-1", "user-1"))

    assert allowed is False
|
||||
|
||||
def test_fail_close_on_timeout(self, validator: KnowledgeBaseAccessValidator):
    """A read timeout must fail closed (deny access)."""
    http = AsyncMock()
    http.get.side_effect = httpx.ReadTimeout("timeout")

    with patch.object(validator, "_get_client", return_value=http):
        allowed = _run(validator.check_access("kb-1", "user-1"))

    assert allowed is False
|
||||
|
||||
def test_request_headers(self, validator: KnowledgeBaseAccessValidator):
    """The outgoing request targets the KB URL and carries the caller in X-User-Id."""
    http = AsyncMock()
    http.get.return_value = _resp(200, json={"code": 200, "data": {}})

    with patch.object(validator, "_get_client", return_value=http):
        _run(validator.check_access("kb-123", "user-456"))

    recorded = http.get.call_args
    positional, keywords = recorded.args, recorded.kwargs
    assert "/knowledge-base/kb-123" in positional[0]
    assert keywords["headers"]["X-User-Id"] == "user-456"
|
||||
|
||||
def test_cross_user_access_denied(self, validator: KnowledgeBaseAccessValidator):
    """Cross-user access: user B asks for user A's knowledge base and is refused.

    The backend is simulated as answering with an "insufficient permission"
    business error.
    """
    # KB created by user A, requested by user B.
    denial = {
        "code": "sys.0005",
        "message": "权限不足",
        "data": None,
    }
    http = AsyncMock()
    http.get.return_value = _resp(200, json=denial)

    with patch.object(validator, "_get_client", return_value=http):
        allowed = _run(validator.check_access("kb-user-a", "user-b"))

    assert allowed is False

    # The request must have been issued on behalf of user B.
    sent_headers = http.get.call_args.kwargs["headers"]
    assert sent_headers["X-User-Id"] == "user-b"
|
||||
|
||||
def test_admin_access_granted(self, validator: KnowledgeBaseAccessValidator):
    """An admin may access another user's knowledge base (the backend skips the owner check)."""
    kb_payload = {
        "code": 200,
        "data": {"id": "kb-user-a", "name": "用户A的知识库", "createdBy": "user-a"},
    }
    http = AsyncMock()
    http.get.return_value = _resp(200, json=kb_payload)

    with patch.object(validator, "_get_client", return_value=http):
        allowed = _run(validator.check_access("kb-user-a", "admin-user"))

    # The backend accepted the admin: HTTP 200 with code=200.
    assert allowed is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# collection_name 绑定校验测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCollectionNameBinding:
    """Tests for the binding between collection_name and knowledge_base_id.

    Guards against a user submitting a legitimate KB ID while tampering with
    collection_name in order to read another knowledge base's Milvus data.
    """

    @staticmethod
    def _check(validator, body, kb_id, user_id, **extra):
        """Run check_access with the HTTP client stubbed to answer 200 with *body* as JSON."""
        http = AsyncMock()
        http.get.return_value = _resp(200, json=body)
        with patch.object(validator, "_get_client", return_value=http):
            return _run(validator.check_access(kb_id, user_id, **extra))

    def test_collection_name_mismatch_denied(self, validator: KnowledgeBaseAccessValidator):
        """KB-A is named 'collection-a' but the request asks for 'collection-b': denied."""
        body = {
            "code": 200,
            "data": {"id": "kb-a", "name": "collection-a"},
        }

        allowed = self._check(
            validator, body, "kb-a", "user-1", collection_name="collection-b",
        )

        assert allowed is False

    def test_collection_name_none_skips_check(self, validator: KnowledgeBaseAccessValidator):
        """Without a collection_name no binding check is made (backward compatible)."""
        body = {
            "code": 200,
            "data": {"id": "kb-1", "name": "some-name"},
        }

        # No collection_name passed: only the permission is checked, not the binding.
        allowed = self._check(validator, body, "kb-1", "user-1")

        assert allowed is True

    def test_response_data_missing_name_denied(self, validator: KnowledgeBaseAccessValidator):
        """The response data has no name field: fail closed and deny."""
        body = {
            "code": 200,
            "data": {"id": "kb-1"},
        }

        allowed = self._check(
            validator, body, "kb-1", "user-1", collection_name="any-collection",
        )

        # There is no name to match "any-collection" against, so access is denied.
        assert allowed is False

    def test_response_data_null_denied(self, validator: KnowledgeBaseAccessValidator):
        """The response data is null: fail closed and deny."""
        body = {
            "code": 200,
            "data": None,
        }

        allowed = self._check(
            validator, body, "kb-1", "user-1", collection_name="any-collection",
        )

        assert allowed is False

    def test_response_data_empty_dict_denied(self, validator: KnowledgeBaseAccessValidator):
        """The response data is an empty dict: fail closed and deny."""
        body = {
            "code": 200,
            "data": {},
        }

        allowed = self._check(
            validator, body, "kb-1", "user-1", collection_name="any-collection",
        )

        assert allowed is False

    def test_cross_kb_collection_swap_denied(self, validator: KnowledgeBaseAccessValidator):
        """A user entitled to KB-A (name 'kb-a-data') pairs KB-A's ID with KB-B's
        collection_name 'kb-b-data': denied.

        This is the full simulation of the core privilege-escalation scenario.
        """
        # The user is entitled to KB-A ...
        body = {
            "code": 200,
            "data": {"id": "kb-a", "name": "kb-a-data", "createdBy": "user-1"},
        }

        # ... but the collection_name points at KB-B's data.
        allowed = self._check(
            validator, body, "kb-a", "user-1", collection_name="kb-b-data",
        )

        assert allowed is False

    def test_chinese_collection_name_match(self, validator: KnowledgeBaseAccessValidator):
        """A Chinese collection_name that matches exactly is accepted."""
        body = {
            "code": 200,
            "data": {"id": "kb-1", "name": "用户行为数据"},
        }

        allowed = self._check(
            validator, body, "kb-1", "user-1", collection_name="用户行为数据",
        )

        assert allowed is True
|
||||
297
runtime/datamate-python/app/module/kg_graphrag/test_kg_client.py
Normal file
297
runtime/datamate-python/app/module/kg_graphrag/test_kg_client.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""KG 服务 REST 客户端的单元测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.module.kg_graphrag.kg_client import KGServiceClient
|
||||
|
||||
|
||||
@pytest.fixture
def client() -> KGServiceClient:
    """Provide a KGServiceClient configured with test-only URL, token and timeout."""
    settings = {
        "base_url": "http://test-kg:8080",
        "internal_token": "test-token",
        "timeout": 5.0,
    }
    return KGServiceClient(**settings)
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
# Placeholder request object shared by the fake responses built in this module.
_FAKE_REQUEST = httpx.Request("GET", "http://test")
|
||||
|
||||
|
||||
def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
    """Build an httpx.Response bound to a request (raise_for_status needs one).

    A JSON body is used whenever *json* is not None; otherwise the response
    carries *text*, or an empty string when *text* is falsy.
    """
    if json is None:
        return httpx.Response(status_code, text=text or "", request=_FAKE_REQUEST)
    return httpx.Response(status_code, json=json, request=_FAKE_REQUEST)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fulltext_search 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFulltextSearch:
    """Tests for the fulltext_search method."""

    @staticmethod
    def _http_returning(response) -> AsyncMock:
        """Return a stub HTTP client whose awaited ``get`` yields *response*."""
        http = AsyncMock()
        http.get.return_value = response
        return http

    def test_wrapped_paged_response(self, client: KGServiceClient):
        """The backend returns a globally wrapped PagedResponse: {"code": 200, "data": {"content": [...]}}"""
        page = {
            "page": 0,
            "size": 20,
            "totalElements": 2,
            "totalPages": 1,
            "content": [
                {"id": "e1", "name": "用户数据", "type": "Dataset", "description": "用户行为", "score": 2.5},
                {"id": "e2", "name": "清洗管道", "type": "Workflow", "description": "", "score": 1.8},
            ],
        }
        http = self._http_returning(_resp(200, json={"code": 200, "data": page}))

        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "用户数据", size=10, user_id="u1"))

        assert len(found) == 2
        first, second = found
        assert first.id == "e1"
        assert first.name == "用户数据"
        assert first.type == "Dataset"
        assert second.name == "清洗管道"

    def test_unwrapped_paged_response(self, client: KGServiceClient):
        """The backend returns a bare PagedResponse (no global wrapper)."""
        page = {
            "page": 0,
            "size": 10,
            "totalElements": 1,
            "totalPages": 1,
            "content": [
                {"id": "e1", "name": "A", "type": "Dataset", "description": "desc"},
            ],
        }
        http = self._http_returning(_resp(200, json=page))

        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "A"))

        # The body has no "data" key, so "content" is expected to be read from the body itself.
        assert len(found) == 1
        assert found[0].name == "A"

    def test_empty_content(self, client: KGServiceClient):
        """An empty content list yields an empty result."""
        body = {"code": 200, "data": {"page": 0, "content": []}}
        http = self._http_returning(_resp(200, json=body))

        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "nothing"))

        assert found == []

    def test_fail_open_on_http_error(self, client: KGServiceClient):
        """An HTTP error fails open and returns an empty list."""
        http = self._http_returning(_resp(500, text="Internal Server Error"))

        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "test"))

        assert found == []

    def test_fail_open_on_connection_error(self, client: KGServiceClient):
        """A connection error fails open and returns an empty list."""
        http = AsyncMock()
        http.get.side_effect = httpx.ConnectError("connection refused")

        with patch.object(client, "_get_client", return_value=http):
            found = _run(client.fulltext_search("graph-1", "test"))

        assert found == []

    def test_request_headers(self, client: KGServiceClient):
        """The request carries the expected headers and query parameters."""
        http = self._http_returning(_resp(200, json={"data": {"content": []}}))

        with patch.object(client, "_get_client", return_value=http):
            _run(client.fulltext_search("gid", "q", size=5, user_id="user-123"))

        keywords = http.get.call_args.kwargs
        assert keywords["headers"]["X-Internal-Token"] == "test-token"
        assert keywords["headers"]["X-User-Id"] == "user-123"
        assert keywords["params"] == {"q": "q", "size": 5}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_subgraph 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSubgraph:
    """Tests for the get_subgraph method."""

    @staticmethod
    def _http_posting(response) -> AsyncMock:
        """Return a stub HTTP client whose awaited ``post`` yields *response*."""
        http = AsyncMock()
        http.post.return_value = response
        return http

    def test_wrapped_subgraph_response(self, client: KGServiceClient):
        """The backend returns a globally wrapped SubgraphExportVO."""
        subgraph = {
            "nodes": [
                {"id": "n1", "name": "用户数据", "type": "Dataset", "description": "desc1", "properties": {}},
                {"id": "n2", "name": "user_id", "type": "Field", "description": "", "properties": {}},
            ],
            "edges": [
                {
                    "id": "edge1",
                    "sourceEntityId": "n1",
                    "targetEntityId": "n2",
                    "relationType": "HAS_FIELD",
                    "weight": 1.0,
                    "confidence": 0.9,
                    "sourceId": "kb-1",
                },
            ],
            "nodeCount": 2,
            "edgeCount": 1,
        }
        http = self._http_posting(_resp(200, json={"code": 200, "data": subgraph}))

        with patch.object(client, "_get_client", return_value=http):
            nodes, edges = _run(client.get_subgraph("gid", ["n1"], depth=2, user_id="u1"))

        assert len(nodes) == 2
        assert nodes[0].name == "用户数据"
        assert nodes[1].name == "user_id"

        assert len(edges) == 1
        edge = edges[0]
        assert edge.source_name == "用户数据"
        assert edge.target_name == "user_id"
        assert edge.relation_type == "HAS_FIELD"
        assert edge.source_type == "Dataset"
        assert edge.target_type == "Field"

    def test_unwrapped_subgraph_response(self, client: KGServiceClient):
        """The backend returns a bare SubgraphExportVO (no global wrapper)."""
        subgraph = {
            "nodes": [
                {"id": "n1", "name": "A", "type": "T1", "description": ""},
            ],
            "edges": [],
            "nodeCount": 1,
            "edgeCount": 0,
        }
        http = self._http_posting(_resp(200, json=subgraph))

        with patch.object(client, "_get_client", return_value=http):
            nodes, edges = _run(client.get_subgraph("gid", ["n1"]))

        assert len(nodes) == 1
        assert nodes[0].name == "A"
        assert edges == []

    def test_edge_with_unknown_entity(self, client: KGServiceClient):
        """An edge endpoint missing from the node list falls back to its ID as the name."""
        subgraph = {
            "nodes": [{"id": "n1", "name": "A", "type": "T1", "description": ""}],
            "edges": [
                {
                    "sourceEntityId": "n1",
                    "targetEntityId": "n999",
                    "relationType": "DEPENDS_ON",
                },
            ],
        }
        http = self._http_posting(_resp(200, json={"code": 200, "data": subgraph}))

        with patch.object(client, "_get_client", return_value=http):
            nodes, edges = _run(client.get_subgraph("gid", ["n1"]))

        assert len(edges) == 1
        edge = edges[0]
        assert edge.source_name == "A"
        assert edge.target_name == "n999"  # falls back to the ID
        assert edge.target_type == ""

    def test_fail_open_on_error(self, client: KGServiceClient):
        """An HTTP 500 fails open: both entities and relations come back empty."""
        http = self._http_posting(_resp(500, text="error"))

        with patch.object(client, "_get_client", return_value=http):
            nodes, edges = _run(client.get_subgraph("gid", ["n1"]))

        assert nodes == []
        assert edges == []

    def test_request_params(self, client: KGServiceClient):
        """The subgraph request passes URL, depth and entity IDs through correctly."""
        http = self._http_posting(_resp(200, json={"data": {"nodes": [], "edges": []}}))

        with patch.object(client, "_get_client", return_value=http):
            _run(client.get_subgraph("gid", ["e1", "e2"], depth=3, user_id="u1"))

        recorded = http.post.call_args
        assert "/knowledge-graph/gid/query/subgraph/export" in recorded.args[0]
        assert recorded.kwargs["params"] == {"depth": 3}
        assert recorded.kwargs["json"] == {"entityIds": ["e1", "e2"]}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# headers 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeaders:
    """Tests for the header dict produced by KGServiceClient._headers."""

    def test_headers_with_token_and_user(self, client: KGServiceClient):
        """Both the internal token and the user ID are emitted when available."""
        built = client._headers(user_id="user-1")
        assert built["X-Internal-Token"] == "test-token"
        assert built["X-User-Id"] == "user-1"

    def test_headers_without_user(self, client: KGServiceClient):
        """Without a user ID only the token header is present."""
        built = client._headers()
        assert "X-Internal-Token" in built
        assert "X-User-Id" not in built

    def test_headers_without_token(self):
        """An empty internal token suppresses the token header but keeps the user ID."""
        tokenless = KGServiceClient(base_url="http://test:8080", internal_token="")
        built = tokenless._headers(user_id="u1")
        assert "X-Internal-Token" not in built
        assert built["X-User-Id"] == "u1"
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Milvus 向量检索客户端的单元测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever
|
||||
|
||||
|
||||
@pytest.fixture
def retriever() -> MilvusVectorRetriever:
    """Provide a MilvusVectorRetriever configured with a test URI and embedding model."""
    settings = {
        "uri": "http://test-milvus:19530",
        "embedding_model": "text-embedding-test",
    }
    return MilvusVectorRetriever(**settings)
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# has_collection 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHasCollection:
    """Tests for MilvusVectorRetriever.has_collection."""

    def test_collection_exists(self, retriever: MilvusVectorRetriever):
        """A collection reported as present yields True."""
        milvus = MagicMock()
        milvus.has_collection.return_value = True
        retriever._milvus_client = milvus

        assert _run(retriever.has_collection("my_collection")) is True

    def test_collection_not_exists(self, retriever: MilvusVectorRetriever):
        """A collection reported as absent yields False."""
        milvus = MagicMock()
        milvus.has_collection.return_value = False
        retriever._milvus_client = milvus

        assert _run(retriever.has_collection("nonexistent")) is False

    def test_fail_open_on_error(self, retriever: MilvusVectorRetriever):
        """An exception from the Milvus client is swallowed and reported as False."""
        milvus = MagicMock()
        milvus.has_collection.side_effect = Exception("connection error")
        retriever._milvus_client = milvus

        assert _run(retriever.has_collection("test")) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# search 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearch:
    """Tests for MilvusVectorRetriever.search."""

    @staticmethod
    def _embedder(vector) -> AsyncMock:
        """Return a stub embeddings object whose awaited ``aembed_query`` yields *vector*."""
        embedder = AsyncMock()
        embedder.aembed_query.return_value = vector
        return embedder

    def test_successful_search(self, retriever: MilvusVectorRetriever):
        """A normal search returns a list of VectorChunk objects."""
        retriever._embeddings = self._embedder([0.1, 0.2, 0.3])

        hits = [
            {"id": "doc1", "distance": 0.95, "entity": {"text": "文档片段一", "metadata": {"source": "kb1"}}},
            {"id": "doc2", "distance": 0.82, "entity": {"text": "文档片段二", "metadata": {}}},
        ]
        milvus = MagicMock()
        milvus.search.return_value = [hits]
        retriever._milvus_client = milvus

        found = _run(retriever.search("my_collection", "用户数据", top_k=5))

        assert len(found) == 2
        top, runner_up = found
        assert top.id == "doc1"
        assert top.text == "文档片段一"
        assert top.score == 0.95
        assert top.metadata == {"source": "kb1"}
        assert runner_up.id == "doc2"
        assert runner_up.score == 0.82

    def test_empty_results(self, retriever: MilvusVectorRetriever):
        """An empty Milvus hit list yields an empty result."""
        retriever._embeddings = self._embedder([0.1])

        milvus = MagicMock()
        milvus.search.return_value = [[]]
        retriever._milvus_client = milvus

        assert _run(retriever.search("col", "query")) == []

    def test_fail_open_on_embedding_error(self, retriever: MilvusVectorRetriever):
        """An embedding failure fails open and returns an empty list."""
        embedder = AsyncMock()
        embedder.aembed_query.side_effect = Exception("API error")
        retriever._embeddings = embedder

        assert _run(retriever.search("col", "query")) == []

    def test_fail_open_on_milvus_error(self, retriever: MilvusVectorRetriever):
        """A Milvus search failure fails open and returns an empty list."""
        retriever._embeddings = self._embedder([0.1])

        milvus = MagicMock()
        milvus.search.side_effect = Exception("Milvus down")
        retriever._milvus_client = milvus

        assert _run(retriever.search("col", "query")) == []

    def test_search_uses_to_thread(self, retriever: MilvusVectorRetriever):
        """The synchronous Milvus call is dispatched through asyncio.to_thread."""
        retriever._embeddings = self._embedder([0.1])

        milvus = MagicMock()
        milvus.search.return_value = [[]]
        retriever._milvus_client = milvus

        target = "app.module.kg_graphrag.milvus_client.asyncio.to_thread"
        with patch(target, new_callable=AsyncMock) as to_thread:
            to_thread.return_value = [[]]

            _run(retriever.search("col", "query"))

            # asyncio.to_thread must be the wrapper around the synchronous Milvus call.
            to_thread.assert_called_once()
            assert to_thread.call_args.args[0] == milvus.search
|
||||
234
runtime/datamate-python/app/module/kg_graphrag/test_retriever.py
Normal file
234
runtime/datamate-python/app/module/kg_graphrag/test_retriever.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""GraphRAG 检索编排器的单元测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.module.kg_graphrag.models import (
|
||||
EntitySummary,
|
||||
RelationSummary,
|
||||
RetrievalStrategy,
|
||||
VectorChunk,
|
||||
)
|
||||
from app.module.kg_graphrag.retriever import GraphRAGRetriever
|
||||
|
||||
|
||||
def _run(coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _make_retriever(
    *,
    milvus_search_result: list[VectorChunk] | None = None,
    milvus_has_collection: bool = True,
    kg_fulltext_result: list[EntitySummary] | None = None,
    kg_subgraph_result: tuple[list[EntitySummary], list[RelationSummary]] | None = None,
) -> GraphRAGRetriever:
    """Build a GraphRAGRetriever whose Milvus and KG clients are async mocks.

    Falsy result arguments are replaced by empty defaults: ``[]`` for the
    search results and ``([], [])`` for the subgraph.
    """
    chunks = milvus_search_result if milvus_search_result else []
    seeds = kg_fulltext_result if kg_fulltext_result else []
    subgraph = kg_subgraph_result if kg_subgraph_result else ([], [])

    milvus_stub = AsyncMock()
    milvus_stub.has_collection = AsyncMock(return_value=milvus_has_collection)
    milvus_stub.search = AsyncMock(return_value=chunks)

    kg_stub = AsyncMock()
    kg_stub.fulltext_search = AsyncMock(return_value=seeds)
    kg_stub.get_subgraph = AsyncMock(return_value=subgraph)

    return GraphRAGRetriever(milvus_client=milvus_stub, kg_client=kg_stub)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# retrieve 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRetrieve:
    """Tests for the retrieve method."""

    def test_both_vector_and_graph(self):
        """Vector and graph retrieval are both enabled."""
        vector_hits = [
            VectorChunk(id="c1", text="文档片段关于用户数据", score=0.9),
            VectorChunk(id="c2", text="其他内容", score=0.7),
        ]
        seed_entities = [EntitySummary(id="e1", name="用户数据", type="Dataset")]
        graph_entities = [
            EntitySummary(id="e1", name="用户数据", type="Dataset"),
            EntitySummary(id="e2", name="user_id", type="Field"),
        ]
        graph_relations = [
            RelationSummary(
                source_name="用户数据", source_type="Dataset",
                target_name="user_id", target_type="Field",
                relation_type="HAS_FIELD",
            ),
        ]
        retriever = _make_retriever(
            milvus_search_result=vector_hits,
            kg_fulltext_result=seed_entities,
            kg_subgraph_result=(graph_entities, graph_relations),
        )

        context = _run(retriever.retrieve(
            query="用户数据有哪些字段",
            collection_name="kb1",
            graph_id="graph-1",
            strategy=RetrievalStrategy(),
            user_id="u1",
        ))

        assert len(context.vector_chunks) == 2
        graph = context.graph_context
        assert len(graph.entities) == 2
        assert len(graph.relations) == 1
        assert "用户数据" in graph.textualized
        assert "## 相关文档" in context.merged_text
        assert "## 知识图谱上下文" in context.merged_text

    def test_vector_only(self):
        """Only vector retrieval is enabled."""
        retriever = _make_retriever(
            milvus_search_result=[VectorChunk(id="c1", text="doc", score=0.9)],
        )
        vector_only = RetrievalStrategy(enable_graph=False)

        context = _run(retriever.retrieve(
            query="test", collection_name="kb", graph_id="g",
            strategy=vector_only, user_id="u",
        ))

        assert len(context.vector_chunks) == 1
        assert context.graph_context.entities == []
        # The KG client must not have been consulted.
        retriever._kg.fulltext_search.assert_not_called()

    def test_graph_only(self):
        """Only graph retrieval is enabled."""
        seed_entities = [EntitySummary(id="e1", name="A", type="T")]
        graph_entities = [EntitySummary(id="e1", name="A", type="T")]
        retriever = _make_retriever(
            kg_fulltext_result=seed_entities,
            kg_subgraph_result=(graph_entities, []),
        )
        graph_only = RetrievalStrategy(enable_vector=False)

        context = _run(retriever.retrieve(
            query="test", collection_name="kb", graph_id="g",
            strategy=graph_only, user_id="u",
        ))

        assert context.vector_chunks == []
        assert len(context.graph_context.entities) == 1
        retriever._milvus.search.assert_not_called()

    def test_no_seed_entities(self):
        """With no full-text hits in the graph, the subgraph query is not issued."""
        retriever = _make_retriever(kg_fulltext_result=[])

        context = _run(retriever.retrieve(
            query="test", collection_name="kb", graph_id="g",
            strategy=RetrievalStrategy(enable_vector=False), user_id="u",
        ))

        assert context.graph_context.entities == []
        retriever._kg.get_subgraph.assert_not_called()

    def test_collection_not_found_skips_vector(self):
        """Vector retrieval is skipped when the collection does not exist."""
        retriever = _make_retriever(milvus_has_collection=False)
        vector_only = RetrievalStrategy(enable_graph=False)

        context = _run(retriever.retrieve(
            query="test", collection_name="nonexistent", graph_id="g",
            strategy=vector_only, user_id="u",
        ))

        assert context.vector_chunks == []
        retriever._milvus.search.assert_not_called()

    def test_both_empty(self):
        """Neither retrieval path returns anything."""
        retriever = _make_retriever()

        context = _run(retriever.retrieve(
            query="nothing", collection_name="kb", graph_id="g",
            strategy=RetrievalStrategy(), user_id="u",
        ))

        assert context.vector_chunks == []
        assert context.graph_context.entities == []
        assert "未检索到相关上下文信息" in context.merged_text

    def test_vector_error_fail_open(self):
        """A vector-retrieval exception fails open; graph retrieval still returns results."""
        retriever = _make_retriever()
        retriever._milvus.search = AsyncMock(side_effect=Exception("milvus down"))

        seed_entities = [EntitySummary(id="e1", name="A", type="T")]
        subgraph = ([EntitySummary(id="e1", name="A", type="T")], [])
        retriever._kg.fulltext_search = AsyncMock(return_value=seed_entities)
        retriever._kg.get_subgraph = AsyncMock(return_value=subgraph)

        context = _run(retriever.retrieve(
            query="test", collection_name="kb", graph_id="g",
            strategy=RetrievalStrategy(), user_id="u",
        ))

        # Vector retrieval failed, yet the graph side still produced a result.
        assert context.vector_chunks == []
        assert len(context.graph_context.entities) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _rank_results 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRankResults:
    """Tests for the _rank_results method."""

    def _make_retriever_instance(self) -> GraphRAGRetriever:
        """Build a retriever with placeholder clients for the ranking tests."""
        placeholders = {"milvus_client": MagicMock(), "kg_client": MagicMock()}
        return GraphRAGRetriever(**placeholders)

    def test_empty_chunks(self):
        """Ranking nothing returns an empty list."""
        ranker = self._make_retriever_instance()
        assert ranker._rank_results([], [], [], RetrievalStrategy()) == []

    def test_single_chunk(self):
        """A single chunk is returned as the only ranked entry."""
        ranker = self._make_retriever_instance()
        only_chunk = VectorChunk(id="1", text="text", score=0.9)

        ranked = ranker._rank_results([only_chunk], [], [], RetrievalStrategy())

        assert len(ranked) == 1
        assert ranked[0].id == "1"

    def test_graph_boost_reorders(self):
        """A graph-entity hit should lift a chunk's rank."""
        ranker = self._make_retriever_instance()
        # Chunk "1" has the higher vector score but mentions no graph entity;
        # chunk "2" has the lower vector score but contains the entity name.
        candidates = [
            VectorChunk(id="1", text="无关内容", score=0.9),
            VectorChunk(id="2", text="包含用户数据的内容", score=0.5),
        ]
        graph_entities = [EntitySummary(id="e1", name="用户数据", type="Dataset")]
        graph_heavy = RetrievalStrategy(vector_weight=0.3, graph_weight=0.7)

        ranked = ranker._rank_results(candidates, graph_entities, [], graph_heavy)

        # Chunk "2" is expected first thanks to its larger graph boost.
        assert ranked[0].id == "2"

    def test_all_same_score(self):
        """Identical scores on every chunk must not crash the ranking."""
        ranker = self._make_retriever_instance()
        tied = [
            VectorChunk(id="1", text="a", score=0.5),
            VectorChunk(id="2", text="b", score=0.5),
        ]

        ranked = ranker._rank_results(tied, [], [], RetrievalStrategy())

        assert len(ranked) == 2
|
||||
Reference in New Issue
Block a user