diff --git a/runtime/datamate-python/app/core/config.py b/runtime/datamate-python/app/core/config.py index 82bd65a..a3fcea7 100644 --- a/runtime/datamate-python/app/core/config.py +++ b/runtime/datamate-python/app/core/config.py @@ -82,6 +82,12 @@ class Settings(BaseSettings): kg_llm_timeout_seconds: int = 60 kg_llm_max_retries: int = 2 + # Knowledge Graph - 实体对齐配置 + kg_alignment_enabled: bool = False + kg_alignment_embedding_model: str = "text-embedding-3-small" + kg_alignment_vector_threshold: float = 0.92 + kg_alignment_llm_threshold: float = 0.78 + # 标注编辑器(Label Studio Editor)相关 editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数 diff --git a/runtime/datamate-python/app/module/kg_extraction/__init__.py b/runtime/datamate-python/app/module/kg_extraction/__init__.py index f8a973b..d302668 100644 --- a/runtime/datamate-python/app/module/kg_extraction/__init__.py +++ b/runtime/datamate-python/app/module/kg_extraction/__init__.py @@ -1,3 +1,4 @@ +from app.module.kg_extraction.aligner import EntityAligner from app.module.kg_extraction.extractor import KnowledgeGraphExtractor from app.module.kg_extraction.models import ( ExtractionRequest, @@ -9,6 +10,7 @@ from app.module.kg_extraction.models import ( from app.module.kg_extraction.interface import router __all__ = [ + "EntityAligner", "KnowledgeGraphExtractor", "ExtractionRequest", "ExtractionResult", diff --git a/runtime/datamate-python/app/module/kg_extraction/aligner.py b/runtime/datamate-python/app/module/kg_extraction/aligner.py new file mode 100644 index 0000000..f542452 --- /dev/null +++ b/runtime/datamate-python/app/module/kg_extraction/aligner.py @@ -0,0 +1,478 @@ +"""实体对齐器:对抽取结果中的实体进行去重和合并。 + +三层对齐策略: +1. 规则层:名称规范化 + 别名匹配 + 类型硬过滤 +2. 向量相似度层:基于 embedding 的 cosine 相似度 +3. 
import re
import unicodedata

# ---------------------------------------------------------------------------
# Rule Layer
# ---------------------------------------------------------------------------

