You've already forked DataMate
feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证
新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口
修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖
测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)
关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*
安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)
测试结果:79 tests pass ✅
This commit is contained in:
102
runtime/datamate-python/app/module/kg_graphrag/models.py
Normal file
102
runtime/datamate-python/app/module/kg_graphrag/models.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""GraphRAG 融合查询的请求/响应数据模型。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RetrievalStrategy(BaseModel):
|
||||
"""检索策略配置。"""
|
||||
|
||||
vector_top_k: int = Field(default=5, ge=1, le=50, description="向量检索返回数")
|
||||
graph_depth: int = Field(default=2, ge=1, le=5, description="图谱扩展深度")
|
||||
graph_max_entities: int = Field(default=20, ge=1, le=100, description="图谱最大实体数")
|
||||
vector_weight: float = Field(default=0.6, ge=0.0, le=1.0, description="向量分数权重")
|
||||
graph_weight: float = Field(default=0.4, ge=0.0, le=1.0, description="图谱相关性权重")
|
||||
enable_graph: bool = Field(default=True, description="是否启用图谱检索")
|
||||
enable_vector: bool = Field(default=True, description="是否启用向量检索")
|
||||
|
||||
|
||||
class GraphRAGQueryRequest(BaseModel):
|
||||
"""GraphRAG 查询请求。"""
|
||||
|
||||
query: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=2000,
|
||||
description="用户查询",
|
||||
)
|
||||
knowledge_base_id: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=64,
|
||||
description="知识库 ID,用于权限校验(由上游 Java 后端传入)",
|
||||
)
|
||||
collection_name: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=256,
|
||||
pattern=r"^[a-zA-Z0-9_\-\u4e00-\u9fff]+$",
|
||||
description="Milvus collection 名称(= 知识库名),仅允许字母、数字、下划线、连字符和中文",
|
||||
)
|
||||
graph_id: str = Field(
|
||||
...,
|
||||
pattern=r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$",
|
||||
description="Neo4j 图谱 ID(UUID 格式)",
|
||||
)
|
||||
strategy: RetrievalStrategy = Field(
|
||||
default_factory=RetrievalStrategy,
|
||||
description="可选策略覆盖",
|
||||
)
|
||||
|
||||
|
||||
class VectorChunk(BaseModel):
|
||||
"""向量检索到的文档片段。"""
|
||||
|
||||
id: str
|
||||
text: str
|
||||
score: float
|
||||
metadata: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EntitySummary(BaseModel):
|
||||
"""实体摘要。"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
description: str = ""
|
||||
|
||||
|
||||
class RelationSummary(BaseModel):
|
||||
"""关系摘要。"""
|
||||
|
||||
source_name: str
|
||||
source_type: str
|
||||
target_name: str
|
||||
target_type: str
|
||||
relation_type: str
|
||||
|
||||
|
||||
class GraphContext(BaseModel):
|
||||
"""图谱上下文。"""
|
||||
|
||||
entities: list[EntitySummary] = Field(default_factory=list)
|
||||
relations: list[RelationSummary] = Field(default_factory=list)
|
||||
textualized: str = ""
|
||||
|
||||
|
||||
class RetrievalContext(BaseModel):
|
||||
"""检索上下文(检索结果的结构化表示)。"""
|
||||
|
||||
vector_chunks: list[VectorChunk] = Field(default_factory=list)
|
||||
graph_context: GraphContext = Field(default_factory=GraphContext)
|
||||
merged_text: str = ""
|
||||
|
||||
|
||||
class GraphRAGQueryResponse(BaseModel):
|
||||
"""GraphRAG 查询响应。"""
|
||||
|
||||
answer: str = Field(..., description="LLM 生成的回答")
|
||||
context: RetrievalContext = Field(..., description="检索上下文")
|
||||
model: str = Field(..., description="使用的 LLM 模型名")
|
||||
Reference in New Issue
Block a user