Files
DataMate/runtime/datamate-python/app/module/kg_extraction/extractor.py
Jerry Yan 0e0782a452 feat(kg-extraction): 实现 Python 抽取器 FastAPI 接口
实现功能:
- 创建 kg_extraction/interface.py(FastAPI 路由)
- 实现 POST /api/kg/extract(单条文本抽取)
- 实现 POST /api/kg/extract/batch(批量抽取,最多 50 条)
- 集成到 FastAPI 主路由(/api/kg/ 前缀)

技术实现:
- 配置管理:从环境变量读取 LLM 配置(API Key、Base URL、Model、Temperature)
- 安全性:
  - API Key 使用 SecretStr 保护
  - 错误信息脱敏(使用 trace_id,不暴露原始异常)
  - 请求文本不写入日志(使用 SHA-256 hash)
  - 强制要求 X-User-Id 头(鉴权边界)
- 超时控制:
  - kg_llm_timeout_seconds(60秒)
  - kg_llm_max_retries(2次)
- 输入校验:
  - graph_id 和 source_id 使用 UUID pattern
  - source_type 使用 Enum(4个值)
  - allowed_nodes/relationships 元素使用正则约束(ASCII,1-50字符)
- 审计日志:记录 caller、trace_id、text_hash

代码审查:
- 经过 3 轮 Codex 审查和 2 轮 Claude 修复
- 所有问题已解决(5个 P1/P2 + 3个 P3)
- 语法检查通过

API 端点:
- POST /api/kg/extract:单条文本抽取
- POST /api/kg/extract/batch:批量抽取(最多 50 条)

配置环境变量:
- KG_LLM_API_KEY:LLM API 密钥
- KG_LLM_BASE_URL:自定义端点(可选)
- KG_LLM_MODEL:模型名称(默认 gpt-4o-mini)
- KG_LLM_TEMPERATURE:生成温度(默认 0.0)
- KG_LLM_TIMEOUT_SECONDS:超时时间(默认 60)
- KG_LLM_MAX_RETRIES:重试次数(默认 2)
2026-02-17 22:01:06 +08:00

229 lines
7.6 KiB
Python