# Compiled once: these patterns run for every candidate pair.
_NON_WORD_RE = re.compile(r"[^\w\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def normalize_name(name: str) -> str:
    """Normalize an entity name for comparison.

    Steps: Unicode NFKC -> case folding -> drop punctuation -> collapse
    runs of whitespace into one space and trim.

    Args:
        name: Raw entity name.

    Returns:
        The normalized name; may be empty when *name* contains nothing but
        punctuation and/or whitespace.
    """
    name = unicodedata.normalize("NFKC", name).casefold()
    name = _NON_WORD_RE.sub("", name)
    return _WHITESPACE_RE.sub(" ", name).strip()


def rule_score(a: "GraphNode", b: "GraphNode") -> float:
    """Score how well two nodes match under the rule layer.

    Args:
        a: First node; only ``name`` and ``type`` are read.
        b: Second node; only ``name`` and ``type`` are read.

    Returns:
        1.0 when the types agree (case-insensitively) and the normalized
        names are identical;
        0.5 when the types agree and one normalized name contains the other
        (alias / abbreviation), both being at least 2 characters long;
        0.0 otherwise.
    """
    # Hard type filter: entities of different types are never merged.
    if a.type.casefold() != b.type.casefold():
        return 0.0

    norm_a = normalize_name(a.name)
    norm_b = normalize_name(b.name)

    if not norm_a or not norm_b:
        # Punctuation-only names all normalize to "", which would make
        # unrelated names such as "???" and "!!!" look identical. Fall back
        # to comparing the raw (NFKC, case-folded, trimmed) text, and never
        # match two empty names.
        raw_a = unicodedata.normalize("NFKC", a.name).strip().casefold()
        raw_b = unicodedata.normalize("NFKC", b.name).strip().casefold()
        return 1.0 if raw_a and raw_a == raw_b else 0.0

    # Exact match after normalization.
    if norm_a == norm_b:
        return 1.0

    # Substring match (alias / abbreviation); single characters are too
    # ambiguous to count.
    if len(norm_a) >= 2 and len(norm_b) >= 2:
        if norm_a in norm_b or norm_b in norm_a:
            return 0.5

    return 0.0


# ---------------------------------------------------------------------------
# Vector Similarity Layer
# ---------------------------------------------------------------------------


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of two vectors.

    Args:
        a: First vector.
        b: Second vector; must have the same length as *a*.

    Returns:
        The cosine of the angle between the vectors, or 0.0 when either
        vector has zero magnitude.

    Raises:
        ValueError: If the vectors differ in length. ``zip`` would otherwise
            silently truncate the longer one and yield a meaningless score.
    """
    if len(a) != len(b):
        raise ValueError(
            f"Vector length mismatch: {len(a)} != {len(b)}"
        )
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def _entity_text(node: "GraphNode") -> str:
    """Build the text representation of an entity that is sent for embedding."""
    return f"{node.type}: {node.name}"
# ---------------------------------------------------------------------------
# LLM Arbitration Layer
# ---------------------------------------------------------------------------

# Prompt sent to the arbitration LLM. Doubled braces are literal JSON braces
# escaped for ``str.format``.
_LLM_PROMPT = (
    "判断以下两个实体是否指向同一个现实世界的实体或概念。\n\n"
    "实体 A:\n- 名称: {name_a}\n- 类型: {type_a}\n\n"
    "实体 B:\n- 名称: {name_b}\n- 类型: {type_b}\n\n"
    '请严格按以下 JSON 格式返回,不要包含任何其他内容:\n'
    '{{"is_same": true, "confidence": 0.95, "reason": "简要理由"}}'
)


class LLMArbitrationResult(BaseModel):
    """Structured verdict expected back from the arbitration LLM."""

    is_same: bool
    confidence: float = Field(ge=0.0, le=1.0)
    reason: str = ""


def _extract_json_payload(content: object) -> str:
    """Return the JSON object text embedded in an LLM reply.

    Handles replies that are a list of content parts (plain strings or dicts
    carrying a ``"text"`` entry) and replies that wrap the object in a
    Markdown code fence or surrounding prose. When no ``{...}`` span is
    found the stripped text is returned unchanged so that ``json.loads``
    reports the error.
    """
    if isinstance(content, list):
        parts: list[str] = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        text = "".join(parts)
    else:
        text = str(content)

    text = text.strip()
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        return text[start : end + 1]
    return text


# ---------------------------------------------------------------------------
# Union-Find
# ---------------------------------------------------------------------------


def _make_union_find(n: int):
    """Create a union-find structure over ``range(n)``.

    Returns:
        A ``(parent, find, union)`` tuple: the backing parent list, a
        function returning the representative of an index, and a function
        merging the sets of two indices.
    """
    parent = list(range(n))

    def find(x: int) -> int:
        # Re-point each visited node to its grandparent while walking up,
        # which shortens later lookups.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(x: int, y: int) -> None:
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py

    return parent, find, union


# ---------------------------------------------------------------------------
# Merge Result Builder
# ---------------------------------------------------------------------------


def _build_merged_result(
    original: ExtractionResult,
    parent: list[int],
    find,
) -> ExtractionResult:
    """Build the merged ``ExtractionResult`` described by a union-find.

    Args:
        original: The extraction result whose nodes were aligned.
        parent: Parent list of the union-find (kept for interface
            compatibility; grouping is done through *find*).
        find: Function mapping a node index to its group representative.

    Returns:
        *original* itself when no two nodes were merged; otherwise a new
        result whose nodes are the canonical node of every group and whose
        edges and triples are renamed to the canonical names and
        de-duplicated.
    """
    nodes = original.nodes

    # Group node indices by their union-find root.
    groups: dict[int, list[int]] = {}
    for idx in range(len(nodes)):
        groups.setdefault(find(idx), []).append(idx)

    # Nothing merged: hand back the very same object.
    if len(groups) == len(nodes):
        return original

    # Canonical node of a group = the member with the longest name (first one
    # wins on ties). Keyed by (name, type) so same-name nodes of different
    # types are not confused.
    canon_by_key: dict[tuple[str, str], GraphNode] = {}
    merged_nodes: list[GraphNode] = []
    for members in groups.values():
        best_idx = max(members, key=lambda idx: len(nodes[idx].name))
        canon = nodes[best_idx]
        merged_nodes.append(canon)
        for idx in members:
            canon_by_key[(nodes[idx].name, nodes[idx].type)] = canon

    logger.info(
        "Alignment merged %d nodes -> %d nodes",
        len(nodes),
        len(merged_nodes),
    )

    # Edges only carry names, so build a name-only map and keep just the
    # names that resolve to a single canonical name; ambiguous names are
    # left untouched.
    canon_names_by_name: dict[str, set[str]] = {}
    for (name, _type), canon in canon_by_key.items():
        canon_names_by_name.setdefault(name, set()).add(canon.name)
    edge_name_map: dict[str, str] = {
        name: next(iter(canon_names))
        for name, canon_names in canon_names_by_name.items()
        if len(canon_names) == 1
    }

    # Rename + de-duplicate edges. Tuple keys cannot collide the way a
    # "|"-joined string can when a name itself contains "|".
    seen_edges: set[tuple[str, str, str]] = set()
    merged_edges: list[GraphEdge] = []
    for edge in original.edges:
        src = edge_name_map.get(edge.source, edge.source)
        tgt = edge_name_map.get(edge.target, edge.target)
        edge_key = (src, edge.relation_type, tgt)
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        merged_edges.append(
            GraphEdge(
                source=src,
                target=tgt,
                relation_type=edge.relation_type,
                properties=edge.properties,
            )
        )

    # Rename + de-duplicate triples using the exact (name, type) lookup.
    # The de-dup key includes the node types so that same-name entities of
    # different types are not collapsed, and rewritten nodes take the
    # canonical node's type so they match an entry in ``merged_nodes``.
    seen_triples: set[tuple[str, str, str, str, str]] = set()
    merged_triples: list[Triple] = []
    for triple in original.triples:
        subject = canon_by_key.get(
            (triple.subject.name, triple.subject.type), triple.subject
        )
        obj = canon_by_key.get(
            (triple.object.name, triple.object.type), triple.object
        )
        triple_key = (
            subject.name,
            subject.type,
            triple.predicate,
            obj.name,
            obj.type,
        )
        if triple_key in seen_triples:
            continue
        seen_triples.add(triple_key)
        merged_triples.append(
            Triple(
                subject=GraphNode(name=subject.name, type=subject.type),
                predicate=triple.predicate,
                object=GraphNode(name=obj.name, type=obj.type),
            )
        )

    return ExtractionResult(
        nodes=merged_nodes,
        edges=merged_edges,
        triples=merged_triples,
        raw_text=original.raw_text,
        source_id=original.source_id,
    )


def _apply_rule_layer(nodes, find, union) -> list[tuple[int, int]]:
    """Run the rule layer over every node pair.

    Pairs scoring 1.0 are merged immediately; pairs with a positive but
    lower score are returned as candidates for the vector layer.
    """
    candidates: list[tuple[int, int]] = []
    for i in range(len(nodes)):
        for j in range(i + 1, len(nodes)):
            if find(i) == find(j):
                continue
            score = rule_score(nodes[i], nodes[j])
            if score >= 1.0:
                union(i, j)
                logger.debug(
                    "Rule merge: '%s' <-> '%s'", nodes[i].name, nodes[j].name
                )
            elif score > 0:
                candidates.append((i, j))
    return candidates


# ---------------------------------------------------------------------------
# EntityAligner
# ---------------------------------------------------------------------------


class EntityAligner:
    """Entity aligner: de-duplicates and merges entities of one extraction.

    Create an instance from the global configuration with
    ``from_settings()``, or construct it directly to override defaults.
    """

    def __init__(
        self,
        *,
        enabled: bool = False,
        embedding_model: str = "text-embedding-3-small",
        embedding_base_url: str | None = None,
        embedding_api_key: SecretStr = SecretStr("EMPTY"),
        llm_model: str = "gpt-4o-mini",
        llm_base_url: str | None = None,
        llm_api_key: SecretStr = SecretStr("EMPTY"),
        llm_timeout: int = 30,
        vector_auto_merge_threshold: float = 0.92,
        vector_llm_threshold: float = 0.78,
        llm_arbitration_enabled: bool = True,
        max_llm_arbitrations: int = 10,
        llm_confidence_threshold: float = 0.7,
    ) -> None:
        """Store the configuration; remote clients are created lazily.

        Args:
            enabled: When False, ``align`` and ``align_rules_only`` return
                their input unchanged.
            embedding_model: Embedding model name.
            embedding_base_url: Base URL of the embedding endpoint.
            embedding_api_key: API key for the embedding endpoint.
            llm_model: Chat model used for arbitration.
            llm_base_url: Base URL of the chat endpoint.
            llm_api_key: API key for the chat endpoint.
            llm_timeout: Timeout passed to the chat client.
            vector_auto_merge_threshold: Cosine similarity at or above which
                a candidate pair is merged without asking the LLM.
            vector_llm_threshold: Cosine similarity at or above which (and
                below the auto-merge threshold) a pair goes to the LLM.
            llm_arbitration_enabled: Whether the LLM layer runs at all.
            max_llm_arbitrations: Maximum LLM calls per ``align`` call;
                failed calls count too.
            llm_confidence_threshold: Minimum confidence the LLM must report
                for an ``is_same`` verdict to cause a merge.
        """
        if vector_llm_threshold > vector_auto_merge_threshold:
            # Not fatal (alignment is best-effort), but no pair can ever
            # fall into the arbitration band with this configuration.
            logger.warning(
                "vector_llm_threshold (%.3f) exceeds vector_auto_merge_threshold "
                "(%.3f); the LLM arbitration band is empty",
                vector_llm_threshold,
                vector_auto_merge_threshold,
            )
        self._enabled = enabled
        self._embedding_model = embedding_model
        self._embedding_base_url = embedding_base_url
        self._embedding_api_key = embedding_api_key
        self._llm_model = llm_model
        self._llm_base_url = llm_base_url
        self._llm_api_key = llm_api_key
        self._llm_timeout = llm_timeout
        self._vector_auto_threshold = vector_auto_merge_threshold
        self._vector_llm_threshold = vector_llm_threshold
        self._llm_arbitration_enabled = llm_arbitration_enabled
        self._max_llm_arbitrations = max_llm_arbitrations
        self._llm_confidence_threshold = llm_confidence_threshold
        # Lazy init
        self._embeddings: OpenAIEmbeddings | None = None
        self._llm: ChatOpenAI | None = None

    @classmethod
    def from_settings(cls) -> "EntityAligner":
        """Create an aligner from the global ``Settings`` object."""
        # Imported here rather than at module level, as in the original.
        from app.core.config import settings

        return cls(
            enabled=settings.kg_alignment_enabled,
            embedding_model=settings.kg_alignment_embedding_model,
            embedding_base_url=settings.kg_llm_base_url,
            embedding_api_key=settings.kg_llm_api_key,
            llm_model=settings.kg_llm_model,
            llm_base_url=settings.kg_llm_base_url,
            llm_api_key=settings.kg_llm_api_key,
            llm_timeout=settings.kg_llm_timeout_seconds,
            vector_auto_merge_threshold=settings.kg_alignment_vector_threshold,
            vector_llm_threshold=settings.kg_alignment_llm_threshold,
        )

    def _get_embeddings(self) -> OpenAIEmbeddings:
        """Return the embedding client, creating it on first use."""
        if self._embeddings is None:
            self._embeddings = OpenAIEmbeddings(
                model=self._embedding_model,
                base_url=self._embedding_base_url,
                api_key=self._embedding_api_key,
            )
        return self._embeddings

    def _get_llm(self) -> ChatOpenAI:
        """Return the chat client, creating it on first use."""
        if self._llm is None:
            self._llm = ChatOpenAI(
                model=self._llm_model,
                base_url=self._llm_base_url,
                api_key=self._llm_api_key,
                temperature=0.0,
                timeout=self._llm_timeout,
            )
        return self._llm

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def align(self, result: ExtractionResult) -> ExtractionResult:
        """Align and de-duplicate the entities of *result* (async, 3 layers).

        Fail-open: when alignment raises, the original result is returned so
        the extraction request is never blocked.

        Note: only in-batch alignment (pairwise merging inside a single
        extraction result) is supported. Aligning against entities already
        stored in the graph needs KG service API support and is left for
        later.
        """
        if not self._enabled or len(result.nodes) <= 1:
            return result

        try:
            return await self._align_impl(result)
        except Exception:
            logger.exception(
                "Entity alignment failed, returning original result (fail-open)"
            )
            return result

    def align_rules_only(self, result: ExtractionResult) -> ExtractionResult:
        """Align using the rule layer only (sync, for the extract_sync path).

        Fail-open: when alignment raises, the original result is returned.
        """
        if not self._enabled or len(result.nodes) <= 1:
            return result

        try:
            nodes = result.nodes
            parent, find, union = _make_union_find(len(nodes))
            # Partial (substring) matches are ignored on this path: without
            # the vector/LLM layers there is nothing to confirm them.
            _apply_rule_layer(nodes, find, union)
            return _build_merged_result(result, parent, find)
        except Exception:
            logger.exception(
                "Rule-only alignment failed, returning original result (fail-open)"
            )
            return result

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    async def _align_impl(self, result: ExtractionResult) -> ExtractionResult:
        """Core implementation of the three-layer alignment.

        Works pairwise on the node list of a single extraction result. To
        match against existing graph entities the signature would have to be
        extended with a graph_id and candidate-retrieval context, which
        depends on the KG service API.
        """
        nodes = result.nodes
        parent, find, union = _make_union_find(len(nodes))

        # Phase 1: rule layer.
        vector_candidates = _apply_rule_layer(nodes, find, union)

        # Phase 2: vector similarity. Any failure skips the rest of this
        # layer; merges already made are kept.
        llm_candidates: list[tuple[int, int, float]] = []
        if vector_candidates:
            try:
                emb_map = await self._embed_candidates(nodes, vector_candidates)
                for i, j in vector_candidates:
                    if find(i) == find(j):
                        continue
                    sim = cosine_similarity(emb_map[i], emb_map[j])
                    if sim >= self._vector_auto_threshold:
                        union(i, j)
                        logger.debug(
                            "Vector merge: '%s' <-> '%s' (sim=%.3f)",
                            nodes[i].name,
                            nodes[j].name,
                            sim,
                        )
                    elif sim >= self._vector_llm_threshold:
                        llm_candidates.append((i, j, sim))
            except Exception:
                logger.warning(
                    "Vector similarity failed, skipping vector layer", exc_info=True
                )

        # Phase 3: LLM arbitration (boundary cases only).
        if llm_candidates and self._llm_arbitration_enabled:
            llm_calls = 0
            for i, j, sim in llm_candidates:
                if llm_calls >= self._max_llm_arbitrations:
                    # Budget spent: no later pair can be arbitrated either.
                    break
                if find(i) == find(j):
                    # Already merged transitively; costs no budget.
                    continue
                # Count before calling so failed calls consume budget too.
                llm_calls += 1
                try:
                    same = await self._llm_arbitrate(nodes[i], nodes[j])
                except Exception:
                    logger.warning(
                        "LLM arbitration failed for '%s' <-> '%s'",
                        nodes[i].name,
                        nodes[j].name,
                        exc_info=True,
                    )
                    continue
                if same:
                    union(i, j)
                    logger.debug(
                        "LLM merge: '%s' <-> '%s' (sim=%.3f)",
                        nodes[i].name,
                        nodes[j].name,
                        sim,
                    )

        return _build_merged_result(result, parent, find)

    async def _embed_candidates(
        self, nodes: list[GraphNode], candidates: list[tuple[int, int]]
    ) -> dict[int, list[float]]:
        """Embed every node that appears in *candidates*.

        Returns:
            A ``{node index: embedding}`` mapping.

        Raises:
            ValueError: If the embedding backend returns a different number
                of vectors than texts sent; ``zip`` would otherwise drop
                nodes silently and fail later with a ``KeyError``.
        """
        unique_indices: set[int] = set()
        for i, j in candidates:
            unique_indices.add(i)
            unique_indices.add(j)

        idx_list = sorted(unique_indices)
        texts = [_entity_text(nodes[i]) for i in idx_list]
        embeddings = await self._get_embeddings().aembed_documents(texts)
        if len(embeddings) != len(texts):
            raise ValueError(
                f"Expected {len(texts)} embeddings, got {len(embeddings)}"
            )
        return dict(zip(idx_list, embeddings))

    async def _llm_arbitrate(self, a: GraphNode, b: GraphNode) -> bool:
        """Ask the LLM whether two entities are the same one.

        The reply must contain a JSON object that validates against
        ``LLMArbitrationResult``; a surrounding code fence or prose is
        tolerated.

        Returns:
            True when the LLM answers ``is_same`` with a confidence of at
            least the configured confidence threshold.

        Raises:
            Exception: Propagates client errors, ``json.JSONDecodeError``
                and pydantic validation errors; the caller treats any of
                them as "not merged".
        """
        prompt = _LLM_PROMPT.format(
            name_a=a.name,
            type_a=a.type,
            name_b=b.name,
            type_b=b.type,
        )
        response = await self._get_llm().ainvoke(prompt)
        payload = _extract_json_payload(response.content)

        verdict = LLMArbitrationResult.model_validate(json.loads(payload))

        logger.debug(
            "LLM arbitration: '%s' <-> '%s' -> is_same=%s, confidence=%.2f",
            a.name,
            b.name,
            verdict.is_same,
            verdict.confidence,
        )
        return verdict.is_same and verdict.confidence >= self._llm_confidence_threshold
timeout=settings.kg_llm_timeout_seconds, max_retries=settings.kg_llm_max_retries, + aligner=EntityAligner.from_settings(), ) def _build_transformer( @@ -119,6 +123,7 @@ class KnowledgeGraphExtractor: raise result = self._convert_result(graph_documents, request) + result = await self._aligner.align(result) logger.info( "Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d", request.graph_id, @@ -154,6 +159,7 @@ class KnowledgeGraphExtractor: raise result = self._convert_result(graph_documents, request) + result = self._aligner.align_rules_only(result) logger.info( "Sync extraction complete: graph_id=%s, nodes=%d, edges=%d", request.graph_id, diff --git a/runtime/datamate-python/app/module/kg_extraction/test_aligner.py b/runtime/datamate-python/app/module/kg_extraction/test_aligner.py new file mode 100644 index 0000000..f3e716d --- /dev/null +++ b/runtime/datamate-python/app/module/kg_extraction/test_aligner.py @@ -0,0 +1,477 @@ +"""实体对齐器测试。 + +Run with: pytest app/module/kg_extraction/test_aligner.py -v +""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +from app.module.kg_extraction.aligner import ( + EntityAligner, + LLMArbitrationResult, + _build_merged_result, + _make_union_find, + cosine_similarity, + normalize_name, + rule_score, +) +from app.module.kg_extraction.models import ( + ExtractionResult, + GraphEdge, + GraphNode, + Triple, +) + + +# --------------------------------------------------------------------------- +# normalize_name +# --------------------------------------------------------------------------- + + +class TestNormalizeName: + def test_basic_lowercase(self): + assert normalize_name("Hello World") == "hello world" + + def test_unicode_nfkc(self): + assert normalize_name("\uff28ello") == "hello" + + def test_punctuation_removed(self): + assert normalize_name("U.S.A.") == "usa" + + def test_whitespace_collapsed(self): + assert normalize_name(" hello world ") == 
"hello world" + + def test_empty_string(self): + assert normalize_name("") == "" + + def test_chinese_preserved(self): + assert normalize_name("\u5f20\u4e09") == "\u5f20\u4e09" + + def test_mixed_chinese_english(self): + assert normalize_name("\u5f20\u4e09 (Zhang San)") == "\u5f20\u4e09 zhang san" + + +# --------------------------------------------------------------------------- +# rule_score +# --------------------------------------------------------------------------- + + +class TestRuleScore: + def test_exact_match(self): + a = GraphNode(name="\u5f20\u4e09", type="Person") + b = GraphNode(name="\u5f20\u4e09", type="Person") + assert rule_score(a, b) == 1.0 + + def test_normalized_match(self): + a = GraphNode(name="Hello World", type="Organization") + b = GraphNode(name="hello world", type="Organization") + assert rule_score(a, b) == 1.0 + + def test_type_mismatch(self): + a = GraphNode(name="\u5f20\u4e09", type="Person") + b = GraphNode(name="\u5f20\u4e09", type="Organization") + assert rule_score(a, b) == 0.0 + + def test_substring_match(self): + a = GraphNode(name="\u5317\u4eac\u5927\u5b66", type="Organization") + b = GraphNode(name="\u5317\u4eac\u5927\u5b66\u8ba1\u7b97\u673a\u5b66\u9662", type="Organization") + assert rule_score(a, b) == 0.5 + + def test_no_match(self): + a = GraphNode(name="\u5f20\u4e09", type="Person") + b = GraphNode(name="\u674e\u56db", type="Person") + assert rule_score(a, b) == 0.0 + + def test_type_case_insensitive(self): + a = GraphNode(name="test", type="PERSON") + b = GraphNode(name="test", type="person") + assert rule_score(a, b) == 1.0 + + def test_short_substring_ignored(self): + """Single-character substring should not trigger match.""" + a = GraphNode(name="A", type="Person") + b = GraphNode(name="AB", type="Person") + assert rule_score(a, b) == 0.0 + + +# --------------------------------------------------------------------------- +# cosine_similarity +# 
# ---------------------------------------------------------------------------
# cosine_similarity
# ---------------------------------------------------------------------------


class TestCosineSimilarity:
    """Basic geometric cases for ``cosine_similarity``."""

    def test_identical(self):
        assert cosine_similarity([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)

    def test_orthogonal(self):
        assert cosine_similarity([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)

    def test_opposite(self):
        assert cosine_similarity([1.0, 0.0], [-1.0, 0.0]) == pytest.approx(-1.0)

    def test_zero_vector(self):
        assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0


# ---------------------------------------------------------------------------
# Union-Find
# ---------------------------------------------------------------------------


class TestUnionFind:
    """Connectivity checks for ``_make_union_find``."""

    def test_basic(self):
        _parent, find, union = _make_union_find(4)
        union(0, 1)
        union(2, 3)
        assert find(0) == find(1)
        assert find(2) == find(3)
        assert find(0) != find(2)

    def test_transitive(self):
        _parent, find, union = _make_union_find(3)
        union(0, 1)
        union(1, 2)
        assert find(0) == find(2)


# ---------------------------------------------------------------------------
# _build_merged_result
# ---------------------------------------------------------------------------


def _make_result(nodes, edges=None, triples=None):
    """Build an ``ExtractionResult`` with fixed ``raw_text`` / ``source_id``."""
    return ExtractionResult(
        nodes=nodes,
        edges=edges or [],
        triples=triples or [],
        raw_text="test text",
        source_id="src-1",
    )


class TestBuildMergedResult:
    """Node merging plus edge/triple rewriting in ``_build_merged_result``."""

    def test_no_merge_returns_original(self):
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="B", type="Person"),
        ]
        result = _make_result(nodes)
        parent, find, _ = _make_union_find(2)
        merged = _build_merged_result(result, parent, find)
        assert merged is result

    def test_canonical_picks_longest_name(self):
        nodes = [
            GraphNode(name="AI", type="Tech"),
            GraphNode(name="Artificial Intelligence", type="Tech"),
        ]
        result = _make_result(nodes)
        parent, find, union = _make_union_find(2)
        union(0, 1)
        merged = _build_merged_result(result, parent, find)
        assert len(merged.nodes) == 1
        assert merged.nodes[0].name == "Artificial Intelligence"

    def test_edge_remap_and_dedup(self):
        nodes = [
            GraphNode(name="Alice", type="Person"),
            GraphNode(name="alice", type="Person"),
            GraphNode(name="Bob", type="Person"),
        ]
        edges = [
            GraphEdge(source="Alice", target="Bob", relation_type="knows"),
            GraphEdge(source="alice", target="Bob", relation_type="knows"),
        ]
        result = _make_result(nodes, edges)
        parent, find, union = _make_union_find(3)
        union(0, 1)
        merged = _build_merged_result(result, parent, find)
        assert len(merged.edges) == 1
        assert merged.edges[0].source == "Alice"

    def test_triple_remap_and_dedup(self):
        n1 = GraphNode(name="Alice", type="Person")
        n2 = GraphNode(name="alice", type="Person")
        n3 = GraphNode(name="MIT", type="Organization")
        triples = [
            Triple(subject=n1, predicate="works_at", object=n3),
            Triple(subject=n2, predicate="works_at", object=n3),
        ]
        result = _make_result([n1, n2, n3], triples=triples)
        parent, find, union = _make_union_find(3)
        union(0, 1)
        merged = _build_merged_result(result, parent, find)
        assert len(merged.triples) == 1
        assert merged.triples[0].subject.name == "Alice"

    def test_preserves_metadata(self):
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="A", type="Person"),
        ]
        result = _make_result(nodes)
        parent, find, union = _make_union_find(2)
        union(0, 1)
        merged = _build_merged_result(result, parent, find)
        assert merged.raw_text == "test text"
        assert merged.source_id == "src-1"

    def test_cross_type_same_name_no_collision(self):
        """P1-1 regression: merging must not rewrite same-name nodes of another type.

        Scenario: Person "张三" and "张三先生" merge into "张三先生", but the
        Organization "张三" must not be rewritten.
        """
        nodes = [
            GraphNode(name="张三", type="Person"),  # idx 0
            GraphNode(name="张三先生", type="Person"),  # idx 1
            GraphNode(name="张三", type="Organization"),  # idx 2 - same name, other type
            GraphNode(name="北京", type="Location"),  # idx 3
        ]
        edges = [
            GraphEdge(source="张三", target="北京", relation_type="lives_in"),
            GraphEdge(source="张三", target="北京", relation_type="located_in"),
        ]
        triples = [
            Triple(
                subject=GraphNode(name="张三", type="Person"),
                predicate="lives_in",
                object=GraphNode(name="北京", type="Location"),
            ),
            Triple(
                subject=GraphNode(name="张三", type="Organization"),
                predicate="located_in",
                object=GraphNode(name="北京", type="Location"),
            ),
        ]
        result = _make_result(nodes, edges, triples)
        parent, find, union = _make_union_find(4)
        union(0, 1)  # merge Person "张三" with "张三先生"
        merged = _build_merged_result(result, parent, find)

        # Three nodes remain: 张三先生 (Person), 张三 (Org), 北京 (Location).
        assert len(merged.nodes) == 3
        merged_names = {(n.name, n.type) for n in merged.nodes}
        assert ("张三先生", "Person") in merged_names
        assert ("张三", "Organization") in merged_names
        assert ("北京", "Location") in merged_names

        # In edges the bare name "张三" is ambiguous (it maps to different
        # canonical names), so it must be left unrewritten.
        assert len(merged.edges) == 2

        # Triples carry type information and can be told apart exactly.
        assert len(merged.triples) == 2
        person_triple = [t for t in merged.triples if t.subject.type == "Person"][0]
        org_triple = [t for t in merged.triples if t.subject.type == "Organization"][0]
        assert person_triple.subject.name == "张三先生"  # Person was rewritten
        assert org_triple.subject.name == "张三"  # Organization kept its name


