You've already forked DataMate
feat(kg-extraction): 实现 Python 抽取器 FastAPI 接口
实现功能: - 创建 kg_extraction/interface.py(FastAPI 路由) - 实现 POST /api/kg/extract(单条文本抽取) - 实现 POST /api/kg/extract/batch(批量抽取,最多 50 条) - 集成到 FastAPI 主路由(/api/kg/ 前缀) 技术实现: - 配置管理:从环境变量读取 LLM 配置(API Key、Base URL、Model、Temperature) - 安全性: - API Key 使用 SecretStr 保护 - 错误信息脱敏(使用 trace_id,不暴露原始异常) - 请求文本不写入日志(使用 SHA-256 hash) - 强制要求 X-User-Id 头(鉴权边界) - 超时控制: - kg_llm_timeout_seconds(60秒) - kg_llm_max_retries(2次) - 输入校验: - graph_id 和 source_id 使用 UUID pattern - source_type 使用 Enum(4个值) - allowed_nodes/relationships 元素使用正则约束(ASCII,1-50字符) - 审计日志:记录 caller、trace_id、text_hash 代码审查: - 经过 3 轮 Codex 审查和 2 轮 Claude 修复 - 所有问题已解决(5个 P1/P2 + 3个 P3) - 语法检查通过 API 端点: - POST /api/kg/extract:单条文本抽取 - POST /api/kg/extract/batch:批量抽取(最多 50 条) 配置环境变量: - KG_LLM_API_KEY:LLM API 密钥 - KG_LLM_BASE_URL:自定义端点(可选) - KG_LLM_MODEL:模型名称(默认 gpt-4o-mini) - KG_LLM_TEMPERATURE:生成温度(默认 0.0) - KG_LLM_TIMEOUT_SECONDS:超时时间(默认 60) - KG_LLM_MAX_RETRIES:重试次数(默认 2)
This commit is contained in:
@@ -6,13 +6,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import hashlib
|
||||
from typing import Sequence
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_experimental.graph_transformers import LLMGraphTransformer
|
||||
from pydantic import SecretStr
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_extraction.models import (
|
||||
ExtractionRequest,
|
||||
ExtractionResult,
|
||||
@@ -22,36 +24,58 @@ from app.module.kg_extraction.models import (
|
||||
Triple,
|
||||
)
|
||||
|
||||
# Module-level logger via the project's logging factory (not the stdlib
# logging.getLogger directly), so formatting/handlers follow app config.
# NOTE(review): the rendered diff showed two conflicting assignments
# (logging.getLogger vs get_logger); the commit intent keeps get_logger.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _text_fingerprint(text: str) -> str:
|
||||
"""返回文本的短 SHA-256 摘要,用于日志关联而不泄露原文。"""
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
|
||||
|
||||
|
||||
class KnowledgeGraphExtractor:
|
||||
"""基于 LLMGraphTransformer 的三元组抽取器。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str
|
||||
OpenAI 兼容模型名称。
|
||||
base_url : str | None
|
||||
自定义 API base URL(用于对接 vLLM/Ollama 等本地模型服务)。
|
||||
api_key : str
|
||||
API 密钥。
|
||||
temperature : float
|
||||
生成温度,抽取任务建议使用较低值。
|
||||
通过 ``from_settings()`` 工厂方法从全局配置创建实例,
|
||||
也可直接构造以覆盖默认参数。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "gpt-4o-mini",
|
||||
base_url: str | None = None,
|
||||
api_key: str = "EMPTY",
|
||||
api_key: SecretStr = SecretStr("EMPTY"),
|
||||
temperature: float = 0.0,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 2,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
|
||||
model_name,
|
||||
base_url or "default",
|
||||
timeout,
|
||||
max_retries,
|
||||
)
|
||||
self._llm = ChatOpenAI(
|
||||
model=model_name,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls) -> KnowledgeGraphExtractor:
|
||||
"""从全局 Settings 创建抽取器实例。"""
|
||||
from app.core.config import settings
|
||||
|
||||
return cls(
|
||||
model_name=settings.kg_llm_model,
|
||||
base_url=settings.kg_llm_base_url,
|
||||
api_key=settings.kg_llm_api_key,
|
||||
temperature=settings.kg_llm_temperature,
|
||||
timeout=settings.kg_llm_timeout_seconds,
|
||||
max_retries=settings.kg_llm_max_retries,
|
||||
)
|
||||
|
||||
def _build_transformer(
|
||||
@@ -70,55 +94,89 @@ class KnowledgeGraphExtractor:
|
||||
return LLMGraphTransformer(**kwargs)
|
||||
|
||||
async def extract(self, request: ExtractionRequest) -> ExtractionResult:
|
||||
"""从文本中抽取三元组。
|
||||
"""从文本中异步抽取三元组。"""
|
||||
text_hash = _text_fingerprint(request.text)
|
||||
logger.info(
|
||||
"Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
|
||||
request.graph_id,
|
||||
request.source_id,
|
||||
len(request.text),
|
||||
text_hash,
|
||||
)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
request : ExtractionRequest
|
||||
包含文本、schema 约束等信息的抽取请求。
|
||||
|
||||
Returns
|
||||
-------
|
||||
ExtractionResult
|
||||
抽取得到的节点、边和三元组。
|
||||
"""
|
||||
transformer = self._build_transformer(request.schema)
|
||||
documents = [Document(page_content=request.text)]
|
||||
|
||||
try:
|
||||
graph_documents = await transformer.aconvert_to_graph_documents(documents)
|
||||
except Exception:
|
||||
logger.exception("LLM graph extraction failed for source_id=%s", request.source_id)
|
||||
return ExtractionResult(raw_text=request.text, source_id=request.source_id)
|
||||
logger.exception(
|
||||
"LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
|
||||
request.graph_id,
|
||||
request.source_id,
|
||||
text_hash,
|
||||
)
|
||||
raise
|
||||
|
||||
return self._convert_result(graph_documents, request)
|
||||
result = self._convert_result(graph_documents, request)
|
||||
logger.info(
|
||||
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
|
||||
request.graph_id,
|
||||
len(result.nodes),
|
||||
len(result.edges),
|
||||
len(result.triples),
|
||||
)
|
||||
return result
|
||||
|
||||
def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
|
||||
"""同步版本的三元组抽取。"""
|
||||
text_hash = _text_fingerprint(request.text)
|
||||
logger.info(
|
||||
"Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
|
||||
request.graph_id,
|
||||
request.source_id,
|
||||
len(request.text),
|
||||
text_hash,
|
||||
)
|
||||
|
||||
transformer = self._build_transformer(request.schema)
|
||||
documents = [Document(page_content=request.text)]
|
||||
|
||||
try:
|
||||
graph_documents = transformer.convert_to_graph_documents(documents)
|
||||
except Exception:
|
||||
logger.exception("LLM graph extraction failed for source_id=%s", request.source_id)
|
||||
return ExtractionResult(raw_text=request.text, source_id=request.source_id)
|
||||
logger.exception(
|
||||
"LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
|
||||
request.graph_id,
|
||||
request.source_id,
|
||||
text_hash,
|
||||
)
|
||||
raise
|
||||
|
||||
return self._convert_result(graph_documents, request)
|
||||
result = self._convert_result(graph_documents, request)
|
||||
logger.info(
|
||||
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
|
||||
request.graph_id,
|
||||
len(result.nodes),
|
||||
len(result.edges),
|
||||
)
|
||||
return result
|
||||
|
||||
async def extract_batch(
|
||||
self,
|
||||
requests: Sequence[ExtractionRequest],
|
||||
) -> list[ExtractionResult]:
|
||||
"""批量抽取。
|
||||
"""批量抽取,逐条处理。
|
||||
|
||||
对多段文本逐一抽取并汇总结果。
|
||||
如需更高吞吐,可自行用 asyncio.gather 并发调用 extract。
|
||||
如需更高吞吐,可在调用侧用 asyncio.gather 并发调用 extract。
|
||||
"""
|
||||
logger.info("Starting batch extraction: count=%d", len(requests))
|
||||
results: list[ExtractionResult] = []
|
||||
for req in requests:
|
||||
for i, req in enumerate(requests):
|
||||
logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id)
|
||||
result = await self.extract(req)
|
||||
results.append(result)
|
||||
logger.info("Batch extraction complete: count=%d", len(results))
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
@@ -133,7 +191,6 @@ class KnowledgeGraphExtractor:
|
||||
seen_nodes: set[str] = set()
|
||||
|
||||
for doc in graph_documents:
|
||||
# 收集节点
|
||||
for node in doc.nodes:
|
||||
node_key = f"{node.id}:{node.type}"
|
||||
if node_key not in seen_nodes:
|
||||
@@ -146,16 +203,9 @@ class KnowledgeGraphExtractor:
|
||||
)
|
||||
)
|
||||
|
||||
# 收集关系
|
||||
for rel in doc.relationships:
|
||||
source_node = GraphNode(
|
||||
name=rel.source.id,
|
||||
type=rel.source.type,
|
||||
)
|
||||
target_node = GraphNode(
|
||||
name=rel.target.id,
|
||||
type=rel.target.type,
|
||||
)
|
||||
source_node = GraphNode(name=rel.source.id, type=rel.source.type)
|
||||
target_node = GraphNode(name=rel.target.id, type=rel.target.type)
|
||||
|
||||
edges.append(
|
||||
GraphEdge(
|
||||
@@ -165,13 +215,8 @@ class KnowledgeGraphExtractor:
|
||||
properties=rel.properties if hasattr(rel, "properties") else {},
|
||||
)
|
||||
)
|
||||
|
||||
triples.append(
|
||||
Triple(
|
||||
subject=source_node,
|
||||
predicate=rel.type,
|
||||
object=target_node,
|
||||
)
|
||||
Triple(subject=source_node, predicate=rel.type, object=target_node)
|
||||
)
|
||||
|
||||
return ExtractionResult(
|
||||
|
||||
Reference in New Issue
Block a user