"""基于 LLM 的知识图谱三元组抽取器。 利用 LangChain 的 LLMGraphTransformer 从非结构化文本中抽取实体和关系, 支持 schema-guided 抽取以提升准确率。 """ from __future__ import annotations import hashlib from typing import Sequence from langchain_core.documents import Document from langchain_openai import ChatOpenAI from langchain_experimental.graph_transformers import LLMGraphTransformer from pydantic import SecretStr from app.core.logging import get_logger from app.module.kg_extraction.aligner import EntityAligner from app.module.kg_extraction.models import ( ExtractionRequest, ExtractionResult, ExtractionSchema, GraphEdge, GraphNode, Triple, ) logger = get_logger(__name__) def _text_fingerprint(text: str) -> str: """返回文本的短 SHA-256 摘要,用于日志关联而不泄露原文。""" return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12] class KnowledgeGraphExtractor: """基于 LLMGraphTransformer 的三元组抽取器。 通过 ``from_settings()`` 工厂方法从全局配置创建实例, 也可直接构造以覆盖默认参数。 """ def __init__( self, model_name: str = "gpt-4o-mini", base_url: str | None = None, api_key: SecretStr = SecretStr("EMPTY"), temperature: float = 0.0, timeout: int = 60, max_retries: int = 2, aligner: EntityAligner | None = None, ) -> None: logger.info( "Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)", model_name, base_url or "default", timeout, max_retries, ) self._llm = ChatOpenAI( model=model_name, base_url=base_url, api_key=api_key, temperature=temperature, timeout=timeout, max_retries=max_retries, ) self._aligner = aligner or EntityAligner() @classmethod def from_settings(cls) -> KnowledgeGraphExtractor: """从全局 Settings 创建抽取器实例。""" from app.core.config import settings return cls( model_name=settings.kg_llm_model, base_url=settings.kg_llm_base_url, api_key=settings.kg_llm_api_key, temperature=settings.kg_llm_temperature, timeout=settings.kg_llm_timeout_seconds, max_retries=settings.kg_llm_max_retries, aligner=EntityAligner.from_settings(), ) def _build_transformer( self, schema: ExtractionSchema | None = None, ) -> LLMGraphTransformer: """根据可选的 schema 约束构造 LLMGraphTransformer。""" kwargs: dict = {"llm": self._llm} if schema: if schema.entity_types: kwargs["allowed_nodes"] = [et.name for et in schema.entity_types] if schema.relation_types: kwargs["allowed_relationships"] = [rt.name for rt in schema.relation_types] return LLMGraphTransformer(**kwargs) async def extract(self, request: ExtractionRequest) -> ExtractionResult: """从文本中异步抽取三元组。""" text_hash = _text_fingerprint(request.text) logger.info( "Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s", request.graph_id, request.source_id, len(request.text), text_hash, ) transformer = self._build_transformer(request.schema) documents = [Document(page_content=request.text)] try: graph_documents = await transformer.aconvert_to_graph_documents(documents) except Exception: logger.exception( "LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s", request.graph_id, request.source_id, text_hash, ) raise result = self._convert_result(graph_documents, request) result = await self._aligner.align(result) logger.info( "Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d", request.graph_id, len(result.nodes), len(result.edges), len(result.triples), ) return result def extract_sync(self, request: ExtractionRequest) -> ExtractionResult: """同步版本的三元组抽取。""" text_hash = _text_fingerprint(request.text) logger.info( "Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s", request.graph_id, request.source_id, len(request.text), text_hash, ) transformer = 
        documents = [Document(page_content=request.text)]
        try:
            graph_documents = transformer.convert_to_graph_documents(documents)
        except Exception:
            logger.exception(
                "LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
                request.graph_id,
                request.source_id,
                text_hash,
            )
            raise

        result = self._convert_result(graph_documents, request)
        result = self._aligner.align_rules_only(result)
        logger.info(
            "Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
            request.graph_id,
            len(result.nodes),
            len(result.edges),
        )
        return result

    async def extract_batch(
        self,
        requests: Sequence[ExtractionRequest],
    ) -> list[ExtractionResult]:
        """Extract from a batch of requests, processed one at a time.

        For higher throughput, call ``extract`` concurrently with
        ``asyncio.gather`` on the caller side instead.
        """
        logger.info("Starting batch extraction: count=%d", len(requests))
        results: list[ExtractionResult] = []
        for i, req in enumerate(requests):
            logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id)
            result = await self.extract(req)
            results.append(result)
        logger.info("Batch extraction complete: count=%d", len(results))
        return results

    @staticmethod
    def _convert_result(
        graph_documents: list,
        request: ExtractionRequest,
    ) -> ExtractionResult:
        """Convert LangChain GraphDocuments into the internal data models."""
        nodes: list[GraphNode] = []
        edges: list[GraphEdge] = []
        triples: list[Triple] = []
        seen_nodes: set[str] = set()
        for doc in graph_documents:
            # Deduplicate nodes on (id, type) across all returned documents.
            for node in doc.nodes:
                node_key = f"{node.id}:{node.type}"
                if node_key not in seen_nodes:
                    seen_nodes.add(node_key)
                    nodes.append(
                        GraphNode(
                            name=node.id,
                            type=node.type,
                            properties=node.properties if hasattr(node, "properties") else {},
                        )
                    )
            for rel in doc.relationships:
                source_node = GraphNode(name=rel.source.id, type=rel.source.type)
                target_node = GraphNode(name=rel.target.id, type=rel.target.type)
                edges.append(
                    GraphEdge(
                        source=rel.source.id,
                        target=rel.target.id,
                        relation_type=rel.type,
                        properties=rel.properties if hasattr(rel, "properties") else {},
                    )
                )
                triples.append(
                    Triple(subject=source_node, predicate=rel.type, object=target_node)
                )
        return ExtractionResult(
            nodes=nodes,
            edges=edges,
            triples=triples,
            raw_text=request.text,
            source_id=request.source_id,
        )
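

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library surface).
# Assumptions: an OpenAI-compatible endpoint is reachable with the given
# credentials, and ``ExtractionRequest`` accepts ``text``, ``graph_id`` and
# ``source_id`` keyword arguments, as the attribute access above suggests.
# Adapt to the actual model definitions in app.module.kg_extraction.models.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Direct construction with a placeholder credential; in real use,
        # prefer ``KnowledgeGraphExtractor.from_settings()``.
        extractor = KnowledgeGraphExtractor(
            api_key=SecretStr("sk-placeholder"),  # placeholder credential
        )
        request = ExtractionRequest(
            text="Marie Curie received the Nobel Prize in Physics in 1903.",
            graph_id="demo-graph",
            source_id="demo-source",
        )
        result = await extractor.extract(request)
        for triple in result.triples:
            print(f"({triple.subject.name}) -[{triple.predicate}]-> ({triple.object.name})")

    asyncio.run(_demo())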