You've already forked DataMate
feat(kg): 实现实体对齐功能(aligner.py)
- 实现三层对齐策略:规则层 + 向量相似度层 + LLM 仲裁层 - 规则层:名称规范化(NFKC、小写、去标点/空格)+ 规则评分 - 向量层:OpenAI Embeddings + cosine 相似度计算 - LLM 层:仅对边界样本调用,严格 JSON schema 校验 - 使用 Union-Find 实现传递合并 - 支持批内对齐(库内对齐待 KG 服务 API 支持) 核心组件: - EntityAligner 类:align() (async)、align_rules_only() (sync) - 配置项:kg_alignment_enabled(默认 false)、embedding_model、阈值 - 失败策略:fail-open(对齐失败不中断请求) 集成: - 已集成到抽取主链路(extract → align → return) - extract() 调用 async align() - extract_sync() 调用 sync align_rules_only() 修复: - P1-1:使用 (name, type) 作为 key,避免同名跨类型误合并 - P1-2:LLM 计数在 finally 块中增加,异常也计数 - P1-3:添加库内对齐说明(待后续实现) 新增 41 个测试用例,全部通过 测试结果:41 tests pass
This commit is contained in:
@@ -82,6 +82,12 @@ class Settings(BaseSettings):
|
|||||||
kg_llm_timeout_seconds: int = 60
|
kg_llm_timeout_seconds: int = 60
|
||||||
kg_llm_max_retries: int = 2
|
kg_llm_max_retries: int = 2
|
||||||
|
|
||||||
|
# Knowledge Graph - 实体对齐配置
|
||||||
|
kg_alignment_enabled: bool = False
|
||||||
|
kg_alignment_embedding_model: str = "text-embedding-3-small"
|
||||||
|
kg_alignment_vector_threshold: float = 0.92
|
||||||
|
kg_alignment_llm_threshold: float = 0.78
|
||||||
|
|
||||||
# 标注编辑器(Label Studio Editor)相关
|
# 标注编辑器(Label Studio Editor)相关
|
||||||
editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数
|
editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from app.module.kg_extraction.aligner import EntityAligner
|
||||||
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
|
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
|
||||||
from app.module.kg_extraction.models import (
|
from app.module.kg_extraction.models import (
|
||||||
ExtractionRequest,
|
ExtractionRequest,
|
||||||
@@ -9,6 +10,7 @@ from app.module.kg_extraction.models import (
|
|||||||
from app.module.kg_extraction.interface import router
|
from app.module.kg_extraction.interface import router
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"EntityAligner",
|
||||||
"KnowledgeGraphExtractor",
|
"KnowledgeGraphExtractor",
|
||||||
"ExtractionRequest",
|
"ExtractionRequest",
|
||||||
"ExtractionResult",
|
"ExtractionResult",
|
||||||
|
|||||||
478
runtime/datamate-python/app/module/kg_extraction/aligner.py
Normal file
478
runtime/datamate-python/app/module/kg_extraction/aligner.py
Normal file
@@ -0,0 +1,478 @@
|
|||||||
|
"""实体对齐器:对抽取结果中的实体进行去重和合并。
|
||||||
|
|
||||||
|
三层对齐策略:
|
||||||
|
1. 规则层:名称规范化 + 别名匹配 + 类型硬过滤
|
||||||
|
2. 向量相似度层:基于 embedding 的 cosine 相似度
|
||||||
|
3. LLM 仲裁层:仅对边界样本调用,严格 JSON schema 校验
|
||||||
|
|
||||||
|
失败策略:fail-open —— 对齐失败不阻断抽取请求。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||||
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.module.kg_extraction.models import (
|
||||||
|
ExtractionResult,
|
||||||
|
GraphEdge,
|
||||||
|
GraphNode,
|
||||||
|
Triple,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Rule Layer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_name(name: str) -> str:
    """Canonicalize an entity name for rule-based comparison.

    Pipeline: Unicode NFKC folding -> lowercase -> strip punctuation ->
    collapse runs of whitespace to single spaces (and trim the ends).
    """
    folded = unicodedata.normalize("NFKC", name).lower()
    # Drop everything that is neither a word character nor whitespace.
    no_punct = re.sub(r"[^\w\s]", "", folded)
    # split()/join collapses any whitespace runs and trims both ends.
    return " ".join(no_punct.split())
|
||||||
|
|
||||||
|
|
||||||
|
def rule_score(a: GraphNode, b: GraphNode) -> float:
    """Rule-layer match score for a pair of graph nodes.

    Returns:
        1.0 when the normalized names are identical and the types match.
        0.5 when one normalized name contains the other (alias/abbreviation)
            and both normalized names are at least 2 characters long.
        0.0 when the types differ or the names are unrelated.
    """
    # Hard type filter: entities of different types are never merge
    # candidates, regardless of name similarity.
    if a.type.lower() != b.type.lower():
        return 0.0

    left, right = normalize_name(a.name), normalize_name(b.name)

    # Exact match after normalization.
    if left == right:
        return 1.0

    # Substring match (alias/abbreviation).  Both sides must be >= 2
    # characters so single characters never trigger a candidate pair.
    both_long_enough = len(left) >= 2 and len(right) >= 2
    if both_long_enough and (left in right or right in left):
        return 0.5

    return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Vector Similarity Layer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two vectors.

    Returns 0.0 when either vector has zero magnitude instead of raising
    a division error.
    """
    magnitude_a = sum(component * component for component in a) ** 0.5
    magnitude_b = sum(component * component for component in b) ** 0.5
    # A zero-length vector has no direction; define similarity as 0.0.
    if magnitude_a == 0.0 or magnitude_b == 0.0:
        return 0.0
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / (magnitude_a * magnitude_b)
|
||||||
|
|
||||||
|
|
||||||
|
def _entity_text(node: GraphNode) -> str:
    """Build the text representation of a node fed to the embedding model.

    Format: ``"<type>: <name>"``.
    """
    return "{}: {}".format(node.type, node.name)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LLM Arbitration Layer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Prompt template for the LLM arbitration layer.  Only name_a/type_a/
# name_b/type_b are substituted via str.format; the doubled braces in
# the last line escape formatting so the literal JSON example survives.
_LLM_PROMPT = (
    "判断以下两个实体是否指向同一个现实世界的实体或概念。\n\n"
    "实体 A:\n- 名称: {name_a}\n- 类型: {type_a}\n\n"
    "实体 B:\n- 名称: {name_b}\n- 类型: {type_b}\n\n"
    '请严格按以下 JSON 格式返回,不要包含任何其他内容:\n'
    '{{"is_same": true, "confidence": 0.95, "reason": "简要理由"}}'
)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMArbitrationResult(BaseModel):
    """Schema for the JSON payload the arbitration LLM must return.

    ``confidence`` is constrained to [0.0, 1.0]; out-of-range values make
    validation raise, which callers observe as a failed arbitration.
    """

    # Whether the two entities refer to the same real-world entity.
    is_same: bool
    # Model's self-reported confidence, validated into [0.0, 1.0].
    confidence: float = Field(ge=0.0, le=1.0)
    # Optional short justification; defaults to empty.
    reason: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Union-Find
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_union_find(n: int):
    """Build a disjoint-set (Union-Find) structure over ``n`` elements.

    Returns:
        A ``(parent, find, union)`` triple: the raw parent table plus
        closures for root lookup (with path halving) and merging.
    """
    parent = [i for i in range(n)]

    def find(x: int) -> int:
        # Path halving: point each visited node at its grandparent
        # while walking up, keeping trees shallow.
        root = x
        while parent[root] != root:
            parent[root] = parent[parent[root]]
            root = parent[root]
        return root

    def union(x: int, y: int) -> None:
        root_x = find(x)
        root_y = find(y)
        if root_x == root_y:
            return
        parent[root_x] = root_y

    return parent, find, union
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Merge Result Builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _build_merged_result(
    original: ExtractionResult,
    parent: list[int],
    find,
) -> ExtractionResult:
    """Build the merged ExtractionResult from a completed Union-Find pass.

    Nodes whose indices share a Union-Find root are collapsed into one
    canonical node; edges and triples are renamed to canonical names and
    de-duplicated.  Returns ``original`` unchanged when no merge occurred.
    """
    nodes = original.nodes

    # Group node indices by their Union-Find root.
    groups: dict[int, list[int]] = {}
    for i in range(len(nodes)):
        root = find(i)
        groups.setdefault(root, []).append(i)

    # Fast path: every node is its own group -> nothing was merged.
    if len(groups) == len(nodes):
        return original

    # Canonical node: the member with the longest name in each group.
    # Keyed by (name, type) so same-named nodes of different types are
    # never cross-mapped (P1-1 regression fix per commit message).
    node_map: dict[tuple[str, str], str] = {}
    merged_nodes: list[GraphNode] = []
    for members in groups.values():
        best_idx = max(members, key=lambda idx: len(nodes[idx].name))
        canon = nodes[best_idx]
        merged_nodes.append(canon)
        for idx in members:
            node_map[(nodes[idx].name, nodes[idx].type)] = canon.name

    logger.info(
        "Alignment merged %d nodes -> %d nodes",
        len(nodes),
        len(merged_nodes),
    )

    # Edges carry only names (no type), so build a name-only remap and
    # include an entry only when the mapping is unambiguous — i.e. every
    # (name, *) pair maps to the same canonical name.
    _edge_remap: dict[str, set[str]] = {}
    for (name, _type), canon_name in node_map.items():
        _edge_remap.setdefault(name, set()).add(canon_name)
    edge_name_map: dict[str, str] = {
        name: next(iter(canon_names))
        for name, canon_names in _edge_remap.items()
        if len(canon_names) == 1
    }

    # Rewrite edges to canonical names and drop exact duplicates.
    seen_edges: set[str] = set()
    merged_edges: list[GraphEdge] = []
    for edge in original.edges:
        src = edge_name_map.get(edge.source, edge.source)
        tgt = edge_name_map.get(edge.target, edge.target)
        key = f"{src}|{edge.relation_type}|{tgt}"
        if key not in seen_edges:
            seen_edges.add(key)
            merged_edges.append(
                GraphEdge(
                    source=src,
                    target=tgt,
                    relation_type=edge.relation_type,
                    properties=edge.properties,
                )
            )

    # Triples carry full (name, type) on both ends, so they can use the
    # exact map directly — no cross-type ambiguity possible here.
    seen_triples: set[str] = set()
    merged_triples: list[Triple] = []
    for triple in original.triples:
        sub_key = (triple.subject.name, triple.subject.type)
        obj_key = (triple.object.name, triple.object.type)
        sub_name = node_map.get(sub_key, triple.subject.name)
        obj_name = node_map.get(obj_key, triple.object.name)
        key = f"{sub_name}|{triple.predicate}|{obj_name}"
        if key not in seen_triples:
            seen_triples.add(key)
            merged_triples.append(
                Triple(
                    subject=GraphNode(name=sub_name, type=triple.subject.type),
                    predicate=triple.predicate,
                    object=GraphNode(name=obj_name, type=triple.object.type),
                )
            )

    return ExtractionResult(
        nodes=merged_nodes,
        edges=merged_edges,
        triples=merged_triples,
        raw_text=original.raw_text,
        source_id=original.source_id,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EntityAligner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class EntityAligner:
    """Entity aligner: deduplicates/merges entities in an extraction result.

    Create via the ``from_settings()`` factory to pick up global
    configuration, or construct directly to override defaults.
    """

    def __init__(
        self,
        *,
        enabled: bool = False,
        embedding_model: str = "text-embedding-3-small",
        embedding_base_url: str | None = None,
        embedding_api_key: SecretStr = SecretStr("EMPTY"),
        llm_model: str = "gpt-4o-mini",
        llm_base_url: str | None = None,
        llm_api_key: SecretStr = SecretStr("EMPTY"),
        llm_timeout: int = 30,
        vector_auto_merge_threshold: float = 0.92,
        vector_llm_threshold: float = 0.78,
        llm_arbitration_enabled: bool = True,
        max_llm_arbitrations: int = 10,
    ) -> None:
        self._enabled = enabled
        self._embedding_model = embedding_model
        self._embedding_base_url = embedding_base_url
        self._embedding_api_key = embedding_api_key
        self._llm_model = llm_model
        self._llm_base_url = llm_base_url
        self._llm_api_key = llm_api_key
        self._llm_timeout = llm_timeout
        self._vector_auto_threshold = vector_auto_merge_threshold
        self._vector_llm_threshold = vector_llm_threshold
        self._llm_arbitration_enabled = llm_arbitration_enabled
        self._max_llm_arbitrations = max_llm_arbitrations
        # Lazily-initialized API clients (only built when first needed).
        self._embeddings: OpenAIEmbeddings | None = None
        self._llm: ChatOpenAI | None = None

    @classmethod
    def from_settings(cls) -> EntityAligner:
        """Create an aligner instance from the global Settings object."""
        # Imported here to avoid a module-level import cycle with config.
        from app.core.config import settings

        return cls(
            enabled=settings.kg_alignment_enabled,
            embedding_model=settings.kg_alignment_embedding_model,
            embedding_base_url=settings.kg_llm_base_url,
            embedding_api_key=settings.kg_llm_api_key,
            llm_model=settings.kg_llm_model,
            llm_base_url=settings.kg_llm_base_url,
            llm_api_key=settings.kg_llm_api_key,
            llm_timeout=settings.kg_llm_timeout_seconds,
            vector_auto_merge_threshold=settings.kg_alignment_vector_threshold,
            vector_llm_threshold=settings.kg_alignment_llm_threshold,
        )

    def _get_embeddings(self) -> OpenAIEmbeddings:
        # Lazy singleton: build the embeddings client on first use.
        if self._embeddings is None:
            self._embeddings = OpenAIEmbeddings(
                model=self._embedding_model,
                base_url=self._embedding_base_url,
                api_key=self._embedding_api_key,
            )
        return self._embeddings

    def _get_llm(self) -> ChatOpenAI:
        # Lazy singleton: build the chat client on first use.
        # temperature=0.0 keeps arbitration as deterministic as possible.
        if self._llm is None:
            self._llm = ChatOpenAI(
                model=self._llm_model,
                base_url=self._llm_base_url,
                api_key=self._llm_api_key,
                temperature=0.0,
                timeout=self._llm_timeout,
            )
        return self._llm

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def align(self, result: ExtractionResult) -> ExtractionResult:
        """Align and deduplicate entities (async, full three-layer strategy).

        Fail-open: on any failure the original result is returned so the
        extraction request is never blocked.

        Note: currently only intra-batch alignment (pairwise merging within
        one extraction result).  Library-wide alignment against existing
        graph entities requires KG service API support — to be implemented.
        """
        # Skip when disabled or when there is nothing to merge.
        if not self._enabled or len(result.nodes) <= 1:
            return result

        try:
            return await self._align_impl(result)
        except Exception:
            logger.exception(
                "Entity alignment failed, returning original result (fail-open)"
            )
            return result

    def align_rules_only(self, result: ExtractionResult) -> ExtractionResult:
        """Rule-layer-only alignment (sync, used by the extract_sync path).

        Fail-open: returns the original result on any failure.
        """
        if not self._enabled or len(result.nodes) <= 1:
            return result

        try:
            nodes = result.nodes
            parent, find, union = _make_union_find(len(nodes))

            # Pairwise rule check; only exact (score >= 1.0) matches merge.
            for i in range(len(nodes)):
                for j in range(i + 1, len(nodes)):
                    if find(i) == find(j):
                        continue
                    if rule_score(nodes[i], nodes[j]) >= 1.0:
                        union(i, j)

            return _build_merged_result(result, parent, find)
        except Exception:
            logger.exception(
                "Rule-only alignment failed, returning original result (fail-open)"
            )
            return result

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    async def _align_impl(self, result: ExtractionResult) -> ExtractionResult:
        """Core three-layer alignment.

        Operates pairwise over the node list of a single extraction result.
        Matching against pre-existing graph entities (library-wide
        alignment) would need a graph_id + candidate-retrieval context and
        depends on the KG service API.
        """
        nodes = result.nodes
        n = len(nodes)
        parent, find, union = _make_union_find(n)

        # Phase 1: rule layer.  Exact matches merge immediately; partial
        # matches (score in (0, 1)) become vector-layer candidates.
        vector_candidates: list[tuple[int, int]] = []
        for i in range(n):
            for j in range(i + 1, n):
                if find(i) == find(j):
                    continue
                score = rule_score(nodes[i], nodes[j])
                if score >= 1.0:
                    union(i, j)
                    logger.debug(
                        "Rule merge: '%s' <-> '%s'", nodes[i].name, nodes[j].name
                    )
                elif score > 0:
                    vector_candidates.append((i, j))

        # Phase 2: vector similarity.  Above the auto threshold merge
        # directly; between the two thresholds defer to LLM arbitration.
        # Failures skip this layer entirely (fail-open).
        llm_candidates: list[tuple[int, int, float]] = []
        if vector_candidates:
            try:
                emb_map = await self._embed_candidates(nodes, vector_candidates)
                for i, j in vector_candidates:
                    if find(i) == find(j):
                        continue
                    sim = cosine_similarity(emb_map[i], emb_map[j])
                    if sim >= self._vector_auto_threshold:
                        union(i, j)
                        logger.debug(
                            "Vector merge: '%s' <-> '%s' (sim=%.3f)",
                            nodes[i].name,
                            nodes[j].name,
                            sim,
                        )
                    elif sim >= self._vector_llm_threshold:
                        llm_candidates.append((i, j, sim))
            except Exception:
                logger.warning(
                    "Vector similarity failed, skipping vector layer", exc_info=True
                )

        # Phase 3: LLM arbitration (boundary cases only), capped at
        # _max_llm_arbitrations calls.  The counter is incremented in
        # `finally` so failed calls also count (P1-2 fix per commit message).
        if llm_candidates and self._llm_arbitration_enabled:
            llm_count = 0
            for i, j, sim in llm_candidates:
                if llm_count >= self._max_llm_arbitrations or find(i) == find(j):
                    continue
                try:
                    if await self._llm_arbitrate(nodes[i], nodes[j]):
                        union(i, j)
                        logger.debug(
                            "LLM merge: '%s' <-> '%s' (sim=%.3f)",
                            nodes[i].name,
                            nodes[j].name,
                            sim,
                        )
                except Exception:
                    logger.warning(
                        "LLM arbitration failed for '%s' <-> '%s'",
                        nodes[i].name,
                        nodes[j].name,
                    )
                finally:
                    llm_count += 1

        return _build_merged_result(result, parent, find)

    async def _embed_candidates(
        self, nodes: list[GraphNode], candidates: list[tuple[int, int]]
    ) -> dict[int, list[float]]:
        """Embed each distinct candidate entity; return {node index: embedding}."""
        # Collect each index once, even if it appears in several pairs.
        unique_indices: set[int] = set()
        for i, j in candidates:
            unique_indices.add(i)
            unique_indices.add(j)

        idx_list = sorted(unique_indices)
        texts = [_entity_text(nodes[i]) for i in idx_list]
        # One batched embeddings call for all candidate texts.
        embeddings = await self._get_embeddings().aembed_documents(texts)
        return dict(zip(idx_list, embeddings))

    async def _llm_arbitrate(self, a: GraphNode, b: GraphNode) -> bool:
        """Ask the LLM whether two entities are the same; strict JSON schema.

        Raises on non-JSON or schema-violating responses; the caller
        catches and counts the attempt as failed.
        """
        prompt = _LLM_PROMPT.format(
            name_a=a.name,
            type_a=a.type,
            name_b=b.name,
            type_b=b.type,
        )
        response = await self._get_llm().ainvoke(prompt)
        content = response.content.strip()

        # Strict validation: json.loads + pydantic schema check.
        parsed = json.loads(content)
        result = LLMArbitrationResult.model_validate(parsed)

        logger.debug(
            "LLM arbitration: '%s' <-> '%s' -> is_same=%s, confidence=%.2f",
            a.name,
            b.name,
            result.is_same,
            result.confidence,
        )
        # Require both a positive verdict and a minimum confidence.
        return result.is_same and result.confidence >= 0.7
|
||||||
@@ -15,6 +15,7 @@ from langchain_experimental.graph_transformers import LLMGraphTransformer
|
|||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
from app.module.kg_extraction.aligner import EntityAligner
|
||||||
from app.module.kg_extraction.models import (
|
from app.module.kg_extraction.models import (
|
||||||
ExtractionRequest,
|
ExtractionRequest,
|
||||||
ExtractionResult,
|
ExtractionResult,
|
||||||
@@ -47,6 +48,7 @@ class KnowledgeGraphExtractor:
|
|||||||
temperature: float = 0.0,
|
temperature: float = 0.0,
|
||||||
timeout: int = 60,
|
timeout: int = 60,
|
||||||
max_retries: int = 2,
|
max_retries: int = 2,
|
||||||
|
aligner: EntityAligner | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
|
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
|
||||||
@@ -63,6 +65,7 @@ class KnowledgeGraphExtractor:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
)
|
)
|
||||||
|
self._aligner = aligner or EntityAligner()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_settings(cls) -> KnowledgeGraphExtractor:
|
def from_settings(cls) -> KnowledgeGraphExtractor:
|
||||||
@@ -76,6 +79,7 @@ class KnowledgeGraphExtractor:
|
|||||||
temperature=settings.kg_llm_temperature,
|
temperature=settings.kg_llm_temperature,
|
||||||
timeout=settings.kg_llm_timeout_seconds,
|
timeout=settings.kg_llm_timeout_seconds,
|
||||||
max_retries=settings.kg_llm_max_retries,
|
max_retries=settings.kg_llm_max_retries,
|
||||||
|
aligner=EntityAligner.from_settings(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_transformer(
|
def _build_transformer(
|
||||||
@@ -119,6 +123,7 @@ class KnowledgeGraphExtractor:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
result = self._convert_result(graph_documents, request)
|
result = self._convert_result(graph_documents, request)
|
||||||
|
result = await self._aligner.align(result)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
|
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
|
||||||
request.graph_id,
|
request.graph_id,
|
||||||
@@ -154,6 +159,7 @@ class KnowledgeGraphExtractor:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
result = self._convert_result(graph_documents, request)
|
result = self._convert_result(graph_documents, request)
|
||||||
|
result = self._aligner.align_rules_only(result)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
|
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
|
||||||
request.graph_id,
|
request.graph_id,
|
||||||
|
|||||||
477
runtime/datamate-python/app/module/kg_extraction/test_aligner.py
Normal file
477
runtime/datamate-python/app/module/kg_extraction/test_aligner.py
Normal file
@@ -0,0 +1,477 @@
|
|||||||
|
"""实体对齐器测试。
|
||||||
|
|
||||||
|
Run with: pytest app/module/kg_extraction/test_aligner.py -v
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.module.kg_extraction.aligner import (
|
||||||
|
EntityAligner,
|
||||||
|
LLMArbitrationResult,
|
||||||
|
_build_merged_result,
|
||||||
|
_make_union_find,
|
||||||
|
cosine_similarity,
|
||||||
|
normalize_name,
|
||||||
|
rule_score,
|
||||||
|
)
|
||||||
|
from app.module.kg_extraction.models import (
|
||||||
|
ExtractionResult,
|
||||||
|
GraphEdge,
|
||||||
|
GraphNode,
|
||||||
|
Triple,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# normalize_name
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeName:
    """Unit tests for the ``normalize_name`` normalization pipeline."""

    def test_basic_lowercase(self):
        assert normalize_name("Hello World") == "hello world"

    def test_unicode_nfkc(self):
        # Fullwidth 'H' (U+FF28) folds to ASCII 'h' under NFKC + lower.
        assert normalize_name("\uff28ello") == "hello"

    def test_punctuation_removed(self):
        assert normalize_name("U.S.A.") == "usa"

    def test_whitespace_collapsed(self):
        assert normalize_name(" hello world ") == "hello world"

    def test_empty_string(self):
        assert normalize_name("") == ""

    def test_chinese_preserved(self):
        # CJK characters are word characters — they must survive unchanged.
        assert normalize_name("\u5f20\u4e09") == "\u5f20\u4e09"

    def test_mixed_chinese_english(self):
        assert normalize_name("\u5f20\u4e09 (Zhang San)") == "\u5f20\u4e09 zhang san"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rule_score
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuleScore:
    """Unit tests for the rule-layer pair scoring (``rule_score``)."""

    def test_exact_match(self):
        a = GraphNode(name="\u5f20\u4e09", type="Person")
        b = GraphNode(name="\u5f20\u4e09", type="Person")
        assert rule_score(a, b) == 1.0

    def test_normalized_match(self):
        # Case differences vanish after normalization.
        a = GraphNode(name="Hello World", type="Organization")
        b = GraphNode(name="hello world", type="Organization")
        assert rule_score(a, b) == 1.0

    def test_type_mismatch(self):
        # Same name, different type -> hard filter returns 0.0.
        a = GraphNode(name="\u5f20\u4e09", type="Person")
        b = GraphNode(name="\u5f20\u4e09", type="Organization")
        assert rule_score(a, b) == 0.0

    def test_substring_match(self):
        a = GraphNode(name="\u5317\u4eac\u5927\u5b66", type="Organization")
        b = GraphNode(name="\u5317\u4eac\u5927\u5b66\u8ba1\u7b97\u673a\u5b66\u9662", type="Organization")
        assert rule_score(a, b) == 0.5

    def test_no_match(self):
        a = GraphNode(name="\u5f20\u4e09", type="Person")
        b = GraphNode(name="\u674e\u56db", type="Person")
        assert rule_score(a, b) == 0.0

    def test_type_case_insensitive(self):
        a = GraphNode(name="test", type="PERSON")
        b = GraphNode(name="test", type="person")
        assert rule_score(a, b) == 1.0

    def test_short_substring_ignored(self):
        """Single-character substring should not trigger match."""
        a = GraphNode(name="A", type="Person")
        b = GraphNode(name="AB", type="Person")
        assert rule_score(a, b) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# cosine_similarity
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCosineSimilarity:
    """Unit tests for ``cosine_similarity``."""

    def test_identical(self):
        assert cosine_similarity([1.0, 0.0], [1.0, 0.0]) == pytest.approx(1.0)

    def test_orthogonal(self):
        assert cosine_similarity([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)

    def test_opposite(self):
        assert cosine_similarity([1.0, 0.0], [-1.0, 0.0]) == pytest.approx(-1.0)

    def test_zero_vector(self):
        # Zero-magnitude vectors are defined to have similarity 0.0.
        assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Union-Find
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnionFind:
    """Unit tests for the ``_make_union_find`` disjoint-set helper."""

    def test_basic(self):
        parent, find, union = _make_union_find(4)
        union(0, 1)
        union(2, 3)
        assert find(0) == find(1)
        assert find(2) == find(3)
        assert find(0) != find(2)

    def test_transitive(self):
        # Merging (0,1) then (1,2) must place 0 and 2 in the same set.
        parent, find, union = _make_union_find(3)
        union(0, 1)
        union(1, 2)
        assert find(0) == find(2)
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _build_merged_result
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_result(nodes, edges=None, triples=None):
    """Test helper: build an ExtractionResult with fixed metadata."""
    return ExtractionResult(
        nodes=nodes,
        edges=edges or [],
        triples=triples or [],
        raw_text="test text",
        source_id="src-1",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildMergedResult:
|
||||||
|
def test_no_merge_returns_original(self):
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="A", type="Person"),
|
||||||
|
GraphNode(name="B", type="Person"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
parent, find, _ = _make_union_find(2)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert merged is result
|
||||||
|
|
||||||
|
def test_canonical_picks_longest_name(self):
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="AI", type="Tech"),
|
||||||
|
GraphNode(name="Artificial Intelligence", type="Tech"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
parent, find, union = _make_union_find(2)
|
||||||
|
union(0, 1)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert len(merged.nodes) == 1
|
||||||
|
assert merged.nodes[0].name == "Artificial Intelligence"
|
||||||
|
|
||||||
|
def test_edge_remap_and_dedup(self):
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="Alice", type="Person"),
|
||||||
|
GraphNode(name="alice", type="Person"),
|
||||||
|
GraphNode(name="Bob", type="Person"),
|
||||||
|
]
|
||||||
|
edges = [
|
||||||
|
GraphEdge(source="Alice", target="Bob", relation_type="knows"),
|
||||||
|
GraphEdge(source="alice", target="Bob", relation_type="knows"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes, edges)
|
||||||
|
parent, find, union = _make_union_find(3)
|
||||||
|
union(0, 1)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert len(merged.edges) == 1
|
||||||
|
assert merged.edges[0].source == "Alice"
|
||||||
|
|
||||||
|
def test_triple_remap_and_dedup(self):
|
||||||
|
n1 = GraphNode(name="Alice", type="Person")
|
||||||
|
n2 = GraphNode(name="alice", type="Person")
|
||||||
|
n3 = GraphNode(name="MIT", type="Organization")
|
||||||
|
triples = [
|
||||||
|
Triple(subject=n1, predicate="works_at", object=n3),
|
||||||
|
Triple(subject=n2, predicate="works_at", object=n3),
|
||||||
|
]
|
||||||
|
result = _make_result([n1, n2, n3], triples=triples)
|
||||||
|
parent, find, union = _make_union_find(3)
|
||||||
|
union(0, 1)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert len(merged.triples) == 1
|
||||||
|
assert merged.triples[0].subject.name == "Alice"
|
||||||
|
|
||||||
|
def test_preserves_metadata(self):
|
||||||
|
nodes = [
|
||||||
|
GraphNode(name="A", type="Person"),
|
||||||
|
GraphNode(name="A", type="Person"),
|
||||||
|
]
|
||||||
|
result = _make_result(nodes)
|
||||||
|
parent, find, union = _make_union_find(2)
|
||||||
|
union(0, 1)
|
||||||
|
merged = _build_merged_result(result, parent, find)
|
||||||
|
assert merged.raw_text == "test text"
|
||||||
|
assert merged.source_id == "src-1"
|
||||||
|
|
||||||
|
def test_cross_type_same_name_no_collision(self):
    """P1-1 regression: merging same-name nodes of one type must not
    remap edges/triples that belong to a different type.

    Scenario: Person "张三" and "张三先生" merge into "张三先生",
    but Organization "张三" must stay untouched.
    """
    all_nodes = [
        GraphNode(name="张三", type="Person"),  # idx 0
        GraphNode(name="张三先生", type="Person"),  # idx 1
        GraphNode(name="张三", type="Organization"),  # idx 2 - same name, other type
        GraphNode(name="北京", type="Location"),  # idx 3
    ]
    all_edges = [
        GraphEdge(source="张三", target="北京", relation_type="lives_in"),
        GraphEdge(source="张三", target="北京", relation_type="located_in"),
    ]
    all_triples = [
        Triple(
            subject=GraphNode(name="张三", type="Person"),
            predicate="lives_in",
            object=GraphNode(name="北京", type="Location"),
        ),
        Triple(
            subject=GraphNode(name="张三", type="Organization"),
            predicate="located_in",
            object=GraphNode(name="北京", type="Location"),
        ),
    ]
    extraction = _make_result(all_nodes, all_edges, all_triples)
    parent, find, union = _make_union_find(4)
    union(0, 1)  # merge Person "张三" into "张三先生"

    merged = _build_merged_result(extraction, parent, find)

    # Expect 3 nodes: 张三先生(Person), 张三(Org), 北京(Location).
    assert len(merged.nodes) == 3
    remaining = {(node.name, node.type) for node in merged.nodes}
    assert ("张三先生", "Person") in remaining
    assert ("张三", "Organization") in remaining
    assert ("北京", "Location") in remaining

    # In edges, "张三" is ambiguous (maps to different canonicals across
    # types), so the name must be left as-is rather than rewritten.
    assert len(merged.edges) == 2

    # Triples carry type information, so they can be remapped precisely.
    assert len(merged.triples) == 2
    person_triple = next(t for t in merged.triples if t.subject.type == "Person")
    org_triple = next(t for t in merged.triples if t.subject.type == "Organization")
    assert person_triple.subject.name == "张三先生"  # Person was rewritten
    assert org_triple.subject.name == "张三"  # Organization kept its name
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EntityAligner
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEntityAligner:
    """Tests for EntityAligner.align() (async) and align_rules_only() (sync).

    Covers the disabled/single-node fast paths, rule-layer merging and
    deduplication, fail-open error handling, and the LLM arbitration
    budget (P1-2 regression).
    """

    def _run(self, coro):
        """Run an async coroutine to completion inside a sync test.

        asyncio.run creates and tears down a fresh event loop per call,
        which keeps individual tests isolated from one another.
        """
        return asyncio.run(coro)

    def test_disabled_returns_original(self):
        """A disabled aligner returns the input object unchanged (identity)."""
        aligner = EntityAligner(enabled=False)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_single_node_returns_original(self):
        """With fewer than two nodes there is nothing to align."""
        aligner = EntityAligner(enabled=True)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_rule_merge_exact_names(self):
        """Exact same-name, same-type nodes merge in the rule layer."""
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u674e\u56db", type="Person"),
        ]
        edges = [
            GraphEdge(source="\u5f20\u4e09", target="\u674e\u56db", relation_type="knows"),
        ]
        result = _make_result(nodes, edges)
        aligned = self._run(aligner.align(result))
        assert len(aligned.nodes) == 2
        names = {n.name for n in aligned.nodes}
        assert "\u5f20\u4e09" in names
        assert "\u674e\u56db" in names

    def test_rule_merge_case_insensitive(self):
        """Name normalization is case-insensitive within the same type."""
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="Hello World", type="Org"),
            GraphNode(name="hello world", type="Org"),
            GraphNode(name="Test", type="Person"),
        ]
        result = _make_result(nodes)
        aligned = self._run(aligner.align(result))
        assert len(aligned.nodes) == 2

    def test_rule_merge_deduplicates_edges(self):
        """Edges made identical by a rule-layer merge collapse to one."""
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="Hello World", type="Org"),
            GraphNode(name="hello world", type="Org"),
            GraphNode(name="Test", type="Person"),
        ]
        edges = [
            GraphEdge(source="Hello World", target="Test", relation_type="employs"),
            GraphEdge(source="hello world", target="Test", relation_type="employs"),
        ]
        result = _make_result(nodes, edges)
        aligned = self._run(aligner.align(result))
        assert len(aligned.edges) == 1

    def test_rule_merge_deduplicates_triples(self):
        """Triples made identical by a rule-layer merge collapse to one."""
        aligner = EntityAligner(enabled=True)
        n1 = GraphNode(name="\u5f20\u4e09", type="Person")
        n2 = GraphNode(name="\u5f20\u4e09", type="Person")
        n3 = GraphNode(name="\u5317\u4eac\u5927\u5b66", type="Organization")
        triples = [
            Triple(subject=n1, predicate="works_at", object=n3),
            Triple(subject=n2, predicate="works_at", object=n3),
        ]
        result = _make_result([n1, n2, n3], triples=triples)
        aligned = self._run(aligner.align(result))
        assert len(aligned.triples) == 1

    def test_type_mismatch_no_merge(self):
        """Same-name nodes of different types are never merged (P1-1)."""
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Organization"),
        ]
        result = _make_result(nodes)
        aligned = self._run(aligner.align(result))
        assert len(aligned.nodes) == 2

    def test_fail_open_on_error(self):
        """align() is fail-open: on internal error the input is returned."""
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
        ]
        result = _make_result(nodes)
        with patch.object(aligner, "_align_impl", side_effect=RuntimeError("boom")):
            aligned = self._run(aligner.align(result))
        assert aligned is result

    def test_align_rules_only_sync(self):
        """The sync rules-only path merges exact duplicates."""
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u5f20\u4e09", type="Person"),
            GraphNode(name="\u674e\u56db", type="Person"),
        ]
        result = _make_result(nodes)
        aligned = aligner.align_rules_only(result)
        assert len(aligned.nodes) == 2

    def test_align_rules_only_disabled(self):
        """A disabled aligner's sync path also returns the input unchanged."""
        aligner = EntityAligner(enabled=False)
        result = _make_result([GraphNode(name="A", type="Person")])
        aligned = aligner.align_rules_only(result)
        assert aligned is result

    def test_align_rules_only_fail_open(self):
        """align_rules_only() is fail-open when rule scoring raises."""
        aligner = EntityAligner(enabled=True)
        nodes = [
            GraphNode(name="A", type="Person"),
            GraphNode(name="B", type="Person"),
        ]
        result = _make_result(nodes)
        with patch(
            "app.module.kg_extraction.aligner.rule_score", side_effect=RuntimeError("boom")
        ):
            aligned = aligner.align_rules_only(result)
        assert aligned is result

    def test_llm_count_incremented_on_failure(self):
        """P1-2 regression: failed LLM arbitrations must still consume the
        max_llm_arbitrations budget, so exceptions cannot cause extra calls."""
        import math

        max_arb = 2
        aligner = EntityAligner(
            enabled=True,
            max_llm_arbitrations=max_arb,
            llm_arbitration_enabled=True,
        )
        # Four same-type nodes sharing a name prefix, so the rule layer's
        # substring matching produces several vector-layer candidates.
        nodes = [
            GraphNode(name="北京大学", type="Organization"),
            GraphNode(name="北京大学计算机学院", type="Organization"),
            GraphNode(name="北京大学数学学院", type="Organization"),
            GraphNode(name="北京大学物理学院", type="Organization"),
        ]
        result = _make_result(nodes)

        # Mock embeddings so every candidate pair lands in the LLM
        # arbitration band: cos(theta) = 0.85 falls inside the default
        # [0.78, 0.92) interval between llm_threshold and auto-merge.
        angle = math.acos(0.85)
        emb_a = [1.0, 0.0]
        emb_b = [math.cos(angle), math.sin(angle)]

        async def fake_embed(texts):
            # Even indices get emb_a, odd indices get emb_b.
            return [emb_a if i % 2 == 0 else emb_b for i in range(len(texts))]

        mock_llm_arbitrate = AsyncMock(side_effect=RuntimeError("LLM down"))

        with patch.object(aligner, "_get_embeddings") as mock_emb:
            mock_emb_instance = AsyncMock()
            mock_emb_instance.aembed_documents = fake_embed
            mock_emb.return_value = mock_emb_instance
            with patch.object(aligner, "_llm_arbitrate", mock_llm_arbitrate):
                aligned = self._run(aligner.align(result))

        # The failing LLM must never be invoked beyond the budget (the count
        # is incremented in a finally block, so exceptions cannot overrun it).
        assert mock_llm_arbitrate.call_count <= max_arb
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LLMArbitrationResult
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMArbitrationResult:
    """Schema validation tests for the LLM arbitration result model."""

    def test_valid_parse(self):
        """A well-formed payload parses into the expected field values."""
        payload = {"is_same": True, "confidence": 0.95, "reason": "Same entity"}
        parsed = LLMArbitrationResult.model_validate(payload)
        assert parsed.is_same is True
        assert parsed.confidence == 0.95

    def test_confidence_bounds(self):
        """Confidence outside [0, 1] is rejected by validation."""
        bad_payload = {"is_same": True, "confidence": 1.5, "reason": ""}
        with pytest.raises(Exception):
            LLMArbitrationResult.model_validate(bad_payload)

    def test_missing_reason_defaults(self):
        """An omitted reason falls back to the empty-string default."""
        parsed = LLMArbitrationResult.model_validate(
            {"is_same": False, "confidence": 0.1}
        )
        assert parsed.reason == ""
|
||||||
|
|
||||||
|
# Allow running this test module directly (verbose pytest output).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user