"""基于 LLM 的知识图谱三元组抽取器。
利用 LangChain 的 LLMGraphTransformer 从非结构化文本中抽取实体和关系,
支持 schema-guided 抽取以提升准确率。
"""
from __future__ import annotations
import hashlib
from typing import Sequence
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from pydantic import SecretStr
from app.core.logging import get_logger
from app.module.kg_extraction.models import (
ExtractionRequest,
ExtractionResult,
ExtractionSchema,
GraphEdge,
GraphNode,
Triple,
)
logger = get_logger(__name__)
def _text_fingerprint(text: str) -> str:
"""返回文本的短 SHA-256 摘要,用于日志关联而不泄露原文。"""
return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
class KnowledgeGraphExtractor:
    """Triple extractor built on LangChain's ``LLMGraphTransformer``.

    Instances are normally created through the :meth:`from_settings` factory,
    which reads the global application configuration; the constructor can also
    be called directly to override individual parameters.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o-mini",
        base_url: str | None = None,
        api_key: SecretStr = SecretStr("EMPTY"),
        temperature: float = 0.0,
        timeout: int = 60,
        max_retries: int = 2,
    ) -> None:
        """Create an extractor backed by a ``ChatOpenAI`` client.

        Args:
            model_name: OpenAI-compatible model identifier.
            base_url: Optional custom endpoint; ``None`` uses the provider default.
            api_key: API key wrapped in ``SecretStr`` so it cannot leak via repr/logs.
            temperature: Sampling temperature; 0.0 favors deterministic extraction.
            timeout: Per-request timeout in seconds.
            max_retries: Automatic retry count for transient LLM failures.
        """
        # Only non-sensitive settings are logged; the API key never appears here.
        logger.info(
            "Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
            model_name,
            base_url or "default",
            timeout,
            max_retries,
        )
        self._llm = ChatOpenAI(
            model=model_name,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature,
            timeout=timeout,
            max_retries=max_retries,
        )

    @classmethod
    def from_settings(cls) -> KnowledgeGraphExtractor:
        """Build an extractor instance from the global ``Settings`` object."""
        # Imported lazily to avoid a circular import at module load time.
        from app.core.config import settings

        return cls(
            model_name=settings.kg_llm_model,
            base_url=settings.kg_llm_base_url,
            api_key=settings.kg_llm_api_key,
            temperature=settings.kg_llm_temperature,
            timeout=settings.kg_llm_timeout_seconds,
            max_retries=settings.kg_llm_max_retries,
        )

    def _build_transformer(
        self,
        schema: ExtractionSchema | None = None,
    ) -> LLMGraphTransformer:
        """Construct an ``LLMGraphTransformer``, optionally schema-constrained.

        When *schema* provides entity/relation types they are forwarded as
        ``allowed_nodes`` / ``allowed_relationships`` to guide (and restrict)
        what the LLM may emit.
        """
        kwargs: dict[str, object] = {"llm": self._llm}
        if schema:
            if schema.entity_types:
                kwargs["allowed_nodes"] = [et.name for et in schema.entity_types]
            if schema.relation_types:
                kwargs["allowed_relationships"] = [
                    rt.name for rt in schema.relation_types
                ]
        return LLMGraphTransformer(**kwargs)

    async def extract(self, request: ExtractionRequest) -> ExtractionResult:
        """Asynchronously extract triples from the request text.

        Raises:
            Exception: Re-raises whatever the underlying LLM call raised,
                after logging it with the text fingerprint (never the text).
        """
        # Only a hash of the text is logged, per the no-raw-text-in-logs policy.
        text_hash = _text_fingerprint(request.text)
        logger.info(
            "Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
            request.graph_id,
            request.source_id,
            len(request.text),
            text_hash,
        )
        transformer = self._build_transformer(request.schema)
        documents = [Document(page_content=request.text)]
        try:
            graph_documents = await transformer.aconvert_to_graph_documents(documents)
        except Exception:
            logger.exception(
                "LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
                request.graph_id,
                request.source_id,
                text_hash,
            )
            raise
        result = self._convert_result(graph_documents, request)
        logger.info(
            "Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
            request.graph_id,
            len(result.nodes),
            len(result.edges),
            len(result.triples),
        )
        return result

    def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
        """Synchronous counterpart of :meth:`extract`.

        Raises:
            Exception: Re-raises whatever the underlying LLM call raised,
                after logging it with the text fingerprint (never the text).
        """
        text_hash = _text_fingerprint(request.text)
        logger.info(
            "Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
            request.graph_id,
            request.source_id,
            len(request.text),
            text_hash,
        )
        transformer = self._build_transformer(request.schema)
        documents = [Document(page_content=request.text)]
        try:
            graph_documents = transformer.convert_to_graph_documents(documents)
        except Exception:
            logger.exception(
                "LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
                request.graph_id,
                request.source_id,
                text_hash,
            )
            raise
        result = self._convert_result(graph_documents, request)
        # Log the same fields as the async path so the two are grep-compatible.
        logger.info(
            "Sync extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
            request.graph_id,
            len(result.nodes),
            len(result.edges),
            len(result.triples),
        )
        return result

    async def extract_batch(
        self,
        requests: Sequence[ExtractionRequest],
    ) -> list[ExtractionResult]:
        """Extract triples for each request sequentially.

        Sequential on purpose: for higher throughput, callers can fan out
        with ``asyncio.gather`` over :meth:`extract` themselves.
        """
        logger.info("Starting batch extraction: count=%d", len(requests))
        results: list[ExtractionResult] = []
        for i, req in enumerate(requests):
            logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id)
            result = await self.extract(req)
            results.append(result)
        logger.info("Batch extraction complete: count=%d", len(results))
        return results

    @staticmethod
    def _convert_result(
        graph_documents: list,
        request: ExtractionRequest,
    ) -> ExtractionResult:
        """Convert LangChain ``GraphDocument`` objects to internal models.

        Nodes are de-duplicated on (id, type); edges and triples are kept
        exactly as the transformer returned them (including duplicates).
        """
        nodes: list[GraphNode] = []
        edges: list[GraphEdge] = []
        triples: list[Triple] = []
        # Dedup on a tuple key: a "id:type" string key could collide when a
        # node id itself contains the ":" separator.
        seen_nodes: set[tuple[str, str]] = set()
        for doc in graph_documents:
            for node in doc.nodes:
                node_key = (node.id, node.type)
                if node_key not in seen_nodes:
                    seen_nodes.add(node_key)
                    nodes.append(
                        GraphNode(
                            name=node.id,
                            type=node.type,
                            # getattr guards against transformer versions whose
                            # node objects lack a properties attribute.
                            properties=getattr(node, "properties", {}),
                        )
                    )
            for rel in doc.relationships:
                source_node = GraphNode(name=rel.source.id, type=rel.source.type)
                target_node = GraphNode(name=rel.target.id, type=rel.target.type)
                edges.append(
                    GraphEdge(
                        source=rel.source.id,
                        target=rel.target.id,
                        relation_type=rel.type,
                        properties=getattr(rel, "properties", {}),
                    )
                )
                triples.append(
                    Triple(subject=source_node, predicate=rel.type, object=target_node)
                )
        return ExtractionResult(
            nodes=nodes,
            edges=edges,
            triples=triples,
            raw_text=request.text,
            source_id=request.source_id,
        )