You've already forked DataMate
feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
This commit is contained in:
249
runtime/datamate-python/app/module/kg_graphrag/interface.py
Normal file
249
runtime/datamate-python/app/module/kg_graphrag/interface.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""GraphRAG 融合查询 API 端点。
|
||||
|
||||
提供向量检索 + 知识图谱的融合查询能力:
|
||||
- POST /api/graphrag/query — 完整 GraphRAG 查询(检索+生成)
|
||||
- POST /api/graphrag/retrieve — 仅检索(返回上下文,不调 LLM)
|
||||
- POST /api/graphrag/query/stream — 流式 GraphRAG 查询(SSE)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
|
||||
from app.module.kg_graphrag.models import (
|
||||
GraphRAGQueryRequest,
|
||||
GraphRAGQueryResponse,
|
||||
RetrievalContext,
|
||||
)
|
||||
from app.module.kg_graphrag.retriever import GraphRAGRetriever
|
||||
from app.module.kg_graphrag.generator import GraphRAGGenerator
|
||||
from app.module.shared.schema import StandardResponse
|
||||
|
||||
router = APIRouter(prefix="/graphrag", tags=["graphrag"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 延迟初始化
|
||||
_retriever: GraphRAGRetriever | None = None
|
||||
_generator: GraphRAGGenerator | None = None
|
||||
_kb_validator: KnowledgeBaseAccessValidator | None = None
|
||||
|
||||
|
||||
def _get_retriever() -> GraphRAGRetriever:
|
||||
global _retriever
|
||||
if _retriever is None:
|
||||
_retriever = GraphRAGRetriever.from_settings()
|
||||
return _retriever
|
||||
|
||||
|
||||
def _get_generator() -> GraphRAGGenerator:
|
||||
global _generator
|
||||
if _generator is None:
|
||||
_generator = GraphRAGGenerator.from_settings()
|
||||
return _generator
|
||||
|
||||
|
||||
def _get_kb_validator() -> KnowledgeBaseAccessValidator:
|
||||
global _kb_validator
|
||||
if _kb_validator is None:
|
||||
_kb_validator = KnowledgeBaseAccessValidator.from_settings()
|
||||
return _kb_validator
|
||||
|
||||
|
||||
def _require_caller_id(
|
||||
x_user_id: Annotated[
|
||||
str,
|
||||
Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递"),
|
||||
],
|
||||
) -> str:
|
||||
caller = x_user_id.strip()
|
||||
if not caller:
|
||||
raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
|
||||
return caller
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# P0: 完整 GraphRAG 查询
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
|
||||
"/query",
|
||||
response_model=StandardResponse[GraphRAGQueryResponse],
|
||||
summary="GraphRAG 查询",
|
||||
description="并行从向量库和知识图谱检索上下文,融合后调用 LLM 生成回答。",
|
||||
)
|
||||
async def query(
|
||||
req: GraphRAGQueryRequest,
|
||||
caller: Annotated[str, Depends(_require_caller_id)],
|
||||
):
|
||||
trace_id = uuid.uuid4().hex[:16]
|
||||
logger.info(
|
||||
"[%s] GraphRAG query: graph_id=%s, collection=%s, caller=%s",
|
||||
trace_id, req.graph_id, req.collection_name, caller,
|
||||
)
|
||||
|
||||
retriever = _get_retriever()
|
||||
generator = _get_generator()
|
||||
|
||||
# 权限校验:验证用户是否有权访问该知识库
|
||||
kb_validator = _get_kb_validator()
|
||||
if not await kb_validator.check_access(
|
||||
req.knowledge_base_id, caller, collection_name=req.collection_name,
|
||||
):
|
||||
logger.warning(
|
||||
"[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
|
||||
trace_id, req.knowledge_base_id, req.collection_name, caller,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"无权访问知识库 {req.knowledge_base_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
context = await retriever.retrieve(
|
||||
query=req.query,
|
||||
collection_name=req.collection_name,
|
||||
graph_id=req.graph_id,
|
||||
strategy=req.strategy,
|
||||
user_id=caller,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[%s] Retrieval failed", trace_id)
|
||||
raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")
|
||||
|
||||
try:
|
||||
answer = await generator.generate(query=req.query, context=context.merged_text)
|
||||
except Exception:
|
||||
logger.exception("[%s] Generation failed", trace_id)
|
||||
raise HTTPException(status_code=502, detail=f"生成服务暂不可用 (trace: {trace_id})")
|
||||
|
||||
result = GraphRAGQueryResponse(
|
||||
answer=answer,
|
||||
context=context,
|
||||
model=generator.model_name,
|
||||
)
|
||||
return StandardResponse(code=200, message="success", data=result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# P1-1: 仅检索
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
|
||||
"/retrieve",
|
||||
response_model=StandardResponse[RetrievalContext],
|
||||
summary="GraphRAG 仅检索",
|
||||
description="并行从向量库和知识图谱检索上下文,返回结构化上下文(不调 LLM)。",
|
||||
)
|
||||
async def retrieve(
|
||||
req: GraphRAGQueryRequest,
|
||||
caller: Annotated[str, Depends(_require_caller_id)],
|
||||
):
|
||||
trace_id = uuid.uuid4().hex[:16]
|
||||
logger.info(
|
||||
"[%s] GraphRAG retrieve: graph_id=%s, collection=%s, caller=%s",
|
||||
trace_id, req.graph_id, req.collection_name, caller,
|
||||
)
|
||||
|
||||
retriever = _get_retriever()
|
||||
|
||||
# 权限校验:验证用户是否有权访问该知识库
|
||||
kb_validator = _get_kb_validator()
|
||||
if not await kb_validator.check_access(
|
||||
req.knowledge_base_id, caller, collection_name=req.collection_name,
|
||||
):
|
||||
logger.warning(
|
||||
"[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
|
||||
trace_id, req.knowledge_base_id, req.collection_name, caller,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"无权访问知识库 {req.knowledge_base_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
context = await retriever.retrieve(
|
||||
query=req.query,
|
||||
collection_name=req.collection_name,
|
||||
graph_id=req.graph_id,
|
||||
strategy=req.strategy,
|
||||
user_id=caller,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[%s] Retrieval failed", trace_id)
|
||||
raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")
|
||||
|
||||
return StandardResponse(code=200, message="success", data=context)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# P1-4: 流式查询 (SSE)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
|
||||
"/query/stream",
|
||||
summary="GraphRAG 流式查询",
|
||||
description="并行检索后,通过 SSE 流式返回 LLM 生成内容。",
|
||||
)
|
||||
async def query_stream(
|
||||
req: GraphRAGQueryRequest,
|
||||
caller: Annotated[str, Depends(_require_caller_id)],
|
||||
):
|
||||
trace_id = uuid.uuid4().hex[:16]
|
||||
logger.info(
|
||||
"[%s] GraphRAG stream: graph_id=%s, collection=%s, caller=%s",
|
||||
trace_id, req.graph_id, req.collection_name, caller,
|
||||
)
|
||||
|
||||
retriever = _get_retriever()
|
||||
generator = _get_generator()
|
||||
|
||||
# 权限校验:验证用户是否有权访问该知识库
|
||||
kb_validator = _get_kb_validator()
|
||||
if not await kb_validator.check_access(
|
||||
req.knowledge_base_id, caller, collection_name=req.collection_name,
|
||||
):
|
||||
logger.warning(
|
||||
"[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
|
||||
trace_id, req.knowledge_base_id, req.collection_name, caller,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"无权访问知识库 {req.knowledge_base_id}",
|
||||
)
|
||||
|
||||
try:
|
||||
context = await retriever.retrieve(
|
||||
query=req.query,
|
||||
collection_name=req.collection_name,
|
||||
graph_id=req.graph_id,
|
||||
strategy=req.strategy,
|
||||
user_id=caller,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("[%s] Retrieval failed", trace_id)
|
||||
raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})")
|
||||
|
||||
import json
|
||||
|
||||
async def event_stream():
|
||||
try:
|
||||
async for token in generator.generate_stream(
|
||||
query=req.query, context=context.merged_text
|
||||
):
|
||||
yield f"data: {json.dumps({'token': token}, ensure_ascii=False)}\n\n"
|
||||
# 结束事件:附带检索上下文
|
||||
yield f"data: {json.dumps({'done': True, 'context': context.model_dump()}, ensure_ascii=False)}\n\n"
|
||||
except Exception:
|
||||
logger.exception("[%s] Stream generation failed", trace_id)
|
||||
yield f"data: {json.dumps({'error': '生成服务暂不可用'})}\n\n"
|
||||
|
||||
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||
Reference in New Issue
Block a user