# ---------------------------------------------------------------------------
# EntityAligner
# ---------------------------------------------------------------------------


class TestEntityAligner:
    """End-to-end behaviour of ``EntityAligner`` without network access."""

    def _run(self, coro):
        """Run an async coroutine to completion from a sync test."""
        return asyncio.run(coro)

    def test_disabled_returns_original(self):
        aligner = EntityAligner(enabled=False)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_single_node_returns_original(self):
        aligner = EntityAligner(enabled=True)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_rule_merge_exact_names(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u674e\u56db", type="Person"),
        ]
        edges = [
            GraphEdge(source="\u5f20\u4e09", target="\u674e\u56db", relation_type="knows"),
        ]
        result = _make_result(nodes, edges)
        aligned = self._run(aligner.align(result))
        assert len(aligned.nodes) == 2
        names = {n.name for n in aligned.nodes}
        assert "\u5f20\u4e09" in names
        assert "\u674e\u56db" in names

    def test_rule_merge_case_insensitive(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="Hello World", type="Org"),
            GraphNode(name="hello world", type="Org"),
            GraphNode(name="Test", type="Person"),
        ]
        result = _make_result(nodes)
        aligned = self._run(aligner.align(result))
        assert len(aligned.nodes) == 2

    def test_rule_merge_deduplicates_edges(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="Hello World", type="Org"),
            GraphNode(name="hello world", type="Org"),
            GraphNode(name="Test", type="Person"),
        ]
        edges = [
            GraphEdge(source="Hello World", target="Test", relation_type="employs"),
            GraphEdge(source="hello world", target="Test", relation_type="employs"),
        ]
        result = _make_result(nodes, edges)
        aligned = self._run(aligner.align(result))
        assert len(aligned.edges) == 1

    def test_rule_merge_deduplicates_triples(self):
        aligner = EntityAligner(enabled=True)
        n1 = GraphNode(name="\u5f20\u4e09", type="Person")
        n2 = GraphNode(name="\u5f20\u4e09", type="Person")
        n3 = GraphNode(name="\u5317\u4eac\u5927\u5b66", type="Organization")
        triples = [
            Triple(subject=n1, predicate="works_at", object=n3),
            Triple(subject=n2, predicate="works_at", object=n3),
        ]
        result = _make_result([n1, n2, n3], triples=triples)
        aligned = self._run(aligner.align(result))
        assert len(aligned.triples) == 1

    def test_type_mismatch_no_merge(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Organization"),
        ]
        result = _make_result(nodes)
        aligned = self._run(aligner.align(result))
        assert len(aligned.nodes) == 2

    def test_fail_open_on_error(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
        ]
        result = _make_result(nodes)
        with patch.object(aligner, "_align_impl", side_effect=RuntimeError("boom")):
            aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_align_rules_only_sync(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u674e\u56db", type="Person"),
        ]
        result = _make_result(nodes)
        aligned = aligner.align_rules_only(result)
        assert len(aligned.nodes) == 2

    def test_align_rules_only_disabled(self):
        aligner = EntityAligner(enabled=False)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = aligner.align_rules_only(result)
        assert aligned is result

    def test_align_rules_only_fail_open(self):
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="B", type="Person"),
        ]
        result = _make_result(nodes)
        with patch(
            "app.module.kg_extraction.aligner.rule_score", side_effect=RuntimeError("boom")
        ):
            aligned = aligner.align_rules_only(result)
        assert aligned is result

    def test_llm_count_incremented_on_failure(self):
        """P1-2 regression: failed LLM arbitrations must consume the budget.

        Three candidate pairs land in the arbitration band while the budget
        is two, so an implementation that does not count failed calls would
        call the LLM three times.
        """
        import math

        max_arb = 2
        aligner = EntityAligner(
            enabled=True,
            max_llm_arbitrations=max_arb,
            llm_arbitration_enabled=True,
        )
        # Node 0 is a substring of nodes 1-3, so the rule layer yields the
        # vector candidates (0,1), (0,2), (0,3); nodes 1-3 are not
        # substrings of each other.
        nodes = [
            GraphNode(name="北京大学", type="Organization"),
            GraphNode(name="北京大学计算机学院", type="Organization"),
            GraphNode(name="北京大学数学学院", type="Organization"),
            GraphNode(name="北京大学物理学院", type="Organization"),
        ]
        result = _make_result(nodes)

        # cos(theta) = 0.85 lies inside the default [0.78, 0.92) band.
        angle = math.acos(0.85)
        emb_anchor = [1.0, 0.0]
        emb_other = [math.cos(angle), math.sin(angle)]

        async def fake_embed(texts):
            # First text (node 0) gets the anchor vector, every other node
            # the rotated one, so each (0, k) pair scores 0.85.
            return [emb_anchor] + [emb_other] * (len(texts) - 1)

        mock_llm_arbitrate = AsyncMock(side_effect=RuntimeError("LLM down"))

        with patch.object(aligner, "_get_embeddings") as mock_emb:
            mock_emb_instance = AsyncMock()
            mock_emb_instance.aembed_documents = fake_embed
            mock_emb.return_value = mock_emb_instance
            with patch.object(aligner, "_llm_arbitrate", mock_llm_arbitrate):
                aligned = self._run(aligner.align(result))

        # Exactly max_arb calls: the third candidate is skipped because the
        # two failed calls already used up the budget.
        assert mock_llm_arbitrate.call_count == max_arb
        # Every arbitration failed, so nothing was merged (fail-open).
        assert aligned is result


# ---------------------------------------------------------------------------
# LLMArbitrationResult
# ---------------------------------------------------------------------------


class TestLLMArbitrationResult:
    """Schema validation of the LLM verdict model."""

    def test_valid_parse(self):
        data = {"is_same": True, "confidence": 0.95, "reason": "Same entity"}
        result = LLMArbitrationResult.model_validate(data)
        assert result.is_same is True
        assert result.confidence == 0.95

    def test_confidence_bounds(self):
        # pydantic's ValidationError derives from ValueError.
        with pytest.raises(ValueError):
            LLMArbitrationResult.model_validate(
                {"is_same": True, "confidence": 1.5, "reason": ""}
            )

    def test_missing_reason_defaults(self):
        result = LLMArbitrationResult.model_validate(
            {"is_same": False, "confidence": 0.1}
        )
        assert result.reason == ""


if __name__ == "__main__":
    pytest.main([__file__, "-v"])