You've already forked DataMate
实现功能: - 创建 kg_extraction/interface.py(FastAPI 路由) - 实现 POST /api/kg/extract(单条文本抽取) - 实现 POST /api/kg/extract/batch(批量抽取,最多 50 条) - 集成到 FastAPI 主路由(/api/kg/ 前缀) 技术实现: - 配置管理:从环境变量读取 LLM 配置(API Key、Base URL、Model、Temperature) - 安全性: - API Key 使用 SecretStr 保护 - 错误信息脱敏(使用 trace_id,不暴露原始异常) - 请求文本不写入日志(使用 SHA-256 hash) - 强制要求 X-User-Id 头(鉴权边界) - 超时控制: - kg_llm_timeout_seconds(60秒) - kg_llm_max_retries(2次) - 输入校验: - graph_id 和 source_id 使用 UUID pattern - source_type 使用 Enum(4个值) - allowed_nodes/relationships 元素使用正则约束(ASCII,1-50字符) - 审计日志:记录 caller、trace_id、text_hash 代码审查: - 经过 3 轮 Codex 审查和 2 轮 Claude 修复 - 所有问题已解决(5个 P1/P2 + 3个 P3) - 语法检查通过 API 端点: - POST /api/kg/extract:单条文本抽取 - POST /api/kg/extract/batch:批量抽取(最多 50 条) 配置环境变量: - KG_LLM_API_KEY:LLM API 密钥 - KG_LLM_BASE_URL:自定义端点(可选) - KG_LLM_MODEL:模型名称(默认 gpt-4o-mini) - KG_LLM_TEMPERATURE:生成温度(默认 0.0) - KG_LLM_TIMEOUT_SECONDS:超时时间(默认 60) - KG_LLM_MAX_RETRIES:重试次数(默认 2)
229 lines
7.6 KiB
Python
229 lines
7.6 KiB
Python
"""基于 LLM 的知识图谱三元组抽取器。
|
|
|
|
利用 LangChain 的 LLMGraphTransformer 从非结构化文本中抽取实体和关系,
|
|
支持 schema-guided 抽取以提升准确率。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
from typing import Sequence
|
|
|
|
from langchain_core.documents import Document
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_experimental.graph_transformers import LLMGraphTransformer
|
|
from pydantic import SecretStr
|
|
|
|
from app.core.logging import get_logger
|
|
from app.module.kg_extraction.models import (
|
|
ExtractionRequest,
|
|
ExtractionResult,
|
|
ExtractionSchema,
|
|
GraphEdge,
|
|
GraphNode,
|
|
Triple,
|
|
)
|
|
|
|
# Module-level logger from the project's logging helper. Extraction methods
# in this module log text length and a short hash of the input, not the text.
logger = get_logger(__name__)
|
|
|
|
|
|
def _text_fingerprint(text: str) -> str:
|
|
"""返回文本的短 SHA-256 摘要,用于日志关联而不泄露原文。"""
|
|
return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
|
|
|
|
|
|
class KnowledgeGraphExtractor:
    """Knowledge-graph triple extractor built on ``LLMGraphTransformer``.

    Use the :meth:`from_settings` factory to build an instance from the
    global configuration, or call the constructor directly to override the
    default parameters.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o-mini",
        base_url: str | None = None,
        api_key: SecretStr = SecretStr("EMPTY"),
        temperature: float = 0.0,
        timeout: int = 60,
        max_retries: int = 2,
    ) -> None:
        """Create the ``ChatOpenAI`` client shared by every extraction.

        Args:
            model_name: Chat model name passed to ``ChatOpenAI``.
            base_url: Optional custom endpoint; ``None`` is passed through
                unchanged and logged as ``"default"``.
            api_key: API key wrapped in ``SecretStr``; it is handed to the
                client but never logged.
            temperature: Sampling temperature.
            timeout: Request timeout in seconds.
            max_retries: Maximum retry count passed to ``ChatOpenAI``.
        """
        # The API key is intentionally absent from this log line.
        logger.info(
            "Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
            model_name,
            base_url or "default",
            timeout,
            max_retries,
        )
        self._llm = ChatOpenAI(
            model=model_name,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature,
            timeout=timeout,
            max_retries=max_retries,
        )

    @classmethod
    def from_settings(cls) -> KnowledgeGraphExtractor:
        """Create an extractor from the global ``Settings`` object.

        Reads the ``kg_llm_*`` fields of ``app.core.config.settings``. The
        import is deferred to call time (presumably to avoid import cycles or
        import-time config loading — confirm).
        """
        from app.core.config import settings

        return cls(
            model_name=settings.kg_llm_model,
            base_url=settings.kg_llm_base_url,
            api_key=settings.kg_llm_api_key,
            temperature=settings.kg_llm_temperature,
            timeout=settings.kg_llm_timeout_seconds,
            max_retries=settings.kg_llm_max_retries,
        )

    def _build_transformer(
        self,
        schema: ExtractionSchema | None = None,
    ) -> LLMGraphTransformer:
        """Build an ``LLMGraphTransformer``, optionally constrained by *schema*.

        Entity type names become ``allowed_nodes`` and relation type names
        become ``allowed_relationships``. When either list is empty or
        missing, that dimension is left unconstrained.
        """
        kwargs: dict[str, object] = {"llm": self._llm}

        if schema is not None:
            if schema.entity_types:
                kwargs["allowed_nodes"] = [et.name for et in schema.entity_types]
            if schema.relation_types:
                kwargs["allowed_relationships"] = [rt.name for rt in schema.relation_types]

        return LLMGraphTransformer(**kwargs)

    async def extract(self, request: ExtractionRequest) -> ExtractionResult:
        """Asynchronously extract triples from ``request.text``.

        Only the text length and a short hash are logged, never the text.

        Raises:
            Exception: Any error from the transformer/LLM call is logged with
                its traceback and re-raised unchanged.
        """
        text_hash = _text_fingerprint(request.text)
        logger.info(
            "Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
            request.graph_id,
            request.source_id,
            len(request.text),
            text_hash,
        )

        transformer = self._build_transformer(request.schema)
        documents = [Document(page_content=request.text)]

        try:
            graph_documents = await transformer.aconvert_to_graph_documents(documents)
        except Exception:
            logger.exception(
                "LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
                request.graph_id,
                request.source_id,
                text_hash,
            )
            raise

        result = self._convert_result(graph_documents, request)
        logger.info(
            "Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
            request.graph_id,
            len(result.nodes),
            len(result.edges),
            len(result.triples),
        )
        return result

    def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
        """Synchronous counterpart of :meth:`extract`.

        Raises:
            Exception: Any error from the transformer/LLM call is logged with
                its traceback and re-raised unchanged.
        """
        text_hash = _text_fingerprint(request.text)
        logger.info(
            "Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
            request.graph_id,
            request.source_id,
            len(request.text),
            text_hash,
        )

        transformer = self._build_transformer(request.schema)
        documents = [Document(page_content=request.text)]

        try:
            graph_documents = transformer.convert_to_graph_documents(documents)
        except Exception:
            logger.exception(
                "LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
                request.graph_id,
                request.source_id,
                text_hash,
            )
            raise

        result = self._convert_result(graph_documents, request)
        # Same counters as the async path so both are equally observable.
        logger.info(
            "Sync extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
            request.graph_id,
            len(result.nodes),
            len(result.edges),
            len(result.triples),
        )
        return result

    async def extract_batch(
        self,
        requests: Sequence[ExtractionRequest],
    ) -> list[ExtractionResult]:
        """Extract triples for each request sequentially, in input order.

        Items are processed one at a time. The first failing item's exception
        propagates, and no partial results are returned. For higher
        throughput, call :meth:`extract` concurrently on the caller side
        (e.g. with ``asyncio.gather``).
        """
        total = len(requests)
        logger.info("Starting batch extraction: count=%d", total)
        results: list[ExtractionResult] = []
        for i, req in enumerate(requests):
            logger.debug("Batch item %d/%d: source_id=%s", i + 1, total, req.source_id)
            result = await self.extract(req)
            results.append(result)
        logger.info("Batch extraction complete: count=%d", len(results))
        return results

    @staticmethod
    def _convert_result(
        graph_documents: list,
        request: ExtractionRequest,
    ) -> ExtractionResult:
        """Convert LangChain ``GraphDocument`` objects into internal models.

        Nodes are de-duplicated across all documents on their ``(id, type)``
        pair. Each relationship yields exactly one ``GraphEdge`` and one
        ``Triple``. A missing or ``None`` ``properties`` attribute becomes an
        empty dict.
        """
        nodes: list[GraphNode] = []
        edges: list[GraphEdge] = []
        triples: list[Triple] = []
        # Tuple keys: a joined "id:type" string would make ("a:b", "c") and
        # ("a", "b:c") collide and silently drop a distinct node.
        seen_nodes: set[tuple[object, object]] = set()

        for doc in graph_documents:
            for node in doc.nodes:
                node_key = (node.id, node.type)
                if node_key in seen_nodes:
                    continue
                seen_nodes.add(node_key)
                nodes.append(
                    GraphNode(
                        name=node.id,
                        type=node.type,
                        properties=getattr(node, "properties", None) or {},
                    )
                )

            for rel in doc.relationships:
                source_node = GraphNode(name=rel.source.id, type=rel.source.type)
                target_node = GraphNode(name=rel.target.id, type=rel.target.type)

                edges.append(
                    GraphEdge(
                        source=rel.source.id,
                        target=rel.target.id,
                        relation_type=rel.type,
                        properties=getattr(rel, "properties", None) or {},
                    )
                )
                triples.append(
                    Triple(subject=source_node, predicate=rel.type, object=target_node)
                )

        return ExtractionResult(
            nodes=nodes,
            edges=edges,
            triples=triples,
            raw_text=request.text,
            source_id=request.source_id,
        )
|