Files
DataMate/runtime/datamate-python/app/module/kg_graphrag/test_context_builder.py
Jerry Yan 39338df808 feat(kg): 实现 Phase 2 GraphRAG 融合功能
核心功能:
- 三层检索策略:向量检索(Milvus)+ 图检索(KG 服务)+ 融合排序
- LLM 生成:支持同步和流式(SSE)响应
- 知识库访问控制:knowledge_base_id 归属校验 + collection_name 绑定验证

新增模块(9个文件):
- models.py: 请求/响应模型(GraphRAGQueryRequest, RetrievalStrategy, GraphContext 等)
- milvus_client.py: Milvus 向量检索客户端(OpenAI Embeddings + asyncio.to_thread)
- kg_client.py: KG 服务 REST 客户端(全文检索 + 子图导出,fail-open)
- context_builder.py: 三元组文本化(10 种关系模板)+ 上下文构建
- generator.py: LLM 生成(ChatOpenAI,支持同步和流式)
- retriever.py: 检索编排(并行检索 + 融合排序)
- kb_access.py: 知识库访问校验(归属验证 + collection 绑定,fail-close)
- interface.py: FastAPI 端点(/query, /retrieve, /query/stream)
- __init__.py: 模块入口

修改文件(3个):
- app/core/config.py: 添加 13 个 graphrag_* 配置项
- app/module/__init__.py: 注册 kg_graphrag_router
- pyproject.toml: 添加 pymilvus 依赖

测试覆盖(79 tests):
- test_context_builder.py: 13 tests(三元组文本化 + 上下文构建)
- test_kg_client.py: 14 tests(KG 响应解析 + PagedResponse + 边字段映射)
- test_milvus_client.py: 8 tests(向量检索 + asyncio.to_thread)
- test_retriever.py: 11 tests(并行检索 + 融合排序 + fail-open)
- test_kb_access.py: 18 tests(归属校验 + collection 绑定 + 跨用户负例)
- test_interface.py: 15 tests(端点级回归 + 403 short-circuit)

关键设计:
- Fail-open: Milvus/KG 服务失败不阻塞管道,返回空结果
- Fail-close: 访问控制失败拒绝请求,防止授权绕过
- 并行检索: asyncio.gather() 并发运行向量和图检索
- 融合排序: Min-max 归一化 + 加权融合(vector_weight/graph_weight)
- 延迟初始化: 所有客户端在首次请求时初始化
- 配置回退: graphrag_llm_* 为空时回退到 kg_llm_*

安全修复:
- P1-1: KG 响应解析(PagedResponse.content)
- P1-2: 子图边字段映射(sourceEntityId/targetEntityId)
- P1-3: collection_name 越权风险(归属校验 + 绑定验证)
- P1-4: 同步 Milvus I/O(asyncio.to_thread)
- P1-5: 测试覆盖(79 tests,包括安全负例)

测试结果:79 tests pass 
2026-02-20 09:41:55 +08:00

183 lines
6.6 KiB
Python

