Files
DataMate/runtime/datamate-python/app/module/kg_graphrag/interface.py
Jerry Yan 75f9b95093
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (java-kotlin) (push) Has been cancelled
CodeQL Advanced / Analyze (javascript-typescript) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
feat(api): 添加 graphrag 权限规则和优化知识图谱缓存失效
- 在权限规则匹配器中添加 /api/graphrag/** 的读写权限控制
- 修改图关系服务中的删除操作以精确失效相关实体缓存
- 更新图同步服务确保 BELONGS_TO 关系在增量同步时正确重建
- 重构图同步步骤服务中的组织归属关系构建逻辑
- 修复前端图_canvas 组件中的元素点击事件处理逻辑
- 实现 Python GraphRAG 缓存的启用/禁用功能
- 为 GraphRAG 缓存统计和清除接口添加调用方日志记录
2026-02-24 09:25:31 +08:00

282 lines
9.4 KiB
Python

"""GraphRAG 融合查询 API 端点。
提供向量检索 + 知识图谱的融合查询能力:
- POST /api/graphrag/query — 完整 GraphRAG 查询(检索+生成)
- POST /api/graphrag/retrieve — 仅检索(返回上下文,不调 LLM)
- POST /api/graphrag/query/stream — 流式 GraphRAG 查询(SSE)
"""
from __future__ import annotations
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from app.core.logging import get_logger
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
from app.module.kg_graphrag.models import (
GraphRAGQueryRequest,
GraphRAGQueryResponse,
RetrievalContext,
)
from app.module.kg_graphrag.retriever import GraphRAGRetriever
from app.module.kg_graphrag.generator import GraphRAGGenerator
from app.module.shared.schema import StandardResponse
# Module logger, plus the router every endpoint in this file registers on.
logger = get_logger(__name__)
router = APIRouter(prefix="/graphrag", tags=["graphrag"])

# Lazily-created singletons. Each stays None until its _get_* accessor first
# builds it via `from_settings()`, so nothing is constructed at import time.
_retriever: GraphRAGRetriever | None = None
_generator: GraphRAGGenerator | None = None
_kb_validator: KnowledgeBaseAccessValidator | None = None
def _get_retriever() -> GraphRAGRetriever:
    """Return the shared retriever, building it from settings on first call."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = GraphRAGRetriever.from_settings()
    return _retriever
def _get_generator() -> GraphRAGGenerator:
    """Return the shared LLM generator, building it from settings on first call."""
    global _generator
    if _generator is not None:
        return _generator
    _generator = GraphRAGGenerator.from_settings()
    return _generator
def _get_kb_validator() -> KnowledgeBaseAccessValidator:
    """Return the shared knowledge-base access validator, created on first call."""
    global _kb_validator
    if _kb_validator is not None:
        return _kb_validator
    _kb_validator = KnowledgeBaseAccessValidator.from_settings()
    return _kb_validator
def _require_caller_id(
x_user_id: Annotated[
str,
Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递"),
],
) -> str:
caller = x_user_id.strip()
if not caller:
raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
return caller
# ---------------------------------------------------------------------------
# P0: 完整 GraphRAG 查询
# ---------------------------------------------------------------------------
@router.post(
    "/query",
    response_model=StandardResponse[GraphRAGQueryResponse],
    summary="GraphRAG 查询",
    description="并行从向量库和知识图谱检索上下文,融合后调用 LLM 生成回答。",
)
async def query(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Run a full GraphRAG query: retrieve fused context, then generate an answer.

    Raises:
        HTTPException: 403 when the caller may not access ``req.knowledge_base_id``;
            502 when retrieval or generation fails. The detail carries a trace
            id that also tags the server-side log lines.
    """
    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG query: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    # Authorize first. A denied request then never triggers lazy construction
    # of the retriever/generator, and a failure there cannot mask the 403.
    kb_validator = _get_kb_validator()
    if not await kb_validator.check_access(
        req.knowledge_base_id, caller, collection_name=req.collection_name,
    ):
        logger.warning(
            "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
            trace_id, req.knowledge_base_id, req.collection_name, caller,
        )
        raise HTTPException(
            status_code=403,
            detail=f"无权访问知识库 {req.knowledge_base_id}",
        )

    retriever = _get_retriever()
    generator = _get_generator()

    try:
        context = await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception as exc:
        logger.exception("[%s] Retrieval failed", trace_id)
        raise HTTPException(
            status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})",
        ) from exc

    try:
        answer = await generator.generate(query=req.query, context=context.merged_text)
    except Exception as exc:
        logger.exception("[%s] Generation failed", trace_id)
        raise HTTPException(
            status_code=502, detail=f"生成服务暂不可用 (trace: {trace_id})",
        ) from exc

    result = GraphRAGQueryResponse(
        answer=answer,
        context=context,
        model=generator.model_name,
    )
    return StandardResponse(code=200, message="success", data=result)
# ---------------------------------------------------------------------------
# P1-1: 仅检索
# ---------------------------------------------------------------------------
@router.post(
    "/retrieve",
    response_model=StandardResponse[RetrievalContext],
    summary="GraphRAG 仅检索",
    description="并行从向量库和知识图谱检索上下文,返回结构化上下文(不调 LLM)。",
)
async def retrieve(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Return the fused vector + knowledge-graph context for a query, without the LLM.

    Raises:
        HTTPException: 403 when the caller may not access ``req.knowledge_base_id``;
            502 when retrieval fails. The detail carries a trace id that also
            tags the server-side log lines.
    """
    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG retrieve: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    # Authorize first. A denied request then never triggers lazy construction
    # of the retriever, and a failure there cannot mask the 403.
    kb_validator = _get_kb_validator()
    if not await kb_validator.check_access(
        req.knowledge_base_id, caller, collection_name=req.collection_name,
    ):
        logger.warning(
            "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
            trace_id, req.knowledge_base_id, req.collection_name, caller,
        )
        raise HTTPException(
            status_code=403,
            detail=f"无权访问知识库 {req.knowledge_base_id}",
        )

    retriever = _get_retriever()

    try:
        context = await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception as exc:
        logger.exception("[%s] Retrieval failed", trace_id)
        raise HTTPException(
            status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})",
        ) from exc

    return StandardResponse(code=200, message="success", data=context)
# ---------------------------------------------------------------------------
# P1-4: 流式查询 (SSE)
# ---------------------------------------------------------------------------
@router.post(
    "/query/stream",
    summary="GraphRAG 流式查询",
    description="并行检索后,通过 SSE 流式返回 LLM 生成内容。",
)
async def query_stream(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Stream a GraphRAG answer as Server-Sent Events.

    Access control and retrieval run before the response starts, so their
    failures surface as ordinary HTTP errors (403 / 502). Once streaming has
    begun, every event is a ``data: <json>`` line followed by a blank line.
    The sequence is zero or more ``{"token": ...}`` events, then a terminal
    event:

    - ``{"done": true, "context": ...}`` on success, or
    - ``{"error": ..., "trace_id": ...}`` if anything fails mid-stream.
    """
    import json

    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG stream: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    # Authorize first. A denied request then never triggers lazy construction
    # of the retriever/generator, and a failure there cannot mask the 403.
    kb_validator = _get_kb_validator()
    if not await kb_validator.check_access(
        req.knowledge_base_id, caller, collection_name=req.collection_name,
    ):
        logger.warning(
            "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
            trace_id, req.knowledge_base_id, req.collection_name, caller,
        )
        raise HTTPException(
            status_code=403,
            detail=f"无权访问知识库 {req.knowledge_base_id}",
        )

    retriever = _get_retriever()
    generator = _get_generator()

    try:
        context = await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception as exc:
        logger.exception("[%s] Retrieval failed", trace_id)
        raise HTTPException(
            status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})",
        ) from exc

    def _sse(payload: dict) -> str:
        # Format one SSE "data" event; ensure_ascii=False keeps CJK text readable.
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def event_stream():
        try:
            async for token in generator.generate_stream(
                query=req.query, context=context.merged_text,
            ):
                yield _sse({"token": token})
            # Terminal event carries the retrieval context. mode="json" turns any
            # non-JSON-native field values (e.g. datetime/enum) into JSON-safe
            # ones, so a successful generation is not reported as an error
            # because of serialization.
            yield _sse({"done": True, "context": context.model_dump(mode="json")})
        except Exception:
            logger.exception("[%s] Stream generation failed", trace_id)
            # Include the trace id, like the non-streaming 502 responses do.
            yield _sse({"error": "生成服务暂不可用", "trace_id": trace_id})

    return StreamingResponse(event_stream(), media_type="text/event-stream")
# ---------------------------------------------------------------------------
# 缓存管理
# ---------------------------------------------------------------------------
@router.get(
    "/cache/stats",
    response_model=StandardResponse[dict],
    summary="缓存统计",
    description="返回 GraphRAG 检索缓存的命中率和容量统计。",
)
async def cache_stats(caller: Annotated[str, Depends(_require_caller_id)]):
    """Report hit-rate and capacity statistics of the GraphRAG retrieval cache."""
    from app.module.kg_graphrag.cache import get_cache

    # Record who asked, for auditing.
    logger.info("GraphRAG cache stats requested by caller=%s", caller)
    snapshot = get_cache().stats()
    return StandardResponse(code=200, message="success", data=snapshot)
@router.post(
    "/cache/clear",
    response_model=StandardResponse[dict],
    summary="清空缓存",
    description="清空所有 GraphRAG 检索缓存。",
)
async def cache_clear(caller: Annotated[str, Depends(_require_caller_id)]):
    """Drop every entry from the GraphRAG retrieval cache."""
    from app.module.kg_graphrag.cache import get_cache

    # Log the caller before the destructive operation, for auditing.
    logger.info("GraphRAG cache clear requested by caller=%s", caller)
    cache = get_cache()
    cache.clear()
    return StandardResponse(code=200, message="success", data={"cleared": True})