You've already forked DataMate
核心功能: - Neo4j 索引优化(entityType, graphId, properties.name) - Redis 缓存(Java 侧,3 个缓存区,TTL 可配置) - LRU 缓存(Python 侧,KG + Embedding,线程安全) - 细粒度缓存清除(graphId 前缀匹配) - 失败路径缓存清除(finally 块) 新增文件(Java 侧,7 个): - V2__PerformanceIndexes.java - Flyway 迁移,创建 3 个索引 - IndexHealthService.java - 索引健康监控 - RedisCacheConfig.java - Spring Cache + Redis 配置 - GraphCacheService.java - 缓存清除管理器 - CacheableIntegrationTest.java - 集成测试(10 tests) - GraphCacheServiceTest.java - 单元测试(19 tests) - V2__PerformanceIndexesTest.java, IndexHealthServiceTest.java 新增文件(Python 侧,2 个): - cache.py - 内存 TTL+LRU 缓存(cachetools) - test_cache.py - 单元测试(20 tests) 修改文件(Java 侧,9 个): - GraphEntityService.java - 添加 @Cacheable,缓存清除 - GraphQueryService.java - 添加 @Cacheable(包含用户权限上下文) - GraphRelationService.java - 添加缓存清除 - GraphSyncService.java - 添加缓存清除(finally 块,失败路径) - KnowledgeGraphProperties.java - 添加 Cache 配置类 - application-knowledgegraph.yml - 添加 Redis 和缓存 TTL 配置 - GraphEntityServiceTest.java - 添加 verify(cacheService) 断言 - GraphRelationServiceTest.java - 添加 verify(cacheService) 断言 - GraphSyncServiceTest.java - 添加失败路径缓存清除测试 修改文件(Python 侧,5 个): - kg_client.py - 集成缓存(fulltext_search, get_subgraph) - interface.py - 添加 /cache/stats 和 /cache/clear 端点 - config.py - 添加缓存配置字段 - pyproject.toml - 添加 cachetools 依赖 - test_kg_client.py - 添加 _disable_cache fixture 安全修复(3 轮迭代): - P0: 缓存 key 用户隔离(防止跨用户数据泄露) - P1-1: 同步子步骤后的缓存清除(18 个方法) - P1-2: 实体创建后的搜索缓存清除 - P1-3: 失败路径缓存清除(finally 块) - P2-1: 细粒度缓存清除(graphId 前缀匹配,避免跨图谱冲刷) - P2-2: 服务层测试添加 verify(cacheService) 断言 测试结果: - Java: 280 tests pass ✅ (270 → 280, +10 new) - Python: 154 tests pass ✅ (140 → 154, +14 new) 缓存配置: - kg:entities - 实体缓存,TTL 1h - kg:queries - 查询结果缓存,TTL 5min - kg:search - 全文搜索缓存,TTL 3min - KG cache (Python) - 256 entries, 5min TTL - Embedding cache (Python) - 512 entries, 10min TTL
280 lines
9.1 KiB
Python
280 lines
9.1 KiB
Python
"""GraphRAG 融合查询 API 端点。
|
|
|
|
提供向量检索 + 知识图谱的融合查询能力:
|
|
- POST /api/graphrag/query — 完整 GraphRAG 查询(检索+生成)
|
|
- POST /api/graphrag/retrieve — 仅检索(返回上下文,不调 LLM)
|
|
- POST /api/graphrag/query/stream — 流式 GraphRAG 查询(SSE)
|
|
"""
|
|
|
|
from __future__ import annotations

import json
import uuid
from typing import Annotated

from fastapi import APIRouter, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse

from app.core.logging import get_logger
from app.module.kg_graphrag.generator import GraphRAGGenerator
from app.module.kg_graphrag.kb_access import KnowledgeBaseAccessValidator
from app.module.kg_graphrag.models import (
    GraphRAGQueryRequest,
    GraphRAGQueryResponse,
    RetrievalContext,
)
from app.module.kg_graphrag.retriever import GraphRAGRetriever
from app.module.shared.schema import StandardResponse
|
|
|
|
router = APIRouter(tags=["graphrag"], prefix="/graphrag")
logger = get_logger(__name__)

