"""Tests for the entity aligner.

Run with: pytest app/module/kg_extraction/test_aligner.py -v
"""

from __future__ import annotations

import asyncio
import math
from unittest.mock import AsyncMock, patch

import pytest

from app.module.kg_extraction.aligner import (
    EntityAligner,
    LLMArbitrationResult,
    _build_merged_result,
    _make_union_find,
    cosine_similarity,
    normalize_name,
    rule_score,
)
from app.module.kg_extraction.models import (
    ExtractionResult,
    GraphEdge,
    GraphNode,
    Triple,
)


# ---------------------------------------------------------------------------
# normalize_name
# ---------------------------------------------------------------------------


class TestNormalizeName:
    """Expected outputs of ``normalize_name`` for representative inputs."""

    def test_basic_lowercase(self):
        assert normalize_name("Hello World") == "hello world"

    def test_unicode_nfkc(self):
        # U+FF28 is the full-width "H"; the expected result is plain ASCII.
        assert normalize_name("\uff28ello") == "hello"

    def test_punctuation_removed(self):
        assert normalize_name("U.S.A.") == "usa"

    def test_whitespace_collapsed(self):
        # NOTE(review): as seen here the input has only single interior
        # spaces, so this exercises trimming but not collapsing of whitespace
        # runs -- confirm the literal against the upstream file.
        assert normalize_name(" hello world ") == "hello world"

    def test_empty_string(self):
        assert normalize_name("") == ""

    def test_chinese_preserved(self):
        assert normalize_name("\u5f20\u4e09") == "\u5f20\u4e09"

    def test_mixed_chinese_english(self):
        assert normalize_name("\u5f20\u4e09 (Zhang San)") == "\u5f20\u4e09 zhang san"


# ---------------------------------------------------------------------------
# rule_score
# ---------------------------------------------------------------------------


class TestRuleScore:
    """Expected ``rule_score`` values for pairs of ``GraphNode`` objects."""

    def test_exact_match(self):
        a = GraphNode(name="\u5f20\u4e09", type="Person")
        b = GraphNode(name="\u5f20\u4e09", type="Person")
        assert rule_score(a, b) == 1.0

    def test_normalized_match(self):
        a = GraphNode(name="Hello World", type="Organization")
        b = GraphNode(name="hello world", type="Organization")
        assert rule_score(a, b) == 1.0

    def test_type_mismatch(self):
        a = GraphNode(name="\u5f20\u4e09", type="Person")
        b = GraphNode(name="\u5f20\u4e09", type="Organization")
        assert rule_score(a, b) == 0.0

    def test_substring_match(self):
        a = GraphNode(name="\u5317\u4eac\u5927\u5b66", type="Organization")
        b = GraphNode(
            name="\u5317\u4eac\u5927\u5b66\u8ba1\u7b97\u673a\u5b66\u9662",
            type="Organization",
        )
        assert rule_score(a, b) == 0.5

    def test_no_match(self):
        a = GraphNode(name="\u5f20\u4e09", type="Person")
        b = GraphNode(name="\u674e\u56db", type="Person")
        assert rule_score(a, b) == 0.0

    def test_type_case_insensitive(self):
        a = GraphNode(name="test", type="PERSON")
        b = GraphNode(name="test", type="person")
        assert rule_score(a, b) == 1.0

    def test_short_substring_ignored(self):
        """Single-character substring should not trigger match."""
        a = GraphNode(name="A", type="Person")
        b = GraphNode(name="AB", type="Person")
        assert rule_score(a, b) == 0.0


# ---------------------------------------------------------------------------
# cosine_similarity
# ---------------------------------------------------------------------------


