"""实体对齐器:对抽取结果中的实体进行去重和合并。 三层对齐策略: 1. 规则层:名称规范化 + 别名匹配 + 类型硬过滤 2. 向量相似度层:基于 embedding 的 cosine 相似度 3. LLM 仲裁层:仅对边界样本调用,严格 JSON schema 校验 失败策略:fail-open —— 对齐失败不阻断抽取请求。 """ from __future__ import annotations import json import re import unicodedata from langchain_openai import ChatOpenAI, OpenAIEmbeddings from pydantic import BaseModel, Field, SecretStr from app.core.logging import get_logger from app.module.kg_extraction.models import ( ExtractionResult, GraphEdge, GraphNode, Triple, ) logger = get_logger(__name__) # --------------------------------------------------------------------------- # Rule Layer # --------------------------------------------------------------------------- def normalize_name(name: str) -> str: """名称规范化:Unicode NFKC -> 小写 -> 去标点 -> 合并空白。""" name = unicodedata.normalize("NFKC", name) name = name.lower() name = re.sub(r"[^\w\s]", "", name) name = re.sub(r"\s+", " ", name).strip() return name def rule_score(a: GraphNode, b: GraphNode) -> float: """规则层匹配分数。 Returns: 1.0 规范化名称完全一致且类型兼容 0.5 一方名称是另一方子串且类型兼容(别名/缩写) 0.0 类型不兼容或名称无关联 """ # 类型硬过滤 if a.type.lower() != b.type.lower(): return 0.0 norm_a = normalize_name(a.name) norm_b = normalize_name(b.name) # 完全匹配 if norm_a == norm_b: return 1.0 # 子串匹配(别名/缩写),要求双方规范化名称至少 2 字符 if len(norm_a) >= 2 and len(norm_b) >= 2: if norm_a in norm_b or norm_b in norm_a: return 0.5 return 0.0 # --------------------------------------------------------------------------- # Vector Similarity Layer # --------------------------------------------------------------------------- def cosine_similarity(a: list[float], b: list[float]) -> float: """计算两个向量的余弦相似度。""" dot = sum(x * y for x, y in zip(a, b)) norm_a = sum(x * x for x in a) ** 0.5 norm_b = sum(x * x for x in b) ** 0.5 if norm_a == 0.0 or norm_b == 0.0: return 0.0 return dot / (norm_a * norm_b) def _entity_text(node: GraphNode) -> str: """构造用于 embedding 的实体文本表示。""" return f"{node.type}: {node.name}" # 
# ---------------------------------------------------------------------------
# LLM Arbitration Layer
# ---------------------------------------------------------------------------

# Prompt sent to the arbitration LLM. The double braces in the example JSON
# escape str.format placeholders. (Prompt text is user-facing model input and
# is kept verbatim.)
_LLM_PROMPT = (
    "判断以下两个实体是否指向同一个现实世界的实体或概念。\n\n"
    "实体 A:\n- 名称: {name_a}\n- 类型: {type_a}\n\n"
    "实体 B:\n- 名称: {name_b}\n- 类型: {type_b}\n\n"
    '请严格按以下 JSON 格式返回,不要包含任何其他内容:\n'
    '{{"is_same": true, "confidence": 0.95, "reason": "简要理由"}}'
)


class LLMArbitrationResult(BaseModel):
    """Strict schema for the LLM arbitration response."""

    # Whether the model judged the two entities identical.
    is_same: bool
    # Model-reported confidence, clamped to [0.0, 1.0] by validation.
    confidence: float = Field(ge=0.0, le=1.0)
    reason: str = ""


# ---------------------------------------------------------------------------
# Union-Find
# ---------------------------------------------------------------------------


def _make_union_find(n: int):
    """Create a Union-Find over ``n`` elements.

    Returns:
        (parent, find, union) -- the backing parent list plus closures for
        root lookup (with path halving) and merging.
    """
    parent = list(range(n))

    def find(x: int) -> int:
        # Path halving: point each visited node at its grandparent.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(x: int, y: int) -> None:
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py

    return parent, find, union


# ---------------------------------------------------------------------------
# Merge Result Builder
# ---------------------------------------------------------------------------