# Lazily-initialised singletons: each starts as None and is built from
# settings the first time its ``_get_*`` accessor is called.
_retriever: GraphRAGRetriever | None = None
_generator: GraphRAGGenerator | None = None
_kb_validator: KnowledgeBaseAccessValidator | None = None
|
|
|
|
|
|
def _get_retriever() -> GraphRAGRetriever:
    """Return the module-wide retriever, creating it from settings on first use."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = GraphRAGRetriever.from_settings()
    return _retriever
|
|
|
|
|
|
def _get_generator() -> GraphRAGGenerator:
    """Return the module-wide LLM generator, creating it from settings on first use."""
    global _generator
    if _generator is not None:
        return _generator
    _generator = GraphRAGGenerator.from_settings()
    return _generator
|
|
|
|
|
|
def _get_kb_validator() -> KnowledgeBaseAccessValidator:
    """Return the module-wide knowledge-base access validator, creating it on first use."""
    global _kb_validator
    if _kb_validator is not None:
        return _kb_validator
    _kb_validator = KnowledgeBaseAccessValidator.from_settings()
    return _kb_validator
|
|
|
|
|
|
def _require_caller_id(
|
|
x_user_id: Annotated[
|
|
str,
|
|
Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递"),
|
|
],
|
|
) -> str:
|
|
caller = x_user_id.strip()
|
|
if not caller:
|
|
raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
|
|
return caller
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Shared helpers for the query endpoints
# ---------------------------------------------------------------------------

# Extra response headers for SSE: forbid caching and ask nginx-style reverse
# proxies not to buffer, so tokens reach the client as they are produced.
_SSE_HEADERS = {
    "Cache-Control": "no-cache",
    "X-Accel-Buffering": "no",
}


def _sse_event(payload: dict) -> str:
    """Encode *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


async def _ensure_kb_access(
    req: GraphRAGQueryRequest,
    caller: str,
    trace_id: str,
) -> None:
    """Raise HTTP 403 unless *caller* may access the requested knowledge base.

    Delegates to ``KnowledgeBaseAccessValidator.check_access`` with the
    request's knowledge base id and collection name.

    Raises:
        HTTPException: 403 when access is denied.
    """
    kb_validator = _get_kb_validator()
    allowed = await kb_validator.check_access(
        req.knowledge_base_id, caller, collection_name=req.collection_name,
    )
    if allowed:
        return
    logger.warning(
        "[%s] KB access denied: kb_id=%s, collection=%s, caller=%s",
        trace_id, req.knowledge_base_id, req.collection_name, caller,
    )
    raise HTTPException(
        status_code=403,
        detail=f"无权访问知识库 {req.knowledge_base_id}",
    )


async def _retrieve_context(
    req: GraphRAGQueryRequest,
    caller: str,
    trace_id: str,
) -> RetrievalContext:
    """Retrieve vector + knowledge-graph context for *req* on behalf of *caller*.

    Raises:
        HTTPException: 502 when the retriever call fails. The full traceback
            is logged server-side; the client only sees the trace id.
    """
    retriever = _get_retriever()
    try:
        return await retriever.retrieve(
            query=req.query,
            collection_name=req.collection_name,
            graph_id=req.graph_id,
            strategy=req.strategy,
            user_id=caller,
        )
    except Exception as exc:
        logger.exception("[%s] Retrieval failed", trace_id)
        raise HTTPException(
            status_code=502, detail=f"检索服务暂不可用 (trace: {trace_id})",
        ) from exc


# ---------------------------------------------------------------------------
# P0: full GraphRAG query
# ---------------------------------------------------------------------------


