You've already forked DataMate
feat(kg-extraction): 实现 Python 抽取器 FastAPI 接口
实现功能: - 创建 kg_extraction/interface.py(FastAPI 路由) - 实现 POST /api/kg/extract(单条文本抽取) - 实现 POST /api/kg/extract/batch(批量抽取,最多 50 条) - 集成到 FastAPI 主路由(/api/kg/ 前缀) 技术实现: - 配置管理:从环境变量读取 LLM 配置(API Key、Base URL、Model、Temperature) - 安全性: - API Key 使用 SecretStr 保护 - 错误信息脱敏(使用 trace_id,不暴露原始异常) - 请求文本不写入日志(使用 SHA-256 hash) - 强制要求 X-User-Id 头(鉴权边界) - 超时控制: - kg_llm_timeout_seconds(60秒) - kg_llm_max_retries(2次) - 输入校验: - graph_id 和 source_id 使用 UUID pattern - source_type 使用 Enum(4个值) - allowed_nodes/relationships 元素使用正则约束(ASCII,1-50字符) - 审计日志:记录 caller、trace_id、text_hash 代码审查: - 经过 3 轮 Codex 审查和 2 轮 Claude 修复 - 所有问题已解决(5个 P1/P2 + 3个 P3) - 语法检查通过 API 端点: - POST /api/kg/extract:单条文本抽取 - POST /api/kg/extract/batch:批量抽取(最多 50 条) 配置环境变量: - KG_LLM_API_KEY:LLM API 密钥 - KG_LLM_BASE_URL:自定义端点(可选) - KG_LLM_MODEL:模型名称(默认 gpt-4o-mini) - KG_LLM_TEMPERATURE:生成温度(默认 0.0) - KG_LLM_TIMEOUT_SECONDS:超时时间(默认 60) - KG_LLM_MAX_RETRIES:重试次数(默认 2)
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
from pydantic import model_validator
|
from pydantic import SecretStr, model_validator
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
@@ -62,9 +62,17 @@ class Settings(BaseSettings):
|
|||||||
# DataMate
|
# DataMate
|
||||||
dm_file_path_prefix: str = "/dataset" # DM存储文件夹前缀
|
dm_file_path_prefix: str = "/dataset" # DM存储文件夹前缀
|
||||||
|
|
||||||
# DataMate Backend (Java) - 用于通过“下载/预览接口”读取文件内容
|
# DataMate Backend (Java) - 用于通过"下载/预览接口"读取文件内容
|
||||||
datamate_backend_base_url: str = "http://datamate-backend:8080/api"
|
datamate_backend_base_url: str = "http://datamate-backend:8080/api"
|
||||||
|
|
||||||
|
# Knowledge Graph - LLM 三元组抽取配置
|
||||||
|
kg_llm_api_key: SecretStr = SecretStr("EMPTY")
|
||||||
|
kg_llm_base_url: Optional[str] = None
|
||||||
|
kg_llm_model: str = "gpt-4o-mini"
|
||||||
|
kg_llm_temperature: float = 0.0
|
||||||
|
kg_llm_timeout_seconds: int = 60
|
||||||
|
kg_llm_max_retries: int = 2
|
||||||
|
|
||||||
# 标注编辑器(Label Studio Editor)相关
|
# 标注编辑器(Label Studio Editor)相关
|
||||||
editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数
|
editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .generation.interface import router as generation_router
|
|||||||
from .evaluation.interface import router as evaluation_router
|
from .evaluation.interface import router as evaluation_router
|
||||||
from .collection.interface import router as collection_route
|
from .collection.interface import router as collection_route
|
||||||
from .dataset.interface import router as dataset_router
|
from .dataset.interface import router as dataset_router
|
||||||
|
from .kg_extraction.interface import router as kg_extraction_router
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/api"
|
prefix="/api"
|
||||||
@@ -19,5 +20,6 @@ router.include_router(generation_router)
|
|||||||
router.include_router(evaluation_router)
|
router.include_router(evaluation_router)
|
||||||
router.include_router(collection_route)
|
router.include_router(collection_route)
|
||||||
router.include_router(dataset_router)
|
router.include_router(dataset_router)
|
||||||
|
router.include_router(kg_extraction_router)
|
||||||
|
|
||||||
__all__ = ["router"]
|
__all__ = ["router"]
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from app.module.kg_extraction.models import (
|
|||||||
GraphNode,
|
GraphNode,
|
||||||
GraphEdge,
|
GraphEdge,
|
||||||
)
|
)
|
||||||
|
from app.module.kg_extraction.interface import router
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"KnowledgeGraphExtractor",
|
"KnowledgeGraphExtractor",
|
||||||
@@ -14,4 +15,5 @@ __all__ = [
|
|||||||
"Triple",
|
"Triple",
|
||||||
"GraphNode",
|
"GraphNode",
|
||||||
"GraphEdge",
|
"GraphEdge",
|
||||||
|
"router",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,13 +6,15 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import hashlib
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langchain_experimental.graph_transformers import LLMGraphTransformer
|
from langchain_experimental.graph_transformers import LLMGraphTransformer
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
from app.module.kg_extraction.models import (
|
from app.module.kg_extraction.models import (
|
||||||
ExtractionRequest,
|
ExtractionRequest,
|
||||||
ExtractionResult,
|
ExtractionResult,
|
||||||
@@ -22,36 +24,58 @@ from app.module.kg_extraction.models import (
|
|||||||
Triple,
|
Triple,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _text_fingerprint(text: str) -> str:
    """Return a short SHA-256 fingerprint of ``text`` for log correlation.

    The text is UTF-8 encoded, hashed with SHA-256, and only the first
    12 hexadecimal characters of the digest are returned, so log lines can
    be correlated without containing the original text.
    """
    hasher = hashlib.sha256()
    hasher.update(text.encode("utf-8"))
    hex_digest = hasher.hexdigest()
    return hex_digest[:12]
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeGraphExtractor:
|
class KnowledgeGraphExtractor:
|
||||||
"""基于 LLMGraphTransformer 的三元组抽取器。
|
"""基于 LLMGraphTransformer 的三元组抽取器。
|
||||||
|
|
||||||
Parameters
|
通过 ``from_settings()`` 工厂方法从全局配置创建实例,
|
||||||
----------
|
也可直接构造以覆盖默认参数。
|
||||||
model_name : str
|
|
||||||
OpenAI 兼容模型名称。
|
|
||||||
base_url : str | None
|
|
||||||
自定义 API base URL(用于对接 vLLM/Ollama 等本地模型服务)。
|
|
||||||
api_key : str
|
|
||||||
API 密钥。
|
|
||||||
temperature : float
|
|
||||||
生成温度,抽取任务建议使用较低值。
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str = "gpt-4o-mini",
|
model_name: str = "gpt-4o-mini",
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
api_key: str = "EMPTY",
|
api_key: SecretStr = SecretStr("EMPTY"),
|
||||||
temperature: float = 0.0,
|
temperature: float = 0.0,
|
||||||
|
timeout: int = 60,
|
||||||
|
max_retries: int = 2,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
logger.info(
|
||||||
|
"Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)",
|
||||||
|
model_name,
|
||||||
|
base_url or "default",
|
||||||
|
timeout,
|
||||||
|
max_retries,
|
||||||
|
)
|
||||||
self._llm = ChatOpenAI(
|
self._llm = ChatOpenAI(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_settings(cls) -> KnowledgeGraphExtractor:
|
||||||
|
"""从全局 Settings 创建抽取器实例。"""
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
model_name=settings.kg_llm_model,
|
||||||
|
base_url=settings.kg_llm_base_url,
|
||||||
|
api_key=settings.kg_llm_api_key,
|
||||||
|
temperature=settings.kg_llm_temperature,
|
||||||
|
timeout=settings.kg_llm_timeout_seconds,
|
||||||
|
max_retries=settings.kg_llm_max_retries,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_transformer(
|
def _build_transformer(
|
||||||
@@ -70,55 +94,89 @@ class KnowledgeGraphExtractor:
|
|||||||
return LLMGraphTransformer(**kwargs)
|
return LLMGraphTransformer(**kwargs)
|
||||||
|
|
||||||
async def extract(self, request: ExtractionRequest) -> ExtractionResult:
|
async def extract(self, request: ExtractionRequest) -> ExtractionResult:
|
||||||
"""从文本中抽取三元组。
|
"""从文本中异步抽取三元组。"""
|
||||||
|
text_hash = _text_fingerprint(request.text)
|
||||||
|
logger.info(
|
||||||
|
"Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
|
||||||
|
request.graph_id,
|
||||||
|
request.source_id,
|
||||||
|
len(request.text),
|
||||||
|
text_hash,
|
||||||
|
)
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
request : ExtractionRequest
|
|
||||||
包含文本、schema 约束等信息的抽取请求。
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
ExtractionResult
|
|
||||||
抽取得到的节点、边和三元组。
|
|
||||||
"""
|
|
||||||
transformer = self._build_transformer(request.schema)
|
transformer = self._build_transformer(request.schema)
|
||||||
documents = [Document(page_content=request.text)]
|
documents = [Document(page_content=request.text)]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
graph_documents = await transformer.aconvert_to_graph_documents(documents)
|
graph_documents = await transformer.aconvert_to_graph_documents(documents)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("LLM graph extraction failed for source_id=%s", request.source_id)
|
logger.exception(
|
||||||
return ExtractionResult(raw_text=request.text, source_id=request.source_id)
|
"LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
|
||||||
|
request.graph_id,
|
||||||
|
request.source_id,
|
||||||
|
text_hash,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
return self._convert_result(graph_documents, request)
|
result = self._convert_result(graph_documents, request)
|
||||||
|
logger.info(
|
||||||
|
"Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d",
|
||||||
|
request.graph_id,
|
||||||
|
len(result.nodes),
|
||||||
|
len(result.edges),
|
||||||
|
len(result.triples),
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
|
def extract_sync(self, request: ExtractionRequest) -> ExtractionResult:
|
||||||
"""同步版本的三元组抽取。"""
|
"""同步版本的三元组抽取。"""
|
||||||
|
text_hash = _text_fingerprint(request.text)
|
||||||
|
logger.info(
|
||||||
|
"Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s",
|
||||||
|
request.graph_id,
|
||||||
|
request.source_id,
|
||||||
|
len(request.text),
|
||||||
|
text_hash,
|
||||||
|
)
|
||||||
|
|
||||||
transformer = self._build_transformer(request.schema)
|
transformer = self._build_transformer(request.schema)
|
||||||
documents = [Document(page_content=request.text)]
|
documents = [Document(page_content=request.text)]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
graph_documents = transformer.convert_to_graph_documents(documents)
|
graph_documents = transformer.convert_to_graph_documents(documents)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("LLM graph extraction failed for source_id=%s", request.source_id)
|
logger.exception(
|
||||||
return ExtractionResult(raw_text=request.text, source_id=request.source_id)
|
"LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s",
|
||||||
|
request.graph_id,
|
||||||
|
request.source_id,
|
||||||
|
text_hash,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
return self._convert_result(graph_documents, request)
|
result = self._convert_result(graph_documents, request)
|
||||||
|
logger.info(
|
||||||
|
"Sync extraction complete: graph_id=%s, nodes=%d, edges=%d",
|
||||||
|
request.graph_id,
|
||||||
|
len(result.nodes),
|
||||||
|
len(result.edges),
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
async def extract_batch(
|
async def extract_batch(
|
||||||
self,
|
self,
|
||||||
requests: Sequence[ExtractionRequest],
|
requests: Sequence[ExtractionRequest],
|
||||||
) -> list[ExtractionResult]:
|
) -> list[ExtractionResult]:
|
||||||
"""批量抽取。
|
"""批量抽取,逐条处理。
|
||||||
|
|
||||||
对多段文本逐一抽取并汇总结果。
|
如需更高吞吐,可在调用侧用 asyncio.gather 并发调用 extract。
|
||||||
如需更高吞吐,可自行用 asyncio.gather 并发调用 extract。
|
|
||||||
"""
|
"""
|
||||||
|
logger.info("Starting batch extraction: count=%d", len(requests))
|
||||||
results: list[ExtractionResult] = []
|
results: list[ExtractionResult] = []
|
||||||
for req in requests:
|
for i, req in enumerate(requests):
|
||||||
|
logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id)
|
||||||
result = await self.extract(req)
|
result = await self.extract(req)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
logger.info("Batch extraction complete: count=%d", len(results))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -133,7 +191,6 @@ class KnowledgeGraphExtractor:
|
|||||||
seen_nodes: set[str] = set()
|
seen_nodes: set[str] = set()
|
||||||
|
|
||||||
for doc in graph_documents:
|
for doc in graph_documents:
|
||||||
# 收集节点
|
|
||||||
for node in doc.nodes:
|
for node in doc.nodes:
|
||||||
node_key = f"{node.id}:{node.type}"
|
node_key = f"{node.id}:{node.type}"
|
||||||
if node_key not in seen_nodes:
|
if node_key not in seen_nodes:
|
||||||
@@ -146,16 +203,9 @@ class KnowledgeGraphExtractor:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 收集关系
|
|
||||||
for rel in doc.relationships:
|
for rel in doc.relationships:
|
||||||
source_node = GraphNode(
|
source_node = GraphNode(name=rel.source.id, type=rel.source.type)
|
||||||
name=rel.source.id,
|
target_node = GraphNode(name=rel.target.id, type=rel.target.type)
|
||||||
type=rel.source.type,
|
|
||||||
)
|
|
||||||
target_node = GraphNode(
|
|
||||||
name=rel.target.id,
|
|
||||||
type=rel.target.type,
|
|
||||||
)
|
|
||||||
|
|
||||||
edges.append(
|
edges.append(
|
||||||
GraphEdge(
|
GraphEdge(
|
||||||
@@ -165,13 +215,8 @@ class KnowledgeGraphExtractor:
|
|||||||
properties=rel.properties if hasattr(rel, "properties") else {},
|
properties=rel.properties if hasattr(rel, "properties") else {},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
triples.append(
|
triples.append(
|
||||||
Triple(
|
Triple(subject=source_node, predicate=rel.type, object=target_node)
|
||||||
subject=source_node,
|
|
||||||
predicate=rel.type,
|
|
||||||
object=target_node,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return ExtractionResult(
|
return ExtractionResult(
|
||||||
|
|||||||
193
runtime/datamate-python/app/module/kg_extraction/interface.py
Normal file
193
runtime/datamate-python/app/module/kg_extraction/interface.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""知识图谱三元组抽取 API。
|
||||||
|
|
||||||
|
注意:本模块的接口由 Java 后端 (datamate-backend) 通过内网调用,
|
||||||
|
外部请求经 API Gateway 鉴权后由 Java 侧转发,不直接暴露给终端用户。
|
||||||
|
当前通过 X-User-Id 请求头获取调用方身份并记录审计日志。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.module.kg_extraction.extractor import KnowledgeGraphExtractor
|
||||||
|
from app.module.kg_extraction.models import (
|
||||||
|
ExtractionRequest,
|
||||||
|
ExtractionResult,
|
||||||
|
ExtractionSchema,
|
||||||
|
EntityTypeConstraint,
|
||||||
|
RelationTypeConstraint,
|
||||||
|
)
|
||||||
|
from app.module.shared.schema import StandardResponse
|
||||||
|
|
||||||
|
# Router for the knowledge-graph endpoints; every route below is mounted
# under the "/kg" prefix.
router = APIRouter(prefix="/kg", tags=["knowledge-graph"])
logger = get_logger(__name__)

# Lazily initialised: the extractor is created on the first request rather
# than at import time (per the original note, to avoid connecting to the LLM
# during startup).
_extractor: KnowledgeGraphExtractor | None = None

# Textual UUID shape (8-4-4-4-12 hex digits, either case) used to validate
# graph_id / source_id.
_UUID_PATTERN = (
    r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)

# Allowed entity/relation type names: letters, digits, underscore, hyphen;
# 1-50 characters.
_TYPE_NAME_PATTERN = r"^[A-Za-z0-9_\-]{1,50}$"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_extractor() -> KnowledgeGraphExtractor:
    """Return the module-wide extractor, building it on first use.

    The instance is created with ``KnowledgeGraphExtractor.from_settings()``
    the first time this function is called and is then cached in the
    module-level ``_extractor`` variable for all later calls.
    """
    global _extractor
    if _extractor is not None:
        return _extractor
    _extractor = KnowledgeGraphExtractor.from_settings()
    return _extractor
|
||||||
|
|
||||||
|
|
||||||
|
def _require_caller_id(
    x_user_id: Annotated[
        Optional[str],
        Header(description="调用方用户 ID,由上游 Java 后端传递"),
    ] = None,
) -> str:
    """Extract the caller's user ID from the ``X-User-Id`` request header.

    The header is declared optional at the parameter level on purpose: a
    required header (or a ``min_length`` constraint) is rejected by FastAPI's
    request validation with 422 before this function runs, which contradicts
    the documented contract of answering 401. Presence is therefore enforced
    here instead.

    Args:
        x_user_id: Raw ``X-User-Id`` header value, or ``None`` when the
            header was not sent.

    Returns:
        The header value with surrounding whitespace removed.

    Raises:
        HTTPException: 401 when the header is missing, empty, or contains
            only whitespace.
    """
    # `or ""` folds the missing-header case into the empty-string case so a
    # single check below covers missing, empty and whitespace-only values.
    caller = (x_user_id or "").strip()
    if not caller:
        raise HTTPException(status_code=401, detail="Missing required header: X-User-Id")
    return caller
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request / Response DTO(API 层,与内部 models 解耦)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SourceType(str, Enum):
    # Closed set of values accepted for ExtractRequest.source_type.
    # The `str` mixin makes each member a string equal to its value;
    # _to_extraction_request passes `.value` on to the internal request.
    ANNOTATION = "ANNOTATION"
    KNOWLEDGE_BASE = "KNOWLEDGE_BASE"
    IMPORT = "IMPORT"
    MANUAL = "MANUAL"
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractRequest(BaseModel):
    """Request body for triple extraction from a single text."""

    # Text to extract from; required, 1 to 50,000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=50000,
        description="待抽取的文本内容",
        examples=["张三是北京大学的教授,研究方向为人工智能。"],
    )
    # Target graph ID; required and must match _UUID_PATTERN.
    graph_id: str = Field(
        ...,
        pattern=_UUID_PATTERN,
        description="目标图谱 ID(UUID 格式)",
        examples=["550e8400-e29b-41d4-a716-446655440000"],
    )
    # Optional list of allowed entity type names: at most 50 entries, each
    # matching _TYPE_NAME_PATTERN.
    allowed_nodes: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
        default=None,
        max_length=50,
        description="允许的实体类型列表(schema-guided 抽取),每项 1-50 个字母/数字/下划线/连字符",
        examples=[["Person", "Organization", "Location"]],
    )
    # Optional list of allowed relation type names, with the same per-item
    # pattern and the same 50-entry cap as allowed_nodes.
    allowed_relationships: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field(
        default=None,
        max_length=50,
        description="允许的关系类型列表(schema-guided 抽取)",
        examples=[["works_at", "located_in"]],
    )
    # Optional source record ID; must match _UUID_PATTERN when provided.
    source_id: Optional[str] = Field(
        default=None,
        pattern=_UUID_PATTERN,
        description="来源 ID(数据集/知识库条目,UUID 格式)",
    )
    # Source category; defaults to KNOWLEDGE_BASE when omitted.
    source_type: SourceType = Field(
        default=SourceType.KNOWLEDGE_BASE,
        description="来源类型",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class BatchExtractRequest(BaseModel):
    """Request body for batch triple extraction."""

    # Individual extraction requests; the list must hold 1 to 50 items.
    items: list[ExtractRequest] = Field(
        ...,
        min_length=1,
        max_length=50,
        description="抽取请求列表,单次最多 50 条",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _to_extraction_request(req: ExtractRequest) -> ExtractionRequest:
    """Convert the API-layer DTO into the internal extraction request.

    A schema is attached only when at least one of ``allowed_nodes`` /
    ``allowed_relationships`` is non-empty; otherwise ``schema`` stays
    ``None``. ``source_type`` is forwarded as the enum's string value.
    """
    node_names = req.allowed_nodes or []
    relation_names = req.allowed_relationships or []

    schema: ExtractionSchema | None = None
    if node_names or relation_names:
        entity_constraints = [EntityTypeConstraint(name=node_name) for node_name in node_names]
        relation_constraints = [
            RelationTypeConstraint(name=relation_name) for relation_name in relation_names
        ]
        schema = ExtractionSchema(
            entity_types=entity_constraints,
            relation_types=relation_constraints,
        )

    return ExtractionRequest(
        text=req.text,
        graph_id=req.graph_id,
        schema=schema,
        source_id=req.source_id,
        source_type=req.source_type.value,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/extract",
    response_model=StandardResponse[ExtractionResult],
    summary="三元组抽取",
    description="从文本中抽取实体和关系,返回知识图谱三元组。支持通过 allowed_nodes 和 allowed_relationships 约束抽取范围。",
)
async def extract(req: ExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
    """Extract triples from a single text.

    A 16-character trace ID is generated per request and written to the log
    together with the graph ID and caller. When the extractor raises, the
    exception is logged under that trace ID and the client receives a 502
    whose detail carries only the trace ID, not the original error.
    """
    trace_id = uuid.uuid4().hex[:16]
    logger.info("[%s] Extract request: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)

    # Obtaining the extractor and converting the DTO stay outside the try
    # block, so only failures of the extraction call itself are mapped to 502.
    kg_extractor = _get_extractor()
    internal_request = _to_extraction_request(req)

    try:
        extraction = await kg_extractor.extract(internal_request)
    except Exception:
        logger.exception("[%s] Extraction failed: graph_id=%s, caller=%s", trace_id, req.graph_id, caller)
        raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})")
    else:
        return StandardResponse(code=200, message="success", data=extraction)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/extract/batch",
    response_model=StandardResponse[list[ExtractionResult]],
    summary="批量三元组抽取",
    description="对多段文本逐条抽取三元组,单次最多 50 条。",
)
async def extract_batch(req: BatchExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]):
    """Extract triples from a batch of texts.

    A 16-character trace ID is generated per request and logged with the item
    count and caller. Any exception raised by the extractor's batch call is
    logged under that trace ID and turned into a 502 for the whole request;
    the response detail carries only the trace ID.
    """
    trace_id = uuid.uuid4().hex[:16]
    logger.info("[%s] Batch extract request: count=%d, caller=%s", trace_id, len(req.items), caller)

    # As in the single-item endpoint, setup happens outside the try block so
    # only the extraction call itself is mapped to 502.
    kg_extractor = _get_extractor()
    internal_requests = [_to_extraction_request(entry) for entry in req.items]

    try:
        extractions = await kg_extractor.extract_batch(internal_requests)
    except Exception:
        logger.exception("[%s] Batch extraction failed: caller=%s", trace_id, caller)
        raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})")
    else:
        return StandardResponse(code=200, message="success", data=extractions)
|
||||||
Reference in New Issue
Block a user