feat(kg): 实现实体对齐功能(aligner.py)

- 实现三层对齐策略:规则层 + 向量相似度层 + LLM 仲裁层
- 规则层:名称规范化(NFKC、小写、去标点/空格)+ 规则评分
- 向量层:OpenAI Embeddings + cosine 相似度计算
- LLM 层:仅对边界样本调用,严格 JSON schema 校验
- 使用 Union-Find 实现传递合并
- 支持批内对齐(库内对齐待 KG 服务 API 支持)

核心组件:
- EntityAligner 类:align() (async)、align_rules_only() (sync)
- 配置项:kg_alignment_enabled(默认 false)、embedding_model、阈值
- 失败策略:fail-open(对齐失败不中断请求)

集成:
- 已集成到抽取主链路(extract → align → return)
- extract() 调用 async align()
- extract_sync() 调用 sync align_rules_only()

修复:
- P1-1:使用 (name, type) 作为 key,避免同名跨类型误合并
- P1-2:LLM 计数在 finally 块中增加,异常也计数
- P1-3:添加库内对齐说明(待后续实现)

新增 41 个测试用例,全部通过
测试结果:41 tests pass
This commit is contained in:
2026-02-19 18:26:54 +08:00
parent 7abdafc338
commit 0ed7dcbee7
5 changed files with 969 additions and 0 deletions

View File

@@ -0,0 +1,478 @@
"""实体对齐器:对抽取结果中的实体进行去重和合并。
三层对齐策略:
1. 规则层:名称规范化 + 别名匹配 + 类型硬过滤
2. 向量相似度层:基于 embedding 的 cosine 相似度
3. LLM 仲裁层:仅对边界样本调用,严格 JSON schema 校验
失败策略:fail-open —— 对齐失败不阻断抽取请求。
"""
from __future__ import annotations
import json
import re
import unicodedata
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import BaseModel, Field, SecretStr
from app.core.logging import get_logger
from app.module.kg_extraction.models import (
ExtractionResult,
GraphEdge,
GraphNode,
Triple,
)
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Rule Layer
# ---------------------------------------------------------------------------
def normalize_name(name: str) -> str:
    """Normalize an entity name for comparison.

    Pipeline: Unicode NFKC fold -> lowercase -> drop punctuation ->
    collapse all whitespace runs to single spaces (trimmed).
    """
    folded = unicodedata.normalize("NFKC", name).lower()
    # Keep only word characters and whitespace.
    depunctuated = re.sub(r"[^\w\s]", "", folded)
    # split()/join collapses any whitespace runs and trims the ends.
    return " ".join(depunctuated.split())
def rule_score(a: GraphNode, b: GraphNode) -> float:
    """Rule-layer match score for a pair of nodes.

    Returns:
        1.0  normalized names are identical and types are compatible
        0.5  one normalized name contains the other (alias/abbreviation)
             and types are compatible
        0.0  incompatible types, or unrelated names
    """
    # Hard type filter: never score pairs across entity types.
    if a.type.lower() != b.type.lower():
        return 0.0

    left = normalize_name(a.name)
    right = normalize_name(b.name)

    if left == right:
        return 1.0

    # Substring match (alias/abbreviation).  Both sides must be at least
    # 2 characters so single-character overlaps never trigger a merge.
    long_enough = min(len(left), len(right)) >= 2
    if long_enough and (left in right or right in left):
        return 0.5
    return 0.0