"""三元组文本化 + 上下文构建的单元测试。"""
from app.module.kg_graphrag.context_builder import (
RELATION_TEMPLATES,
build_context,
textualize_subgraph,
)
from app.module.kg_graphrag.models import (
EntitySummary,
RelationSummary,
VectorChunk,
)
# ---------------------------------------------------------------------------
# textualize_subgraph 测试
# ---------------------------------------------------------------------------
class TestTextualizeSubgraph:
"""textualize_subgraph 函数的测试。"""
def test_single_relation(self):
entities = [
EntitySummary(id="1", name="用户行为数据", type="Dataset"),
EntitySummary(id="2", name="user_id", type="Field"),
]
relations = [
RelationSummary(
source_name="用户行为数据",
source_type="Dataset",
target_name="user_id",
target_type="Field",
relation_type="HAS_FIELD",
),
]
result = textualize_subgraph(entities, relations)
assert "Dataset'用户行为数据'包含字段Field'user_id'" in result
def test_multiple_relations(self):
entities = [
EntitySummary(id="1", name="用户行为数据", type="Dataset"),
EntitySummary(id="2", name="清洗管道", type="Workflow"),
]
relations = [
RelationSummary(
source_name="清洗管道",
source_type="Workflow",
target_name="用户行为数据",
target_type="Dataset",
relation_type="USES_DATASET",
),
RelationSummary(
source_name="用户行为数据",
source_type="Dataset",
target_name="外部系统",
target_type="DataSource",
relation_type="SOURCED_FROM",
),
]
result = textualize_subgraph(entities, relations)
assert "Workflow'清洗管道'使用了数据集Dataset'用户行为数据'" in result
assert "Dataset'用户行为数据'的知识来源于DataSource'外部系统'" in result
def test_all_relation_templates(self):
"""验证所有 10 种关系模板都能正确生成。"""
for rel_type, template in RELATION_TEMPLATES.items():
relations = [
RelationSummary(
source_name="A",
source_type="TypeA",
target_name="B",
target_type="TypeB",
relation_type=rel_type,
),
]
result = textualize_subgraph([], relations)
assert "TypeA'A'" in result
assert "TypeB'B'" in result
assert result # 非空
def test_unknown_relation_type(self):
"""未知关系类型使用通用模板。"""
relations = [
RelationSummary(
source_name="X",
source_type="T1",
target_name="Y",
target_type="T2",
relation_type="CUSTOM_REL",
),
]
result = textualize_subgraph([], relations)
assert "T1'X'与T2'Y'存在CUSTOM_REL关系" in result
def test_orphan_entity_with_description(self):
"""无关系的独立实体(有描述)。"""
entities = [
EntitySummary(id="1", name="孤立实体", type="Dataset", description="这是一个测试实体"),
]
result = textualize_subgraph(entities, [])
assert "Dataset'孤立实体': 这是一个测试实体" in result
def test_orphan_entity_without_description(self):
"""无关系的独立实体(无描述)。"""
entities = [
EntitySummary(id="1", name="孤立实体", type="Dataset"),
]
result = textualize_subgraph(entities, [])
assert "存在Dataset'孤立实体'" in result
def test_empty_inputs(self):
result = textualize_subgraph([], [])
assert result == ""
def test_entity_with_relation_not_orphan(self):
"""有关系的实体不应出现在独立实体部分。"""
entities = [
EntitySummary(id="1", name="A", type="Dataset"),
EntitySummary(id="2", name="B", type="Field"),
EntitySummary(id="3", name="C", type="Workflow"),
]
relations = [
RelationSummary(
source_name="A",
source_type="Dataset",
target_name="B",
target_type="Field",
relation_type="HAS_FIELD",
),
]
result = textualize_subgraph(entities, relations)
# A 和 B 有关系,不应作为独立实体出现
# C 无关系,应出现
assert "存在Workflow'C'" in result
lines = result.strip().split("\n")
assert len(lines) == 2 # 一条关系 + 一个独立实体
# ---------------------------------------------------------------------------
# build_context 测试
# ---------------------------------------------------------------------------
class TestBuildContext:
"""build_context 函数的测试。"""
def test_both_vector_and_graph(self):
chunks = [
VectorChunk(id="1", text="文档片段一", score=0.9),
VectorChunk(id="2", text="文档片段二", score=0.8),
]
graph_text = "Dataset'用户数据'包含字段Field'user_id'"
result = build_context(chunks, graph_text)
assert "## 相关文档" in result
assert "[1] 文档片段一" in result
assert "[2] 文档片段二" in result
assert "## 知识图谱上下文" in result
assert graph_text in result
def test_vector_only(self):
chunks = [VectorChunk(id="1", text="文档片段", score=0.9)]
result = build_context(chunks, "")
assert "## 相关文档" in result
assert "## 知识图谱上下文" not in result
def test_graph_only(self):
result = build_context([], "图谱内容")
assert "## 知识图谱上下文" in result
assert "## 相关文档" not in result
def test_empty_both(self):
result = build_context([], "")
assert "未检索到相关上下文信息" in result
def test_context_section_order(self):
"""验证文档在图谱之前。"""
chunks = [VectorChunk(id="1", text="doc", score=0.9)]
result = build_context(chunks, "graph")
doc_pos = result.index("## 相关文档")
graph_pos = result.index("## 知识图谱上下文")
assert doc_pos < graph_pos