You've already forked DataMate
- 实现三层对齐策略:规则层 + 向量相似度层 + LLM 仲裁层 - 规则层:名称规范化(NFKC、小写、去标点/空格)+ 规则评分 - 向量层:OpenAI Embeddings + cosine 相似度计算 - LLM 层:仅对边界样本调用,严格 JSON schema 校验 - 使用 Union-Find 实现传递合并 - 支持批内对齐(库内对齐待 KG 服务 API 支持) 核心组件: - EntityAligner 类:align() (async)、align_rules_only() (sync) - 配置项:kg_alignment_enabled(默认 false)、embedding_model、阈值 - 失败策略:fail-open(对齐失败不中断请求) 集成: - 已集成到抽取主链路(extract → align → return) - extract() 调用 async align() - extract_sync() 调用 sync align_rules_only() 修复: - P1-1:使用 (name, type) 作为 key,避免同名跨类型误合并 - P1-2:LLM 计数在 finally 块中增加,异常也计数 - P1-3:添加库内对齐说明(待后续实现) 新增 41 个测试用例,全部通过 测试结果:41 tests pass
235 lines
7.9 KiB
Python
235 lines
7.9 KiB
Python
"""基于 LLM 的知识图谱三元组抽取器。
|
|
|
|
利用 LangChain 的 LLMGraphTransformer 从非结构化文本中抽取实体和关系,
|
|
支持 schema-guided 抽取以提升准确率。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
from typing import Sequence
|
|
|
|
from langchain_core.documents import Document
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_experimental.graph_transformers import LLMGraphTransformer
|
|
from pydantic import SecretStr
|
|
|
|
from app.core.logging import get_logger
|
|
from app.module.kg_extraction.aligner import EntityAligner
|
|
from app.module.kg_extraction.models import (
|
|
ExtractionRequest,
|
|
ExtractionResult,
|
|
ExtractionSchema,
|
|
GraphEdge,
|
|
GraphNode,
|
|
Triple,
|
|
)
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def _text_fingerprint(text: str) -> str:
|
|
"""返回文本的短 SHA-256 摘要,用于日志关联而不泄露原文。"""
|
|
return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
|
|
|
|
|
|
class KnowledgeGraphExtractor:
    """Triple extractor built on LangChain's ``LLMGraphTransformer``.

    Create instances via the :meth:`from_settings` factory to pull defaults
    from the global configuration, or construct directly to override any
    parameter.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o-mini",
        base_url: str | None = None,
        api_key: SecretStr = SecretStr("EMPTY"),
        temperature: float = 0.0,
        timeout: int = 60,
        max_retries: int = 2,
        aligner: EntityAligner | None = None,
    ) -> None:
        """Initialize the underlying chat model and entity aligner.

        Args:
            model_name: OpenAI-compatible chat model identifier.
            base_url: Custom API endpoint; ``None`` uses the client default.
            api_key: API key; the placeholder default suits keyless endpoints.
            temperature: Sampling temperature (0.0 favors determinism).
            timeout: Per-request timeout in seconds.
            max_retries: Client-level retry count for transient failures.
            aligner: Entity aligner; a default ``EntityAligner`` is created
                when omitted.
        """
        logger.info(
            "Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
            model_name,
            base_url or "default",
            timeout,
            max_retries,
        )
        self._llm = ChatOpenAI(
            model=model_name,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature,
            timeout=timeout,
            max_retries=max_retries,
        )
        self._aligner = aligner or EntityAligner()

    @classmethod
    def from_settings(cls) -> KnowledgeGraphExtractor:
        """Create an extractor instance from the global ``Settings``."""
        # Imported lazily to avoid import cycles at module load time.
        from app.core.config import settings

        return cls(
            model_name=settings.kg_llm_model,
            base_url=settings.kg_llm_base_url,
            api_key=settings.kg_llm_api_key,
            temperature=settings.kg_llm_temperature,
            timeout=settings.kg_llm_timeout_seconds,
            max_retries=settings.kg_llm_max_retries,
            aligner=EntityAligner.from_settings(),
        )

    def _build_transformer(
        self,
        schema: ExtractionSchema | None = None,
    ) -> LLMGraphTransformer:
        """Build an ``LLMGraphTransformer``, optionally constrained by *schema*.

        When the schema declares entity / relation types they are forwarded as
        ``allowed_nodes`` / ``allowed_relationships`` so the LLM restricts its
        output to the declared vocabulary (schema-guided extraction).
        """
        kwargs: dict = {"llm": self._llm}

        if schema:
            if schema.entity_types:
                kwargs["allowed_nodes"] = [et.name for et in schema.entity_types]
            if schema.relation_types:
                kwargs["allowed_relationships"] = [rt.name for rt in schema.relation_types]

        return LLMGraphTransformer(**kwargs)

    async def extract(self, request: ExtractionRequest) -> ExtractionResult:
        """Asynchronously extract triples from ``request.text``.

        Pipeline: LLM extraction → conversion to internal models → full
        (async) entity alignment.

        Raises:
            Exception: re-raised from the LLM call after logging, so callers
                control failure handling.
        """
        text_hash = _text_fingerprint(request.text)
        logger.info(
            "Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
            request.graph_id,
            request.source_id,
            len(request.text),
            text_hash,
        )

        transformer = self._build_transformer(request.schema)
        documents = [Document(page_content=request.text)]

        try:
            graph_documents = await transformer.aconvert_to_graph_documents(documents)
        except Exception:
            logger.exception(
                "LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
                request.graph_id,
                request.source_id,
                text_hash,
            )
            raise

        result = self._convert_result(graph_documents, request)
        result = await self._aligner.align(result)
        logger.info(
            "Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
            request.graph_id,
            len(result.nodes),
            len(result.edges),
            len(result.triples),
        )
        return result

    def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
        """Synchronous variant of :meth:`extract`.

        Uses the rules-only alignment path (``align_rules_only``) since the
        full alignment pipeline is async-only.

        Raises:
            Exception: re-raised from the LLM call after logging.
        """
        text_hash = _text_fingerprint(request.text)
        logger.info(
            "Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
            request.graph_id,
            request.source_id,
            len(request.text),
            text_hash,
        )

        transformer = self._build_transformer(request.schema)
        documents = [Document(page_content=request.text)]

        try:
            graph_documents = transformer.convert_to_graph_documents(documents)
        except Exception:
            logger.exception(
                "LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
                request.graph_id,
                request.source_id,
                text_hash,
            )
            raise

        result = self._convert_result(graph_documents, request)
        result = self._aligner.align_rules_only(result)
        logger.info(
            "Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
            request.graph_id,
            len(result.nodes),
            len(result.edges),
        )
        return result

    async def extract_batch(
        self,
        requests: Sequence[ExtractionRequest],
    ) -> list[ExtractionResult]:
        """Extract from multiple requests, processed sequentially.

        For higher throughput, callers can instead invoke :meth:`extract`
        concurrently via ``asyncio.gather``.
        """
        logger.info("Starting batch extraction: count=%d", len(requests))
        results: list[ExtractionResult] = []
        for i, req in enumerate(requests):
            logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id)
            result = await self.extract(req)
            results.append(result)
        logger.info("Batch extraction complete: count=%d", len(results))
        return results

    @staticmethod
    def _convert_result(
        graph_documents: list,
        request: ExtractionRequest,
    ) -> ExtractionResult:
        """Convert LangChain ``GraphDocument`` objects to internal models.

        Nodes are de-duplicated on the ``(name, type)`` pair; each
        relationship is mapped 1:1 to both a ``GraphEdge`` and a ``Triple``.
        """
        nodes: list[GraphNode] = []
        edges: list[GraphEdge] = []
        triples: list[Triple] = []
        # Tuple key instead of a ":"-joined string: a joined key collides
        # when a node id itself contains the separator (e.g. id="a:b",
        # type="c" vs id="a", type="b:c"), which would silently drop nodes.
        seen_nodes: set[tuple[str, str]] = set()

        for doc in graph_documents:
            for node in doc.nodes:
                node_key = (node.id, node.type)
                if node_key not in seen_nodes:
                    seen_nodes.add(node_key)
                    nodes.append(
                        GraphNode(
                            name=node.id,
                            type=node.type,
                            # Older GraphDocument nodes may lack "properties".
                            properties=getattr(node, "properties", {}),
                        )
                    )

            for rel in doc.relationships:
                source_node = GraphNode(name=rel.source.id, type=rel.source.type)
                target_node = GraphNode(name=rel.target.id, type=rel.target.type)

                edges.append(
                    GraphEdge(
                        source=rel.source.id,
                        target=rel.target.id,
                        relation_type=rel.type,
                        properties=getattr(rel, "properties", {}),
                    )
                )
                triples.append(
                    Triple(subject=source_node, predicate=rel.type, object=target_node)
                )

        return ExtractionResult(
            nodes=nodes,
            edges=edges,
            triples=triples,
            raw_text=request.text,
            source_id=request.source_id,
        )
|