class TestCosineSimilarity:
    """Expected ``cosine_similarity`` values for simple 2-D vectors."""

    def test_identical(self):
        assert cosine_similarity([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)

    def test_orthogonal(self):
        assert cosine_similarity([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)

    def test_opposite(self):
        assert cosine_similarity([1.0, 0.0], [-1.0, 0.0]) == pytest.approx(-1.0)

    def test_zero_vector(self):
        # A zero vector is expected to yield exactly 0.0 rather than an error.
        assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0


# ---------------------------------------------------------------------------
# Union-Find
# ---------------------------------------------------------------------------


class TestUnionFind:
    """Checks on the ``(parent, find, union)`` triple from ``_make_union_find``."""

    def test_basic(self):
        _, find, union = _make_union_find(4)
        union(0, 1)
        union(2, 3)
        assert find(0) == find(1)
        assert find(2) == find(3)
        assert find(0) != find(2)

    def test_transitive(self):
        _, find, union = _make_union_find(3)
        union(0, 1)
        union(1, 2)
        assert find(0) == find(2)


# ---------------------------------------------------------------------------
# _build_merged_result
# ---------------------------------------------------------------------------


def _make_result(nodes, edges=None, triples=None):
    """Build an ``ExtractionResult`` fixture with fixed text and source id.

    Args:
        nodes: Nodes to place in the result.
        edges: Optional edges; a falsy value becomes an empty list.
        triples: Optional triples; a falsy value becomes an empty list.

    Returns:
        An ``ExtractionResult`` with ``raw_text="test text"`` and
        ``source_id="src-1"``.
    """
    return ExtractionResult(
        nodes=nodes,
        edges=edges or [],
        triples=triples or [],
        raw_text="test text",
        source_id="src-1",
    )


class TestBuildMergedResult:
    """Checks on ``_build_merged_result`` driven by hand-built union-find state."""

    def test_no_merge_returns_original(self):
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="B", type="Person"),
        ]
        result = _make_result(nodes)
        parent, find, _ = _make_union_find(2)

        merged = _build_merged_result(result, parent, find)

        # With no unions performed, the very same object is expected back.
        assert merged is result

    def test_canonical_picks_longest_name(self):
        nodes = [
            GraphNode(name="AI", type="Tech"),
            GraphNode(name="Artificial Intelligence", type="Tech"),
        ]
        result = _make_result(nodes)
        parent, find, union = _make_union_find(2)
        union(0, 1)

        merged = _build_merged_result(result, parent, find)

        assert len(merged.nodes) == 1
        assert merged.nodes[0].name == "Artificial Intelligence"

    def test_edge_remap_and_dedup(self):
        nodes = [
            GraphNode(name="Alice", type="Person"),
            GraphNode(name="alice", type="Person"),
            GraphNode(name="Bob", type="Person"),
        ]
        edges = [
            GraphEdge(source="Alice", target="Bob", relation_type="knows"),
            GraphEdge(source="alice", target="Bob", relation_type="knows"),
        ]
        result = _make_result(nodes, edges)
        parent, find, union = _make_union_find(3)
        union(0, 1)

        merged = _build_merged_result(result, parent, find)

        assert len(merged.edges) == 1
        assert merged.edges[0].source == "Alice"

    def test_triple_remap_and_dedup(self):
        n1 = GraphNode(name="Alice", type="Person")
        n2 = GraphNode(name="alice", type="Person")
        n3 = GraphNode(name="MIT", type="Organization")
        triples = [
            Triple(subject=n1, predicate="works_at", object=n3),
            Triple(subject=n2, predicate="works_at", object=n3),
        ]
        result = _make_result([n1, n2, n3], triples=triples)
        parent, find, union = _make_union_find(3)
        union(0, 1)

        merged = _build_merged_result(result, parent, find)

        assert len(merged.triples) == 1
        assert merged.triples[0].subject.name == "Alice"

    def test_preserves_metadata(self):
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="A", type="Person"),
        ]
        result = _make_result(nodes)
        parent, find, union = _make_union_find(2)
        union(0, 1)

        merged = _build_merged_result(result, parent, find)

        assert merged.raw_text == "test text"
        assert merged.source_id == "src-1"

    def test_cross_type_same_name_no_collision(self):
        """P1-1 regression: merging must not remap same-named nodes of other types.

        Scenario: Person "张三" and Person "张三先生" are merged into
        "张三先生", while Organization "张三" must not be rewritten.
        """
        nodes = [
            GraphNode(name="张三", type="Person"),  # idx 0
            GraphNode(name="张三先生", type="Person"),  # idx 1
            GraphNode(name="张三", type="Organization"),  # idx 2 - same name, different type
            GraphNode(name="北京", type="Location"),  # idx 3
        ]
        edges = [
            GraphEdge(source="张三", target="北京", relation_type="lives_in"),
            GraphEdge(source="张三", target="北京", relation_type="located_in"),
        ]
        triples = [
            Triple(
                subject=GraphNode(name="张三", type="Person"),
                predicate="lives_in",
                object=GraphNode(name="北京", type="Location"),
            ),
            Triple(
                subject=GraphNode(name="张三", type="Organization"),
                predicate="located_in",
                object=GraphNode(name="北京", type="Location"),
            ),
        ]
        result = _make_result(nodes, edges, triples)
        parent, find, union = _make_union_find(4)
        union(0, 1)  # merge Person "张三" with Person "张三先生"

        merged = _build_merged_result(result, parent, find)

        # Expect 3 nodes: the merged Person, the Organization and the Location.
        assert len(merged.nodes) == 3
        merged_names = {(n.name, n.type) for n in merged.nodes}
        assert ("张三先生", "Person") in merged_names
        assert ("张三", "Organization") in merged_names
        assert ("北京", "Location") in merged_names

        # Edges reference nodes by name only, so "张三" is ambiguous here (it
        # maps to different canonicals). This test only checks that both edges
        # survive the merge.
        assert len(merged.edges) == 2

        # Triples carry node types, so they can be told apart precisely.
        assert len(merged.triples) == 2
        person_triple = next(t for t in merged.triples if t.subject.type == "Person")
        org_triple = next(
            t for t in merged.triples if t.subject.type == "Organization"
        )
        assert person_triple.subject.name == "张三先生"  # Person was rewritten
        assert org_triple.subject.name == "张三"  # Organization kept its name


