You've already forked DataMate
feat(kg-extraction): 实现 Python 抽取器 FastAPI 接口
实现功能: - 创建 kg_extraction/interface.py(FastAPI 路由) - 实现 POST /api/kg/extract(单条文本抽取) - 实现 POST /api/kg/extract/batch(批量抽取,最多 50 条) - 集成到 FastAPI 主路由(/api/kg/ 前缀) 技术实现: - 配置管理:从环境变量读取 LLM 配置(API Key、Base URL、Model、Temperature) - 安全性: - API Key 使用 SecretStr 保护 - 错误信息脱敏(使用 trace_id,不暴露原始异常) - 请求文本不写入日志(使用 SHA-256 hash) - 强制要求 X-User-Id 头(鉴权边界) - 超时控制: - kg_llm_timeout_seconds(60秒) - kg_llm_max_retries(2次) - 输入校验: - graph_id 和 source_id 使用 UUID pattern - source_type 使用 Enum(4个值) - allowed_nodes/relationships 元素使用正则约束(ASCII,1-50字符) - 审计日志:记录 caller、trace_id、text_hash 代码审查: - 经过 3 轮 Codex 审查和 2 轮 Claude 修复 - 所有问题已解决(5个 P1/P2 + 3个 P3) - 语法检查通过 API 端点: - POST /api/kg/extract:单条文本抽取 - POST /api/kg/extract/batch:批量抽取(最多 50 条) 配置环境变量: - KG_LLM_API_KEY:LLM API 密钥 - KG_LLM_BASE_URL:自定义端点(可选) - KG_LLM_MODEL:模型名称(默认 gpt-4o-mini) - KG_LLM_TEMPERATURE:生成温度(默认 0.0) - KG_LLM_TIMEOUT_SECONDS:超时时间(默认 60) - KG_LLM_MAX_RETRIES:重试次数(默认 2)
This commit is contained in:
@@ -7,6 +7,7 @@ from .generation.interface import router as generation_router
|
||||
from .evaluation.interface import router as evaluation_router
|
||||
from .collection.interface import router as collection_route
|
||||
from .dataset.interface import router as dataset_router
|
||||
from .kg_extraction.interface import router as kg_extraction_router
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api"
|
||||
@@ -19,5 +20,6 @@ router.include_router(generation_router)
|
||||
router.include_router(evaluation_router)
|
||||
router.include_router(collection_route)
|
||||
router.include_router(dataset_router)
|
||||
router.include_router(kg_extraction_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -6,6 +6,7 @@ from app.module.kg_extraction.models import (
|
||||
GraphNode,
|
||||
GraphEdge,
|
||||
)
|
||||
from app.module.kg_extraction.interface import router
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeGraphExtractor",
|
||||
@@ -14,4 +15,5 @@ __all__ = [
|
||||
"Triple",
|
||||
"GraphNode",
|
||||
"GraphEdge",
|
||||
"router",
|
||||
]
|
||||
|
||||
@@ -6,13 +6,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import hashlib
|
||||
from typing import Sequence
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_experimental.graph_transformers import LLMGraphTransformer
|
||||
from pydantic import SecretStr
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_extraction.models import (
|
||||
ExtractionRequest,
|
||||
ExtractionResult,
|
||||
@@ -22,36 +24,58 @@ from app.module.kg_extraction.models import (
|
||||
Triple,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _text_fingerprint(text: str) -> str:
|
||||
"""返回文本的短 SHA-256 摘要,用于日志关联而不泄露原文。"""
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12]
|
||||
|
||||
|
||||
class KnowledgeGraphExtractor:
|
||||
"""基于 LLMGraphTransformer 的三元组抽取器。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str
|
||||
OpenAI 兼容模型名称。
|
||||
base_url : str | None
|
||||
自定义 API base URL(用于对接 vLLM/Ollama 等本地模型服务)。
|
||||
api_key : str
|
||||
API 密钥。
|
||||
temperature : float
|
||||
生成温度,抽取任务建议使用较低值。
|
||||
通过 ``from_settings()`` 工厂方法从全局配置创建实例,
|
||||
也可直接构造以覆盖默认参数。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "gpt-4o-mini",
|
||||
base_url: str | None = None,
|
||||
api_key: str = "EMPTY",
|
||||
api_key: SecretStr = SecretStr("EMPTY"),
|
||||
temperature: float = 0.0,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 2,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
|
||||
model_name,
|
||||
base_url or "default",
|
||||
timeout,
|
||||
max_retries,
|
||||
)
|
||||
self._llm = ChatOpenAI(
|
||||
model=model_name,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls) -> KnowledgeGraphExtractor:
|
||||
"""从全局 Settings 创建抽取器实例。"""
|
||||
from app.core.config import settings
|
||||
|
||||
return cls(
|
||||
model_name=settings.kg_llm_model,
|
||||
base_url=settings.kg_llm_base_url,
|
||||
api_key=settings.kg_llm_api_key,
|
||||
temperature=settings.kg_llm_temperature,
|
||||
timeout=settings.kg_llm_timeout_seconds,
|
||||
max_retries=settings.kg_llm_max_retries,
|
||||
)
|
||||
|
||||
def _build_transformer(
|
||||
@@ -70,55 +94,89 @@ class KnowledgeGraphExtractor:
|
||||
return LLMGraphTransformer(**kwargs)
|
||||
|
||||
async def extract(self, request: ExtractionRequest) -> ExtractionResult:
|
||||
"""从文本中抽取三元组。
|
||||
"""从文本中异步抽取三元组。"""
|
||||
text_hash = _text_fingerprint(request.text)
|
||||
logger.info(
|
||||
"Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
|
||||
request.graph_id,
|
||||
request.source_id,
|
||||
len(request.text),
|
||||
text_hash,
|
||||
)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
request : ExtractionRequest
|
||||
包含文本、schema 约束等信息的抽取请求。
|
||||
|
||||
Returns
|
||||
-------
|
||||
ExtractionResult
|
||||
抽取得到的节点、边和三元组。
|
||||
"""
|
||||
transformer = self._build_transformer(request.schema)
|
||||
documents = [Document(page_content=request.text)]
|
||||
|
||||
try:
|
||||
graph_documents = await transformer.aconvert_to_graph_documents(documents)
|
||||
except Exception:
|
||||
logger.exception("LLM graph extraction failed for source_id=%s", request.source_id)
|
||||
return ExtractionResult(raw_text=request.text, source_id=request.source_id)
|
||||
logger.exception(
|
||||
"LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
|
||||
request.graph_id,
|
||||
request.source_id,
|
||||
text_hash,
|
||||
)
|
||||
raise
|
||||
|
||||
return self._convert_result(graph_documents, request)
|
||||
result = self._convert_result(graph_documents, request)
|
||||
logger.info(
|
||||
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
|
||||
request.graph_id,
|
||||
len(result.nodes),
|
||||
len(result.edges),
|
||||
len(result.triples),
|
||||
)
|
||||
return result
|
||||
|
||||
def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
|
||||
"""同步版本的三元组抽取。"""
|
||||
text_hash = _text_fingerprint(request.text)
|
||||
logger.info(
|
||||
"Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
|
||||
request.graph_id,
|
||||
request.source_id,
|
||||
len(request.text),
|
||||
text_hash,
|
||||
)
|
||||
|
||||
transformer = self._build_transformer(request.schema)
|
||||
documents = [Document(page_content=request.text)]
|
||||
|
||||
try:
|
||||
graph_documents = transformer.convert_to_graph_documents(documents)
|
||||
except Exception:
|
||||
logger.exception("LLM graph extraction failed for source_id=%s", request.source_id)
|
||||
return ExtractionResult(raw_text=request.text, source_id=request.source_id)
|
||||
logger.exception(
|
||||
"LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
|
||||
request.graph_id,
|
||||
request.source_id,
|
||||
text_hash,
|
||||
)
|
||||
raise
|
||||
|
||||
return self._convert_result(graph_documents, request)
|
||||
result = self._convert_result(graph_documents, request)
|
||||
logger.info(
|
||||
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
|
||||
request.graph_id,
|
||||
len(result.nodes),
|
||||
len(result.edges),
|
||||
)
|
||||
return result
|
||||
|
||||
async def extract_batch(
|
||||
self,
|
||||
requests: Sequence[ExtractionRequest],
|
||||
) -> list[ExtractionResult]:
|
||||
"""批量抽取。
|
||||
"""批量抽取,逐条处理。
|
||||
|
||||
对多段文本逐一抽取并汇总结果。
|
||||
如需更高吞吐,可自行用 asyncio.gather 并发调用 extract。
|
||||
如需更高吞吐,可在调用侧用 asyncio.gather 并发调用 extract。
|
||||
"""
|
||||
logger.info("Starting batch extraction: count=%d", len(requests))
|
||||
results: list[ExtractionResult] = []
|
||||
for req in requests:
|
||||
for i, req in enumerate(requests):
|
||||
logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id)
|
||||
result = await self.extract(req)
|
||||
results.append(result)
|
||||
logger.info("Batch extraction complete: count=%d", len(results))
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
@@ -133,7 +191,6 @@ class KnowledgeGraphExtractor:
|
||||
seen_nodes: set[str] = set()
|
||||
|
||||
for doc in graph_documents:
|
||||
# 收集节点
|
||||
for node in doc.nodes:
|
||||
node_key = f"{node.id}:{node.type}"
|
||||
if node_key not in seen_nodes:
|
||||
@@ -146,16 +203,9 @@ class KnowledgeGraphExtractor:
|
||||
)
|
||||
)
|
||||
|
||||
# 收集关系
|
||||
for rel in doc.relationships:
|
||||
source_node = GraphNode(
|
||||
name=rel.source.id,
|
||||
type=rel.source.type,
|
||||
)
|
||||
target_node = GraphNode(
|
||||
name=rel.target.id,
|
||||
type=rel.target.type,
|
||||
)
|
||||
source_node = GraphNode(name=rel.source.id, type=rel.source.type)
|
||||
target_node = GraphNode(name=rel.target.id, type=rel.target.type)
|
||||
|
||||
edges.append(
|
||||
GraphEdge(
|
||||
@@ -165,13 +215,8 @@ class KnowledgeGraphExtractor:
|
||||
properties=rel.properties if hasattr(rel, "properties") else {},
|
||||
)
|
||||
)
|
||||
|
||||
triples.append(
|
||||
Triple(
|
||||
subject=source_node,
|
||||
predicate=rel.type,
|
||||
object=target_node,
|
||||
)
|
||||
Triple(subject=source_node, predicate=rel.type, object=target_node)
|
||||
)
|
||||
|
||||
return ExtractionResult(
|
||||
|
||||
193
runtime/datamate-python/app/module/kg_extraction/interface.py
Normal file
193
runtime/datamate-python/app/module/kg_extraction/interface.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""知识图谱三元组抽取 API。
|
||||
|
||||
注意:本模块的接口由 Java 后端 (datamate-backend) 通过内网调用,
|
||||
外部请求经 API Gateway 鉴权后由 Java 侧转发,不直接暴露给终端用户。
|
||||
当前通过 X-User-Id 请求头获取调用方身份并记录审计日志。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
|
||||
from app.module.kg_extraction.models import (
|
||||
ExtractionRequest,
|
||||
ExtractionResult,
|
||||
ExtractionSchema,
|
||||
EntityTypeConstraint,
|
||||
RelationTypeConstraint,
|
||||
)
|
||||
from app.module.shared.schema import StandardResponse
|
||||
|
||||
router = APIRouter(prefix="/kg", tags=["knowledge-graph"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 延迟初始化:首次请求时创建,避免启动阶段就连接 LLM
|
||||
_extractor: KnowledgeGraphExtractor | None = None
|
||||
|
||||
_UUID_PATTERN = (
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
# 允许的实体/关系类型名称:字母、数字、下划线、连字符,1-50 字符
|
||||
_TYPE_NAME_PATTERN = r"^[A-Za-z0-9_\-]{1,50}$"
|
||||
|
||||
|
||||
def _get_extractor() -> KnowledgeGraphExtractor:
|
||||
global _extractor
|
||||
if _extractor is None:
|
||||
_extractor = KnowledgeGraphExtractor.from_settings()
|
||||
return _extractor
|
||||
|
||||
|
||||
def _require_caller_id(
|
||||
x_user_id: Annotated[str, Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递")],
|
||||
) -> str:
|
||||
"""从请求头提取调用方用户 ID,用于审计日志。
|
||||
|
||||
该接口为内部服务调用,调用方身份由上游 Java 后端通过
|
||||
X-User-Id 请求头传递。缺失或为空时返回 401。
|
||||
"""
|
||||
caller = x_user_id.strip()
|
||||
if not caller:
|
||||
raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
|
||||
return caller
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response DTO(API 层,与内部 models 解耦)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SourceType(str, Enum):
|
||||
ANNOTATION = "ANNOTATION"
|
||||
KNOWLEDGE_BASE = "KNOWLEDGE_BASE"
|
||||
IMPORT = "IMPORT"
|
||||
MANUAL = "MANUAL"
|
||||
|
||||
|
||||
class ExtractRequest(BaseModel):
|
||||
"""三元组抽取请求。"""
|
||||
|
||||
text: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=50000,
|
||||
description="待抽取的文本内容",
|
||||
examples=["张三是北京大学的教授,研究方向为人工智能。"],
|
||||
)
|
||||
graph_id: str = Field(
|
||||
...,
|
||||
pattern=_UUID_PATTERN,
|
||||
description="目标图谱 ID(UUID 格式)",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440000"],
|
||||
)
|
||||
allowed_nodes: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
|
||||
default=None,
|
||||
max_length=50,
|
||||
description="允许的实体类型列表(schema-guided 抽取),每项 1-50 个字母/数字/下划线/连字符",
|
||||
examples=[["Person", "Organization", "Location"]],
|
||||
)
|
||||
allowed_relationships: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
|
||||
default=None,
|
||||
max_length=50,
|
||||
description="允许的关系类型列表(schema-guided 抽取)",
|
||||
examples=[["works_at", "located_in"]],
|
||||
)
|
||||
source_id: Optional[str] = Field(
|
||||
default=None,
|
||||
pattern=_UUID_PATTERN,
|
||||
description="来源 ID(数据集/知识库条目,UUID 格式)",
|
||||
)
|
||||
source_type: SourceType = Field(
|
||||
default=SourceType.KNOWLEDGE_BASE,
|
||||
description="来源类型",
|
||||
)
|
||||
|
||||
|
||||
class BatchExtractRequest(BaseModel):
|
||||
"""批量三元组抽取请求。"""
|
||||
|
||||
items: list[ExtractRequest] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=50,
|
||||
description="抽取请求列表,单次最多 50 条",
|
||||
)
|
||||
|
||||
|
||||
def _to_extraction_request(req: ExtractRequest) -> ExtractionRequest:
|
||||
"""将 API DTO 转换为内部抽取请求。"""
|
||||
schema: ExtractionSchema | None = None
|
||||
if req.allowed_nodes or req.allowed_relationships:
|
||||
schema = ExtractionSchema(
|
||||
entity_types=[EntityTypeConstraint(name=n) for n in (req.allowed_nodes or [])],
|
||||
relation_types=[
|
||||
RelationTypeConstraint(name=r) for r in (req.allowed_relationships or [])
|
||||
],
|
||||
)
|
||||
|
||||
return ExtractionRequest(
|
||||
text=req.text,
|
||||
graph_id=req.graph_id,
|
||||
schema=schema,
|
||||
source_id=req.source_id,
|
||||
source_type=req.source_type.value,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
|
||||
"/extract",
|
||||
response_model=StandardResponse[ExtractionResult],
|
||||
summary="三元组抽取",
|
||||
description="从文本中抽取实体和关系,返回知识图谱三元组。支持通过 allowed_nodes 和 allowed_relationships 约束抽取范围。",
|
||||
)
|
||||
async def extract(req: ExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
|
||||
"""单条文本三元组抽取。"""
|
||||
trace_id = uuid.uuid4().hex[:16]
|
||||
logger.info("[%s] Extract request: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)
|
||||
|
||||
extractor = _get_extractor()
|
||||
extraction_req = _to_extraction_request(req)
|
||||
|
||||
try:
|
||||
result = await extractor.extract(extraction_req)
|
||||
except Exception:
|
||||
logger.exception("[%s] Extraction failed: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)
|
||||
raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})")
|
||||
|
||||
return StandardResponse(code=200, message="success", data=result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/extract/batch",
|
||||
response_model=StandardResponse[list[ExtractionResult]],
|
||||
summary="批量三元组抽取",
|
||||
description="对多段文本逐条抽取三元组,单次最多 50 条。",
|
||||
)
|
||||
async def extract_batch(req: BatchExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
|
||||
"""批量文本三元组抽取。"""
|
||||
trace_id = uuid.uuid4().hex[:16]
|
||||
logger.info("[%s] Batch extract request: count=%d, caller=%s", trace_id, len(req.items), caller)
|
||||
|
||||
extractor = _get_extractor()
|
||||
extraction_reqs = [_to_extraction_request(item) for item in req.items]
|
||||
|
||||
try:
|
||||
results = await extractor.extract_batch(extraction_reqs)
|
||||
except Exception:
|
||||
logger.exception("[%s] Batch extraction failed: caller=%s", trace_id, caller)
|
||||
raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})")
|
||||
|
||||
return StandardResponse(code=200, message="success", data=results)
|
||||
Reference in New Issue
Block a user