You've already forked DataMate
feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
This commit is contained in:
197
runtime/datamate-python/app/module/kg_graphrag/kg_client.py
Normal file
197
runtime/datamate-python/app/module/kg_graphrag/kg_client.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""KG 服务 REST 客户端。
|
||||
|
||||
通过 httpx 调用 Java 侧 knowledge-graph-service 的查询 API,
|
||||
包括全文检索和子图导出。
|
||||
|
||||
失败策略:fail-open —— KG 服务不可用时返回空结果 + 日志告警。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_graphrag.models import EntitySummary, RelationSummary
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class KGServiceClient:
|
||||
"""Java KG 服务 REST 客户端。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str = "http://datamate-kg:8080",
|
||||
internal_token: str = "",
|
||||
timeout: float = 30.0,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._internal_token = internal_token
|
||||
self._timeout = timeout
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls) -> KGServiceClient:
|
||||
from app.core.config import settings
|
||||
|
||||
return cls(
|
||||
base_url=settings.graphrag_kg_service_url,
|
||||
internal_token=settings.graphrag_kg_internal_token,
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self._base_url,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
return self._client
|
||||
|
||||
def _headers(self, user_id: str = "") -> dict[str, str]:
|
||||
headers: dict[str, str] = {}
|
||||
if self._internal_token:
|
||||
headers["X-Internal-Token"] = self._internal_token
|
||||
if user_id:
|
||||
headers["X-User-Id"] = user_id
|
||||
return headers
|
||||
|
||||
async def fulltext_search(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
size: int = 10,
|
||||
user_id: str = "",
|
||||
) -> list[EntitySummary]:
|
||||
"""调用 KG 服务全文检索,返回匹配的实体列表。
|
||||
|
||||
Fail-open: KG 服务不可用时返回空列表。
|
||||
"""
|
||||
try:
|
||||
return await self._fulltext_search_impl(graph_id, query, size, user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"KG fulltext search failed for graph_id=%s (fail-open, returning empty)",
|
||||
graph_id,
|
||||
)
|
||||
return []
|
||||
|
||||
async def _fulltext_search_impl(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
size: int,
|
||||
user_id: str,
|
||||
) -> list[EntitySummary]:
|
||||
client = self._get_client()
|
||||
resp = await client.get(
|
||||
f"/knowledge-graph/{graph_id}/query/search",
|
||||
params={"q": query, "size": size},
|
||||
headers=self._headers(user_id),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
body = resp.json()
|
||||
|
||||
# Java 返回 PagedResponse<SearchHitVO>:
|
||||
# 可能被全局包装为 {"code": 200, "data": PagedResponse}
|
||||
# 也可能直接返回 PagedResponse {"page": 0, "content": [...]}
|
||||
data = body.get("data", body)
|
||||
# PagedResponse 将实体列表放在 content 字段中
|
||||
items: list[dict] = (
|
||||
data.get("content", []) if isinstance(data, dict) else data if isinstance(data, list) else []
|
||||
)
|
||||
entities: list[EntitySummary] = []
|
||||
for item in items:
|
||||
entities.append(
|
||||
EntitySummary(
|
||||
id=str(item.get("id", "")),
|
||||
name=item.get("name", ""),
|
||||
type=item.get("type", ""),
|
||||
description=item.get("description", ""),
|
||||
)
|
||||
)
|
||||
return entities
|
||||
|
||||
async def get_subgraph(
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_ids: list[str],
|
||||
depth: int = 1,
|
||||
user_id: str = "",
|
||||
) -> tuple[list[EntitySummary], list[RelationSummary]]:
|
||||
"""获取种子实体的 N-hop 子图。
|
||||
|
||||
Fail-open: KG 服务不可用时返回空子图。
|
||||
"""
|
||||
try:
|
||||
return await self._get_subgraph_impl(graph_id, entity_ids, depth, user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"KG subgraph export failed for graph_id=%s (fail-open, returning empty)",
|
||||
graph_id,
|
||||
)
|
||||
return [], []
|
||||
|
||||
async def _get_subgraph_impl(
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_ids: list[str],
|
||||
depth: int,
|
||||
user_id: str,
|
||||
) -> tuple[list[EntitySummary], list[RelationSummary]]:
|
||||
client = self._get_client()
|
||||
resp = await client.post(
|
||||
f"/knowledge-graph/{graph_id}/query/subgraph/export",
|
||||
params={"depth": depth},
|
||||
json={"entityIds": entity_ids},
|
||||
headers=self._headers(user_id),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
body = resp.json()
|
||||
|
||||
# Java 返回 SubgraphExportVO:
|
||||
# 可能被全局包装为 {"code": 200, "data": SubgraphExportVO}
|
||||
# 也可能直接返回 SubgraphExportVO {"nodes": [...], "edges": [...]}
|
||||
data = body.get("data", body) if isinstance(body.get("data"), dict) else body
|
||||
nodes_raw = data.get("nodes", [])
|
||||
edges_raw = data.get("edges", [])
|
||||
|
||||
# ExportNodeVO: id, name, type, description, properties (Map)
|
||||
entities: list[EntitySummary] = []
|
||||
for node in nodes_raw:
|
||||
entities.append(
|
||||
EntitySummary(
|
||||
id=str(node.get("id", "")),
|
||||
name=node.get("name", ""),
|
||||
type=node.get("type", ""),
|
||||
description=node.get("description", ""),
|
||||
)
|
||||
)
|
||||
|
||||
relations: list[RelationSummary] = []
|
||||
# 构建 id -> entity 的映射用于查找 source/target 名称和类型
|
||||
entity_map = {e.id: e for e in entities}
|
||||
# ExportEdgeVO: sourceEntityId, targetEntityId, relationType
|
||||
# 注意:sourceId 是数据来源 ID,不是源实体 ID
|
||||
for edge in edges_raw:
|
||||
source_id = str(edge.get("sourceEntityId", ""))
|
||||
target_id = str(edge.get("targetEntityId", ""))
|
||||
source_entity = entity_map.get(source_id)
|
||||
target_entity = entity_map.get(target_id)
|
||||
relations.append(
|
||||
RelationSummary(
|
||||
source_name=source_entity.name if source_entity else source_id,
|
||||
source_type=source_entity.type if source_entity else "",
|
||||
target_name=target_entity.name if target_entity else target_id,
|
||||
target_type=target_entity.type if target_entity else "",
|
||||
relation_type=edge.get("relationType", ""),
|
||||
)
|
||||
)
|
||||
|
||||
return entities, relations
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
Reference in New Issue
Block a user