feat(kg-extraction): 实现 Python 抽取器 FastAPI 接口

实现功能:
- 创建 kg_extraction/interface.py(FastAPI 路由)
- 实现 POST /api/kg/extract(单条文本抽取)
- 实现 POST /api/kg/extract/batch(批量抽取,最多 50 条)
- 集成到 FastAPI 主路由(/api/kg/ 前缀)

技术实现:
- 配置管理:从环境变量读取 LLM 配置(API Key、Base URL、Model、Temperature)
- 安全性:
  - API Key 使用 SecretStr 保护
  - 错误信息脱敏(使用 trace_id,不暴露原始异常)
  - 请求文本不写入日志(使用 SHA-256 hash)
  - 强制要求 X-User-Id 头(鉴权边界)
- 超时控制:
  - kg_llm_timeout_seconds(60秒)
  - kg_llm_max_retries(2次)
- 输入校验:
  - graph_id 和 source_id 使用 UUID pattern
  - source_type 使用 Enum(4个值)
  - allowed_nodes/relationships 元素使用正则约束(ASCII,1-50字符)
- 审计日志:记录 caller、trace_id、text_hash

代码审查:
- 经过 3 轮 Codex 审查和 2 轮 Claude 修复
- 所有问题已解决(5个 P1/P2 + 3个 P3)
- 语法检查通过

API 端点:
- POST /api/kg/extract:单条文本抽取
- POST /api/kg/extract/batch:批量抽取(最多 50 条)

配置环境变量:
- KG_LLM_API_KEY:LLM API 密钥
- KG_LLM_BASE_URL:自定义端点(可选)
- KG_LLM_MODEL:模型名称(默认 gpt-4o-mini)
- KG_LLM_TEMPERATURE:生成温度(默认 0.0)
- KG_LLM_TIMEOUT_SECONDS:超时时间(默认 60)
- KG_LLM_MAX_RETRIES:重试次数(默认 2)
This commit is contained in:
2026-02-17 22:01:06 +08:00
parent 5a553ddde3
commit 0e0782a452
5 changed files with 302 additions and 52 deletions

View File

@@ -0,0 +1,193 @@
"""知识图谱三元组抽取 API。
注意:本模块的接口由 Java 后端 (datamate-backend) 通过内网调用,
外部请求经 API Gateway 鉴权后由 Java 侧转发,不直接暴露给终端用户。
当前通过 X-User-Id 请求头获取调用方身份并记录审计日志。
"""
from __future__ import annotations
import uuid
from enum import Enum
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, Header, HTTPException
from pydantic import BaseModel, Field
from app.core.logging import get_logger
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
from app.module.kg_extraction.models import (
ExtractionRequest,
ExtractionResult,
ExtractionSchema,
EntityTypeConstraint,
RelationTypeConstraint,
)
from app.module.shared.schema import StandardResponse
router = APIRouter(prefix="/kg", tags=["knowledge-graph"])
logger = get_logger(__name__)
# 延迟初始化:首次请求时创建,避免启动阶段就连接 LLM
_extractor: KnowledgeGraphExtractor | None = None
_UUID_PATTERN = (
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)
# 允许的实体/关系类型名称:字母、数字、下划线、连字符,1-50 字符
_TYPE_NAME_PATTERN = r"^[A-Za-z0-9_\-]{1,50}$"
def _get_extractor() -> KnowledgeGraphExtractor:
global _extractor
if _extractor is None:
_extractor = KnowledgeGraphExtractor.from_settings()
return _extractor
def _require_caller_id(
x_user_id: Annotated[str, Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递")],
) -> str:
"""从请求头提取调用方用户 ID,用于审计日志。
该接口为内部服务调用,调用方身份由上游 Java 后端通过
X-User-Id 请求头传递。缺失或为空时返回 401。
"""
caller = x_user_id.strip()
if not caller:
raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
return caller
# ---------------------------------------------------------------------------
# Request / Response DTO(API 层,与内部 models 解耦)
# ---------------------------------------------------------------------------
class SourceType(str, Enum):
ANNOTATION = "ANNOTATION"
KNOWLEDGE_BASE = "KNOWLEDGE_BASE"
IMPORT = "IMPORT"
MANUAL = "MANUAL"
class ExtractRequest(BaseModel):
"""三元组抽取请求。"""
text: str = Field(
...,
min_length=1,
max_length=50000,
description="待抽取的文本内容",
examples=["张三是北京大学的教授,研究方向为人工智能。"],
)
graph_id: str = Field(
...,
pattern=_UUID_PATTERN,
description="目标图谱 ID(UUID 格式)",
examples=["550e8400-e29b-41d4-a716-446655440000"],
)
allowed_nodes: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
default=None,
max_length=50,
description="允许的实体类型列表(schema-guided 抽取),每项 1-50 个字母/数字/下划线/连字符",
examples=[["Person", "Organization", "Location"]],
)
allowed_relationships: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
default=None,
max_length=50,
description="允许的关系类型列表(schema-guided 抽取)",
examples=[["works_at", "located_in"]],
)
source_id: Optional[str] = Field(
default=None,
pattern=_UUID_PATTERN,
description="来源 ID(数据集/知识库条目,UUID 格式)",
)
source_type: SourceType = Field(
default=SourceType.KNOWLEDGE_BASE,
description="来源类型",
)
class BatchExtractRequest(BaseModel):
"""批量三元组抽取请求。"""
items: list[ExtractRequest] = Field(
...,
min_length=1,
max_length=50,
description="抽取请求列表,单次最多 50 条",
)
def _to_extraction_request(req: ExtractRequest) -> ExtractionRequest:
"""将 API DTO 转换为内部抽取请求。"""
schema: ExtractionSchema | None = None
if req.allowed_nodes or req.allowed_relationships:
schema = ExtractionSchema(
entity_types=[EntityTypeConstraint(name=n) for n in (req.allowed_nodes or [])],
relation_types=[
RelationTypeConstraint(name=r) for r in (req.allowed_relationships or [])
],
)
return ExtractionRequest(
text=req.text,
graph_id=req.graph_id,
schema=schema,
source_id=req.source_id,
source_type=req.source_type.value,
)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post(
"/extract",
response_model=StandardResponse[ExtractionResult],
summary="三元组抽取",
description="从文本中抽取实体和关系,返回知识图谱三元组。支持通过 allowed_nodes 和 allowed_relationships 约束抽取范围。",
)
async def extract(req: ExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
"""单条文本三元组抽取。"""
trace_id = uuid.uuid4().hex[:16]
logger.info("[%s] Extract request: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)
extractor = _get_extractor()
extraction_req = _to_extraction_request(req)
try:
result = await extractor.extract(extraction_req)
except Exception:
logger.exception("[%s] Extraction failed: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)
raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})")
return StandardResponse(code=200, message="success", data=result)
@router.post(
"/extract/batch",
response_model=StandardResponse[list[ExtractionResult]],
summary="批量三元组抽取",
description="对多段文本逐条抽取三元组,单次最多 50 条。",
)
async def extract_batch(req: BatchExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
"""批量文本三元组抽取。"""
trace_id = uuid.uuid4().hex[:16]
logger.info("[%s] Batch extract request: count=%d, caller=%s", trace_id, len(req.items), caller)
extractor = _get_extractor()
extraction_reqs = [_to_extraction_request(item) for item in req.items]
try:
results = await extractor.extract_batch(extraction_reqs)
except Exception:
logger.exception("[%s] Batch extraction failed: caller=%s", trace_id, caller)
raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})")
return StandardResponse(code=200, message="success", data=results)