"""基于 LLM 的知识图谱三元组抽取器。 利用 LangChain 的 LLMGraphTransformer 从非结构化文本中抽取实体和关系, 支持 schema-guided 抽取以提升准确率。 """ from __future__ import annotations import logging from typing import Sequence from langchain_core.documents import Document from langchain_openai import ChatOpenAI from langchain_experimental.graph_transformers import LLMGraphTransformer from app.module.kg_extraction.models import ( ExtractionRequest, ExtractionResult, ExtractionSchema, GraphEdge, GraphNode, Triple, ) logger = logging.getLogger(__name__) class KnowledgeGraphExtractor: """基于 LLMGraphTransformer 的三元组抽取器。 Parameters ---------- model_name : str OpenAI 兼容模型名称。 base_url : str | None 自定义 API base URL(用于对接 vLLM/Ollama 等本地模型服务)。 api_key : str API 密钥。 temperature : float 生成温度,抽取任务建议使用较低值。 """ def __init__( self, model_name: str = "gpt-4o-mini", base_url: str | None = None, api_key: str = "EMPTY", temperature: float = 0.0, ) -> None: self._llm = ChatOpenAI( model=model_name, base_url=base_url, api_key=api_key, temperature=temperature, ) def _build_transformer( self, schema: ExtractionSchema | None = None, ) -> LLMGraphTransformer: """根据可选的 schema 约束构造 LLMGraphTransformer。""" kwargs: dict = {"llm": self._llm} if schema: if schema.entity_types: kwargs["allowed_nodes"] = [et.name for et in schema.entity_types] if schema.relation_types: kwargs["allowed_relationships"] = [rt.name for rt in schema.relation_types] return LLMGraphTransformer(**kwargs) async def extract(self, request: ExtractionRequest) -> ExtractionResult: """从文本中抽取三元组。 Parameters ---------- request : ExtractionRequest 包含文本、schema 约束等信息的抽取请求。 Returns ------- ExtractionResult 抽取得到的节点、边和三元组。 """ transformer = self._build_transformer(request.schema) documents = [Document(page_content=request.text)] try: graph_documents = await transformer.aconvert_to_graph_documents(documents) except Exception: logger.exception("LLM graph extraction failed for source_id=%s", request.source_id) return ExtractionResult(raw_text=request.text, source_id=request.source_id) return self._convert_result(graph_documents, request) def extract_sync(self, request: ExtractionRequest) -> ExtractionResult: """同步版本的三元组抽取。""" transformer = self._build_transformer(request.schema) documents = [Document(page_content=request.text)] try: graph_documents = transformer.convert_to_graph_documents(documents) except Exception: logger.exception("LLM graph extraction failed for source_id=%s", request.source_id) return ExtractionResult(raw_text=request.text, source_id=request.source_id) return self._convert_result(graph_documents, request) async def extract_batch( self, requests: Sequence[ExtractionRequest], ) -> list[ExtractionResult]: """批量抽取。 对多段文本逐一抽取并汇总结果。 如需更高吞吐,可自行用 asyncio.gather 并发调用 extract。 """ results: list[ExtractionResult] = [] for req in requests: result = await self.extract(req) results.append(result) return results @staticmethod def _convert_result( graph_documents: list, request: ExtractionRequest, ) -> ExtractionResult: """将 LangChain GraphDocument 转换为内部数据模型。""" nodes: list[GraphNode] = [] edges: list[GraphEdge] = [] triples: list[Triple] = [] seen_nodes: set[str] = set() for doc in graph_documents: # 收集节点 for node in doc.nodes: node_key = f"{node.id}:{node.type}" if node_key not in seen_nodes: seen_nodes.add(node_key) nodes.append( GraphNode( name=node.id, type=node.type, properties=node.properties if hasattr(node, "properties") else {}, ) ) # 收集关系 for rel in doc.relationships: source_node = GraphNode( name=rel.source.id, type=rel.source.type, ) target_node = GraphNode( name=rel.target.id, type=rel.target.type, ) edges.append( GraphEdge( source=rel.source.id, target=rel.target.id, relation_type=rel.type, properties=rel.properties if hasattr(rel, "properties") else {}, ) ) triples.append( Triple( subject=source_node, predicate=rel.type, object=target_node, ) ) return ExtractionResult( nodes=nodes, edges=edges, triples=triples, raw_text=request.text, source_id=request.source_id, )