"""GraphRAG 检索编排器。 并行执行向量检索和图谱检索,融合排序后构建统一上下文。 """ from __future__ import annotations import asyncio from app.core.logging import get_logger from app.module.kg_graphrag.context_builder import build_context, textualize_subgraph from app.module.kg_graphrag.kg_client import KGServiceClient from app.module.kg_graphrag.milvus_client import MilvusVectorRetriever from app.module.kg_graphrag.models import ( EntitySummary, GraphContext, RelationSummary, RetrievalContext, RetrievalStrategy, VectorChunk, ) logger = get_logger(__name__) class GraphRAGRetriever: """GraphRAG 检索编排器。""" def __init__( self, *, milvus_client: MilvusVectorRetriever, kg_client: KGServiceClient, ) -> None: self._milvus = milvus_client self._kg = kg_client @classmethod def from_settings(cls) -> GraphRAGRetriever: return cls( milvus_client=MilvusVectorRetriever.from_settings(), kg_client=KGServiceClient.from_settings(), ) async def retrieve( self, query: str, collection_name: str, graph_id: str, strategy: RetrievalStrategy, user_id: str = "", ) -> RetrievalContext: """并行执行向量检索 + 图谱检索,融合结果。""" # 构建并行任务 tasks: dict[str, asyncio.Task] = {} if strategy.enable_vector: # 先校验 collection 存在性,防止越权访问 if not await self._milvus.has_collection(collection_name): logger.warning( "Collection %s not found, skipping vector retrieval", collection_name, ) else: tasks["vector"] = asyncio.create_task( self._milvus.search( collection_name=collection_name, query=query, top_k=strategy.vector_top_k, ) ) if strategy.enable_graph: tasks["graph"] = asyncio.create_task( self._retrieve_graph( query=query, graph_id=graph_id, strategy=strategy, user_id=user_id, ) ) # 等待所有任务完成 if tasks: await asyncio.gather(*tasks.values(), return_exceptions=True) # 收集结果 vector_chunks: list[VectorChunk] = [] if "vector" in tasks: try: vector_chunks = tasks["vector"].result() except Exception: logger.exception("Vector retrieval task failed") entities: list[EntitySummary] = [] relations: list[RelationSummary] = [] if "graph" in tasks: try: entities, relations = tasks["graph"].result() except Exception: logger.exception("Graph retrieval task failed") # 融合排序 vector_chunks = self._rank_results( vector_chunks, entities, relations, strategy ) # 三元组文本化 graph_text = textualize_subgraph(entities, relations) # 构建上下文 merged_text = build_context( vector_chunks, graph_text, vector_weight=strategy.vector_weight, graph_weight=strategy.graph_weight, ) return RetrievalContext( vector_chunks=vector_chunks, graph_context=GraphContext( entities=entities, relations=relations, textualized=graph_text, ), merged_text=merged_text, ) async def _retrieve_graph( self, query: str, graph_id: str, strategy: RetrievalStrategy, user_id: str, ) -> tuple[list[EntitySummary], list[RelationSummary]]: """图谱检索:全文搜索 -> 种子实体 -> 子图扩展。""" # 1. 全文检索获取种子实体 seed_entities = await self._kg.fulltext_search( graph_id=graph_id, query=query, size=strategy.graph_max_entities, user_id=user_id, ) if not seed_entities: logger.debug("No seed entities found for query: %s", query) return [], [] # 2. 获取种子实体的 N-hop 子图 seed_ids = [e.id for e in seed_entities] entities, relations = await self._kg.get_subgraph( graph_id=graph_id, entity_ids=seed_ids, depth=strategy.graph_depth, user_id=user_id, ) logger.info( "Graph retrieval: %d seed entities -> %d entities, %d relations", len(seed_entities), len(entities), len(relations), ) return entities, relations def _rank_results( self, vector_chunks: list[VectorChunk], entities: list[EntitySummary], relations: list[RelationSummary], strategy: RetrievalStrategy, ) -> list[VectorChunk]: """对向量检索结果进行融合排序。 基于向量分数归一化后加权排序。图谱关联度通过实体度数近似评估。 """ if not vector_chunks: return vector_chunks # 向量分数归一化 (min-max scaling) scores = [c.score for c in vector_chunks] min_score = min(scores) max_score = max(scores) score_range = max_score - min_score # 构建图谱实体名称集合,用于关联度加分 graph_entity_names = {e.name.lower() for e in entities} ranked: list[tuple[float, VectorChunk]] = [] for chunk in vector_chunks: # 归一化向量分数 norm_score = ( (chunk.score - min_score) / score_range if score_range > 0 else 1.0 ) # 图谱关联度加分:文档片段中提及图谱实体名称 graph_boost = 0.0 if graph_entity_names: chunk_text_lower = chunk.text.lower() mentioned = sum( 1 for name in graph_entity_names if name in chunk_text_lower ) graph_boost = min(mentioned / max(len(graph_entity_names), 1), 1.0) # 加权融合分数 final_score = ( strategy.vector_weight * norm_score + strategy.graph_weight * graph_boost ) ranked.append((final_score, chunk)) # 按融合分数降序排序 ranked.sort(key=lambda x: x[0], reverse=True) return [chunk for _, chunk in ranked]