You've already forked DataMate
- 实现三层对齐策略:规则层 + 向量相似度层 + LLM 仲裁层 - 规则层:名称规范化(NFKC、小写、去标点/空格)+ 规则评分 - 向量层:OpenAI Embeddings + cosine 相似度计算 - LLM 层:仅对边界样本调用,严格 JSON schema 校验 - 使用 Union-Find 实现传递合并 - 支持批内对齐(库内对齐待 KG 服务 API 支持) 核心组件: - EntityAligner 类:align() (async)、align_rules_only() (sync) - 配置项:kg_alignment_enabled(默认 false)、embedding_model、阈值 - 失败策略:fail-open(对齐失败不中断请求) 集成: - 已集成到抽取主链路(extract → align → return) - extract() 调用 async align() - extract_sync() 调用 sync align_rules_only() 修复: - P1-1:使用 (name, type) 作为 key,避免同名跨类型误合并 - P1-2:LLM 计数在 finally 块中增加,异常也计数 - P1-3:添加库内对齐说明(待后续实现) 新增 41 个测试用例,全部通过 测试结果:41 tests pass
478 lines
17 KiB
Python
478 lines
17 KiB
Python
"""实体对齐器测试。
|
|
|
|
Run with: pytest app/module/kg_extraction/test_aligner.py -v
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.module.kg_extraction.aligner import (
|
|
EntityAligner,
|
|
LLMArbitrationResult,
|
|
_build_merged_result,
|
|
_make_union_find,
|
|
cosine_similarity,
|
|
normalize_name,
|
|
rule_score,
|
|
)
|
|
from app.module.kg_extraction.models import (
|
|
ExtractionResult,
|
|
GraphEdge,
|
|
GraphNode,
|
|
Triple,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# normalize_name
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestNormalizeName:
    """Pin the canonical form produced by ``normalize_name``.

    Covers lowercasing, NFKC folding, punctuation removal, whitespace handling,
    the empty string, and mixed CJK/Latin input.
    """

    def test_basic_lowercase(self):
        expected = "hello world"
        assert normalize_name("Hello World") == expected

    def test_unicode_nfkc(self):
        # U+FF28 is the fullwidth capital "H"; NFKC folds it to ASCII before
        # lowercasing.
        fullwidth_input = "\uff28ello"
        assert normalize_name(fullwidth_input) == "hello"

    def test_punctuation_removed(self):
        expected = "usa"
        assert normalize_name("U.S.A.") == expected

    def test_whitespace_collapsed(self):
        expected = "hello world"
        assert normalize_name(" hello world ") == expected

    def test_empty_string(self):
        expected = ""
        assert normalize_name("") == expected

    def test_chinese_preserved(self):
        # CJK characters must pass through unchanged.
        name = "\u5f20\u4e09"
        assert normalize_name(name) == name

    def test_mixed_chinese_english(self):
        # Parentheses are dropped and the Latin part is lowercased.
        raw = "\u5f20\u4e09 (Zhang San)"
        assert normalize_name(raw) == "\u5f20\u4e09 zhang san"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# rule_score
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRuleScore:
    """Pin the scores returned by the rule layer (``rule_score``).

    Equal names, after normalization, score 1.0. A name contained in a longer
    same-type name scores 0.5. A type mismatch, unrelated names, or a
    single-character substring score 0.0. Type comparison ignores case.
    """

    @staticmethod
    def _score(name_a, type_a, name_b, type_b):
        """Build two ``GraphNode`` objects and return ``rule_score(first, second)``."""
        first = GraphNode(name=name_a, type=type_a)
        second = GraphNode(name=name_b, type=type_b)
        return rule_score(first, second)

    def test_exact_match(self):
        score = self._score("\u5f20\u4e09", "Person", "\u5f20\u4e09", "Person")
        assert score == 1.0

    def test_normalized_match(self):
        score = self._score("Hello World", "Organization", "hello world", "Organization")
        assert score == 1.0

    def test_type_mismatch(self):
        # Identical names but different types must never match.
        score = self._score("\u5f20\u4e09", "Person", "\u5f20\u4e09", "Organization")
        assert score == 0.0

    def test_substring_match(self):
        # The shorter organization name is a prefix of the longer one.
        shorter = "\u5317\u4eac\u5927\u5b66"
        longer = shorter + "\u8ba1\u7b97\u673a\u5b66\u9662"
        assert self._score(shorter, "Organization", longer, "Organization") == 0.5

    def test_no_match(self):
        score = self._score("\u5f20\u4e09", "Person", "\u674e\u56db", "Person")
        assert score == 0.0

    def test_type_case_insensitive(self):
        score = self._score("test", "PERSON", "test", "person")
        assert score == 1.0

    def test_short_substring_ignored(self):
        """Single-character substring should not trigger match."""
        score = self._score("A", "Person", "AB", "Person")
        assert score == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# cosine_similarity
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCosineSimilarity:
    """Check ``cosine_similarity`` on simple 2-D vectors, including a zero vector."""

    def test_identical(self):
        first, second = [1.0, 0.0], [1.0, 0.0]
        assert cosine_similarity(first, second) == pytest.approx(1.0)

    def test_orthogonal(self):
        x_axis, y_axis = [1.0, 0.0], [0.0, 1.0]
        assert cosine_similarity(x_axis, y_axis) == pytest.approx(0.0)

    def test_opposite(self):
        forward, backward = [1.0, 0.0], [-1.0, 0.0]
        assert cosine_similarity(forward, backward) == pytest.approx(-1.0)

    def test_zero_vector(self):
        # A zero vector has no direction; the expected result is exactly 0.0.
        zero, x_axis = [0.0, 0.0], [1.0, 0.0]
        assert cosine_similarity(zero, x_axis) == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Union-Find
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUnionFind:
    """Check the ``(parent, find, union)`` triple returned by ``_make_union_find``."""

    def test_basic(self):
        _, find, union = _make_union_find(4)
        for left, right in ((0, 1), (2, 3)):
            union(left, right)
        assert find(0) == find(1)
        assert find(2) == find(3)
        assert find(0) != find(2)

    def test_transitive(self):
        # Joining 0-1 and then 1-2 must place 0 and 2 in the same set.
        _, find, union = _make_union_find(3)
        union(0, 1)
        union(1, 2)
        assert find(0) == find(2)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _build_merged_result
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_result(nodes, edges=None, triples=None):
    """Wrap *nodes* (and optional *edges* / *triples*) in a minimal ExtractionResult.

    Falsy ``edges`` / ``triples`` become empty lists. ``raw_text`` and
    ``source_id`` are fixed values so tests can assert they survive merging.
    """
    fields = {
        "nodes": nodes,
        "edges": edges if edges else [],
        "triples": triples if triples else [],
        "raw_text": "test text",
        "source_id": "src-1",
    }
    return ExtractionResult(**fields)
|
|
|
|
|
|
class TestBuildMergedResult:
    """Tests for ``_build_merged_result``: canonical naming plus edge/triple remapping."""

    def test_no_merge_returns_original(self):
        # Every node is its own Union-Find set, so the original object comes back.
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="B", type="Person"),
        ]
        result = _make_result(nodes)
        parent, find, _ = _make_union_find(2)
        merged = _build_merged_result(result, parent, find)
        assert merged is result

    def test_canonical_picks_longest_name(self):
        nodes = [
            GraphNode(name="AI", type="Tech"),
            GraphNode(name="Artificial Intelligence", type="Tech"),
        ]
        result = _make_result(nodes)
        parent, find, union = _make_union_find(2)
        union(0, 1)
        merged = _build_merged_result(result, parent, find)
        assert len(merged.nodes) == 1
        assert merged.nodes[0].name == "Artificial Intelligence"

    def test_edge_remap_and_dedup(self):
        nodes = [
            GraphNode(name="Alice", type="Person"),
            GraphNode(name="alice", type="Person"),
            GraphNode(name="Bob", type="Person"),
        ]
        edges = [
            GraphEdge(source="Alice", target="Bob", relation_type="knows"),
            GraphEdge(source="alice", target="Bob", relation_type="knows"),
        ]
        result = _make_result(nodes, edges)
        parent, find, union = _make_union_find(3)
        union(0, 1)
        merged = _build_merged_result(result, parent, find)
        assert len(merged.edges) == 1
        assert merged.edges[0].source == "Alice"

    def test_triple_remap_and_dedup(self):
        n1 = GraphNode(name="Alice", type="Person")
        n2 = GraphNode(name="alice", type="Person")
        n3 = GraphNode(name="MIT", type="Organization")
        triples = [
            Triple(subject=n1, predicate="works_at", object=n3),
            Triple(subject=n2, predicate="works_at", object=n3),
        ]
        result = _make_result([n1, n2, n3], triples=triples)
        parent, find, union = _make_union_find(3)
        union(0, 1)
        merged = _build_merged_result(result, parent, find)
        assert len(merged.triples) == 1
        assert merged.triples[0].subject.name == "Alice"

    def test_preserves_metadata(self):
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="A", type="Person"),
        ]
        result = _make_result(nodes)
        parent, find, union = _make_union_find(2)
        union(0, 1)
        merged = _build_merged_result(result, parent, find)
        assert merged.raw_text == "test text"
        assert merged.source_id == "src-1"

    def test_cross_type_same_name_no_collision(self):
        """P1-1 regression: merging same-named nodes of one type must not remap
        edges or triples that belong to a different type.

        Scenario: Person "张三" and "张三先生" merge into "张三先生", but the
        Organization "张三" must not be rewritten.
        """
        nodes = [
            GraphNode(name="张三", type="Person"),  # idx 0
            GraphNode(name="张三先生", type="Person"),  # idx 1
            GraphNode(name="张三", type="Organization"),  # idx 2 - same name, other type
            GraphNode(name="北京", type="Location"),  # idx 3
        ]
        edges = [
            GraphEdge(source="张三", target="北京", relation_type="lives_in"),
            GraphEdge(source="张三", target="北京", relation_type="located_in"),
        ]
        triples = [
            Triple(
                subject=GraphNode(name="张三", type="Person"),
                predicate="lives_in",
                object=GraphNode(name="北京", type="Location"),
            ),
            Triple(
                subject=GraphNode(name="张三", type="Organization"),
                predicate="located_in",
                object=GraphNode(name="北京", type="Location"),
            ),
        ]
        result = _make_result(nodes, edges, triples)
        parent, find, union = _make_union_find(4)
        union(0, 1)  # merge Person "张三" with "张三先生"
        merged = _build_merged_result(result, parent, find)

        # Expect 3 nodes: the merged Person, the untouched Organization, and the Location.
        assert len(merged.nodes) == 3
        merged_keys = {(n.name, n.type) for n in merged.nodes}
        assert ("张三先生", "Person") in merged_keys
        assert ("张三", "Organization") in merged_keys
        assert ("北京", "Location") in merged_keys

        # Edges carry no type, so "张三" is ambiguous (it maps to different
        # canonical names per type) and must keep its original name. The two
        # edges have distinct relation types, so a length check alone would also
        # pass after a wrong rewrite -- assert the endpoints explicitly.
        assert len(merged.edges) == 2
        assert {(e.source, e.target) for e in merged.edges} == {("张三", "北京")}

        # Triples carry type information, so they can be resolved precisely.
        assert len(merged.triples) == 2
        person_triple = next(t for t in merged.triples if t.subject.type == "Person")
        org_triple = next(t for t in merged.triples if t.subject.type == "Organization")
        assert person_triple.subject.name == "张三先生"  # Person is rewritten
        assert org_triple.subject.name == "张三"  # Organization keeps its name
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# EntityAligner
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEntityAligner:
    """End-to-end tests for ``EntityAligner``: async ``align`` and sync ``align_rules_only``.

    Both entry points are expected to return the input unchanged when disabled,
    and to fail open (return the input) when an internal step raises.
    """

    def _run(self, coro):
        """Run *coro* to completion on a fresh event loop and return its result.

        ``asyncio.run`` creates and closes the loop. Unlike a bare
        ``run_until_complete``, it also cancels leftover tasks and finalizes
        async generators before closing.
        """
        return asyncio.run(coro)

    def test_disabled_returns_original(self):
        aligner = EntityAligner(enabled=False)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_single_node_returns_original(self):
        aligner = EntityAligner(enabled=True)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_rule_merge_exact_names(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u674e\u56db", type="Person"),
        ]
        edges = [
            GraphEdge(source="\u5f20\u4e09", target="\u674e\u56db", relation_type="knows"),
        ]
        result = _make_result(nodes, edges)
        aligned = self._run(aligner.align(result))
        assert len(aligned.nodes) == 2
        names = {n.name for n in aligned.nodes}
        assert "\u5f20\u4e09" in names
        assert "\u674e\u56db" in names

    def test_rule_merge_case_insensitive(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="Hello World", type="Org"),
            GraphNode(name="hello world", type="Org"),
            GraphNode(name="Test", type="Person"),
        ]
        result = _make_result(nodes)
        aligned = self._run(aligner.align(result))
        assert len(aligned.nodes) == 2

    def test_rule_merge_deduplicates_edges(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="Hello World", type="Org"),
            GraphNode(name="hello world", type="Org"),
            GraphNode(name="Test", type="Person"),
        ]
        edges = [
            GraphEdge(source="Hello World", target="Test", relation_type="employs"),
            GraphEdge(source="hello world", target="Test", relation_type="employs"),
        ]
        result = _make_result(nodes, edges)
        aligned = self._run(aligner.align(result))
        assert len(aligned.edges) == 1

    def test_rule_merge_deduplicates_triples(self):
        aligner = EntityAligner(enabled=True)
        n1 = GraphNode(name="\u5f20\u4e09", type="Person")
        n2 = GraphNode(name="\u5f20\u4e09", type="Person")
        n3 = GraphNode(name="\u5317\u4eac\u5927\u5b66", type="Organization")
        triples = [
            Triple(subject=n1, predicate="works_at", object=n3),
            Triple(subject=n2, predicate="works_at", object=n3),
        ]
        result = _make_result([n1, n2, n3], triples=triples)
        aligned = self._run(aligner.align(result))
        assert len(aligned.triples) == 1

    def test_type_mismatch_no_merge(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Organization"),
        ]
        result = _make_result(nodes)
        aligned = self._run(aligner.align(result))
        assert len(aligned.nodes) == 2

    def test_fail_open_on_error(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
        ]
        result = _make_result(nodes)
        with patch.object(aligner, "_align_impl", side_effect=RuntimeError("boom")):
            aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_align_rules_only_sync(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u674e\u56db", type="Person"),
        ]
        result = _make_result(nodes)
        aligned = aligner.align_rules_only(result)
        assert len(aligned.nodes) == 2

    def test_align_rules_only_disabled(self):
        aligner = EntityAligner(enabled=False)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = aligner.align_rules_only(result)
        assert aligned is result

    def test_align_rules_only_fail_open(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="B", type="Person"),
        ]
        result = _make_result(nodes)
        with patch(
            "app.module.kg_extraction.aligner.rule_score", side_effect=RuntimeError("boom")
        ):
            aligned = aligner.align_rules_only(result)
        assert aligned is result

    def test_llm_count_incremented_on_failure(self):
        """P1-2 regression: failed LLM arbitrations must still count against
        the ``max_llm_arbitrations`` budget."""
        import math

        max_arb = 2
        aligner = EntityAligner(
            enabled=True,
            max_llm_arbitrations=max_arb,
            llm_arbitration_enabled=True,
        )
        # Four same-type nodes: rule-layer substring matches yield several
        # vector-layer candidates.
        nodes = [
            GraphNode(name="北京大学", type="Organization"),
            GraphNode(name="北京大学计算机学院", type="Organization"),
            GraphNode(name="北京大学数学学院", type="Organization"),
            GraphNode(name="北京大学物理学院", type="Organization"),
        ]
        result = _make_result(nodes)

        # Two unit vectors whose cosine is 0.85, intended to fall inside the
        # default LLM-arbitration band [0.78, 0.92).
        angle = math.acos(0.85)
        emb_a = [1.0, 0.0]
        emb_b = [math.cos(angle), math.sin(angle)]

        async def fake_embed(texts):
            # Even indices get emb_a, odd indices get emb_b.
            return [emb_a if i % 2 == 0 else emb_b for i in range(len(texts))]

        mock_llm_arbitrate = AsyncMock(side_effect=RuntimeError("LLM down"))

        with patch.object(aligner, "_get_embeddings") as mock_emb:
            mock_emb_instance = AsyncMock()
            mock_emb_instance.aembed_documents = fake_embed
            mock_emb.return_value = mock_emb_instance
            with patch.object(aligner, "_llm_arbitrate", mock_llm_arbitrate):
                self._run(aligner.align(result))

        calls = mock_llm_arbitrate.call_count
        # At least one call proves the arbitration band was actually reached;
        # without it the upper-bound check below would pass vacuously.
        assert calls >= 1
        # Never more than the budget, even though every call raised.
        assert calls <= max_arb
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LLMArbitrationResult
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestLLMArbitrationResult:
    """Schema checks for the ``LLMArbitrationResult`` model used to parse LLM verdicts."""

    def test_valid_parse(self):
        data = {"is_same": True, "confidence": 0.95, "reason": "Same entity"}
        result = LLMArbitrationResult.model_validate(data)
        assert result.is_same is True
        assert result.confidence == 0.95

    def test_confidence_bounds(self):
        # A confidence above 1.0 must be rejected. Pydantic's ValidationError
        # subclasses ValueError, so this stays narrow without importing pydantic.
        # A bare ``Exception`` would also accept unrelated failures such as
        # AttributeError or TypeError.
        with pytest.raises(ValueError):
            LLMArbitrationResult.model_validate(
                {"is_same": True, "confidence": 1.5, "reason": ""}
            )

    def test_missing_reason_defaults(self):
        result = LLMArbitrationResult.model_validate(
            {"is_same": False, "confidence": 0.1}
        )
        assert result.reason == ""
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly exits non-zero
    # on test failures instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|