"""GraphRAG 融合查询 API 端点。 提供向量检索 + 知识图谱的融合查询能力: - POST /api/graphrag/query — 完整 GraphRAG 查询(检索+生成) - POST /api/graphrag/retrieve — 仅检索(返回上下文,不调 LLM) - POST /api/graphrag/query/stream — 流式 GraphRAG 查询(SSE) """ from __future__ import annotations import uuid from typing import Annotated from fastapi import APIRouter, Depends, Header, HTTPException from fastapi.responses import StreamingResponse from app.core.logging import get_logger from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator from app.module.kg_graphrag.models import ( GraphRAGQueryRequest, GraphRAGQueryResponse, RetrievalContext, ) from app.module.kg_graphrag.retriever import GraphRAGRetriever from app.module.kg_graphrag.generator import GraphRAGGenerator from app.module.shared.schema import StandardResponse router = APIRouter(prefix="/graphrag", tags=["graphrag"]) logger = get_logger(__name__) # 延迟初始化 _retriever: GraphRAGRetriever | None = None _generator: GraphRAGGenerator | None = None _kb_validator: KnowledgeBaseAccessValidator | None = None def _get_retriever() -> GraphRAGRetriever: global _retriever if _retriever is None: _retriever = GraphRAGRetriever.from_settings() return _retriever def _get_generator() -> GraphRAGGenerator: global _generator if _generator is None: _generator = GraphRAGGenerator.from_settings() return _generator def _get_kb_validator() -> KnowledgeBaseAccessValidator: global _kb_validator if _kb_validator is None: _kb_validator = KnowledgeBaseAccessValidator.from_settings() return _kb_validator def _require_caller_id( x_user_id: Annotated[ str, Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递"), ], ) -> str: caller = x_user_id.strip() if not caller: raise HTTPException(status_code=401, detail="Missing required header: X-User-Id") return caller # --------------------------------------------------------------------------- # P0: 完整 GraphRAG 查询 # --------------------------------------------------------------------------- @router.post( "/query", response_model=StandardResponse[GraphRAGQueryResponse], summary="GraphRAG 查询", description="并行从向量库和知识图谱检索上下文,融合后调用 LLM 生成回答。", ) async def query( req: GraphRAGQueryRequest, caller: Annotated[str, Depends(_require_caller_id)], ): trace_id = uuid.uuid4().hex[:16] logger.info( "[%s] GraphRAG query: graph_id=%s, collection=%s, caller=%s", trace_id, req.graph_id, req.collection_name, caller, ) retriever = _get_retriever() generator = _get_generator() # 权限校验:验证用户是否有权访问该知识库 kb_validator = _get_kb_validator() if not await kb_validator.check_access( req.knowledge_base_id, caller, collection_name=req.collection_name, ): logger.warning( "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s", trace_id, req.knowledge_base_id, req.collection_name, caller, ) raise HTTPException( status_code=403, detail=f"无权访问知识库 {req.knowledge_base_id}", ) try: context = await retriever.retrieve( query=req.query, collection_name=req.collection_name, graph_id=req.graph_id, strategy=req.strategy, user_id=caller, ) except Exception: logger.exception("[%s] Retrieval failed", trace_id) raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})") try: answer = await generator.generate(query=req.query, context=context.merged_text) except Exception: logger.exception("[%s] Generation failed", trace_id) raise HTTPException(status_code=502, detail=f"生成服务暂不可用 (trace: {trace_id})") result = GraphRAGQueryResponse( answer=answer, context=context, model=generator.model_name, ) return StandardResponse(code=200, message="success", data=result) # --------------------------------------------------------------------------- # P1-1: 仅检索 # --------------------------------------------------------------------------- @router.post( "/retrieve", response_model=StandardResponse[RetrievalContext], summary="GraphRAG 仅检索", description="并行从向量库和知识图谱检索上下文,返回结构化上下文(不调 LLM)。", ) async def retrieve( req: GraphRAGQueryRequest, caller: Annotated[str, Depends(_require_caller_id)], ): trace_id = uuid.uuid4().hex[:16] logger.info( "[%s] GraphRAG retrieve: graph_id=%s, collection=%s, caller=%s", trace_id, req.graph_id, req.collection_name, caller, ) retriever = _get_retriever() # 权限校验:验证用户是否有权访问该知识库 kb_validator = _get_kb_validator() if not await kb_validator.check_access( req.knowledge_base_id, caller, collection_name=req.collection_name, ): logger.warning( "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s", trace_id, req.knowledge_base_id, req.collection_name, caller, ) raise HTTPException( status_code=403, detail=f"无权访问知识库 {req.knowledge_base_id}", ) try: context = await retriever.retrieve( query=req.query, collection_name=req.collection_name, graph_id=req.graph_id, strategy=req.strategy, user_id=caller, ) except Exception: logger.exception("[%s] Retrieval failed", trace_id) raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})") return StandardResponse(code=200, message="success", data=context) # --------------------------------------------------------------------------- # P1-4: 流式查询 (SSE) # --------------------------------------------------------------------------- @router.post( "/query/stream", summary="GraphRAG 流式查询", description="并行检索后,通过 SSE 流式返回 LLM 生成内容。", ) async def query_stream( req: GraphRAGQueryRequest, caller: Annotated[str, Depends(_require_caller_id)], ): trace_id = uuid.uuid4().hex[:16] logger.info( "[%s] GraphRAG stream: graph_id=%s, collection=%s, caller=%s", trace_id, req.graph_id, req.collection_name, caller, ) retriever = _get_retriever() generator = _get_generator() # 权限校验:验证用户是否有权访问该知识库 kb_validator = _get_kb_validator() if not await kb_validator.check_access( req.knowledge_base_id, caller, collection_name=req.collection_name, ): logger.warning( "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s", trace_id, req.knowledge_base_id, req.collection_name, caller, ) raise HTTPException( status_code=403, detail=f"无权访问知识库 {req.knowledge_base_id}", ) try: context = await retriever.retrieve( query=req.query, collection_name=req.collection_name, graph_id=req.graph_id, strategy=req.strategy, user_id=caller, ) except Exception: logger.exception("[%s] Retrieval failed", trace_id) raise HTTPException(status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})") import json async def event_stream(): try: async for token in generator.generate_stream( query=req.query, context=context.merged_text ): yield f"data: {json.dumps({'token': token}, ensure_ascii=False)}\n\n" # 结束事件:附带检索上下文 yield f"data: {json.dumps({'done': True, 'context': context.model_dump()}, ensure_ascii=False)}\n\n" except Exception: logger.exception("[%s] Stream generation failed", trace_id) yield f"data: {json.dumps({'error': '生成服务暂不可用'})}\n\n" return StreamingResponse(event_stream(), media_type="text/event-stream") # --------------------------------------------------------------------------- # 缓存管理 # --------------------------------------------------------------------------- @router.get( "/cache/stats", response_model=StandardResponse[dict], summary="缓存统计", description="返回 GraphRAG 检索缓存的命中率和容量统计。", ) async def cache_stats(): from app.module.kg_graphrag.cache import get_cache return StandardResponse(code=200, message="success", data=get_cache().stats()) @router.post( "/cache/clear", response_model=StandardResponse[dict], summary="清空缓存", description="清空所有 GraphRAG 检索缓存。", ) async def cache_clear(): from app.module.kg_graphrag.cache import get_cache get_cache().clear() return StandardResponse(code=200, message="success", data={"cleared": True})