# ---------------------------------------------------------------------------
# Vector Similarity Layer
# ---------------------------------------------------------------------------
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of two vectors.

    Falls back to 0.0 when either vector has zero magnitude, which avoids
    a division-by-zero error.
    """
    magnitude_a = sum(x * x for x in a) ** 0.5
    magnitude_b = sum(y * y for y in b) ** 0.5
    if magnitude_a == 0.0 or magnitude_b == 0.0:
        return 0.0
    inner = sum(x * y for x, y in zip(a, b))
    return inner / (magnitude_a * magnitude_b)
def _entity_text(node: GraphNode) -> str:
"""构造用于 embedding 的实体文本表示。"""
return f"{node.type}: {node.name}"
# ---------------------------------------------------------------------------
# LLM Arbitration Layer
# ---------------------------------------------------------------------------
# Prompt template for the LLM arbitration layer, filled via str.format with
# the two entities' names and types.  The doubled braces on the last line
# escape the literal example JSON from formatting.  The prompt text is
# runtime data sent to the model and is intentionally left unchanged.
_LLM_PROMPT = (
    "判断以下两个实体是否指向同一个现实世界的实体或概念。\n\n"
    "实体 A:\n- 名称: {name_a}\n- 类型: {type_a}\n\n"
    "实体 B:\n- 名称: {name_b}\n- 类型: {type_b}\n\n"
    '请严格按以下 JSON 格式返回,不要包含任何其他内容:\n'
    '{{"is_same": true, "confidence": 0.95, "reason": "简要理由"}}'
)
class LLMArbitrationResult(BaseModel):
    """Validated shape of the LLM arbitration response."""
    # Whether the two entities refer to the same real-world entity/concept.
    is_same: bool
    # Model-reported confidence; validation requires 0.0 <= value <= 1.0.
    confidence: float = Field(ge=0.0, le=1.0)
    # Optional short free-text justification from the model.
    reason: str = ""
# ---------------------------------------------------------------------------
# Union-Find
# ---------------------------------------------------------------------------
def _make_union_find(n: int):
"""创建 Union-Find 数据结构,返回 (parent, find, union)。"""
parent = list(range(n))
def find(x: int) -> int:
while parent[x] != x:
parent[x] = parent[parent[x]]
x = parent[x]
return x
def union(x: int, y: int) -> None:
px, py = find(x), find(y)
if px != py:
parent[px] = py
return parent, find, union
# ---------------------------------------------------------------------------
# Merge Result Builder
# ---------------------------------------------------------------------------
def _build_merged_result(
    original: ExtractionResult,
    parent: list[int],
    find,
) -> ExtractionResult:
    """Build the merged ExtractionResult from Union-Find groupings.

    Args:
        original: The extraction result whose nodes were aligned.
        parent: Raw parent table from ``_make_union_find``.
        find: Root-lookup function from ``_make_union_find``.

    Returns:
        ``original`` unchanged when no merge happened, otherwise a new
        ExtractionResult with canonical nodes and remapped, deduplicated
        edges and triples.
    """
    nodes = original.nodes
    # Group node indices by their Union-Find root.
    groups: dict[int, list[int]] = {}
    for i in range(len(nodes)):
        root = find(i)
        groups.setdefault(root, []).append(i)
    # Every node is its own group -> nothing was merged; return the original.
    if len(groups) == len(nodes):
        return original
    # Canonical node: the longest-named member of each group.
    # Keyed by (name, type) so same-named nodes of different types are
    # never mapped to the wrong canonical node.
    node_map: dict[tuple[str, str], str] = {}
    merged_nodes: list[GraphNode] = []
    for members in groups.values():
        best_idx = max(members, key=lambda idx: len(nodes[idx].name))
        canon = nodes[best_idx]
        merged_nodes.append(canon)
        for idx in members:
            node_map[(nodes[idx].name, nodes[idx].type)] = canon.name
    logger.info(
        "Alignment merged %d nodes -> %d nodes",
        len(nodes),
        len(merged_nodes),
    )
    # Edges carry names only (no type), so build a name-only remap and keep
    # an entry only when the mapping is unambiguous across types.
    _edge_remap: dict[str, set[str]] = {}
    for (name, _type), canon_name in node_map.items():
        _edge_remap.setdefault(name, set()).add(canon_name)
    edge_name_map: dict[str, str] = {
        name: next(iter(canon_names))
        for name, canon_names in _edge_remap.items()
        if len(canon_names) == 1
    }
    # Rewrite edges to canonical names, then drop exact duplicates.
    seen_edges: set[str] = set()
    merged_edges: list[GraphEdge] = []
    for edge in original.edges:
        src = edge_name_map.get(edge.source, edge.source)
        tgt = edge_name_map.get(edge.target, edge.target)
        key = f"{src}|{edge.relation_type}|{tgt}"
        if key not in seen_edges:
            seen_edges.add(key)
            merged_edges.append(
                GraphEdge(
                    source=src,
                    target=tgt,
                    relation_type=edge.relation_type,
                    properties=edge.properties,
                )
            )
    # Rewrite triples via the exact (name, type) map to avoid cross-type
    # mismapping, then deduplicate.
    seen_triples: set[str] = set()
    merged_triples: list[Triple] = []
    for triple in original.triples:
        sub_key = (triple.subject.name, triple.subject.type)
        obj_key = (triple.object.name, triple.object.type)
        sub_name = node_map.get(sub_key, triple.subject.name)
        obj_name = node_map.get(obj_key, triple.object.name)
        key = f"{sub_name}|{triple.predicate}|{obj_name}"
        if key not in seen_triples:
            seen_triples.add(key)
            merged_triples.append(
                Triple(
                    subject=GraphNode(name=sub_name, type=triple.subject.type),
                    predicate=triple.predicate,
                    object=GraphNode(name=obj_name, type=triple.object.type),
                )
            )
    return ExtractionResult(
        nodes=merged_nodes,
        edges=merged_edges,
        triples=merged_triples,
        raw_text=original.raw_text,
        source_id=original.source_id,
    )
# ---------------------------------------------------------------------------
# EntityAligner
# ---------------------------------------------------------------------------
class EntityAligner:
    """Entity aligner: deduplicates entities within an extraction result.

    Create instances via the ``from_settings()`` factory to pick up the
    global configuration, or construct directly to override individual
    parameters (all keyword-only).
    """
    def __init__(
        self,
        *,
        enabled: bool = False,
        embedding_model: str = "text-embedding-3-small",
        embedding_base_url: str | None = None,
        embedding_api_key: SecretStr = SecretStr("EMPTY"),
        llm_model: str = "gpt-4o-mini",
        llm_base_url: str | None = None,
        llm_api_key: SecretStr = SecretStr("EMPTY"),
        llm_timeout: int = 30,
        vector_auto_merge_threshold: float = 0.92,
        vector_llm_threshold: float = 0.78,
        llm_arbitration_enabled: bool = True,
        max_llm_arbitrations: int = 10,
    ) -> None:
        """Store configuration; embedding/LLM clients are created lazily.

        Args:
            enabled: Master switch; when False, alignment returns input as-is.
            embedding_model: Embedding model name for the vector layer.
            embedding_base_url: Optional base URL override for embeddings.
            embedding_api_key: API key for the embeddings endpoint.
            llm_model: Chat model used by the arbitration layer.
            llm_base_url: Optional base URL override for the chat endpoint.
            llm_api_key: API key for the chat endpoint.
            llm_timeout: LLM request timeout in seconds.
            vector_auto_merge_threshold: Cosine similarity at or above which
                a candidate pair is merged without LLM arbitration.
            vector_llm_threshold: Lower similarity bound for sending a pair
                to LLM arbitration (boundary cases).
            llm_arbitration_enabled: Whether the LLM layer runs at all.
            max_llm_arbitrations: Budget of LLM calls per alignment run.
        """
        self._enabled = enabled
        self._embedding_model = embedding_model
        self._embedding_base_url = embedding_base_url
        self._embedding_api_key = embedding_api_key
        self._llm_model = llm_model
        self._llm_base_url = llm_base_url
        self._llm_api_key = llm_api_key
        self._llm_timeout = llm_timeout
        self._vector_auto_threshold = vector_auto_merge_threshold
        self._vector_llm_threshold = vector_llm_threshold
        self._llm_arbitration_enabled = llm_arbitration_enabled
        self._max_llm_arbitrations = max_llm_arbitrations
        # Lazy init: clients are only constructed on first use.
        self._embeddings: OpenAIEmbeddings | None = None
        self._llm: ChatOpenAI | None = None
    @classmethod
    def from_settings(cls) -> EntityAligner:
        """Create an aligner instance from the global ``Settings`` object."""
        from app.core.config import settings
        return cls(
            enabled=settings.kg_alignment_enabled,
            embedding_model=settings.kg_alignment_embedding_model,
            embedding_base_url=settings.kg_llm_base_url,
            embedding_api_key=settings.kg_llm_api_key,
            llm_model=settings.kg_llm_model,
            llm_base_url=settings.kg_llm_base_url,
            llm_api_key=settings.kg_llm_api_key,
            llm_timeout=settings.kg_llm_timeout_seconds,
            vector_auto_merge_threshold=settings.kg_alignment_vector_threshold,
            vector_llm_threshold=settings.kg_alignment_llm_threshold,
        )
    def _get_embeddings(self) -> OpenAIEmbeddings:
        # Lazily build and cache the embeddings client.
        if self._embeddings is None:
            self._embeddings = OpenAIEmbeddings(
                model=self._embedding_model,
                base_url=self._embedding_base_url,
                api_key=self._embedding_api_key,
            )
        return self._embeddings
    def _get_llm(self) -> ChatOpenAI:
        # Lazily build and cache the chat client; temperature 0 keeps
        # arbitration verdicts as deterministic as the backend allows.
        if self._llm is None:
            self._llm = ChatOpenAI(
                model=self._llm_model,
                base_url=self._llm_base_url,
                api_key=self._llm_api_key,
                temperature=0.0,
                timeout=self._llm_timeout,
            )
        return self._llm
    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    async def align(self, result: ExtractionResult) -> ExtractionResult:
        """Align and deduplicate entities (async, full three-layer strategy).

        Fail-open: on any error the original result is returned, so the
        extraction request is never blocked.

        Note: only in-batch alignment (pairwise merging within a single
        extraction result) is supported.  Library-wide alignment (matching
        against entities already in the graph) requires KG service API
        support and is planned for later.
        """
        if not self._enabled or len(result.nodes) <= 1:
            return result
        try:
            return await self._align_impl(result)
        except Exception:
            logger.exception(
                "Entity alignment failed, returning original result (fail-open)"
            )
            return result
    def align_rules_only(self, result: ExtractionResult) -> ExtractionResult:
        """Rule-layer-only alignment (sync; used on the extract_sync path).

        Fail-open: returns the original result on any failure.
        """
        if not self._enabled or len(result.nodes) <= 1:
            return result
        try:
            nodes = result.nodes
            parent, find, union = _make_union_find(len(nodes))
            for i in range(len(nodes)):
                for j in range(i + 1, len(nodes)):
                    if find(i) == find(j):
                        continue
                    if rule_score(nodes[i], nodes[j]) >= 1.0:
                        union(i, j)
            return _build_merged_result(result, parent, find)
        except Exception:
            logger.exception(
                "Rule-only alignment failed, returning original result (fail-open)"
            )
            return result
    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------
    async def _align_impl(self, result: ExtractionResult) -> ExtractionResult:
        """Core implementation of the three-layer alignment.

        Operates pairwise within a single extraction result only.  Matching
        against existing graph entities (library-wide alignment) would need
        the signature extended with graph_id plus a candidate-retrieval
        context, which depends on the KG service API.
        """
        nodes = result.nodes
        n = len(nodes)
        parent, find, union = _make_union_find(n)
        # Phase 1: Rule layer -- exact matches merge immediately; partial
        # matches (score 0.5) become candidates for the vector layer.
        vector_candidates: list[tuple[int, int]] = []
        for i in range(n):
            for j in range(i + 1, n):
                if find(i) == find(j):
                    continue
                score = rule_score(nodes[i], nodes[j])
                if score >= 1.0:
                    union(i, j)
                    logger.debug(
                        "Rule merge: '%s' <-> '%s'", nodes[i].name, nodes[j].name
                    )
                elif score > 0:
                    vector_candidates.append((i, j))
        # Phase 2: Vector similarity -- auto-merge above the high threshold;
        # queue boundary pairs for LLM arbitration.
        llm_candidates: list[tuple[int, int, float]] = []
        if vector_candidates:
            try:
                emb_map = await self._embed_candidates(nodes, vector_candidates)
                for i, j in vector_candidates:
                    if find(i) == find(j):
                        continue
                    sim = cosine_similarity(emb_map[i], emb_map[j])
                    if sim >= self._vector_auto_threshold:
                        union(i, j)
                        logger.debug(
                            "Vector merge: '%s' <-> '%s' (sim=%.3f)",
                            nodes[i].name,
                            nodes[j].name,
                            sim,
                        )
                    elif sim >= self._vector_llm_threshold:
                        llm_candidates.append((i, j, sim))
            except Exception:
                logger.warning(
                    "Vector similarity failed, skipping vector layer", exc_info=True
                )
        # Phase 3: LLM arbitration (boundary cases only)
        if llm_candidates and self._llm_arbitration_enabled:
            llm_count = 0
            for i, j, sim in llm_candidates:
                # Skip pairs once the call budget is spent or the pair is
                # already in the same group.
                if llm_count >= self._max_llm_arbitrations or find(i) == find(j):
                    continue
                try:
                    if await self._llm_arbitrate(nodes[i], nodes[j]):
                        union(i, j)
                        logger.debug(
                            "LLM merge: '%s' <-> '%s' (sim=%.3f)",
                            nodes[i].name,
                            nodes[j].name,
                            sim,
                        )
                except Exception:
                    logger.warning(
                        "LLM arbitration failed for '%s' <-> '%s'",
                        nodes[i].name,
                        nodes[j].name,
                    )
                finally:
                    # Counted in finally so failed calls also consume budget.
                    llm_count += 1
        return _build_merged_result(result, parent, find)
    async def _embed_candidates(
        self, nodes: list[GraphNode], candidates: list[tuple[int, int]]
    ) -> dict[int, list[float]]:
        """Embed each unique candidate node once; returns {index: embedding}."""
        unique_indices: set[int] = set()
        for i, j in candidates:
            unique_indices.add(i)
            unique_indices.add(j)
        idx_list = sorted(unique_indices)
        texts = [_entity_text(nodes[i]) for i in idx_list]
        embeddings = await self._get_embeddings().aembed_documents(texts)
        return dict(zip(idx_list, embeddings))
    async def _llm_arbitrate(self, a: GraphNode, b: GraphNode) -> bool:
        """Ask the LLM whether two entities are identical (strict JSON schema).

        Raises on malformed responses; the caller logs the failure and
        treats the pair as not merged.
        """
        prompt = _LLM_PROMPT.format(
            name_a=a.name,
            type_a=a.type,
            name_b=b.name,
            type_b=b.type,
        )
        response = await self._get_llm().ainvoke(prompt)
        content = response.content.strip()
        parsed = json.loads(content)
        result = LLMArbitrationResult.model_validate(parsed)
        logger.debug(
            "LLM arbitration: '%s' <-> '%s' -> is_same=%s, confidence=%.2f",
            a.name,
            b.name,
            result.is_same,
            result.confidence,
        )
        # Merge only on a positive verdict with confidence >= 0.7.
        return result.is_same and result.confidence >= 0.7