# ---------------------------------------------------------------------------
# EntityAligner
# ---------------------------------------------------------------------------


class TestEntityAligner:
    """End-to-end checks on ``EntityAligner.align`` and ``align_rules_only``."""

    def _run(self, coro):
        """Helper to run async coroutine in sync test."""
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            # Always close the private loop, even when the coroutine raises.
            loop.close()

    def test_disabled_returns_original(self):
        aligner = EntityAligner(enabled=False)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_single_node_returns_original(self):
        aligner = EntityAligner(enabled=True)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_rule_merge_exact_names(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u674e\u56db", type="Person"),
        ]
        edges = [
            GraphEdge(
                source="\u5f20\u4e09", target="\u674e\u56db", relation_type="knows"
            ),
        ]
        result = _make_result(nodes, edges)

        aligned = self._run(aligner.align(result))

        assert len(aligned.nodes) == 2
        names = {n.name for n in aligned.nodes}
        assert "\u5f20\u4e09" in names
        assert "\u674e\u56db" in names

    def test_rule_merge_case_insensitive(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="Hello World", type="Org"),
            GraphNode(name="hello world", type="Org"),
            GraphNode(name="Test", type="Person"),
        ]
        result = _make_result(nodes)

        aligned = self._run(aligner.align(result))

        assert len(aligned.nodes) == 2

    def test_rule_merge_deduplicates_edges(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="Hello World", type="Org"),
            GraphNode(name="hello world", type="Org"),
            GraphNode(name="Test", type="Person"),
        ]
        edges = [
            GraphEdge(source="Hello World", target="Test", relation_type="employs"),
            GraphEdge(source="hello world", target="Test", relation_type="employs"),
        ]
        result = _make_result(nodes, edges)

        aligned = self._run(aligner.align(result))

        assert len(aligned.edges) == 1

    def test_rule_merge_deduplicates_triples(self):
        aligner = EntityAligner(enabled=True)
        n1 = GraphNode(name="\u5f20\u4e09", type="Person")
        n2 = GraphNode(name="\u5f20\u4e09", type="Person")
        n3 = GraphNode(name="\u5317\u4eac\u5927\u5b66", type="Organization")
        triples = [
            Triple(subject=n1, predicate="works_at", object=n3),
            Triple(subject=n2, predicate="works_at", object=n3),
        ]
        result = _make_result([n1, n2, n3], triples=triples)

        aligned = self._run(aligner.align(result))

        assert len(aligned.triples) == 1

    def test_type_mismatch_no_merge(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Organization"),
        ]
        result = _make_result(nodes)

        aligned = self._run(aligner.align(result))

        assert len(aligned.nodes) == 2

    def test_fail_open_on_error(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
        ]
        result = _make_result(nodes)

        # When the implementation raises, align() is expected to hand back the
        # untouched input object instead of propagating the error.
        with patch.object(aligner, "_align_impl", side_effect=RuntimeError("boom")):
            aligned = self._run(aligner.align(result))

        assert aligned is result

    def test_align_rules_only_sync(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u674e\u56db", type="Person"),
        ]
        result = _make_result(nodes)

        aligned = aligner.align_rules_only(result)

        assert len(aligned.nodes) == 2

    def test_align_rules_only_disabled(self):
        aligner = EntityAligner(enabled=False)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = aligner.align_rules_only(result)
        assert aligned is result

    def test_align_rules_only_fail_open(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="B", type="Person"),
        ]
        result = _make_result(nodes)

        with patch(
            "app.module.kg_extraction.aligner.rule_score",
            side_effect=RuntimeError("boom"),
        ):
            aligned = aligner.align_rules_only(result)

        assert aligned is result

    def test_llm_count_incremented_on_failure(self):
        """P1-2 regression: failed LLM arbitrations must still count toward
        the ``max_llm_arbitrations`` budget."""
        max_arb = 2
        aligner = EntityAligner(
            enabled=True,
            max_llm_arbitrations=max_arb,
            llm_arbitration_enabled=True,
        )
        # Four same-typed nodes; the first name is a substring of the others,
        # which is meant to produce several vector-stage candidates.
        nodes = [
            GraphNode(name="北京大学", type="Organization"),
            GraphNode(name="北京大学计算机学院", type="Organization"),
            GraphNode(name="北京大学数学学院", type="Organization"),
            GraphNode(name="北京大学物理学院", type="Organization"),
        ]
        result = _make_result(nodes)

        # Two unit vectors with cos(theta) = 0.85, intended to fall inside the
        # LLM-arbitration band (the original note cites a default [0.78, 0.92)
        # range -- TODO confirm against the aligner defaults).
        angle = math.acos(0.85)
        emb_a = [1.0, 0.0]
        emb_b = [math.cos(angle), math.sin(angle)]

        async def fake_embed(texts):
            # Even positions get emb_a, odd positions get emb_b.
            return [emb_a if i % 2 == 0 else emb_b for i in range(len(texts))]

        mock_llm_arbitrate = AsyncMock(side_effect=RuntimeError("LLM down"))

        with patch.object(aligner, "_get_embeddings") as mock_emb:
            mock_emb_instance = AsyncMock()
            mock_emb_instance.aembed_documents = fake_embed
            mock_emb.return_value = mock_emb_instance
            with patch.object(aligner, "_llm_arbitrate", mock_llm_arbitrate):
                self._run(aligner.align(result))

        # The arbiter must never be called more often than the budget allows,
        # even though every call raises.
        # NOTE(review): this is an upper bound only; it also passes when the
        # arbiter is never reached, so it cannot detect a vacuous run.
        assert mock_llm_arbitrate.call_count <= max_arb


# ---------------------------------------------------------------------------
# LLMArbitrationResult
# ---------------------------------------------------------------------------


class TestLLMArbitrationResult:
    """Validation behaviour of ``LLMArbitrationResult.model_validate``."""

    def test_valid_parse(self):
        data = {"is_same": True, "confidence": 0.95, "reason": "Same entity"}
        result = LLMArbitrationResult.model_validate(data)
        assert result.is_same is True
        assert result.confidence == 0.95

    def test_confidence_bounds(self):
        # NOTE(review): Exception is very broad; narrow it to the model's
        # validation error type once that type is confirmed.
        with pytest.raises(Exception):
            LLMArbitrationResult.model_validate(
                {"is_same": True, "confidence": 1.5, "reason": ""}
            )

    def test_missing_reason_defaults(self):
        result = LLMArbitrationResult.model_validate(
            {"is_same": False, "confidence": 0.1}
        )
        assert result.reason == ""


if __name__ == "__main__":
    pytest.main([__file__, "-v"])