feat(kg): 实现 Phase 2 GraphRAG 融合功能

核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证

新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口

修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖

测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)

关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*

安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)

测试结果:79 tests pass 
This commit is contained in:
2026-02-20 09:41:55 +08:00
parent 0ed7dcbee7
commit 39338df808
18 changed files with 2745 additions and 0 deletions

View File

@@ -88,6 +88,29 @@ class Settings(BaseSettings):
kg_alignment_vector_threshold: float = 0.92
kg_alignment_llm_threshold: float = 0.78
# GraphRAG 融合查询配置
graphrag_enabled: bool = False
graphrag_milvus_uri: str = "http://milvus-standalone:19530"
graphrag_kg_service_url: str = "http://datamate-kg:8080"
graphrag_kg_internal_token: str = ""
# GraphRAG - 检索策略默认值
graphrag_vector_top_k: int = 5
graphrag_graph_depth: int = 2
graphrag_graph_max_entities: int = 20
graphrag_vector_weight: float = 0.6
graphrag_graph_weight: float = 0.4
# GraphRAG - LLM(空则复用 kg_llm_* 配置)
graphrag_llm_model: str = ""
graphrag_llm_base_url: Optional[str] = None
graphrag_llm_api_key: SecretStr = SecretStr("EMPTY")
graphrag_llm_temperature: float = 0.1
graphrag_llm_timeout_seconds: int = 60
# GraphRAG - Embedding(空则复用 kg_alignment_embedding_* 配置)
graphrag_embedding_model: str = ""
# 标注编辑器(Label Studio Editor)相关
editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数

View File

