You've already forked DataMate
feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
This commit is contained in:
214
runtime/datamate-python/app/module/kg_graphrag/retriever.py
Normal file
214
runtime/datamate-python/app/module/kg_graphrag/retriever.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""GraphRAG 检索编排器。
|
||||
|
||||
并行执行向量检索和图谱检索,融合排序后构建统一上下文。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_graphrag.context_builder import build_context, textualize_subgraph
|
||||
from app.module.kg_graphrag.kg_client import KGServiceClient
|
||||
from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever
|
||||
from app.module.kg_graphrag.models import (
|
||||
EntitySummary,
|
||||
GraphContext,
|
||||
RelationSummary,
|
||||
RetrievalContext,
|
||||
RetrievalStrategy,
|
||||
VectorChunk,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GraphRAGRetriever:
|
||||
"""GraphRAG 检索编排器。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
milvus_client: MilvusVectorRetriever,
|
||||
kg_client: KGServiceClient,
|
||||
) -> None:
|
||||
self._milvus = milvus_client
|
||||
self._kg = kg_client
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls) -> GraphRAGRetriever:
|
||||
return cls(
|
||||
milvus_client=MilvusVectorRetriever.from_settings(),
|
||||
kg_client=KGServiceClient.from_settings(),
|
||||
)
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
collection_name: str,
|
||||
graph_id: str,
|
||||
strategy: RetrievalStrategy,
|
||||
user_id: str = "",
|
||||
) -> RetrievalContext:
|
||||
"""并行执行向量检索 + 图谱检索,融合结果。"""
|
||||
# 构建并行任务
|
||||
tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
if strategy.enable_vector:
|
||||
# 先校验 collection 存在性,防止越权访问
|
||||
if not await self._milvus.has_collection(collection_name):
|
||||
logger.warning(
|
||||
"Collection %s not found, skipping vector retrieval",
|
||||
collection_name,
|
||||
)
|
||||
else:
|
||||
tasks["vector"] = asyncio.create_task(
|
||||
self._milvus.search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
top_k=strategy.vector_top_k,
|
||||
)
|
||||
)
|
||||
|
||||
if strategy.enable_graph:
|
||||
tasks["graph"] = asyncio.create_task(
|
||||
self._retrieve_graph(
|
||||
query=query,
|
||||
graph_id=graph_id,
|
||||
strategy=strategy,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# 等待所有任务完成
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks.values(), return_exceptions=True)
|
||||
|
||||
# 收集结果
|
||||
vector_chunks: list[VectorChunk] = []
|
||||
if "vector" in tasks:
|
||||
try:
|
||||
vector_chunks = tasks["vector"].result()
|
||||
except Exception:
|
||||
logger.exception("Vector retrieval task failed")
|
||||
|
||||
entities: list[EntitySummary] = []
|
||||
relations: list[RelationSummary] = []
|
||||
if "graph" in tasks:
|
||||
try:
|
||||
entities, relations = tasks["graph"].result()
|
||||
except Exception:
|
||||
logger.exception("Graph retrieval task failed")
|
||||
|
||||
# 融合排序
|
||||
vector_chunks = self._rank_results(
|
||||
vector_chunks, entities, relations, strategy
|
||||
)
|
||||
|
||||
# 三元组文本化
|
||||
graph_text = textualize_subgraph(entities, relations)
|
||||
|
||||
# 构建上下文
|
||||
merged_text = build_context(
|
||||
vector_chunks,
|
||||
graph_text,
|
||||
vector_weight=strategy.vector_weight,
|
||||
graph_weight=strategy.graph_weight,
|
||||
)
|
||||
|
||||
return RetrievalContext(
|
||||
vector_chunks=vector_chunks,
|
||||
graph_context=GraphContext(
|
||||
entities=entities,
|
||||
relations=relations,
|
||||
textualized=graph_text,
|
||||
),
|
||||
merged_text=merged_text,
|
||||
)
|
||||
|
||||
async def _retrieve_graph(
|
||||
self,
|
||||
query: str,
|
||||
graph_id: str,
|
||||
strategy: RetrievalStrategy,
|
||||
user_id: str,
|
||||
) -> tuple[list[EntitySummary], list[RelationSummary]]:
|
||||
"""图谱检索:全文搜索 -> 种子实体 -> 子图扩展。"""
|
||||
# 1. 全文检索获取种子实体
|
||||
seed_entities = await self._kg.fulltext_search(
|
||||
graph_id=graph_id,
|
||||
query=query,
|
||||
size=strategy.graph_max_entities,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if not seed_entities:
|
||||
logger.debug("No seed entities found for query: %s", query)
|
||||
return [], []
|
||||
|
||||
# 2. 获取种子实体的 N-hop 子图
|
||||
seed_ids = [e.id for e in seed_entities]
|
||||
entities, relations = await self._kg.get_subgraph(
|
||||
graph_id=graph_id,
|
||||
entity_ids=seed_ids,
|
||||
depth=strategy.graph_depth,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Graph retrieval: %d seed entities -> %d entities, %d relations",
|
||||
len(seed_entities), len(entities), len(relations),
|
||||
)
|
||||
return entities, relations
|
||||
|
||||
def _rank_results(
|
||||
self,
|
||||
vector_chunks: list[VectorChunk],
|
||||
entities: list[EntitySummary],
|
||||
relations: list[RelationSummary],
|
||||
strategy: RetrievalStrategy,
|
||||
) -> list[VectorChunk]:
|
||||
"""对向量检索结果进行融合排序。
|
||||
|
||||
基于向量分数归一化后加权排序。图谱关联度通过实体度数近似评估。
|
||||
"""
|
||||
if not vector_chunks:
|
||||
return vector_chunks
|
||||
|
||||
# 向量分数归一化 (min-max scaling)
|
||||
scores = [c.score for c in vector_chunks]
|
||||
min_score = min(scores)
|
||||
max_score = max(scores)
|
||||
score_range = max_score - min_score
|
||||
|
||||
# 构建图谱实体名称集合,用于关联度加分
|
||||
graph_entity_names = {e.name.lower() for e in entities}
|
||||
|
||||
ranked: list[tuple[float, VectorChunk]] = []
|
||||
for chunk in vector_chunks:
|
||||
# 归一化向量分数
|
||||
norm_score = (
|
||||
(chunk.score - min_score) / score_range
|
||||
if score_range > 0
|
||||
else 1.0
|
||||
)
|
||||
|
||||
# 图谱关联度加分:文档片段中提及图谱实体名称
|
||||
graph_boost = 0.0
|
||||
if graph_entity_names:
|
||||
chunk_text_lower = chunk.text.lower()
|
||||
mentioned = sum(
|
||||
1 for name in graph_entity_names if name in chunk_text_lower
|
||||
)
|
||||
graph_boost = min(mentioned / max(len(graph_entity_names), 1), 1.0)
|
||||
|
||||
# 加权融合分数
|
||||
final_score = (
|
||||
strategy.vector_weight * norm_score
|
||||
+ strategy.graph_weight * graph_boost
|
||||
)
|
||||
ranked.append((final_score, chunk))
|
||||
|
||||
# 按融合分数降序排序
|
||||
ranked.sort(key=lambda x: x[0], reverse=True)
|
||||
return [chunk for _, chunk in ranked]
|
||||
Reference in New Issue
Block a user