DataMate/runtime/datamate-python/app/module/kg_extraction/extractor.py
Jerry Yan 0ed7dcbee7 feat(kg): implement entity alignment (aligner.py)
- Implements a three-tier alignment strategy: rules tier + vector-similarity tier + LLM arbitration tier
- Rules tier: name normalization (NFKC, lowercasing, stripping punctuation/whitespace) plus rule-based scoring; the normalization and the Union-Find merge are sketched after this list
- Vector tier: OpenAI Embeddings with cosine similarity
- LLM tier: invoked only for borderline pairs, with strict JSON-schema validation of its output
- Union-Find for transitive merging
- Supports intra-batch alignment (alignment against the existing store awaits KG-service API support)
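
A rough sketch of the rules tier's normalization and the Union-Find merge (illustrative only; the real logic lives in aligner.py and may differ):

    import unicodedata

    def normalize_name(name: str) -> str:
        # NFKC-fold and lowercase, then drop whitespace and all Unicode punctuation.
        folded = unicodedata.normalize("NFKC", name).lower()
        return "".join(
            ch for ch in folded
            if not ch.isspace() and not unicodedata.category(ch).startswith("P")
        )

    class UnionFind:
        """Transitive merging: union(a, b) + union(b, c) puts a, b, c in one cluster."""

        def __init__(self) -> None:
            self.parent: dict[str, str] = {}

        def find(self, x: str) -> str:
            self.parent.setdefault(x, x)
            while self.parent[x] != x:
                self.parent[x] = self.parent[self.parent[x]]  # path halving
                x = self.parent[x]
            return x

        def union(self, a: str, b: str) -> None:
            self.parent[self.find(a)] = self.find(b)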

Core components:
- EntityAligner class: align() (async) and align_rules_only() (sync)
- Config options: kg_alignment_enabled (default: false), embedding_model, similarity thresholds (sketched after this list)
- Failure policy: fail-open (an alignment failure does not abort the request)
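
What the configuration might look like, as a pydantic-settings sketch (the threshold field names and all defaults below are assumptions, not the actual values):

    from pydantic_settings import BaseSettings

    class Settings(BaseSettings):
        kg_alignment_enabled: bool = False     # alignment is opt-in, off by default
        kg_alignment_embedding_model: str = "text-embedding-3-small"  # assumed default
        kg_alignment_upper_threshold: float = 0.9   # assumed: auto-merge above this
        kg_alignment_lower_threshold: float = 0.75  # assumed: no match below this;
                                                    # in-between pairs go to LLM arbitration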

Integration:
- Wired into the main extraction pipeline (extract → align → return)
- extract() calls the async align()
- extract_sync() calls the sync align_rules_only()

Fixes:
- P1-1: key entities on (name, type) to avoid merging same-named entities of different types
- P1-2: increment the LLM call counter in a finally block so failed calls are counted too (sketched after this list)
- P1-3: document that in-store alignment is deferred until the KG service exposes the required API
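
P1-2 in miniature: incrementing in a finally block moves the counter even when the LLM call raises (the method and attribute names here are hypothetical):

    async def _arbitrate(self, pair: tuple[str, str]) -> bool:
        try:
            return await self._call_llm(pair)  # may raise on timeout or bad JSON
        finally:
            self._llm_call_count += 1          # counted even on failure (P1-2)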

Added 41 new test cases; all 41 pass.
2026-02-19 18:26:54 +08:00

"""基于 LLM 的知识图谱三元组抽取器。
利用 LangChain 的 LLMGraphTransformer 从非结构化文本中抽取实体和关系,
支持 schema-guided 抽取以提升准确率。
"""

from __future__ import annotations

import hashlib
from typing import Sequence

from langchain_core.documents import Document
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_openai import ChatOpenAI
from pydantic import SecretStr

from app.core.logging import get_logger
from app.module.kg_extraction.aligner import EntityAligner
from app.module.kg_extraction.models import (
    ExtractionRequest,
    ExtractionResult,
    ExtractionSchema,
    GraphEdge,
    GraphNode,
    Triple,
)

logger = get_logger(__name__)


def _text_fingerprint(text: str) -> str:
    """Return a short SHA-256 digest of the text, for log correlation without leaking the content."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]


class KnowledgeGraphExtractor:
    """Triple extractor built on LLMGraphTransformer.

    Create instances via the ``from_settings()`` factory to pull values from
    the global configuration, or construct directly to override the defaults.
    """

def __init__(
self,
model_name: str = "gpt-4o-mini",
base_url: str | None = None,
api_key: SecretStr = SecretStr("EMPTY"),
temperature: float = 0.0,
timeout: int = 60,
max_retries: int = 2,
aligner: EntityAligner | None = None,
) -> None:
logger.info(
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
model_name,
base_url or "default",
timeout,
max_retries,
)
self._llm = ChatOpenAI(
model=model_name,
base_url=base_url,
api_key=api_key,
temperature=temperature,
timeout=timeout,
max_retries=max_retries,
)
self._aligner = aligner or EntityAligner()

    @classmethod
    def from_settings(cls) -> KnowledgeGraphExtractor:
        """Create an extractor instance from the global Settings."""
from app.core.config import settings
return cls(
model_name=settings.kg_llm_model,
base_url=settings.kg_llm_base_url,
api_key=settings.kg_llm_api_key,
temperature=settings.kg_llm_temperature,
timeout=settings.kg_llm_timeout_seconds,
max_retries=settings.kg_llm_max_retries,
aligner=EntityAligner.from_settings(),
)

    def _build_transformer(
        self,
        schema: ExtractionSchema | None = None,
    ) -> LLMGraphTransformer:
        """Build an LLMGraphTransformer, optionally constrained to the schema's entity and relation types."""
kwargs: dict = {"llm": self._llm}
if schema:
if schema.entity_types:
kwargs["allowed_nodes"] = [et.name for et in schema.entity_types]
if schema.relation_types:
kwargs["allowed_relationships"] = [rt.name for rt in schema.relation_types]
return LLMGraphTransformer(**kwargs)

    async def extract(self, request: ExtractionRequest) -> ExtractionResult:
        """Asynchronously extract triples from the request text, then align entities."""
text_hash = _text_fingerprint(request.text)
logger.info(
"Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
request.graph_id,
request.source_id,
len(request.text),
text_hash,
)
transformer = self._build_transformer(request.schema)
documents = [Document(page_content=request.text)]
try:
graph_documents = await transformer.aconvert_to_graph_documents(documents)
except Exception:
logger.exception(
"LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
request.graph_id,
request.source_id,
text_hash,
)
raise
result = self._convert_result(graph_documents, request)
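        # Alignment is fail-open: per the failure policy, errors inside
        # EntityAligner do not abort the request; the result comes back unaligned.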
result = await self._aligner.align(result)
logger.info(
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
request.graph_id,
len(result.nodes),
len(result.edges),
len(result.triples),
)
return result

    def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
        """Synchronously extract triples; alignment uses the rules tier only."""
text_hash = _text_fingerprint(request.text)
logger.info(
"Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
request.graph_id,
request.source_id,
len(request.text),
text_hash,
)
transformer = self._build_transformer(request.schema)
documents = [Document(page_content=request.text)]
try:
graph_documents = transformer.convert_to_graph_documents(documents)
except Exception:
logger.exception(
"LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
request.graph_id,
request.source_id,
text_hash,
)
raise
result = self._convert_result(graph_documents, request)
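        # Sync path: rules-tier alignment only (no embedding or LLM calls).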
result = self._aligner.align_rules_only(result)
logger.info(
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
request.graph_id,
len(result.nodes),
len(result.edges),
)
return result

    async def extract_batch(
        self,
        requests: Sequence[ExtractionRequest],
    ) -> list[ExtractionResult]:
        """Extract from a batch of requests, one at a time.

        For higher throughput, run extract() concurrently on the caller side
        with asyncio.gather (sketched after this method).
        """
logger.info("Starting batch extraction: count=%d", len(requests))
results: list[ExtractionResult] = []
for i, req in enumerate(requests):
logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id)
result = await self.extract(req)
results.append(result)
logger.info("Batch extraction complete: count=%d", len(results))
return results
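
    # Illustrative caller-side alternative for throughput (not part of this API):
    #
    #     import asyncio
    #     extractor = KnowledgeGraphExtractor.from_settings()
    #     results = await asyncio.gather(*(extractor.extract(r) for r in requests))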

    @staticmethod
    def _convert_result(
        graph_documents: list,
        request: ExtractionRequest,
    ) -> ExtractionResult:
        """Convert LangChain GraphDocuments into the internal data model."""
nodes: list[GraphNode] = []
edges: list[GraphEdge] = []
triples: list[Triple] = []
seen_nodes: set[str] = set()
for doc in graph_documents:
for node in doc.nodes:
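                # P1-1 fix: key on (name, type) so same-named entities of
                # different types are not merged together.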
node_key = f"{node.id}:{node.type}"
if node_key not in seen_nodes:
seen_nodes.add(node_key)
nodes.append(
GraphNode(
name=node.id,
type=node.type,
properties=node.properties if hasattr(node, "properties") else {},
)
)
for rel in doc.relationships:
source_node = GraphNode(name=rel.source.id, type=rel.source.type)
target_node = GraphNode(name=rel.target.id, type=rel.target.type)
edges.append(
GraphEdge(
source=rel.source.id,
target=rel.target.id,
relation_type=rel.type,
properties=rel.properties if hasattr(rel, "properties") else {},
)
)
triples.append(
Triple(subject=source_node, predicate=rel.type, object=target_node)
)
return ExtractionResult(
nodes=nodes,
edges=edges,
triples=triples,
raw_text=request.text,
source_id=request.source_id,
)