@@ -8,6 +8,7 @@ from .evaluation.interface import router as evaluation_router
from .collection.interface import router as collection_route
from .dataset.interface import router as dataset_router
from .kg_extraction.interface import router as kg_extraction_router
from .kg_graphrag.interface import router as kg_graphrag_router
router = APIRouter(
prefix="/api"
@@ -21,5 +22,6 @@ router.include_router(evaluation_router)
router.include_router(collection_route)
router.include_router(dataset_router)
router.include_router(kg_extraction_router)
router.include_router(kg_graphrag_router)
__all__ = ["router"]

View File

@@ -0,0 +1,5 @@
"""GraphRAG 融合查询模块。"""
from app.module.kg_graphrag.interface import router
__all__ = ["router"]

View File

@@ -0,0 +1,110 @@
"""三元组文本化 + 上下文构建。
将图谱子图(实体 + 关系)转为自然语言描述,
并与向量检索片段合并为 LLM 可消费的上下文文本。
"""
from __future__ import annotations
from app.module.kg_graphrag.models import (
EntitySummary,
RelationSummary,
VectorChunk,
)
# 关系类型 -> 中文模板映射
RELATION_TEMPLATES: dict[str, str] = {
"HAS_FIELD": "{source}包含字段{target}",
"DERIVED_FROM": "{source}来源于{target}",
"USES_DATASET": "{source}使用了数据集{target}",
"PRODUCES": "{source}产出了{target}",
"ASSIGNED_TO": "{source}分配给了{target}",
"BELONGS_TO": "{source}属于{target}",
"TRIGGERS": "{source}触发了{target}",
"DEPENDS_ON": "{source}依赖于{target}",
"IMPACTS": "{source}影响了{target}",
"SOURCED_FROM": "{source}的知识来源于{target}",
}
# 通用模板(未在映射中的关系类型)
_DEFAULT_TEMPLATE = "{source}{target}存在{relation}关系"
def textualize_subgraph(
entities: list[EntitySummary],
relations: list[RelationSummary],
) -> str:
"""将图谱子图转为自然语言描述。
Args:
entities: 子图中的实体列表。
relations: 子图中的关系列表。
Returns:
文本化后的图谱描述,每条关系/实体一行。
"""
lines: list[str] = []
# 记录有关系的实体名称
mentioned_entities: set[str] = set()
# 1. 对每条关系生成一句话
for rel in relations:
source_label = f"{rel.source_type}'{rel.source_name}'"
target_label = f"{rel.target_type}'{rel.target_name}'"
template = RELATION_TEMPLATES.get(rel.relation_type, _DEFAULT_TEMPLATE)
line = template.format(
source=source_label,
target=target_label,
relation=rel.relation_type,
)
lines.append(line)
mentioned_entities.add(rel.source_name)
mentioned_entities.add(rel.target_name)
# 2. 对独立实体(无关系)生成描述句
for entity in entities:
if entity.name not in mentioned_entities:
desc = entity.description or ""
if desc:
lines.append(f"{entity.type}'{entity.name}': {desc}")
else:
lines.append(f"存在{entity.type}'{entity.name}'")
return "\n".join(lines)
def build_context(
vector_chunks: list[VectorChunk],
graph_text: str,
vector_weight: float = 0.6,
graph_weight: float = 0.4,
) -> str:
"""合并向量检索片段和图谱文本化内容为 LLM 上下文。
Args:
vector_chunks: 向量检索到的文档片段列表。
graph_text: 文本化后的图谱描述。
vector_weight: 向量分数权重(当前用于日志/调试,不影响上下文排序)。
graph_weight: 图谱相关性权重。
Returns:
合并后的上下文文本,分为「相关文档」和「知识图谱上下文」两个部分。
"""
sections: list[str] = []
# 向量检索片段
if vector_chunks:
doc_lines = ["## 相关文档"]
for i, chunk in enumerate(vector_chunks, 1):
doc_lines.append(f"[{i}] {chunk.text}")
sections.append("\n".join(doc_lines))
# 图谱文本化内容
if graph_text:
sections.append(f"## 知识图谱上下文\n{graph_text}")
if not sections:
return "(未检索到相关上下文信息)"
return "\n\n".join(sections)

View File

@@ -0,0 +1,101 @@
"""LLM 生成器。
基于增强上下文(向量 + 图谱)调用 LLM 生成回答,
支持同步和流式两种模式。
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from pydantic import SecretStr
from app.core.logging import get_logger
logger = get_logger(__name__)
_SYSTEM_PROMPT = (
"你是 DataMate 数据管理平台的智能助手。请根据以下上下文信息回答用户的问题。\n"
"如果上下文中没有相关信息,请明确说明。不要编造信息。"
)
class GraphRAGGenerator:
"""GraphRAG LLM 生成器。"""
def __init__(
self,
*,
model: str = "gpt-4o-mini",
base_url: str | None = None,
api_key: SecretStr = SecretStr("EMPTY"),
temperature: float = 0.1,
timeout: int = 60,
) -> None:
self._model = model
self._base_url = base_url
self._api_key = api_key
self._temperature = temperature
self._timeout = timeout
self._llm = None
@property
def model_name(self) -> str:
return self._model
@classmethod
def from_settings(cls) -> GraphRAGGenerator:
from app.core.config import settings
model = settings.graphrag_llm_model or settings.kg_llm_model
base_url = settings.graphrag_llm_base_url or settings.kg_llm_base_url
api_key = (
settings.graphrag_llm_api_key
if settings.graphrag_llm_api_key.get_secret_value() != "EMPTY"
else settings.kg_llm_api_key
)
return cls(
model=model,
base_url=base_url,
api_key=api_key,
temperature=settings.graphrag_llm_temperature,
timeout=settings.graphrag_llm_timeout_seconds,
)
def _get_llm(self):
if self._llm is None:
from langchain_openai import ChatOpenAI
self._llm = ChatOpenAI(
model=self._model,
base_url=self._base_url,
api_key=self._api_key,
temperature=self._temperature,
timeout=self._timeout,
)
return self._llm
def _build_messages(self, query: str, context: str) -> list[dict[str, str]]:
return [
{"role": "system", "content": _SYSTEM_PROMPT},
{
"role": "user",
"content": f"{context}\n\n用户问题: {query}\n\n请基于上下文中的信息回答。",
},
]
async def generate(self, query: str, context: str) -> str:
"""基于增强上下文生成回答。"""
messages = self._build_messages(query, context)
llm = self._get_llm()
response = await llm.ainvoke(messages)
return str(response.content)
async def generate_stream(self, query: str, context: str) -> AsyncIterator[str]:
"""基于增强上下文流式生成回答,逐 token 返回。"""
messages = self._build_messages(query, context)
llm = self._get_llm()
async for chunk in llm.astream(messages):
content = chunk.content
if content:
yield str(content)

View File

@@ -0,0 +1,249 @@
"""GraphRAG 融合查询 API 端点。
提供向量检索 + 知识图谱的融合查询能力:
- POST /api/graphrag/query — 完整 GraphRAG 查询(检索+生成)
- POST /api/graphrag/retrieve — 仅检索(返回上下文,不调 LLM)
- POST /api/graphrag/query/stream — 流式 GraphRAG 查询(SSE)
"""
from __future__ import annotations
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from app.core.logging import get_logger
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
from app.module.kg_graphrag.models import (
GraphRAGQueryRequest,
GraphRAGQueryResponse,
RetrievalContext,
)
from app.module.kg_graphrag.retriever import GraphRAGRetriever
from app.module.kg_graphrag.generator import GraphRAGGenerator
from app.module.shared.schema import StandardResponse
router = APIRouter(prefix="/graphrag", tags=["graphrag"])
logger = get_logger(__name__)
# 延迟初始化
_retriever: GraphRAGRetriever | None = None
_generator: GraphRAGGenerator | None = None
_kb_validator: KnowledgeBaseAccessValidator | None = None
def _get_retriever() -> GraphRAGRetriever:
global _retriever
if _retriever is None:
_retriever = GraphRAGRetriever.from_settings()
return _retriever
def _get_generator() -> GraphRAGGenerator:
global _generator
if _generator is None:
_generator = GraphRAGGenerator.from_settings()
return _generator
def _get_kb_validator() -> KnowledgeBaseAccessValidator:
global _kb_validator
if _kb_validator is None:
_kb_validator = KnowledgeBaseAccessValidator.from_settings()
return _kb_validator
def _require_caller_id(
x_user_id: Annotated[
str,
Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递"),
],
) -> str:
caller = x_user_id.strip()
if not caller:
raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
return caller
# ---------------------------------------------------------------------------
# P0: 完整 GraphRAG 查询
# ---------------------------------------------------------------------------
@router.post(
"/query",
response_model=StandardResponse[GraphRAGQueryResponse],
summary="GraphRAG 查询",
description="并行从向量库和知识图谱检索上下文,融合后调用 LLM 生成回答。",
)
async def query(
req: GraphRAGQueryRequest,
caller: Annotated[str, Depends(_require_caller_id)],
):
trace_id = uuid.uuid4().hex[:16]
logger.info(
"[%s] GraphRAG query: graph_id=%s, collection=%s, caller=%s",
trace_id, req.graph_id, req.collection_name, caller,
)
retriever = _get_retriever()
generator = _get_generator()
# 权限校验:验证用户是否有权访问该知识库
kb_validator = _get_kb_validator()
if not await kb_validator.check_access(
req.knowledge_base_id, caller, collection_name=req.collection_name,
):
logger.warning(
"[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
trace_id, req.knowledge_base_id, req.collection_name, caller,
)
raise HTTPException(
status_code=403,
detail=f"无权访问知识库 {req.knowledge_base_id}",
)
try:
context = await retriever.retrieve(
query=req.query,
collection_name=req.collection_name,
graph_id=req.graph_id,
strategy=req.strategy,
user_id=caller,
)
except Exception:
logger.exception("[%s] Retrieval failed", trace_id)
raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")
try:
answer = await generator.generate(query=req.query, context=context.merged_text)
except Exception:
logger.exception("[%s] Generation failed", trace_id)
raise HTTPException(status_code=502, detail=f"生成服务暂不可用 (trace: {trace_id})")
result = GraphRAGQueryResponse(
answer=answer,
context=context,
model=generator.model_name,
)
return StandardResponse(code=200, message="success", data=result)
# ---------------------------------------------------------------------------
# P1-1: 仅检索
# ---------------------------------------------------------------------------
@router.post(
"/retrieve",
response_model=StandardResponse[RetrievalContext],
summary="GraphRAG 仅检索",
description="并行从向量库和知识图谱检索上下文,返回结构化上下文(不调 LLM)。",
)
async def retrieve(
req: GraphRAGQueryRequest,
caller: Annotated[str, Depends(_require_caller_id)],
):
trace_id = uuid.uuid4().hex[:16]
logger.info(
"[%s] GraphRAG retrieve: graph_id=%s, collection=%s, caller=%s",
trace_id, req.graph_id, req.collection_name, caller,
)
retriever = _get_retriever()
# 权限校验:验证用户是否有权访问该知识库
kb_validator = _get_kb_validator()
if not await kb_validator.check_access(
req.knowledge_base_id, caller, collection_name=req.collection_name,
):
logger.warning(
"[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
trace_id, req.knowledge_base_id, req.collection_name, caller,
)
raise HTTPException(
status_code=403,
detail=f"无权访问知识库 {req.knowledge_base_id}",
)
try:
context = await retriever.retrieve(
query=req.query,
collection_name=req.collection_name,
graph_id=req.graph_id,
strategy=req.strategy,
user_id=caller,
)
except Exception:
logger.exception("[%s] Retrieval failed", trace_id)
raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")
return StandardResponse(code=200, message="success", data=context)
# ---------------------------------------------------------------------------
# P1-4: 流式查询 (SSE)
# ---------------------------------------------------------------------------
@router.post(
"/query/stream",
summary="GraphRAG 流式查询",
description="并行检索后,通过 SSE 流式返回 LLM 生成内容。",
)
async def query_stream(
req: GraphRAGQueryRequest,
caller: Annotated[str, Depends(_require_caller_id)],
):
trace_id = uuid.uuid4().hex[:16]
logger.info(
"[%s] GraphRAG stream: graph_id=%s, collection=%s, caller=%s",
trace_id, req.graph_id, req.collection_name, caller,
)
retriever = _get_retriever()
generator = _get_generator()
# 权限校验:验证用户是否有权访问该知识库
kb_validator = _get_kb_validator()
if not await kb_validator.check_access(
req.knowledge_base_id, caller, collection_name=req.collection_name,
):
logger.warning(
"[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
trace_id, req.knowledge_base_id, req.collection_name, caller,
)
raise HTTPException(
status_code=403,
detail=f"无权访问知识库 {req.knowledge_base_id}",
)
try:
context = await retriever.retrieve(
query=req.query,
collection_name=req.collection_name,
graph_id=req.graph_id,
strategy=req.strategy,
user_id=caller,
)
except Exception:
logger.exception("[%s] Retrieval failed", trace_id)
raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")
import json
async def event_stream():
try:
async for token in generator.generate_stream(
query=req.query, context=context.merged_text
):
yield f"data: {json.dumps({'token': token}, ensure_ascii=False)}\n\n"
# 结束事件:附带检索上下文
yield f"data: {json.dumps({'done': True, 'context': context.model_dump()}, ensure_ascii=False)}\n\n"
except Exception:
logger.exception("[%s] Stream generation failed", trace_id)
yield f"data: {json.dumps({'error': '生成服务暂不可用'})}\n\n"
return StreamingResponse(event_stream(), media_type="text/event-stream")

View File

@@ -0,0 +1,118 @@
"""知识库访问权限校验。
在执行 GraphRAG 检索前,调用 Java rag-indexer-service 的
GET /knowledge-base/{id} 端点验证当前用户是否有权访问该知识库。
Java 侧实现参考:KnowledgeBaseService.getKnowledgeBaseWithAccessCheck()
- 查找 KB 是否存在
- 校验 createdBy == currentUserId(管理员跳过)
- 不满足则抛出 sys.0005 (INSUFFICIENT_PERMISSIONS)
"""
from __future__ import annotations
import httpx
from app.core.logging import get_logger
logger = get_logger(__name__)
class KnowledgeBaseAccessValidator:
"""通过 Java 后端校验用户是否有权访问指定知识库。"""
def __init__(
self,
*,
base_url: str = "http://datamate-backend:8080/api",
timeout: float = 10.0,
) -> None:
self._base_url = base_url.rstrip("/")
self._timeout = timeout
self._client: httpx.AsyncClient | None = None
@classmethod
def from_settings(cls) -> KnowledgeBaseAccessValidator:
from app.core.config import settings
return cls(base_url=settings.datamate_backend_base_url)
def _get_client(self) -> httpx.AsyncClient:
if self._client is None:
self._client = httpx.AsyncClient(
base_url=self._base_url,
timeout=self._timeout,
)
return self._client
async def check_access(
self,
knowledge_base_id: str,
user_id: str,
*,
collection_name: str | None = None,
) -> bool:
"""校验用户是否有权访问指定知识库。
调用 Java 后端 GET /knowledge-base/{id},该端点内部执行
owner 校验(createdBy == currentUserId,管理员跳过)。
当 *collection_name* 不为 None 时,还会校验请求中的
collection_name 与该知识库实际的 name 是否一致,防止
用户提交合法 KB ID 但篡改 collection_name 来访问
其他知识库的 Milvus 数据。
Returns:
True — 用户有权访问且 collection_name 匹配
False — 无权访问、collection_name 不匹配或校验失败
"""
try:
client = self._get_client()
resp = await client.get(
f"/knowledge-base/{knowledge_base_id}",
headers={"X-User-Id": user_id},
)
if resp.status_code == 200:
body = resp.json()
# Java 全局包装: {"code": 200, "data": {...}}
# code != 200 说明业务层拒绝(如权限不足)
code = body.get("code", resp.status_code)
if code != 200:
logger.warning(
"KB access denied: kb_id=%s, user=%s, biz_code=%s, msg=%s",
knowledge_base_id, user_id, code, body.get("message", ""),
)
return False
# 校验 collection_name 与 KB 实际名称的绑定关系
if collection_name is not None:
data = body.get("data") or {}
actual_name = data.get("name") if isinstance(data, dict) else None
if actual_name != collection_name:
logger.warning(
"KB collection_name mismatch: kb_id=%s, "
"expected=%s, actual=%s, user=%s",
knowledge_base_id, collection_name,
actual_name, user_id,
)
return False
return True
# HTTP 4xx/5xx
logger.warning(
"KB access check returned HTTP %d: kb_id=%s, user=%s",
resp.status_code, knowledge_base_id, user_id,
)
return False
except Exception:
# 网络异常时 fail-close:拒绝访问,防止绕过权限
logger.exception(
"KB access check failed (fail-close): kb_id=%s, user=%s",
knowledge_base_id, user_id,
)
return False
async def close(self) -> None:
if self._client is not None:
await self._client.aclose()
self._client = None

View File

@@ -0,0 +1,197 @@
"""KG 服务 REST 客户端。
通过 httpx 调用 Java 侧 knowledge-graph-service 的查询 API,
包括全文检索和子图导出。
失败策略:fail-open —— KG 服务不可用时返回空结果 + 日志告警。
"""
from __future__ import annotations
import httpx
from app.core.logging import get_logger
from app.module.kg_graphrag.models import EntitySummary, RelationSummary
logger = get_logger(__name__)
class KGServiceClient:
"""Java KG 服务 REST 客户端。"""
def __init__(
self,
*,
base_url: str = "http://datamate-kg:8080",
internal_token: str = "",
timeout: float = 30.0,
) -> None:
self._base_url = base_url.rstrip("/")
self._internal_token = internal_token
self._timeout = timeout
self._client: httpx.AsyncClient | None = None
@classmethod
def from_settings(cls) -> KGServiceClient:
from app.core.config import settings
return cls(
base_url=settings.graphrag_kg_service_url,
internal_token=settings.graphrag_kg_internal_token,
timeout=30.0,
)
def _get_client(self) -> httpx.AsyncClient:
if self._client is None:
self._client = httpx.AsyncClient(
base_url=self._base_url,
timeout=self._timeout,
)
return self._client
def _headers(self, user_id: str = "") -> dict[str, str]:
headers: dict[str, str] = {}
if self._internal_token:
headers["X-Internal-Token"] = self._internal_token
if user_id:
headers["X-User-Id"] = user_id
return headers
async def fulltext_search(
self,
graph_id: str,
query: str,
size: int = 10,
user_id: str = "",
) -> list[EntitySummary]:
"""调用 KG 服务全文检索,返回匹配的实体列表。
Fail-open: KG 服务不可用时返回空列表。
"""
try:
return await self._fulltext_search_impl(graph_id, query, size, user_id)
except Exception:
logger.exception(
"KG fulltext search failed for graph_id=%s (fail-open, returning empty)",
graph_id,
)
return []
async def _fulltext_search_impl(
self,
graph_id: str,
query: str,
size: int,
user_id: str,
) -> list[EntitySummary]:
client = self._get_client()
resp = await client.get(
f"/knowledge-graph/{graph_id}/query/search",
params={"q": query, "size": size},
headers=self._headers(user_id),
)
resp.raise_for_status()
body = resp.json()
# Java 返回 PagedResponse<SearchHitVO>:
# 可能被全局包装为 {"code": 200, "data": PagedResponse}
# 也可能直接返回 PagedResponse {"page": 0, "content": [...]}
data = body.get("data", body)
# PagedResponse 将实体列表放在 content 字段中
items: list[dict] = (
data.get("content", []) if isinstance(data, dict) else data if isinstance(data, list) else []
)
entities: list[EntitySummary] = []
for item in items:
entities.append(
EntitySummary(
id=str(item.get("id", "")),
name=item.get("name", ""),
type=item.get("type", ""),
description=item.get("description", ""),
)
)
return entities
async def get_subgraph(
self,
graph_id: str,
entity_ids: list[str],
depth: int = 1,
user_id: str = "",
) -> tuple[list[EntitySummary], list[RelationSummary]]:
"""获取种子实体的 N-hop 子图。
Fail-open: KG 服务不可用时返回空子图。
"""
try:
return await self._get_subgraph_impl(graph_id, entity_ids, depth, user_id)
except Exception:
logger.exception(
"KG subgraph export failed for graph_id=%s (fail-open, returning empty)",
graph_id,
)
return [], []
async def _get_subgraph_impl(
self,
graph_id: str,
entity_ids: list[str],
depth: int,
user_id: str,
) -> tuple[list[EntitySummary], list[RelationSummary]]:
client = self._get_client()
resp = await client.post(
f"/knowledge-graph/{graph_id}/query/subgraph/export",
params={"depth": depth},
json={"entityIds": entity_ids},
headers=self._headers(user_id),
)
resp.raise_for_status()
body = resp.json()
# Java 返回 SubgraphExportVO:
# 可能被全局包装为 {"code": 200, "data": SubgraphExportVO}
# 也可能直接返回 SubgraphExportVO {"nodes": [...], "edges": [...]}
data = body.get("data", body) if isinstance(body.get("data"), dict) else body
nodes_raw = data.get("nodes", [])
edges_raw = data.get("edges", [])
# ExportNodeVO: id, name, type, description, properties (Map)
entities: list[EntitySummary] = []
for node in nodes_raw:
entities.append(
EntitySummary(
id=str(node.get("id", "")),
name=node.get("name", ""),
type=node.get("type", ""),
description=node.get("description", ""),
)
)
relations: list[RelationSummary] = []
# 构建 id -> entity 的映射用于查找 source/target 名称和类型
entity_map = {e.id: e for e in entities}
# ExportEdgeVO: sourceEntityId, targetEntityId, relationType
# 注意:sourceId 是数据来源 ID,不是源实体 ID
for edge in edges_raw:
source_id = str(edge.get("sourceEntityId", ""))
target_id = str(edge.get("targetEntityId", ""))
source_entity = entity_map.get(source_id)
target_entity = entity_map.get(target_id)
relations.append(
RelationSummary(
source_name=source_entity.name if source_entity else source_id,
source_type=source_entity.type if source_entity else "",
target_name=target_entity.name if target_entity else target_id,
target_type=target_entity.type if target_entity else "",
relation_type=edge.get("relationType", ""),
)
)
return entities, relations
async def close(self) -> None:
if self._client is not None:
await self._client.aclose()
self._client = None

View File

@@ -0,0 +1,135 @@
"""Milvus 向量检索客户端。
通过 pymilvus 连接 Milvus,对查询文本进行 embedding 后执行混合搜索,
返回 top-K 文档片段。
失败策略:fail-open —— Milvus 不可用时返回空列表 + 日志告警。
"""
from __future__ import annotations
import asyncio
from pydantic import SecretStr
from app.core.logging import get_logger
from app.module.kg_graphrag.models import VectorChunk
logger = get_logger(__name__)
class MilvusVectorRetriever:
"""Milvus 向量检索器。"""
def __init__(
self,
*,
uri: str = "http://milvus-standalone:19530",
embedding_model: str = "text-embedding-3-small",
embedding_base_url: str | None = None,
embedding_api_key: SecretStr = SecretStr("EMPTY"),
) -> None:
self._uri = uri
self._embedding_model = embedding_model
self._embedding_base_url = embedding_base_url
self._embedding_api_key = embedding_api_key
# Lazy init
self._milvus_client = None
self._embeddings = None
@classmethod
def from_settings(cls) -> MilvusVectorRetriever:
from app.core.config import settings
embedding_model = (
settings.graphrag_embedding_model
or settings.kg_alignment_embedding_model
)
return cls(
uri=settings.graphrag_milvus_uri,
embedding_model=embedding_model,
embedding_base_url=settings.kg_llm_base_url,
embedding_api_key=settings.kg_llm_api_key,
)
def _get_embeddings(self):
if self._embeddings is None:
from langchain_openai import OpenAIEmbeddings
self._embeddings = OpenAIEmbeddings(
model=self._embedding_model,
base_url=self._embedding_base_url,
api_key=self._embedding_api_key,
)
return self._embeddings
def _get_milvus_client(self):
if self._milvus_client is None:
from pymilvus import MilvusClient
self._milvus_client = MilvusClient(uri=self._uri)
logger.info("Connected to Milvus at %s", self._uri)
return self._milvus_client
async def has_collection(self, collection_name: str) -> bool:
"""检查 Milvus 中是否存在指定 collection(防止越权访问不存在的库)。"""
try:
client = self._get_milvus_client()
return await asyncio.to_thread(client.has_collection, collection_name)
except Exception:
logger.exception("Milvus has_collection check failed for %s", collection_name)
return False
async def search(
self,
collection_name: str,
query: str,
top_k: int = 5,
) -> list[VectorChunk]:
"""向量搜索:embed query -> Milvus search -> 返回 top-K 文档片段。
Fail-open: Milvus 不可用时返回空列表。
"""
try:
return await self._search_impl(collection_name, query, top_k)
except Exception:
logger.exception(
"Milvus search failed for collection=%s (fail-open, returning empty)",
collection_name,
)
return []
async def _search_impl(
self,
collection_name: str,
query: str,
top_k: int,
) -> list[VectorChunk]:
# 1. Embed query
query_vector = await self._get_embeddings().aembed_query(query)
# 2. Milvus search(同步 I/O,通过 to_thread 避免阻塞事件循环)
client = self._get_milvus_client()
results = await asyncio.to_thread(
client.search,
collection_name=collection_name,
data=[query_vector],
limit=top_k,
output_fields=["text", "metadata"],
search_params={"metric_type": "COSINE", "params": {"nprobe": 16}},
)
# 3. 转换为 VectorChunk
chunks: list[VectorChunk] = []
if results and len(results) > 0:
for hit in results[0]:
entity = hit.get("entity", {})
chunks.append(
VectorChunk(
id=str(hit.get("id", "")),
text=entity.get("text", ""),
score=float(hit.get("distance", 0.0)),
metadata=entity.get("metadata", {}),
)
)
return chunks

View File

@@ -0,0 +1,102 @@
"""GraphRAG 融合查询的请求/响应数据模型。"""
from __future__ import annotations
from pydantic import BaseModel, Field
class RetrievalStrategy(BaseModel):
"""检索策略配置。"""
vector_top_k: int = Field(default=5, ge=1, le=50, description="向量检索返回数")
graph_depth: int = Field(default=2, ge=1, le=5, description="图谱扩展深度")
graph_max_entities: int = Field(default=20, ge=1, le=100, description="图谱最大实体数")
vector_weight: float = Field(default=0.6, ge=0.0, le=1.0, description="向量分数权重")
graph_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="图谱相关性权重")
enable_graph: bool = Field(default=True, description="是否启用图谱检索")
enable_vector: bool = Field(default=True, description="是否启用向量检索")
class GraphRAGQueryRequest(BaseModel):
"""GraphRAG 查询请求。"""
query: str = Field(
...,
min_length=1,
max_length=2000,
description="用户查询",
)
knowledge_base_id: str = Field(
...,
min_length=1,
max_length=64,
description="知识库 ID,用于权限校验(由上游 Java 后端传入)",
)
collection_name: str = Field(
...,
min_length=1,
max_length=256,
pattern=r"^[a-zA-Z0-9_\-\u4e00-\u9fff]+$",
description="Milvus collection 名称(= 知识库名),仅允许字母、数字、下划线、连字符和中文",
)
graph_id: str = Field(
...,
pattern=r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$",
description="Neo4j 图谱 ID(UUID 格式)",
)
strategy: RetrievalStrategy = Field(
default_factory=RetrievalStrategy,
description="可选策略覆盖",
)
class VectorChunk(BaseModel):
"""向量检索到的文档片段。"""
id: str
text: str
score: float
metadata: dict[str, object] = Field(default_factory=dict)
class EntitySummary(BaseModel):
"""实体摘要。"""
id: str
name: str
type: str
description: str = ""
class RelationSummary(BaseModel):
"""关系摘要。"""
source_name: str
source_type: str
target_name: str
target_type: str
relation_type: str
class GraphContext(BaseModel):
"""图谱上下文。"""
entities: list[EntitySummary] = Field(default_factory=list)
relations: list[RelationSummary] = Field(default_factory=list)
textualized: str = ""
class RetrievalContext(BaseModel):
"""检索上下文(检索结果的结构化表示)。"""
vector_chunks: list[VectorChunk] = Field(default_factory=list)
graph_context: GraphContext = Field(default_factory=GraphContext)
merged_text: str = ""
class GraphRAGQueryResponse(BaseModel):
"""GraphRAG 查询响应。"""
answer: str = Field(..., description="LLM 生成的回答")
context: RetrievalContext = Field(..., description="检索上下文")
model: str = Field(..., description="使用的 LLM 模型名")

View File

@@ -0,0 +1,214 @@
"""GraphRAG 检索编排器。
并行执行向量检索和图谱检索,融合排序后构建统一上下文。
"""
from __future__ import annotations
import asyncio
from app.core.logging import get_logger
from app.module.kg_graphrag.context_builder import build_context, textualize_subgraph
from app.module.kg_graphrag.kg_client import KGServiceClient
from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever
from app.module.kg_graphrag.models import (
EntitySummary,
GraphContext,
RelationSummary,
RetrievalContext,
RetrievalStrategy,
VectorChunk,
)
logger = get_logger(__name__)
class GraphRAGRetriever:
"""GraphRAG 检索编排器。"""
def __init__(
self,
*,
milvus_client: MilvusVectorRetriever,
kg_client: KGServiceClient,
) -> None:
self._milvus = milvus_client
self._kg = kg_client
@classmethod
def from_settings(cls) -> GraphRAGRetriever:
return cls(
milvus_client=MilvusVectorRetriever.from_settings(),
kg_client=KGServiceClient.from_settings(),
)
async def retrieve(
self,
query: str,
collection_name: str,
graph_id: str,
strategy: RetrievalStrategy,
user_id: str = "",
) -> RetrievalContext:
"""并行执行向量检索 + 图谱检索,融合结果。"""
# 构建并行任务
tasks: dict[str, asyncio.Task] = {}
if strategy.enable_vector:
# 先校验 collection 存在性,防止越权访问
if not await self._milvus.has_collection(collection_name):
logger.warning(
"Collection %s not found, skipping vector retrieval",
collection_name,
)
else:
tasks["vector"] = asyncio.create_task(
self._milvus.search(
collection_name=collection_name,
query=query,
top_k=strategy.vector_top_k,
)
)
if strategy.enable_graph:
tasks["graph"] = asyncio.create_task(
self._retrieve_graph(
query=query,
graph_id=graph_id,
strategy=strategy,
user_id=user_id,
)
)
# 等待所有任务完成
if tasks:
await asyncio.gather(*tasks.values(), return_exceptions=True)
# 收集结果
vector_chunks: list[VectorChunk] = []
if "vector" in tasks:
try:
vector_chunks = tasks["vector"].result()
except Exception:
logger.exception("Vector retrieval task failed")
entities: list[EntitySummary] = []
relations: list[RelationSummary] = []
if "graph" in tasks:
try:
entities, relations = tasks["graph"].result()
except Exception:
logger.exception("Graph retrieval task failed")
# 融合排序
vector_chunks = self._rank_results(
vector_chunks, entities, relations, strategy
)
# 三元组文本化
graph_text = textualize_subgraph(entities, relations)
# 构建上下文
merged_text = build_context(
vector_chunks,
graph_text,
vector_weight=strategy.vector_weight,
graph_weight=strategy.graph_weight,
)
return RetrievalContext(
vector_chunks=vector_chunks,
graph_context=GraphContext(
entities=entities,
relations=relations,
textualized=graph_text,
),
merged_text=merged_text,
)
async def _retrieve_graph(
self,
query: str,
graph_id: str,
strategy: RetrievalStrategy,
user_id: str,
) -> tuple[list[EntitySummary], list[RelationSummary]]:
"""图谱检索:全文搜索 -> 种子实体 -> 子图扩展。"""
# 1. 全文检索获取种子实体
seed_entities = await self._kg.fulltext_search(
graph_id=graph_id,
query=query,
size=strategy.graph_max_entities,
user_id=user_id,
)
if not seed_entities:
logger.debug("No seed entities found for query: %s", query)
return [], []
# 2. 获取种子实体的 N-hop 子图
seed_ids = [e.id for e in seed_entities]
entities, relations = await self._kg.get_subgraph(
graph_id=graph_id,
entity_ids=seed_ids,
depth=strategy.graph_depth,
user_id=user_id,
)
logger.info(
"Graph retrieval: %d seed entities -> %d entities, %d relations",
len(seed_entities), len(entities), len(relations),
)
return entities, relations
def _rank_results(
self,
vector_chunks: list[VectorChunk],
entities: list[EntitySummary],
relations: list[RelationSummary],
strategy: RetrievalStrategy,
) -> list[VectorChunk]:
"""对向量检索结果进行融合排序。
基于向量分数归一化后加权排序。图谱关联度通过实体度数近似评估。
"""
if not vector_chunks:
return vector_chunks
# 向量分数归一化 (min-max scaling)
scores = [c.score for c in vector_chunks]
min_score = min(scores)
max_score = max(scores)
score_range = max_score - min_score
# 构建图谱实体名称集合,用于关联度加分
graph_entity_names = {e.name.lower() for e in entities}
ranked: list[tuple[float, VectorChunk]] = []
for chunk in vector_chunks:
# 归一化向量分数
norm_score = (
(chunk.score - min_score) / score_range
if score_range > 0
else 1.0
)
# 图谱关联度加分:文档片段中提及图谱实体名称
graph_boost = 0.0
if graph_entity_names:
chunk_text_lower = chunk.text.lower()
mentioned = sum(
1 for name in graph_entity_names if name in chunk_text_lower
)
graph_boost = min(mentioned / max(len(graph_entity_names), 1), 1.0)
# 加权融合分数
final_score = (
strategy.vector_weight * norm_score
+ strategy.graph_weight * graph_boost
)
ranked.append((final_score, chunk))
# 按融合分数降序排序
ranked.sort(key=lambda x: x[0], reverse=True)
return [chunk for _, chunk in ranked]

View File

@@ -0,0 +1,182 @@
"""三元组文本化 + 上下文构建的单元测试。"""
from app.module.kg_graphrag.context_builder import (
RELATION_TEMPLATES,
build_context,
textualize_subgraph,
)
from app.module.kg_graphrag.models import (
EntitySummary,
RelationSummary,
VectorChunk,
)
# ---------------------------------------------------------------------------
# textualize_subgraph 测试
# ---------------------------------------------------------------------------
class TestTextualizeSubgraph:
"""textualize_subgraph 函数的测试。"""
def test_single_relation(self):
entities = [
EntitySummary(id="1", name="用户行为数据", type="Dataset"),
EntitySummary(id="2", name="user_id", type="Field"),
]
relations = [
RelationSummary(
source_name="用户行为数据",
source_type="Dataset",
target_name="user_id",
target_type="Field",
relation_type="HAS_FIELD",
),
]
result = textualize_subgraph(entities, relations)
assert "Dataset'用户行为数据'包含字段Field'user_id'" in result
def test_multiple_relations(self):
entities = [
EntitySummary(id="1", name="用户行为数据", type="Dataset"),
EntitySummary(id="2", name="清洗管道", type="Workflow"),
]
relations = [
RelationSummary(
source_name="清洗管道",
source_type="Workflow",
target_name="用户行为数据",
target_type="Dataset",
relation_type="USES_DATASET",
),
RelationSummary(
source_name="用户行为数据",
source_type="Dataset",
target_name="外部系统",
target_type="DataSource",
relation_type="SOURCED_FROM",
),
]
result = textualize_subgraph(entities, relations)
assert "Workflow'清洗管道'使用了数据集Dataset'用户行为数据'" in result
assert "Dataset'用户行为数据'的知识来源于DataSource'外部系统'" in result
def test_all_relation_templates(self):
"""验证所有 10 种关系模板都能正确生成。"""
for rel_type, template in RELATION_TEMPLATES.items():
relations = [
RelationSummary(
source_name="A",
source_type="TypeA",
target_name="B",
target_type="TypeB",
relation_type=rel_type,
),
]
result = textualize_subgraph([], relations)
assert "TypeA'A'" in result
assert "TypeB'B'" in result
assert result # 非空
def test_unknown_relation_type(self):
"""未知关系类型使用通用模板。"""
relations = [
RelationSummary(
source_name="X",
source_type="T1",
target_name="Y",
target_type="T2",
relation_type="CUSTOM_REL",
),
]
result = textualize_subgraph([], relations)
assert "T1'X'与T2'Y'存在CUSTOM_REL关系" in result
def test_orphan_entity_with_description(self):
"""无关系的独立实体(有描述)。"""
entities = [
EntitySummary(id="1", name="孤立实体", type="Dataset", description="这是一个测试实体"),
]
result = textualize_subgraph(entities, [])
assert "Dataset'孤立实体': 这是一个测试实体" in result
def test_orphan_entity_without_description(self):
"""无关系的独立实体(无描述)。"""
entities = [
EntitySummary(id="1", name="孤立实体", type="Dataset"),
]
result = textualize_subgraph(entities, [])
assert "存在Dataset'孤立实体'" in result
def test_empty_inputs(self):
result = textualize_subgraph([], [])
assert result == ""
def test_entity_with_relation_not_orphan(self):
"""有关系的实体不应出现在独立实体部分。"""
entities = [
EntitySummary(id="1", name="A", type="Dataset"),
EntitySummary(id="2", name="B", type="Field"),
EntitySummary(id="3", name="C", type="Workflow"),
]
relations = [
RelationSummary(
source_name="A",
source_type="Dataset",
target_name="B",
target_type="Field",
relation_type="HAS_FIELD",
),
]
result = textualize_subgraph(entities, relations)
# A 和 B 有关系,不应作为独立实体出现
# C 无关系,应出现
assert "存在Workflow'C'" in result
lines = result.strip().split("\n")
assert len(lines) == 2 # 一条关系 + 一个独立实体
# ---------------------------------------------------------------------------
# build_context 测试
# ---------------------------------------------------------------------------
class TestBuildContext:
"""build_context 函数的测试。"""
def test_both_vector_and_graph(self):
chunks = [
VectorChunk(id="1", text="文档片段一", score=0.9),
VectorChunk(id="2", text="文档片段二", score=0.8),
]
graph_text = "Dataset'用户数据'包含字段Field'user_id'"
result = build_context(chunks, graph_text)
assert "## 相关文档" in result
assert "[1] 文档片段一" in result
assert "[2] 文档片段二" in result
assert "## 知识图谱上下文" in result
assert graph_text in result
def test_vector_only(self):
chunks = [VectorChunk(id="1", text="文档片段", score=0.9)]
result = build_context(chunks, "")
assert "## 相关文档" in result
assert "## 知识图谱上下文" not in result
def test_graph_only(self):
result = build_context([], "图谱内容")
assert "## 知识图谱上下文" in result
assert "## 相关文档" not in result
def test_empty_both(self):
result = build_context([], "")
assert "未检索到相关上下文信息" in result
def test_context_section_order(self):
"""验证文档在图谱之前。"""
chunks = [VectorChunk(id="1", text="doc", score=0.9)]
result = build_context(chunks, "graph")
doc_pos = result.index("## 相关文档")
graph_pos = result.index("## 知识图谱上下文")
assert doc_pos < graph_pos

View File

@@ -0,0 +1,300 @@
"""GraphRAG API 端点回归测试。
验证 /graphrag/query、/graphrag/retrieve、/graphrag/query/stream 端点
的权限校验行为,确保 collection_name 不一致时返回 403 且不进入检索链路。
"""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.testclient import TestClient
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.exception import (
fastapi_http_exception_handler,
starlette_http_exception_handler,
validation_exception_handler,
)
from app.module.kg_graphrag.interface import router
from app.module.kg_graphrag.models import (
GraphContext,
RetrievalContext,
)
# ---------------------------------------------------------------------------
# 测试用 FastAPI 应用(仅挂载 graphrag router + 异常处理器)
# ---------------------------------------------------------------------------
_app = FastAPI()
_app.include_router(router, prefix="/api")
_app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler)
_app.add_exception_handler(HTTPException, fastapi_http_exception_handler)
_app.add_exception_handler(RequestValidationError, validation_exception_handler)
_VALID_GRAPH_ID = "12345678-1234-1234-1234-123456789abc"
_VALID_BODY = {
"query": "测试查询",
"knowledge_base_id": "kb-1",
"collection_name": "test-collection",
"graph_id": _VALID_GRAPH_ID,
}
_HEADERS = {"X-User-Id": "user-1"}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fake_retrieval_context() -> RetrievalContext:
return RetrievalContext(
vector_chunks=[],
graph_context=GraphContext(),
merged_text="test context",
)
def _make_retriever_mock() -> AsyncMock:
m = AsyncMock()
m.retrieve = AsyncMock(return_value=_fake_retrieval_context())
return m
def _make_generator_mock() -> AsyncMock:
m = AsyncMock()
m.generate = AsyncMock(return_value="test answer")
m.model_name = "test-model"
async def _stream(*, query: str, context: str): # noqa: ARG001
for token in ["hello", " ", "world"]:
yield token
m.generate_stream = _stream
return m
def _make_kb_validator_mock(*, access_granted: bool = True) -> AsyncMock:
m = AsyncMock()
m.check_access = AsyncMock(return_value=access_granted)
return m
def _patch_all(
*,
access_granted: bool = True,
retriever: AsyncMock | None = None,
generator: AsyncMock | None = None,
validator: AsyncMock | None = None,
):
"""返回 context manager,统一 patch 三个懒加载工厂函数。"""
retriever = retriever or _make_retriever_mock()
generator = generator or _make_generator_mock()
validator = validator or _make_kb_validator_mock(access_granted=access_granted)
class _Ctx:
def __init__(self):
self.retriever = retriever
self.generator = generator
self.validator = validator
self._patches = [
patch("app.module.kg_graphrag.interface._get_retriever", return_value=retriever),
patch("app.module.kg_graphrag.interface._get_generator", return_value=generator),
patch("app.module.kg_graphrag.interface._get_kb_validator", return_value=validator),
]
def __enter__(self):
for p in self._patches:
p.__enter__()
return self
def __exit__(self, *args):
for p in reversed(self._patches):
p.__exit__(*args)
return _Ctx()
@pytest.fixture
def client():
return TestClient(_app)
# ---------------------------------------------------------------------------
# POST /api/graphrag/query
# ---------------------------------------------------------------------------
class TestQueryEndpoint:
"""POST /api/graphrag/query 端点测试。"""
def test_success(self, client: TestClient):
"""权限校验通过 + 检索 + 生成 → 200。"""
with _patch_all(access_granted=True) as ctx:
resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)
assert resp.status_code == 200
body = resp.json()
assert body["code"] == 200
assert body["data"]["answer"] == "test answer"
assert body["data"]["model"] == "test-model"
ctx.retriever.retrieve.assert_awaited_once()
ctx.generator.generate.assert_awaited_once()
def test_access_denied_returns_403(self, client: TestClient):
"""check_access 返回 False → 403 + 标准错误格式。"""
with _patch_all(access_granted=False):
resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)
assert resp.status_code == 403
body = resp.json()
assert body["code"] == 403
assert "kb-1" in body["data"]["detail"]
def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
"""权限拒绝时,retriever.retrieve 和 generator.generate 均不调用。"""
with _patch_all(access_granted=False) as ctx:
resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)
assert resp.status_code == 403
ctx.retriever.retrieve.assert_not_called()
ctx.generator.generate.assert_not_called()
def test_check_access_receives_collection_name(self, client: TestClient):
"""验证 check_access 被调用时携带正确的 collection_name 参数。"""
with _patch_all(access_granted=True) as ctx:
resp = client.post("/api/graphrag/query", json=_VALID_BODY, headers=_HEADERS)
assert resp.status_code == 200
ctx.validator.check_access.assert_awaited_once_with(
"kb-1", "user-1", collection_name="test-collection",
)
def test_missing_user_id_returns_422(self, client: TestClient):
"""缺少 X-User-Id 请求头 → 422 验证错误。"""
with _patch_all(access_granted=True):
resp = client.post("/api/graphrag/query", json=_VALID_BODY)
assert resp.status_code == 422
# ---------------------------------------------------------------------------
# POST /api/graphrag/retrieve
# ---------------------------------------------------------------------------
class TestRetrieveEndpoint:
"""POST /api/graphrag/retrieve 端点测试。"""
def test_success(self, client: TestClient):
"""权限通过 → 检索 → 返回 RetrievalContext。"""
with _patch_all(access_granted=True) as ctx:
resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)
assert resp.status_code == 200
body = resp.json()
assert body["code"] == 200
assert body["data"]["merged_text"] == "test context"
ctx.retriever.retrieve.assert_awaited_once()
def test_access_denied_returns_403(self, client: TestClient):
"""权限拒绝 → 403。"""
with _patch_all(access_granted=False):
resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)
assert resp.status_code == 403
body = resp.json()
assert body["code"] == 403
def test_access_denied_skips_retrieval(self, client: TestClient):
"""权限拒绝时不调用 retriever.retrieve。"""
with _patch_all(access_granted=False) as ctx:
resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)
assert resp.status_code == 403
ctx.retriever.retrieve.assert_not_called()
def test_check_access_receives_collection_name(self, client: TestClient):
"""验证 check_access 收到 collection_name 参数。"""
with _patch_all(access_granted=True) as ctx:
resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY, headers=_HEADERS)
assert resp.status_code == 200
ctx.validator.check_access.assert_awaited_once_with(
"kb-1", "user-1", collection_name="test-collection",
)
def test_missing_user_id_returns_422(self, client: TestClient):
"""缺少 X-User-Id → 422。"""
with _patch_all(access_granted=True):
resp = client.post("/api/graphrag/retrieve", json=_VALID_BODY)
assert resp.status_code == 422
# ---------------------------------------------------------------------------
# POST /api/graphrag/query/stream
# ---------------------------------------------------------------------------
class TestQueryStreamEndpoint:
"""POST /api/graphrag/query/stream 端点测试。"""
def test_success_returns_sse(self, client: TestClient):
"""权限通过 → SSE 流式响应,包含 token 和 done 事件。"""
with _patch_all(access_granted=True):
resp = client.post(
"/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
)
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/event-stream")
text = resp.text
assert '"token"' in text
assert '"done": true' in text or '"done":true' in text
def test_access_denied_returns_403(self, client: TestClient):
"""权限拒绝 → 403。"""
with _patch_all(access_granted=False):
resp = client.post(
"/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
)
assert resp.status_code == 403
body = resp.json()
assert body["code"] == 403
def test_access_denied_skips_retrieval_and_generation(self, client: TestClient):
"""权限拒绝时不调用检索和生成。"""
with _patch_all(access_granted=False) as ctx:
resp = client.post(
"/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
)
assert resp.status_code == 403
ctx.retriever.retrieve.assert_not_called()
def test_check_access_receives_collection_name(self, client: TestClient):
"""验证 check_access 收到 collection_name 参数。"""
with _patch_all(access_granted=True) as ctx:
resp = client.post(
"/api/graphrag/query/stream", json=_VALID_BODY, headers=_HEADERS,
)
assert resp.status_code == 200
ctx.validator.check_access.assert_awaited_once_with(
"kb-1", "user-1", collection_name="test-collection",
)
def test_missing_user_id_returns_422(self, client: TestClient):
"""缺少 X-User-Id → 422。"""
with _patch_all(access_granted=True):
resp = client.post("/api/graphrag/query/stream", json=_VALID_BODY)
assert resp.status_code == 422

View File

@@ -0,0 +1,330 @@
"""知识库访问权限校验的单元测试。"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
@pytest.fixture
def validator() -> KnowledgeBaseAccessValidator:
return KnowledgeBaseAccessValidator(
base_url="http://test-backend:8080/api",
timeout=5.0,
)
def _run(coro):
return asyncio.run(coro)
_FAKE_REQUEST = httpx.Request("GET", "http://test")
def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
"""创建带 request 的 httpx.Response。"""
if json is not None:
return httpx.Response(status_code, json=json, request=_FAKE_REQUEST)
return httpx.Response(status_code, text=text or "", request=_FAKE_REQUEST)
# ---------------------------------------------------------------------------
# check_access 测试
# ---------------------------------------------------------------------------
class TestCheckAccess:
"""check_access 方法的测试。"""
def test_access_granted(self, validator: KnowledgeBaseAccessValidator):
"""Java 返回 200 + code=200: 用户有权访问。"""
mock_resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "test-kb"}})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access("kb-1", "user-1"))
assert result is True
def test_access_granted_with_matching_collection(self, validator: KnowledgeBaseAccessValidator):
"""权限通过且 collection_name 与 KB name 一致:允许访问。"""
mock_resp = _resp(200, json={"code": 200, "data": {"id": "kb-1", "name": "my-collection"}})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access(
"kb-1", "user-1", collection_name="my-collection",
))
assert result is True
def test_access_denied_by_biz_code(self, validator: KnowledgeBaseAccessValidator):
"""Java 返回 HTTP 200 但 code != 200(权限不足 sys.0005)。"""
mock_resp = _resp(200, json={"code": "sys.0005", "message": "权限不足"})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access("kb-1", "other-user"))
assert result is False
def test_access_denied_http_403(self, validator: KnowledgeBaseAccessValidator):
"""Java 返回 HTTP 403。"""
mock_resp = _resp(403, text="Forbidden")
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access("kb-1", "user-1"))
assert result is False
def test_kb_not_found_http_404(self, validator: KnowledgeBaseAccessValidator):
"""知识库不存在,Java 返回 404。"""
mock_resp = _resp(404, text="Not Found")
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access("nonexistent-kb", "user-1"))
assert result is False
def test_server_error_http_500(self, validator: KnowledgeBaseAccessValidator):
"""Java 后端返回 500。"""
mock_resp = _resp(500, text="Internal Server Error")
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access("kb-1", "user-1"))
assert result is False
def test_fail_close_on_connection_error(self, validator: KnowledgeBaseAccessValidator):
"""网络异常时 fail-close(拒绝访问),防止绕过权限校验。"""
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(side_effect=httpx.ConnectError("connection refused"))
mock_get.return_value = mock_http
result = _run(validator.check_access("kb-1", "user-1"))
assert result is False
def test_fail_close_on_timeout(self, validator: KnowledgeBaseAccessValidator):
"""超时时 fail-close(拒绝访问)。"""
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(side_effect=httpx.ReadTimeout("timeout"))
mock_get.return_value = mock_http
result = _run(validator.check_access("kb-1", "user-1"))
assert result is False
def test_request_headers(self, validator: KnowledgeBaseAccessValidator):
"""验证请求中携带正确的 X-User-Id header。"""
mock_resp = _resp(200, json={"code": 200, "data": {}})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
_run(validator.check_access("kb-123", "user-456"))
call_kwargs = mock_http.get.call_args
assert "/knowledge-base/kb-123" in call_kwargs.args[0]
assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-456"
def test_cross_user_access_denied(self, validator: KnowledgeBaseAccessValidator):
"""跨用户访问:用户 B 试图访问用户 A 的知识库,应被拒绝。
模拟 Java 后端返回权限不足的业务错误。
"""
# 用户 A 创建的 KB,用户 B 请求访问
mock_resp = _resp(200, json={
"code": "sys.0005",
"message": "权限不足",
"data": None,
})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access("kb-user-a", "user-b"))
assert result is False
# 确认请求携带的是用户 B 的 ID
call_kwargs = mock_http.get.call_args
assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-b"
def test_admin_access_granted(self, validator: KnowledgeBaseAccessValidator):
"""管理员访问其他用户的知识库:Java 侧管理员跳过 owner 校验。"""
mock_resp = _resp(200, json={
"code": 200,
"data": {"id": "kb-user-a", "name": "用户A的知识库", "createdBy": "user-a"},
})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access("kb-user-a", "admin-user"))
# Java 侧管理员校验通过,返回 200 + code=200
assert result is True
# ---------------------------------------------------------------------------
# collection_name 绑定校验测试
# ---------------------------------------------------------------------------
class TestCollectionNameBinding:
"""collection_name 与 knowledge_base_id 的绑定校验测试。
防止用户提交合法的 KB ID 但篡改 collection_name 来读取其他
知识库的 Milvus 数据。
"""
def test_collection_name_mismatch_denied(self, validator: KnowledgeBaseAccessValidator):
"""KB-A 的 name='collection-a',但请求传了 collection_name='collection-b':拒绝。"""
mock_resp = _resp(200, json={
"code": 200,
"data": {"id": "kb-a", "name": "collection-a"},
})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access(
"kb-a", "user-1", collection_name="collection-b",
))
assert result is False
def test_collection_name_none_skips_check(self, validator: KnowledgeBaseAccessValidator):
"""collection_name=None 时不做绑定校验(向后兼容)。"""
mock_resp = _resp(200, json={
"code": 200,
"data": {"id": "kb-1", "name": "some-name"},
})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
# 不传 collection_name → 仅校验权限,不校验绑定
result = _run(validator.check_access("kb-1", "user-1"))
assert result is True
def test_response_data_missing_name_denied(self, validator: KnowledgeBaseAccessValidator):
"""Java 响应 data 中没有 name 字段:fail-close 拒绝。"""
mock_resp = _resp(200, json={
"code": 200,
"data": {"id": "kb-1"},
})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access(
"kb-1", "user-1", collection_name="any-collection",
))
# data.name is None, doesn't match "any-collection" → denied
assert result is False
def test_response_data_null_denied(self, validator: KnowledgeBaseAccessValidator):
"""Java 响应 data 为 null:fail-close 拒绝。"""
mock_resp = _resp(200, json={
"code": 200,
"data": None,
})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access(
"kb-1", "user-1", collection_name="any-collection",
))
assert result is False
def test_response_data_empty_dict_denied(self, validator: KnowledgeBaseAccessValidator):
"""Java 响应 data 为空 dict {}:fail-close 拒绝。"""
mock_resp = _resp(200, json={
"code": 200,
"data": {},
})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access(
"kb-1", "user-1", collection_name="any-collection",
))
assert result is False
def test_cross_kb_collection_swap_denied(self, validator: KnowledgeBaseAccessValidator):
"""用户有权访问 KB-A(name='kb-a-data'),试图用 KB-A 的 ID 搭配 KB-B 的
collection_name='kb-b-data':应被拒绝。
这是核心越权场景的完整模拟。
"""
# 用户有权访问 KB-A
mock_resp = _resp(200, json={
"code": 200,
"data": {"id": "kb-a", "name": "kb-a-data", "createdBy": "user-1"},
})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
# 但 collection_name 指向 KB-B 的数据
result = _run(validator.check_access(
"kb-a", "user-1", collection_name="kb-b-data",
))
assert result is False
def test_chinese_collection_name_match(self, validator: KnowledgeBaseAccessValidator):
"""中文 collection_name 精确匹配。"""
mock_resp = _resp(200, json={
"code": 200,
"data": {"id": "kb-1", "name": "用户行为数据"},
})
with patch.object(validator, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
result = _run(validator.check_access(
"kb-1", "user-1", collection_name="用户行为数据",
))
assert result is True

View File

@@ -0,0 +1,297 @@
"""KG 服务 REST 客户端的单元测试。"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from app.module.kg_graphrag.kg_client import KGServiceClient
@pytest.fixture
def client() -> KGServiceClient:
return KGServiceClient(
base_url="http://test-kg:8080",
internal_token="test-token",
timeout=5.0,
)
def _run(coro):
return asyncio.run(coro)
_FAKE_REQUEST = httpx.Request("GET", "http://test")
def _resp(status_code: int, *, json=None, text=None) -> httpx.Response:
"""创建带 request 的 httpx.Response(raise_for_status 需要)。"""
if json is not None:
return httpx.Response(status_code, json=json, request=_FAKE_REQUEST)
return httpx.Response(status_code, text=text or "", request=_FAKE_REQUEST)
# ---------------------------------------------------------------------------
# fulltext_search 测试
# ---------------------------------------------------------------------------
class TestFulltextSearch:
"""fulltext_search 方法的测试。"""
def test_wrapped_paged_response(self, client: KGServiceClient):
"""Java 返回被全局包装的 PagedResponse: {"code": 200, "data": {"content": [...]}}"""
mock_body = {
"code": 200,
"data": {
"page": 0,
"size": 20,
"totalElements": 2,
"totalPages": 1,
"content": [
{"id": "e1", "name": "用户数据", "type": "Dataset", "description": "用户行为", "score": 2.5},
{"id": "e2", "name": "清洗管道", "type": "Workflow", "description": "", "score": 1.8},
],
},
}
mock_resp = _resp(200, json=mock_body)
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
entities = _run(client.fulltext_search("graph-1", "用户数据", size=10, user_id="u1"))
assert len(entities) == 2
assert entities[0].id == "e1"
assert entities[0].name == "用户数据"
assert entities[0].type == "Dataset"
assert entities[1].name == "清洗管道"
def test_unwrapped_paged_response(self, client: KGServiceClient):
"""Java 直接返回 PagedResponse(无全局包装)。"""
mock_body = {
"page": 0,
"size": 10,
"totalElements": 1,
"totalPages": 1,
"content": [
{"id": "e1", "name": "A", "type": "Dataset", "description": "desc"},
],
}
mock_resp = _resp(200, json=mock_body)
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
entities = _run(client.fulltext_search("graph-1", "A"))
# body has no "data" key → fallback to body itself → read "content"
assert len(entities) == 1
assert entities[0].name == "A"
def test_empty_content(self, client: KGServiceClient):
mock_body = {"code": 200, "data": {"page": 0, "content": []}}
mock_resp = _resp(200, json=mock_body)
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
entities = _run(client.fulltext_search("graph-1", "nothing"))
assert entities == []
def test_fail_open_on_http_error(self, client: KGServiceClient):
"""HTTP 错误时 fail-open 返回空列表。"""
mock_resp = _resp(500, text="Internal Server Error")
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
entities = _run(client.fulltext_search("graph-1", "test"))
assert entities == []
def test_fail_open_on_connection_error(self, client: KGServiceClient):
"""连接错误时 fail-open 返回空列表。"""
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(side_effect=httpx.ConnectError("connection refused"))
mock_get.return_value = mock_http
entities = _run(client.fulltext_search("graph-1", "test"))
assert entities == []
def test_request_headers(self, client: KGServiceClient):
"""验证请求中携带正确的 headers。"""
mock_resp = _resp(200, json={"data": {"content": []}})
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.get = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
_run(client.fulltext_search("gid", "q", size=5, user_id="user-123"))
call_kwargs = mock_http.get.call_args
assert call_kwargs.kwargs["headers"]["X-Internal-Token"] == "test-token"
assert call_kwargs.kwargs["headers"]["X-User-Id"] == "user-123"
assert call_kwargs.kwargs["params"] == {"q": "q", "size": 5}
# ---------------------------------------------------------------------------
# get_subgraph 测试
# ---------------------------------------------------------------------------
class TestGetSubgraph:
"""get_subgraph 方法的测试。"""
def test_wrapped_subgraph_response(self, client: KGServiceClient):
"""Java 返回被全局包装的 SubgraphExportVO。"""
mock_body = {
"code": 200,
"data": {
"nodes": [
{"id": "n1", "name": "用户数据", "type": "Dataset", "description": "desc1", "properties": {}},
{"id": "n2", "name": "user_id", "type": "Field", "description": "", "properties": {}},
],
"edges": [
{
"id": "edge1",
"sourceEntityId": "n1",
"targetEntityId": "n2",
"relationType": "HAS_FIELD",
"weight": 1.0,
"confidence": 0.9,
"sourceId": "kb-1",
},
],
"nodeCount": 2,
"edgeCount": 1,
},
}
mock_resp = _resp(200, json=mock_body)
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.post = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
entities, relations = _run(client.get_subgraph("gid", ["n1"], depth=2, user_id="u1"))
assert len(entities) == 2
assert entities[0].name == "用户数据"
assert entities[1].name == "user_id"
assert len(relations) == 1
assert relations[0].source_name == "用户数据"
assert relations[0].target_name == "user_id"
assert relations[0].relation_type == "HAS_FIELD"
assert relations[0].source_type == "Dataset"
assert relations[0].target_type == "Field"
def test_unwrapped_subgraph_response(self, client: KGServiceClient):
"""Java 直接返回 SubgraphExportVO(无全局包装)。"""
mock_body = {
"nodes": [
{"id": "n1", "name": "A", "type": "T1", "description": ""},
],
"edges": [],
"nodeCount": 1,
"edgeCount": 0,
}
mock_resp = _resp(200, json=mock_body)
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.post = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
entities, relations = _run(client.get_subgraph("gid", ["n1"]))
assert len(entities) == 1
assert entities[0].name == "A"
assert relations == []
def test_edge_with_unknown_entity(self, client: KGServiceClient):
"""边引用的实体不在 nodes 列表中时,使用 ID 作为 fallback。"""
mock_body = {
"code": 200,
"data": {
"nodes": [{"id": "n1", "name": "A", "type": "T1", "description": ""}],
"edges": [
{
"sourceEntityId": "n1",
"targetEntityId": "n999",
"relationType": "DEPENDS_ON",
},
],
},
}
mock_resp = _resp(200, json=mock_body)
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.post = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
entities, relations = _run(client.get_subgraph("gid", ["n1"]))
assert len(relations) == 1
assert relations[0].source_name == "A"
assert relations[0].target_name == "n999" # fallback to ID
assert relations[0].target_type == ""
def test_fail_open_on_error(self, client: KGServiceClient):
mock_resp = _resp(500, text="error")
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.post = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
entities, relations = _run(client.get_subgraph("gid", ["n1"]))
assert entities == []
assert relations == []
def test_request_params(self, client: KGServiceClient):
"""验证子图请求参数正确传递。"""
mock_resp = _resp(200, json={"data": {"nodes": [], "edges": []}})
with patch.object(client, "_get_client") as mock_get:
mock_http = AsyncMock()
mock_http.post = AsyncMock(return_value=mock_resp)
mock_get.return_value = mock_http
_run(client.get_subgraph("gid", ["e1", "e2"], depth=3, user_id="u1"))
call_kwargs = mock_http.post.call_args
assert "/knowledge-graph/gid/query/subgraph/export" in call_kwargs.args[0]
assert call_kwargs.kwargs["params"] == {"depth": 3}
assert call_kwargs.kwargs["json"] == {"entityIds": ["e1", "e2"]}
# ---------------------------------------------------------------------------
# headers 测试
# ---------------------------------------------------------------------------
class TestHeaders:
def test_headers_with_token_and_user(self, client: KGServiceClient):
headers = client._headers(user_id="user-1")
assert headers["X-Internal-Token"] == "test-token"
assert headers["X-User-Id"] == "user-1"
def test_headers_without_user(self, client: KGServiceClient):
headers = client._headers()
assert "X-Internal-Token" in headers
assert "X-User-Id" not in headers
def test_headers_without_token(self):
c = KGServiceClient(base_url="http://test:8080", internal_token="")
headers = c._headers(user_id="u1")
assert "X-Internal-Token" not in headers
assert headers["X-User-Id"] == "u1"

View File

@@ -0,0 +1,145 @@
"""Milvus 向量检索客户端的单元测试。"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever
@pytest.fixture
def retriever() -> MilvusVectorRetriever:
return MilvusVectorRetriever(
uri="http://test-milvus:19530",
embedding_model="text-embedding-test",
)
def _run(coro):
return asyncio.run(coro)
# ---------------------------------------------------------------------------
# has_collection 测试
# ---------------------------------------------------------------------------
class TestHasCollection:
def test_collection_exists(self, retriever: MilvusVectorRetriever):
mock_client = MagicMock()
mock_client.has_collection = MagicMock(return_value=True)
retriever._milvus_client = mock_client
result = _run(retriever.has_collection("my_collection"))
assert result is True
def test_collection_not_exists(self, retriever: MilvusVectorRetriever):
mock_client = MagicMock()
mock_client.has_collection = MagicMock(return_value=False)
retriever._milvus_client = mock_client
result = _run(retriever.has_collection("nonexistent"))
assert result is False
def test_fail_open_on_error(self, retriever: MilvusVectorRetriever):
mock_client = MagicMock()
mock_client.has_collection = MagicMock(side_effect=Exception("connection error"))
retriever._milvus_client = mock_client
result = _run(retriever.has_collection("test"))
assert result is False
# ---------------------------------------------------------------------------
# search 测试
# ---------------------------------------------------------------------------
class TestSearch:
def test_successful_search(self, retriever: MilvusVectorRetriever):
"""正常搜索返回 VectorChunk 列表。"""
mock_embeddings = AsyncMock()
mock_embeddings.aembed_query = AsyncMock(return_value=[0.1, 0.2, 0.3])
retriever._embeddings = mock_embeddings
mock_milvus = MagicMock()
mock_milvus.search = MagicMock(return_value=[
[
{"id": "doc1", "distance": 0.95, "entity": {"text": "文档片段一", "metadata": {"source": "kb1"}}},
{"id": "doc2", "distance": 0.82, "entity": {"text": "文档片段二", "metadata": {}}},
]
])
retriever._milvus_client = mock_milvus
chunks = _run(retriever.search("my_collection", "用户数据", top_k=5))
assert len(chunks) == 2
assert chunks[0].id == "doc1"
assert chunks[0].text == "文档片段一"
assert chunks[0].score == 0.95
assert chunks[0].metadata == {"source": "kb1"}
assert chunks[1].id == "doc2"
assert chunks[1].score == 0.82
def test_empty_results(self, retriever: MilvusVectorRetriever):
mock_embeddings = AsyncMock()
mock_embeddings.aembed_query = AsyncMock(return_value=[0.1])
retriever._embeddings = mock_embeddings
mock_milvus = MagicMock()
mock_milvus.search = MagicMock(return_value=[[]])
retriever._milvus_client = mock_milvus
chunks = _run(retriever.search("col", "query"))
assert chunks == []
def test_fail_open_on_embedding_error(self, retriever: MilvusVectorRetriever):
"""Embedding 失败时 fail-open 返回空列表。"""
mock_embeddings = AsyncMock()
mock_embeddings.aembed_query = AsyncMock(side_effect=Exception("API error"))
retriever._embeddings = mock_embeddings
chunks = _run(retriever.search("col", "query"))
assert chunks == []
def test_fail_open_on_milvus_error(self, retriever: MilvusVectorRetriever):
"""Milvus 搜索失败时 fail-open 返回空列表。"""
mock_embeddings = AsyncMock()
mock_embeddings.aembed_query = AsyncMock(return_value=[0.1])
retriever._embeddings = mock_embeddings
mock_milvus = MagicMock()
mock_milvus.search = MagicMock(side_effect=Exception("Milvus down"))
retriever._milvus_client = mock_milvus
chunks = _run(retriever.search("col", "query"))
assert chunks == []
def test_search_uses_to_thread(self, retriever: MilvusVectorRetriever):
"""验证搜索通过 asyncio.to_thread 执行同步 Milvus I/O。"""
mock_embeddings = AsyncMock()
mock_embeddings.aembed_query = AsyncMock(return_value=[0.1])
retriever._embeddings = mock_embeddings
mock_milvus = MagicMock()
mock_milvus.search = MagicMock(return_value=[[]])
retriever._milvus_client = mock_milvus
with patch("app.module.kg_graphrag.milvus_client.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread:
mock_to_thread.return_value = [[]]
chunks = _run(retriever.search("col", "query"))
# asyncio.to_thread 应该被调用来包装同步 Milvus 调用
mock_to_thread.assert_called_once()
call_args = mock_to_thread.call_args
assert call_args.args[0] == mock_milvus.search

View File

@@ -0,0 +1,234 @@
"""GraphRAG 检索编排器的单元测试。"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.module.kg_graphrag.models import (
EntitySummary,
RelationSummary,
RetrievalStrategy,
VectorChunk,
)
from app.module.kg_graphrag.retriever import GraphRAGRetriever
def _run(coro):
return asyncio.run(coro)
def _make_retriever(
*,
milvus_search_result: list[VectorChunk] | None = None,
milvus_has_collection: bool = True,
kg_fulltext_result: list[EntitySummary] | None = None,
kg_subgraph_result: tuple[list[EntitySummary], list[RelationSummary]] | None = None,
) -> GraphRAGRetriever:
"""创建带 mock 依赖的 retriever。"""
mock_milvus = AsyncMock()
mock_milvus.has_collection = AsyncMock(return_value=milvus_has_collection)
mock_milvus.search = AsyncMock(return_value=milvus_search_result or [])
mock_kg = AsyncMock()
mock_kg.fulltext_search = AsyncMock(return_value=kg_fulltext_result or [])
mock_kg.get_subgraph = AsyncMock(return_value=kg_subgraph_result or ([], []))
return GraphRAGRetriever(milvus_client=mock_milvus, kg_client=mock_kg)
# ---------------------------------------------------------------------------
# retrieve 测试
# ---------------------------------------------------------------------------
class TestRetrieve:
"""retrieve 方法的测试。"""
def test_both_vector_and_graph(self):
"""同时启用向量和图谱检索。"""
chunks = [
VectorChunk(id="c1", text="文档片段关于用户数据", score=0.9),
VectorChunk(id="c2", text="其他内容", score=0.7),
]
seed = [EntitySummary(id="e1", name="用户数据", type="Dataset")]
entities = [
EntitySummary(id="e1", name="用户数据", type="Dataset"),
EntitySummary(id="e2", name="user_id", type="Field"),
]
relations = [
RelationSummary(
source_name="用户数据", source_type="Dataset",
target_name="user_id", target_type="Field",
relation_type="HAS_FIELD",
),
]
retriever = _make_retriever(
milvus_search_result=chunks,
kg_fulltext_result=seed,
kg_subgraph_result=(entities, relations),
)
ctx = _run(retriever.retrieve(
query="用户数据有哪些字段",
collection_name="kb1",
graph_id="graph-1",
strategy=RetrievalStrategy(),
user_id="u1",
))
assert len(ctx.vector_chunks) == 2
assert len(ctx.graph_context.entities) == 2
assert len(ctx.graph_context.relations) == 1
assert "用户数据" in ctx.graph_context.textualized
assert "## 相关文档" in ctx.merged_text
assert "## 知识图谱上下文" in ctx.merged_text
def test_vector_only(self):
"""仅启用向量检索。"""
chunks = [VectorChunk(id="c1", text="doc", score=0.9)]
retriever = _make_retriever(milvus_search_result=chunks)
strategy = RetrievalStrategy(enable_graph=False)
ctx = _run(retriever.retrieve(
query="test", collection_name="kb", graph_id="g",
strategy=strategy, user_id="u",
))
assert len(ctx.vector_chunks) == 1
assert ctx.graph_context.entities == []
# KG client should not be called
retriever._kg.fulltext_search.assert_not_called()
def test_graph_only(self):
"""仅启用图谱检索。"""
seed = [EntitySummary(id="e1", name="A", type="T")]
entities = [EntitySummary(id="e1", name="A", type="T")]
retriever = _make_retriever(
kg_fulltext_result=seed,
kg_subgraph_result=(entities, []),
)
strategy = RetrievalStrategy(enable_vector=False)
ctx = _run(retriever.retrieve(
query="test", collection_name="kb", graph_id="g",
strategy=strategy, user_id="u",
))
assert ctx.vector_chunks == []
assert len(ctx.graph_context.entities) == 1
retriever._milvus.search.assert_not_called()
def test_no_seed_entities(self):
"""图谱全文检索无结果时,不调用子图查询。"""
retriever = _make_retriever(kg_fulltext_result=[])
ctx = _run(retriever.retrieve(
query="test", collection_name="kb", graph_id="g",
strategy=RetrievalStrategy(enable_vector=False), user_id="u",
))
assert ctx.graph_context.entities == []
retriever._kg.get_subgraph.assert_not_called()
def test_collection_not_found_skips_vector(self):
"""collection 不存在时跳过向量检索。"""
retriever = _make_retriever(milvus_has_collection=False)
strategy = RetrievalStrategy(enable_graph=False)
ctx = _run(retriever.retrieve(
query="test", collection_name="nonexistent", graph_id="g",
strategy=strategy, user_id="u",
))
assert ctx.vector_chunks == []
retriever._milvus.search.assert_not_called()
def test_both_empty(self):
"""两条检索路径都无结果。"""
retriever = _make_retriever()
ctx = _run(retriever.retrieve(
query="nothing", collection_name="kb", graph_id="g",
strategy=RetrievalStrategy(), user_id="u",
))
assert ctx.vector_chunks == []
assert ctx.graph_context.entities == []
assert "未检索到相关上下文信息" in ctx.merged_text
def test_vector_error_fail_open(self):
"""向量检索异常时 fail-open,图谱检索仍可正常返回。"""
retriever = _make_retriever()
retriever._milvus.search = AsyncMock(side_effect=Exception("milvus down"))
seed = [EntitySummary(id="e1", name="A", type="T")]
retriever._kg.fulltext_search = AsyncMock(return_value=seed)
retriever._kg.get_subgraph = AsyncMock(
return_value=([EntitySummary(id="e1", name="A", type="T")], [])
)
ctx = _run(retriever.retrieve(
query="test", collection_name="kb", graph_id="g",
strategy=RetrievalStrategy(), user_id="u",
))
# 向量检索失败,但图谱检索仍有结果
assert ctx.vector_chunks == []
assert len(ctx.graph_context.entities) == 1
# ---------------------------------------------------------------------------
# _rank_results 测试
# ---------------------------------------------------------------------------
class TestRankResults:
"""_rank_results 方法的测试。"""
def _make_retriever_instance(self) -> GraphRAGRetriever:
return GraphRAGRetriever(
milvus_client=MagicMock(),
kg_client=MagicMock(),
)
def test_empty_chunks(self):
r = self._make_retriever_instance()
result = r._rank_results([], [], [], RetrievalStrategy())
assert result == []
def test_single_chunk(self):
r = self._make_retriever_instance()
chunks = [VectorChunk(id="1", text="text", score=0.9)]
result = r._rank_results(chunks, [], [], RetrievalStrategy())
assert len(result) == 1
assert result[0].id == "1"
def test_graph_boost_reorders(self):
"""图谱实体命中应提升文档片段排名。"""
r = self._make_retriever_instance()
# chunk1 向量分高但无图谱命中
# chunk2 向量分低但命中图谱实体
chunks = [
VectorChunk(id="1", text="无关内容", score=0.9),
VectorChunk(id="2", text="包含用户数据的内容", score=0.5),
]
entities = [EntitySummary(id="e1", name="用户数据", type="Dataset")]
strategy = RetrievalStrategy(vector_weight=0.3, graph_weight=0.7)
result = r._rank_results(chunks, entities, [], strategy)
# chunk2 应该排在前面(graph_boost 更高)
assert result[0].id == "2"
def test_all_same_score(self):
"""所有 chunk 分数相同时不崩溃。"""
r = self._make_retriever_instance()
chunks = [
VectorChunk(id="1", text="a", score=0.5),
VectorChunk(id="2", text="b", score=0.5),
]
result = r._rank_results(chunks, [], [], RetrievalStrategy())
assert len(result) == 2

View File

@@ -36,6 +36,7 @@ dependencies = [
"sqlalchemy (>=2.0.45,<3.0.0)",
"fastapi (>=0.124.0,<0.125.0)",
"Pillow (>=11.0.0,<12.0.0)",
"pymilvus (>=2.5.0,<3.0.0)",
]