@router.post(
    "/query",
    response_model=StandardResponse[GraphRAGQueryResponse],
    summary="GraphRAG 查询",
    description="并行从向量库和知识图谱检索上下文,融合后调用 LLM 生成回答。",
)
async def query(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Check KB access, retrieve context, then generate an LLM answer.

    Raises:
        HTTPException: 403 on denied access, 502 when retrieval or
            generation fails.
    """
    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG query: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    await _ensure_kb_access(req, caller, trace_id)
    # Build the generator before retrieving so a misconfigured LLM fails the
    # request without spending a retrieval round-trip first.
    generator = _get_generator()
    context = await _retrieve_context(req, caller, trace_id)

    try:
        answer = await generator.generate(query=req.query, context=context.merged_text)
    except Exception as exc:
        logger.exception("[%s] Generation failed", trace_id)
        raise HTTPException(
            status_code=502, detail=f"生成服务暂不可用 (trace: {trace_id})",
        ) from exc

    result = GraphRAGQueryResponse(
        answer=answer,
        context=context,
        model=generator.model_name,
    )
    return StandardResponse(code=200, message="success", data=result)


# ---------------------------------------------------------------------------
# P1-1: retrieval only
# ---------------------------------------------------------------------------


@router.post(
    "/retrieve",
    response_model=StandardResponse[RetrievalContext],
    summary="GraphRAG 仅检索",
    description="并行从向量库和知识图谱检索上下文,返回结构化上下文(不调 LLM)。",
)
async def retrieve(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Check KB access and return the retrieved context without calling the LLM.

    Raises:
        HTTPException: 403 on denied access, 502 when retrieval fails.
    """
    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG retrieve: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    await _ensure_kb_access(req, caller, trace_id)
    context = await _retrieve_context(req, caller, trace_id)
    return StandardResponse(code=200, message="success", data=context)


# ---------------------------------------------------------------------------
# P1-4: streaming query (SSE)
# ---------------------------------------------------------------------------


@router.post(
    "/query/stream",
    summary="GraphRAG 流式查询",
    description="并行检索后,通过 SSE 流式返回 LLM 生成内容。",
)
async def query_stream(
    req: GraphRAGQueryRequest,
    caller: Annotated[str, Depends(_require_caller_id)],
):
    """Check KB access, retrieve context, then stream LLM tokens over SSE.

    Access and retrieval failures are raised as HTTP errors (403 / 502)
    before the stream starts. Once streaming, each frame is one of:
    ``{"token": ...}`` per generated token, a final
    ``{"done": true, "context": {...}}``, or
    ``{"error": ..., "trace_id": ...}`` if generation fails mid-stream.
    """
    trace_id = uuid.uuid4().hex[:16]
    logger.info(
        "[%s] GraphRAG stream: graph_id=%s, collection=%s, caller=%s",
        trace_id, req.graph_id, req.collection_name, caller,
    )

    await _ensure_kb_access(req, caller, trace_id)
    generator = _get_generator()
    context = await _retrieve_context(req, caller, trace_id)

    async def event_stream():
        try:
            async for token in generator.generate_stream(
                query=req.query, context=context.merged_text,
            ):
                yield _sse_event({"token": token})
            # mode="json" converts non-JSON-native field values (dates, UUIDs,
            # ...) so json.dumps cannot fail after all tokens were already sent.
            yield _sse_event({"done": True, "context": context.model_dump(mode="json")})
        except Exception:
            logger.exception("[%s] Stream generation failed", trace_id)
            yield _sse_event({"error": "生成服务暂不可用", "trace_id": trace_id})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers=_SSE_HEADERS,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 缓存管理
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@router.get(
    "/cache/stats",
    summary="缓存统计",
    description="返回 GraphRAG 检索缓存的命中率和容量统计。",
    response_model=StandardResponse[dict],
)
async def cache_stats():
    """Return the statistics dict reported by the GraphRAG cache's ``stats()``.

    NOTE(review): unlike the query endpoints, this route does not depend on
    ``_require_caller_id`` — confirm that unauthenticated access is intended.
    """
    # Imported at call time rather than at module load; presumably to defer
    # loading the cache module — TODO confirm whether this is required.
    from app.module.kg_graphrag.cache import get_cache

    stats = get_cache().stats()
    return StandardResponse(code=200, message="success", data=stats)
|
|
|
|
|
|
@router.post(
    "/cache/clear",
    summary="清空缓存",
    description="清空所有 GraphRAG 检索缓存。",
    response_model=StandardResponse[dict],
)
async def cache_clear():
    """Clear the GraphRAG cache returned by ``get_cache()`` and acknowledge it.

    NOTE(review): this mutating route does not depend on ``_require_caller_id``,
    unlike the query endpoints — confirm that unauthenticated clears are intended.
    """
    # Imported at call time rather than at module load; presumably to defer
    # loading the cache module — TODO confirm whether this is required.
    from app.module.kg_graphrag.cache import get_cache

    cache = get_cache()
    cache.clear()
    payload = {"cleared": True}
    return StandardResponse(code=200, message="success", data=payload)
|