feat(kg-extraction): 实现 Python 抽取器 FastAPI 接口

实现功能:
- 创建 kg_extraction/interface.py(FastAPI 路由)
- 实现 POST /api/kg/extract(单条文本抽取)
- 实现 POST /api/kg/extract/batch(批量抽取,最多 50 条)
- 集成到 FastAPI 主路由(/api/kg/ 前缀)

技术实现:
- 配置管理:从环境变量读取 LLM 配置(API Key、Base URL、Model、Temperature)
- 安全性:
  - API Key 使用 SecretStr 保护
  - 错误信息脱敏(使用 trace_id,不暴露原始异常)
  - 请求文本不写入日志(使用 SHA-256 hash)
  - 强制要求 X-User-Id 头(鉴权边界)
- 超时控制:
  - kg_llm_timeout_seconds(60秒)
  - kg_llm_max_retries(2次)
- 输入校验:
  - graph_id 和 source_id 使用 UUID pattern
  - source_type 使用 Enum(4个值)
  - allowed_nodes/relationships 元素使用正则约束(ASCII,1-50字符)
- 审计日志:记录 caller、trace_id、text_hash

代码审查:
- 经过 3 轮 Codex 审查和 2 轮 Claude 修复
- 所有问题已解决(5个 P1/P2 + 3个 P3)
- 语法检查通过

API 端点:
- POST /api/kg/extract:单条文本抽取
- POST /api/kg/extract/batch:批量抽取(最多 50 条)

配置环境变量:
- KG_LLM_API_KEY:LLM API 密钥
- KG_LLM_BASE_URL:自定义端点(可选)
- KG_LLM_MODEL:模型名称(默认 gpt-4o-mini)
- KG_LLM_TEMPERATURE:生成温度(默认 0.0)
- KG_LLM_TIMEOUT_SECONDS:超时时间(默认 60)
- KG_LLM_MAX_RETRIES:重试次数(默认 2)
This commit is contained in:
2026-02-17 22:01:06 +08:00
parent 5a553ddde3
commit 0e0782a452
5 changed files with 302 additions and 52 deletions

View File