def _build_merged_result(
    original: ExtractionResult,
    parent: list[int],
    find,
) -> ExtractionResult:
    """Build the merged ExtractionResult from Union-Find groups.

    Args:
        original: the unmerged extraction result.
        parent: Union-Find backing list (kept for interface symmetry with
            ``_make_union_find``; only ``find`` is consulted here).
        find: root-lookup closure from ``_make_union_find``.

    Returns:
        A new ExtractionResult with merged nodes and remapped edges/triples,
        or ``original`` unchanged when no merge occurred.
    """
    nodes = original.nodes

    # Group node indices by their Union-Find root.
    groups: dict[int, list[int]] = {}
    for i in range(len(nodes)):
        root = find(i)
        groups.setdefault(root, []).append(i)

    # No merge happened: every node is its own group.
    if len(groups) == len(nodes):
        return original

    # Canonical node per group: the member with the longest name.
    # Key the map by (name, type) to avoid cross-type collisions of same-named nodes.
    node_map: dict[tuple[str, str], str] = {}
    merged_nodes: list[GraphNode] = []
    for members in groups.values():
        best_idx = max(members, key=lambda idx: len(nodes[idx].name))
        canon = nodes[best_idx]
        merged_nodes.append(canon)
        for idx in members:
            node_map[(nodes[idx].name, nodes[idx].type)] = canon.name

    logger.info(
        "Alignment merged %d nodes -> %d nodes",
        len(nodes),
        len(merged_nodes),
    )

    # Name-only remap for edges -- included only when a name maps to a single
    # canonical name (i.e. no ambiguity across types).
    _edge_remap: dict[str, set[str]] = {}
    for (name, _type), canon_name in node_map.items():
        _edge_remap.setdefault(name, set()).add(canon_name)
    edge_name_map: dict[str, str] = {
        name: next(iter(canon_names))
        for name, canon_names in _edge_remap.items()
        if len(canon_names) == 1
    }

    # Rewrite edges (rename endpoints + deduplicate).
    seen_edges: set[str] = set()
    merged_edges: list[GraphEdge] = []
    for edge in original.edges:
        src = edge_name_map.get(edge.source, edge.source)
        tgt = edge_name_map.get(edge.target, edge.target)
        key = f"{src}|{edge.relation_type}|{tgt}"
        if key not in seen_edges:
            seen_edges.add(key)
            merged_edges.append(
                GraphEdge(
                    source=src,
                    target=tgt,
                    relation_type=edge.relation_type,
                    properties=edge.properties,
                )
            )

    # Rewrite triples using exact (name, type) lookup to avoid cross-type remaps.
    seen_triples: set[str] = set()
    merged_triples: list[Triple] = []
    for triple in original.triples:
        sub_key = (triple.subject.name, triple.subject.type)
        obj_key = (triple.object.name, triple.object.type)
        sub_name = node_map.get(sub_key, triple.subject.name)
        obj_name = node_map.get(obj_key, triple.object.name)
        key = f"{sub_name}|{triple.predicate}|{obj_name}"
        if key not in seen_triples:
            seen_triples.add(key)
            merged_triples.append(
                Triple(
                    subject=GraphNode(name=sub_name, type=triple.subject.type),
                    predicate=triple.predicate,
                    object=GraphNode(name=obj_name, type=triple.object.type),
                )
            )

    return ExtractionResult(
        nodes=merged_nodes,
        edges=merged_edges,
        triples=merged_triples,
        raw_text=original.raw_text,
        source_id=original.source_id,
    )


# ---------------------------------------------------------------------------
# EntityAligner
# ---------------------------------------------------------------------------


