You've already forked DataMate
feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
This commit is contained in:
135
runtime/datamate-python/app/module/kg_graphrag/milvus_client.py
Normal file
135
runtime/datamate-python/app/module/kg_graphrag/milvus_client.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Milvus 向量检索客户端。
|
||||
|
||||
通过 pymilvus 连接 Milvus,对查询文本进行 embedding 后执行混合搜索,
|
||||
返回 top-K 文档片段。
|
||||
|
||||
失败策略:fail-open —— Milvus 不可用时返回空列表 + 日志告警。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_graphrag.models import VectorChunk
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MilvusVectorRetriever:
|
||||
"""Milvus 向量检索器。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
uri: str = "http://milvus-standalone:19530",
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
embedding_base_url: str | None = None,
|
||||
embedding_api_key: SecretStr = SecretStr("EMPTY"),
|
||||
) -> None:
|
||||
self._uri = uri
|
||||
self._embedding_model = embedding_model
|
||||
self._embedding_base_url = embedding_base_url
|
||||
self._embedding_api_key = embedding_api_key
|
||||
# Lazy init
|
||||
self._milvus_client = None
|
||||
self._embeddings = None
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls) -> MilvusVectorRetriever:
|
||||
from app.core.config import settings
|
||||
|
||||
embedding_model = (
|
||||
settings.graphrag_embedding_model
|
||||
or settings.kg_alignment_embedding_model
|
||||
)
|
||||
return cls(
|
||||
uri=settings.graphrag_milvus_uri,
|
||||
embedding_model=embedding_model,
|
||||
embedding_base_url=settings.kg_llm_base_url,
|
||||
embedding_api_key=settings.kg_llm_api_key,
|
||||
)
|
||||
|
||||
def _get_embeddings(self):
|
||||
if self._embeddings is None:
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
self._embeddings = OpenAIEmbeddings(
|
||||
model=self._embedding_model,
|
||||
base_url=self._embedding_base_url,
|
||||
api_key=self._embedding_api_key,
|
||||
)
|
||||
return self._embeddings
|
||||
|
||||
def _get_milvus_client(self):
|
||||
if self._milvus_client is None:
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
self._milvus_client = MilvusClient(uri=self._uri)
|
||||
logger.info("Connected to Milvus at %s", self._uri)
|
||||
return self._milvus_client
|
||||
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
"""检查 Milvus 中是否存在指定 collection(防止越权访问不存在的库)。"""
|
||||
try:
|
||||
client = self._get_milvus_client()
|
||||
return await asyncio.to_thread(client.has_collection, collection_name)
|
||||
except Exception:
|
||||
logger.exception("Milvus has_collection check failed for %s", collection_name)
|
||||
return False
|
||||
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> list[VectorChunk]:
|
||||
"""向量搜索:embed query -> Milvus search -> 返回 top-K 文档片段。
|
||||
|
||||
Fail-open: Milvus 不可用时返回空列表。
|
||||
"""
|
||||
try:
|
||||
return await self._search_impl(collection_name, query, top_k)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Milvus search failed for collection=%s (fail-open, returning empty)",
|
||||
collection_name,
|
||||
)
|
||||
return []
|
||||
|
||||
async def _search_impl(
|
||||
self,
|
||||
collection_name: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
) -> list[VectorChunk]:
|
||||
# 1. Embed query
|
||||
query_vector = await self._get_embeddings().aembed_query(query)
|
||||
|
||||
# 2. Milvus search(同步 I/O,通过 to_thread 避免阻塞事件循环)
|
||||
client = self._get_milvus_client()
|
||||
results = await asyncio.to_thread(
|
||||
client.search,
|
||||
collection_name=collection_name,
|
||||
data=[query_vector],
|
||||
limit=top_k,
|
||||
output_fields=["text", "metadata"],
|
||||
search_params={"metric_type": "COSINE", "params": {"nprobe": 16}},
|
||||
)
|
||||
|
||||
# 3. 转换为 VectorChunk
|
||||
chunks: list[VectorChunk] = []
|
||||
if results and len(results) > 0:
|
||||
for hit in results[0]:
|
||||
entity = hit.get("entity", {})
|
||||
chunks.append(
|
||||
VectorChunk(
|
||||
id=str(hit.get("id", "")),
|
||||
text=entity.get("text", ""),
|
||||
score=float(hit.get("distance", 0.0)),
|
||||
metadata=entity.get("metadata", {}),
|
||||
)
|
||||
)
|
||||
return chunks
|
||||
Reference in New Issue
Block a user