@@ -6,6 +6,7 @@ from app.module.kg_extraction.models import (
GraphNode,
GraphEdge,
)
from app.module.kg_extraction.interface import router
__all__ = [
"KnowledgeGraphExtractor",
@@ -14,4 +15,5 @@ __all__ = [
"Triple",
"GraphNode",
"GraphEdge",
"router",
]

View File

@@ -6,13 +6,15 @@
from __future__ import annotations
import logging
import hashlib
from typing import Sequence
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from pydantic import SecretStr
from app.core.logging import get_logger
from app.module.kg_extraction.models import (
ExtractionRequest,
ExtractionResult,
@@ -22,36 +24,58 @@ from app.module.kg_extraction.models import (
Triple,
)
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
def _text_fingerprint(text: str) -> str:
    """Return a short SHA-256 digest of *text*.

    Used to correlate log entries for the same input without ever
    writing the original text into the logs.
    """
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
    return digest[:12]
class KnowledgeGraphExtractor:
"""基于 LLMGraphTransformer 的三元组抽取器。
Parameters
----------
model_name : str
OpenAI 兼容模型名称。
base_url : str | None
自定义 API base URL(用于对接 vLLM/Ollama 等本地模型服务)。
api_key : str
API 密钥。
temperature : float
生成温度,抽取任务建议使用较低值。
通过 ``from_settings()`` 工厂方法从全局配置创建实例,
也可直接构造以覆盖默认参数。
"""
def __init__(
self,
model_name: str = "gpt-4o-mini",
base_url: str | None = None,
api_key: str = "EMPTY",
api_key: SecretStr = SecretStr("EMPTY"),
temperature: float = 0.0,
timeout: int = 60,
max_retries: int = 2,
) -> None:
logger.info(
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
model_name,
base_url or "default",
timeout,
max_retries,
)
self._llm = ChatOpenAI(
model=model_name,
base_url=base_url,
api_key=api_key,
temperature=temperature,
timeout=timeout,
max_retries=max_retries,
)
@classmethod
def from_settings(cls) -> KnowledgeGraphExtractor:
    """Create an extractor instance from the global application Settings.

    Reads the ``kg_llm_*`` configuration values (model name, base URL,
    API key, temperature, timeout, max retries), which are populated
    from environment variables.

    Returns
    -------
    KnowledgeGraphExtractor
        A new extractor configured from ``app.core.config.settings``.
    """
    # Function-scope import — NOTE(review): presumably deferred to avoid
    # importing the app config (or a circular import) at module load time;
    # confirm against the config module.
    from app.core.config import settings
    return cls(
        model_name=settings.kg_llm_model,
        base_url=settings.kg_llm_base_url,
        api_key=settings.kg_llm_api_key,
        temperature=settings.kg_llm_temperature,
        timeout=settings.kg_llm_timeout_seconds,
        max_retries=settings.kg_llm_max_retries,
    )
def _build_transformer(
@@ -70,55 +94,89 @@ class KnowledgeGraphExtractor:
return LLMGraphTransformer(**kwargs)
async def extract(self, request: ExtractionRequest) -> ExtractionResult:
"""从文本中抽取三元组。
"""从文本中异步抽取三元组。"""
text_hash = _text_fingerprint(request.text)
logger.info(
"Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
request.graph_id,
request.source_id,
len(request.text),
text_hash,
)
Parameters
----------
request : ExtractionRequest
包含文本、schema 约束等信息的抽取请求。
Returns
-------
ExtractionResult
抽取得到的节点、边和三元组。
"""
transformer = self._build_transformer(request.schema)
documents = [Document(page_content=request.text)]
try:
graph_documents = await transformer.aconvert_to_graph_documents(documents)
except Exception:
logger.exception("LLM graph extraction failed for source_id=%s", request.source_id)
return ExtractionResult(raw_text=request.text, source_id=request.source_id)
logger.exception(
"LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
request.graph_id,
request.source_id,
text_hash,
)
raise
return self._convert_result(graph_documents, request)
result = self._convert_result(graph_documents, request)
logger.info(
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
request.graph_id,
len(result.nodes),
len(result.edges),
len(result.triples),
)
return result
def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
"""同步版本的三元组抽取。"""
text_hash = _text_fingerprint(request.text)
logger.info(
"Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
request.graph_id,
request.source_id,
len(request.text),
text_hash,
)
transformer = self._build_transformer(request.schema)
documents = [Document(page_content=request.text)]
try:
graph_documents = transformer.convert_to_graph_documents(documents)
except Exception:
logger.exception("LLM graph extraction failed for source_id=%s", request.source_id)
return ExtractionResult(raw_text=request.text, source_id=request.source_id)
logger.exception(
"LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
request.graph_id,
request.source_id,
text_hash,
)
raise
return self._convert_result(graph_documents, request)
result = self._convert_result(graph_documents, request)
logger.info(
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
request.graph_id,
len(result.nodes),
len(result.edges),
)
return result
async def extract_batch(
self,
requests: Sequence[ExtractionRequest],
) -> list[ExtractionResult]:
"""批量抽取。
"""批量抽取,逐条处理
对多段文本逐一抽取并汇总结果
如需更高吞吐,可自行用 asyncio.gather 并发调用 extract。
如需更高吞吐,可在调用侧用 asyncio.gather 并发调用 extract
"""
logger.info("Starting batch extraction: count=%d", len(requests))
results: list[ExtractionResult] = []
for req in requests:
for i, req in enumerate(requests):
logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id)
result = await self.extract(req)
results.append(result)
logger.info("Batch extraction complete: count=%d", len(results))
return results
@staticmethod
@@ -133,7 +191,6 @@ class KnowledgeGraphExtractor:
seen_nodes: set[str] = set()
for doc in graph_documents:
# 收集节点
for node in doc.nodes:
node_key = f"{node.id}:{node.type}"
if node_key not in seen_nodes:
@@ -146,16 +203,9 @@ class KnowledgeGraphExtractor:
)
)
# 收集关系
for rel in doc.relationships:
source_node = GraphNode(
name=rel.source.id,
type=rel.source.type,
)
target_node = GraphNode(
name=rel.target.id,
type=rel.target.type,
)
source_node = GraphNode(name=rel.source.id, type=rel.source.type)
target_node = GraphNode(name=rel.target.id, type=rel.target.type)
edges.append(
GraphEdge(
@@ -165,13 +215,8 @@ class KnowledgeGraphExtractor:
properties=rel.properties if hasattr(rel, "properties") else {},
)
)
triples.append(
Triple(
subject=source_node,
predicate=rel.type,
object=target_node,
)
Triple(subject=source_node, predicate=rel.type, object=target_node)
)
return ExtractionResult(

View File

@@ -0,0 +1,193 @@
"""知识图谱三元组抽取 API。
注意:本模块的接口由 Java 后端 (datamate-backend) 通过内网调用,
外部请求经 API Gateway 鉴权后由 Java 侧转发,不直接暴露给终端用户。
当前通过 X-User-Id 请求头获取调用方身份并记录审计日志。
"""
from __future__ import annotations
import uuid
from enum import Enum
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, Header, HTTPException
from pydantic import BaseModel, Field
from app.core.logging import get_logger
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
from app.module.kg_extraction.models import (
ExtractionRequest,
ExtractionResult,
ExtractionSchema,
EntityTypeConstraint,
RelationTypeConstraint,
)
from app.module.shared.schema import StandardResponse
router = APIRouter(prefix="/kg", tags=["knowledge-graph"])
logger = get_logger(__name__)
# Lazy initialization: created on first request so the app can start
# without connecting to the LLM backend.
_extractor: KnowledgeGraphExtractor | None = None
# Canonical 8-4-4-4-12 hex UUID format (case-insensitive).
_UUID_PATTERN = (
    r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)
# Allowed entity/relation type names: letters, digits, underscore, hyphen, 1-50 chars.
_TYPE_NAME_PATTERN = r"^[A-Za-z0-9_\-]{1,50}$"
def _get_extractor() -> KnowledgeGraphExtractor:
    """Return the process-wide extractor, creating it lazily on first use.

    Deferring construction keeps application startup independent of the
    LLM configuration.
    """
    global _extractor
    if _extractor is not None:
        return _extractor
    _extractor = KnowledgeGraphExtractor.from_settings()
    return _extractor
def _require_caller_id(
    x_user_id: Annotated[str, Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递")],
) -> str:
    """Resolve the caller's user ID from the X-User-Id header for audit logs.

    This API is called service-to-service: the upstream Java backend
    forwards the caller identity via the X-User-Id header. A missing or
    blank value is rejected with 401.
    """
    stripped = x_user_id.strip()
    if stripped:
        return stripped
    raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
# ---------------------------------------------------------------------------
# Request / Response DTO(API 层,与内部 models 解耦)
# ---------------------------------------------------------------------------
class SourceType(str, Enum):
    """Provenance category of the text submitted for extraction."""

    ANNOTATION = "ANNOTATION"
    KNOWLEDGE_BASE = "KNOWLEDGE_BASE"
    IMPORT = "IMPORT"
    MANUAL = "MANUAL"
class ExtractRequest(BaseModel):
    """API DTO for a single triple-extraction request (decoupled from internal models)."""

    # Raw input text; bounded at 50k chars to keep a single LLM call tractable.
    text: str = Field(
        ...,
        min_length=1,
        max_length=50000,
        description="待抽取的文本内容",
        examples=["张三是北京大学的教授,研究方向为人工智能。"],
    )
    # Target graph, validated as a canonical UUID string.
    graph_id: str = Field(
        ...,
        pattern=_UUID_PATTERN,
        description="目标图谱 ID(UUID 格式)",
        examples=["550e8400-e29b-41d4-a716-446655440000"],
    )
    # Optional schema guidance: each element must match _TYPE_NAME_PATTERN
    # (ASCII letters/digits/underscore/hyphen, 1-50 chars); list capped at 50.
    allowed_nodes: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
        default=None,
        max_length=50,
        description="允许的实体类型列表(schema-guided 抽取),每项 1-50 个字母/数字/下划线/连字符",
        examples=[["Person", "Organization", "Location"]],
    )
    # Same element/length constraints as allowed_nodes, for relation type names.
    allowed_relationships: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
        default=None,
        max_length=50,
        description="允许的关系类型列表(schema-guided 抽取)",
        examples=[["works_at", "located_in"]],
    )
    # Optional provenance id (dataset / knowledge-base entry), UUID format.
    source_id: Optional[str] = Field(
        default=None,
        pattern=_UUID_PATTERN,
        description="来源 ID(数据集/知识库条目,UUID 格式)",
    )
    # Provenance category; defaults to KNOWLEDGE_BASE.
    source_type: SourceType = Field(
        default=SourceType.KNOWLEDGE_BASE,
        description="来源类型",
    )
class BatchExtractRequest(BaseModel):
    """API DTO for a batch triple-extraction request."""

    # Hard cap of 50 items per call; must contain at least one item.
    items: list[ExtractRequest] = Field(
        ...,
        min_length=1,
        max_length=50,
        description="抽取请求列表,单次最多 50 条",
    )
def _to_extraction_request(req: ExtractRequest) -> ExtractionRequest:
    """Map the API-layer DTO onto the internal extraction request.

    A schema constraint is attached only when the caller supplied at
    least one allowed node or relationship type.
    """
    node_names = req.allowed_nodes or []
    rel_names = req.allowed_relationships or []
    constraint: ExtractionSchema | None = None
    if node_names or rel_names:
        constraint = ExtractionSchema(
            entity_types=[EntityTypeConstraint(name=name) for name in node_names],
            relation_types=[RelationTypeConstraint(name=name) for name in rel_names],
        )
    return ExtractionRequest(
        text=req.text,
        graph_id=req.graph_id,
        schema=constraint,
        source_id=req.source_id,
        source_type=req.source_type.value,
    )
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post(
    "/extract",
    response_model=StandardResponse[ExtractionResult],
    summary="三元组抽取",
    description="从文本中抽取实体和关系,返回知识图谱三元组。支持通过 allowed_nodes 和 allowed_relationships 约束抽取范围。",
)
async def extract(req: ExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
    """Extract knowledge-graph triples from a single piece of text."""
    # Per-request trace id correlates audit logs with the sanitized error
    # detail returned to the caller (the raw exception is never exposed).
    trace_id = uuid.uuid4().hex[:16]
    logger.info("[%s] Extract request: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)
    internal_req = _to_extraction_request(req)
    try:
        outcome = await _get_extractor().extract(internal_req)
    except Exception:
        logger.exception("[%s] Extraction failed: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)
        raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})")
    return StandardResponse(code=200, message="success", data=outcome)
@router.post(
    "/extract/batch",
    response_model=StandardResponse[list[ExtractionResult]],
    summary="批量三元组抽取",
    description="对多段文本逐条抽取三元组,单次最多 50 条。",
)
async def extract_batch(req: BatchExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
    """Extract triples for every text in the batch."""
    trace_id = uuid.uuid4().hex[:16]
    logger.info("[%s] Batch extract request: count=%d, caller=%s", trace_id, len(req.items), caller)
    internal_reqs = [_to_extraction_request(item) for item in req.items]
    try:
        outcomes = await _get_extractor().extract_batch(internal_reqs)
    except Exception:
        logger.exception("[%s] Batch extraction failed: caller=%s", trace_id, caller)
        raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})")
    return StandardResponse(code=200, message="success", data=outcomes)