You've already forked DataMate
feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
This commit is contained in:
118
runtime/datamate-python/app/module/kg_graphrag/kb_access.py
Normal file
118
runtime/datamate-python/app/module/kg_graphrag/kb_access.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""知识库访问权限校验。
|
||||
|
||||
在执行 GraphRAG 检索前,调用 Java rag-indexer-service 的
|
||||
GET /knowledge-base/{id} 端点验证当前用户是否有权访问该知识库。
|
||||
|
||||
Java 侧实现参考:KnowledgeBaseService.getKnowledgeBaseWithAccessCheck()
|
||||
- 查找 KB 是否存在
|
||||
- 校验 createdBy == currentUserId(管理员跳过)
|
||||
- 不满足则抛出 sys.0005 (INSUFFICIENT_PERMISSIONS)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KnowledgeBaseAccessValidator:
|
||||
"""通过 Java 后端校验用户是否有权访问指定知识库。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str = "http://datamate-backend:8080/api",
|
||||
timeout: float = 10.0,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls) -> KnowledgeBaseAccessValidator:
|
||||
from app.core.config import settings
|
||||
|
||||
return cls(base_url=settings.datamate_backend_base_url)
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self._base_url,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def check_access(
|
||||
self,
|
||||
knowledge_base_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
collection_name: str | None = None,
|
||||
) -> bool:
|
||||
"""校验用户是否有权访问指定知识库。
|
||||
|
||||
调用 Java 后端 GET /knowledge-base/{id},该端点内部执行
|
||||
owner 校验(createdBy == currentUserId,管理员跳过)。
|
||||
|
||||
当 *collection_name* 不为 None 时,还会校验请求中的
|
||||
collection_name 与该知识库实际的 name 是否一致,防止
|
||||
用户提交合法 KB ID 但篡改 collection_name 来访问
|
||||
其他知识库的 Milvus 数据。
|
||||
|
||||
Returns:
|
||||
True — 用户有权访问且 collection_name 匹配
|
||||
False — 无权访问、collection_name 不匹配或校验失败
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
resp = await client.get(
|
||||
f"/knowledge-base/{knowledge_base_id}",
|
||||
headers={"X-User-Id": user_id},
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
body = resp.json()
|
||||
# Java 全局包装: {"code": 200, "data": {...}}
|
||||
# code != 200 说明业务层拒绝(如权限不足)
|
||||
code = body.get("code", resp.status_code)
|
||||
if code != 200:
|
||||
logger.warning(
|
||||
"KB access denied: kb_id=%s, user=%s, biz_code=%s, msg=%s",
|
||||
knowledge_base_id, user_id, code, body.get("message", ""),
|
||||
)
|
||||
return False
|
||||
|
||||
# 校验 collection_name 与 KB 实际名称的绑定关系
|
||||
if collection_name is not None:
|
||||
data = body.get("data") or {}
|
||||
actual_name = data.get("name") if isinstance(data, dict) else None
|
||||
if actual_name != collection_name:
|
||||
logger.warning(
|
||||
"KB collection_name mismatch: kb_id=%s, "
|
||||
"expected=%s, actual=%s, user=%s",
|
||||
knowledge_base_id, collection_name,
|
||||
actual_name, user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
# HTTP 4xx/5xx
|
||||
logger.warning(
|
||||
"KB access check returned HTTP %d: kb_id=%s, user=%s",
|
||||
resp.status_code, knowledge_base_id, user_id,
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
# 网络异常时 fail-close:拒绝访问,防止绕过权限
|
||||
logger.exception(
|
||||
"KB access check failed (fail-close): kb_id=%s, user=%s",
|
||||
knowledge_base_id, user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
Reference in New Issue
Block a user