class EntityAligner:
    """Entity aligner.

    Create instances via the ``from_settings()`` factory to pull defaults
    from global configuration, or construct directly to override parameters.
    """

    def __init__(
        self,
        *,
        enabled: bool = False,
        embedding_model: str = "text-embedding-3-small",
        embedding_base_url: str | None = None,
        embedding_api_key: SecretStr = SecretStr("EMPTY"),
        llm_model: str = "gpt-4o-mini",
        llm_base_url: str | None = None,
        llm_api_key: SecretStr = SecretStr("EMPTY"),
        llm_timeout: int = 30,
        vector_auto_merge_threshold: float = 0.92,
        vector_llm_threshold: float = 0.78,
        llm_arbitration_enabled: bool = True,
        max_llm_arbitrations: int = 10,
    ) -> None:
        self._enabled = enabled
        self._embedding_model = embedding_model
        self._embedding_base_url = embedding_base_url
        self._embedding_api_key = embedding_api_key
        self._llm_model = llm_model
        self._llm_base_url = llm_base_url
        self._llm_api_key = llm_api_key
        self._llm_timeout = llm_timeout
        # Similarity >= auto threshold merges directly; similarity in
        # [llm_threshold, auto_threshold) is escalated to LLM arbitration.
        self._vector_auto_threshold = vector_auto_merge_threshold
        self._vector_llm_threshold = vector_llm_threshold
        self._llm_arbitration_enabled = llm_arbitration_enabled
        self._max_llm_arbitrations = max_llm_arbitrations
        # Lazily initialized clients (see _get_embeddings / _get_llm).
        self._embeddings: OpenAIEmbeddings | None = None
        self._llm: ChatOpenAI | None = None

    @classmethod
    def from_settings(cls) -> EntityAligner:
        """Create an aligner instance from global Settings.

        NOTE(review): ``llm_arbitration_enabled`` / ``max_llm_arbitrations``
        are not sourced from settings here and keep their constructor
        defaults -- confirm this is intentional.
        """
        from app.core.config import settings

        return cls(
            enabled=settings.kg_alignment_enabled,
            embedding_model=settings.kg_alignment_embedding_model,
            embedding_base_url=settings.kg_llm_base_url,
            embedding_api_key=settings.kg_llm_api_key,
            llm_model=settings.kg_llm_model,
            llm_base_url=settings.kg_llm_base_url,
            llm_api_key=settings.kg_llm_api_key,
            llm_timeout=settings.kg_llm_timeout_seconds,
            vector_auto_merge_threshold=settings.kg_alignment_vector_threshold,
            vector_llm_threshold=settings.kg_alignment_llm_threshold,
        )

    def _get_embeddings(self) -> OpenAIEmbeddings:
        """Lazily construct and cache the embeddings client."""
        if self._embeddings is None:
            self._embeddings = OpenAIEmbeddings(
                model=self._embedding_model,
                base_url=self._embedding_base_url,
                api_key=self._embedding_api_key,
            )
        return self._embeddings

    def _get_llm(self) -> ChatOpenAI:
        """Lazily construct and cache the arbitration LLM client."""
        if self._llm is None:
            self._llm = ChatOpenAI(
                model=self._llm_model,
                base_url=self._llm_base_url,
                api_key=self._llm_api_key,
                temperature=0.0,
                timeout=self._llm_timeout,
            )
        return self._llm

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def align(self, result: ExtractionResult) -> ExtractionResult:
        """Align and deduplicate entities in an extraction result (async, 3 layers).

        Fail-open: on any alignment failure the original result is returned
        and the request is not blocked.

        Note: currently supports intra-batch alignment only (pairwise merges
        within a single extraction result). Cross-graph alignment against
        existing entities requires KG-service API support (future work).
        """
        if not self._enabled or len(result.nodes) <= 1:
            return result
        try:
            return await self._align_impl(result)
        except Exception:
            logger.exception(
                "Entity alignment failed, returning original result (fail-open)"
            )
            return result

    def align_rules_only(self, result: ExtractionResult) -> ExtractionResult:
        """Rule-layer-only alignment (synchronous, for the extract_sync path).

        Fail-open: returns the original result on failure.
        """
        if not self._enabled or len(result.nodes) <= 1:
            return result
        try:
            nodes = result.nodes
            parent, find, union = _make_union_find(len(nodes))
            for i in range(len(nodes)):
                for j in range(i + 1, len(nodes)):
                    if find(i) == find(j):
                        continue
                    # Only exact-match rule score (1.0) merges on this path.
                    if rule_score(nodes[i], nodes[j]) >= 1.0:
                        union(i, j)
            return _build_merged_result(result, parent, find)
        except Exception:
            logger.exception(
                "Rule-only alignment failed, returning original result (fail-open)"
            )
            return result

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    async def _align_impl(self, result: ExtractionResult) -> ExtractionResult:
        """Core three-layer alignment.

        Pairwise alignment is performed only within this extraction result's
        node list. Matching against pre-existing graph entities would require
        extending the signature with a graph_id + candidate-retrieval context,
        which depends on the KG service API.
        """
        nodes = result.nodes
        n = len(nodes)
        parent, find, union = _make_union_find(n)

        # Phase 1: Rule layer. Exact matches merge immediately; partial
        # (substring) matches become candidates for the vector layer.
        vector_candidates: list[tuple[int, int]] = []
        for i in range(n):
            for j in range(i + 1, n):
                if find(i) == find(j):
                    continue
                score = rule_score(nodes[i], nodes[j])
                if score >= 1.0:
                    union(i, j)
                    logger.debug(
                        "Rule merge: '%s' <-> '%s'", nodes[i].name, nodes[j].name
                    )
                elif score > 0:
                    vector_candidates.append((i, j))

        # Phase 2: Vector similarity. High similarity merges automatically;
        # the boundary band is escalated to LLM arbitration.
        llm_candidates: list[tuple[int, int, float]] = []
        if vector_candidates:
            try:
                emb_map = await self._embed_candidates(nodes, vector_candidates)
                for i, j in vector_candidates:
                    if find(i) == find(j):
                        continue
                    sim = cosine_similarity(emb_map[i], emb_map[j])
                    if sim >= self._vector_auto_threshold:
                        union(i, j)
                        logger.debug(
                            "Vector merge: '%s' <-> '%s' (sim=%.3f)",
                            nodes[i].name,
                            nodes[j].name,
                            sim,
                        )
                    elif sim >= self._vector_llm_threshold:
                        llm_candidates.append((i, j, sim))
            except Exception:
                logger.warning(
                    "Vector similarity failed, skipping vector layer", exc_info=True
                )

        # Phase 3: LLM arbitration (boundary cases only, budget-capped).
        if llm_candidates and self._llm_arbitration_enabled:
            llm_count = 0
            for i, j, sim in llm_candidates:
                if llm_count >= self._max_llm_arbitrations:
                    # Budget exhausted -- no remaining candidate can merge.
                    break
                if find(i) == find(j):
                    continue
                try:
                    if await self._llm_arbitrate(nodes[i], nodes[j]):
                        union(i, j)
                        logger.debug(
                            "LLM merge: '%s' <-> '%s' (sim=%.3f)",
                            nodes[i].name,
                            nodes[j].name,
                            sim,
                        )
                except Exception:
                    # A failed arbitration counts against the budget but
                    # never aborts alignment (fail-open per pair).
                    logger.warning(
                        "LLM arbitration failed for '%s' <-> '%s'",
                        nodes[i].name,
                        nodes[j].name,
                        exc_info=True,
                    )
                finally:
                    llm_count += 1

        return _build_merged_result(result, parent, find)

    async def _embed_candidates(
        self, nodes: list[GraphNode], candidates: list[tuple[int, int]]
    ) -> dict[int, list[float]]:
        """Embed all entities referenced by candidate pairs.

        Returns:
            {node index: embedding vector}, computed in one batched call.
        """
        unique_indices: set[int] = set()
        for i, j in candidates:
            unique_indices.add(i)
            unique_indices.add(j)
        idx_list = sorted(unique_indices)
        texts = [_entity_text(nodes[i]) for i in idx_list]
        embeddings = await self._get_embeddings().aembed_documents(texts)
        return dict(zip(idx_list, embeddings))

    async def _llm_arbitrate(self, a: GraphNode, b: GraphNode) -> bool:
        """LLM-arbitrate whether two entities are identical (strict JSON validation).

        Returns True only when the model answers ``is_same`` with
        confidence >= 0.7. Raises on malformed responses; the caller treats
        any exception as "do not merge".
        """
        prompt = _LLM_PROMPT.format(
            name_a=a.name,
            type_a=a.type,
            name_b=b.name,
            type_b=b.type,
        )
        response = await self._get_llm().ainvoke(prompt)
        # ``content`` may not be a plain str for every provider; coerce first.
        raw = response.content
        content = (raw if isinstance(raw, str) else str(raw)).strip()
        # Defensive: strip Markdown code fences that some models emit despite
        # the prompt asking for raw JSON.
        if content.startswith("```"):
            content = re.sub(r"^```[A-Za-z]*\s*", "", content)
            content = re.sub(r"\s*```$", "", content).strip()
        parsed = json.loads(content)
        result = LLMArbitrationResult.model_validate(parsed)
        logger.debug(
            "LLM arbitration: '%s' <-> '%s' -> is_same=%s, confidence=%.2f",
            a.name,
            b.name,
            result.is_same,
            result.confidence,
        )
        return result.is_same and result.confidence >= 0.7