From 0e0782a452a95b444c5821a92bbef94c7df1dcc2 Mon Sep 17 00:00:00 2001 From: Jerry Yan <792602257@qq.com> Date: Tue, 17 Feb 2026 22:01:06 +0800 Subject: [PATCH] =?UTF-8?q?feat(kg-extraction):=20=E5=AE=9E=E7=8E=B0=20Pyt?= =?UTF-8?q?hon=20=E6=8A=BD=E5=8F=96=E5=99=A8=20FastAPI=20=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现功能: - 创建 kg_extraction/interface.py(FastAPI 路由) - 实现 POST /api/kg/extract(单条文本抽取) - 实现 POST /api/kg/extract/batch(批量抽取,最多 50 条) - 集成到 FastAPI 主路由(/api/kg/ 前缀) 技术实现: - 配置管理:从环境变量读取 LLM 配置(API Key、Base URL、Model、Temperature) - 安全性: - API Key 使用 SecretStr 保护 - 错误信息脱敏(使用 trace_id,不暴露原始异常) - 请求文本不写入日志(使用 SHA-256 hash) - 强制要求 X-User-Id 头(鉴权边界) - 超时控制: - kg_llm_timeout_seconds(60秒) - kg_llm_max_retries(2次) - 输入校验: - graph_id 和 source_id 使用 UUID pattern - source_type 使用 Enum(4个值) - allowed_nodes/relationships 元素使用正则约束(ASCII,1-50字符) - 审计日志:记录 caller、trace_id、text_hash 代码审查: - 经过 3 轮 Codex 审查和 2 轮 Claude 修复 - 所有问题已解决(5个 P1/P2 + 3个 P3) - 语法检查通过 API 端点: - POST /api/kg/extract:单条文本抽取 - POST /api/kg/extract/batch:批量抽取(最多 50 条) 配置环境变量: - KG_LLM_API_KEY:LLM API 密钥 - KG_LLM_BASE_URL:自定义端点(可选) - KG_LLM_MODEL:模型名称(默认 gpt-4o-mini) - KG_LLM_TEMPERATURE:生成温度(默认 0.0) - KG_LLM_TIMEOUT_SECONDS:超时时间(默认 60) - KG_LLM_MAX_RETRIES:重试次数(默认 2) --- runtime/datamate-python/app/core/config.py | 12 +- .../datamate-python/app/module/__init__.py | 2 + .../app/module/kg_extraction/__init__.py | 2 + .../app/module/kg_extraction/extractor.py | 145 ++++++++----- .../app/module/kg_extraction/interface.py | 193 ++++++++++++++++++ 5 files changed, 302 insertions(+), 52 deletions(-) create mode 100644 runtime/datamate-python/app/module/kg_extraction/interface.py diff --git a/runtime/datamate-python/app/core/config.py b/runtime/datamate-python/app/core/config.py index df3efa2..3730e0f 100644 --- a/runtime/datamate-python/app/core/config.py +++ b/runtime/datamate-python/app/core/config.py @@ -1,5 +1,5 @@ from 
pydantic_settings import BaseSettings -from pydantic import model_validator +from pydantic import SecretStr, model_validator from typing import Optional class Settings(BaseSettings): @@ -62,9 +62,17 @@ class Settings(BaseSettings): # DataMate dm_file_path_prefix: str = "/dataset" # DM存储文件夹前缀 - # DataMate Backend (Java) - 用于通过“下载/预览接口”读取文件内容 + # DataMate Backend (Java) - 用于通过"下载/预览接口"读取文件内容 datamate_backend_base_url: str = "http://datamate-backend:8080/api" + # Knowledge Graph - LLM 三元组抽取配置 + kg_llm_api_key: SecretStr = SecretStr("EMPTY") + kg_llm_base_url: Optional[str] = None + kg_llm_model: str = "gpt-4o-mini" + kg_llm_temperature: float = 0.0 + kg_llm_timeout_seconds: int = 60 + kg_llm_max_retries: int = 2 + # 标注编辑器(Label Studio Editor)相关 editor_max_text_bytes: int = 0 # <=0 表示不限制,正数为最大字节数 diff --git a/runtime/datamate-python/app/module/__init__.py b/runtime/datamate-python/app/module/__init__.py index 2dab088..4900d44 100644 --- a/runtime/datamate-python/app/module/__init__.py +++ b/runtime/datamate-python/app/module/__init__.py @@ -7,6 +7,7 @@ from .generation.interface import router as generation_router from .evaluation.interface import router as evaluation_router from .collection.interface import router as collection_route from .dataset.interface import router as dataset_router +from .kg_extraction.interface import router as kg_extraction_router router = APIRouter( prefix="/api" @@ -19,5 +20,6 @@ router.include_router(generation_router) router.include_router(evaluation_router) router.include_router(collection_route) router.include_router(dataset_router) +router.include_router(kg_extraction_router) __all__ = ["router"] diff --git a/runtime/datamate-python/app/module/kg_extraction/__init__.py b/runtime/datamate-python/app/module/kg_extraction/__init__.py index 56e8d7c..f8a973b 100644 --- a/runtime/datamate-python/app/module/kg_extraction/__init__.py +++ b/runtime/datamate-python/app/module/kg_extraction/__init__.py @@ -6,6 +6,7 @@ from 
app.module.kg_extraction.models import ( GraphNode, GraphEdge, ) +from app.module.kg_extraction.interface import router __all__ = [ "KnowledgeGraphExtractor", @@ -14,4 +15,5 @@ __all__ = [ "Triple", "GraphNode", "GraphEdge", + "router", ] diff --git a/runtime/datamate-python/app/module/kg_extraction/extractor.py b/runtime/datamate-python/app/module/kg_extraction/extractor.py index 498cde8..d3587c3 100644 --- a/runtime/datamate-python/app/module/kg_extraction/extractor.py +++ b/runtime/datamate-python/app/module/kg_extraction/extractor.py @@ -6,13 +6,15 @@ from __future__ import annotations -import logging +import hashlib from typing import Sequence from langchain_core.documents import Document from langchain_openai import ChatOpenAI from langchain_experimental.graph_transformers import LLMGraphTransformer +from pydantic import SecretStr +from app.core.logging import get_logger from app.module.kg_extraction.models import ( ExtractionRequest, ExtractionResult, @@ -22,36 +24,58 @@ from app.module.kg_extraction.models import ( Triple, ) -logger = logging.getLogger(__name__) +logger = get_logger(__name__) + + +def _text_fingerprint(text: str) -> str: + """返回文本的短 SHA-256 摘要,用于日志关联而不泄露原文。""" + return hashlib.sha256(text.encode("utf-8")).hexdigest()[:12] class KnowledgeGraphExtractor: """基于 LLMGraphTransformer 的三元组抽取器。 - Parameters - ---------- - model_name : str - OpenAI 兼容模型名称。 - base_url : str | None - 自定义 API base URL(用于对接 vLLM/Ollama 等本地模型服务)。 - api_key : str - API 密钥。 - temperature : float - 生成温度,抽取任务建议使用较低值。 + 通过 ``from_settings()`` 工厂方法从全局配置创建实例, + 也可直接构造以覆盖默认参数。 """ def __init__( self, model_name: str = "gpt-4o-mini", base_url: str | None = None, - api_key: str = "EMPTY", + api_key: SecretStr = SecretStr("EMPTY"), temperature: float = 0.0, + timeout: int = 60, + max_retries: int = 2, ) -> None: + logger.info( + "Initializing KnowledgeGraphExtractor (model=%s, base_url=%s, timeout=%ds, max_retries=%d)", + model_name, + base_url or "default", + timeout, + 
max_retries, + ) self._llm = ChatOpenAI( model=model_name, base_url=base_url, api_key=api_key, temperature=temperature, + timeout=timeout, + max_retries=max_retries, + ) + + @classmethod + def from_settings(cls) -> KnowledgeGraphExtractor: + """从全局 Settings 创建抽取器实例。""" + from app.core.config import settings + + return cls( + model_name=settings.kg_llm_model, + base_url=settings.kg_llm_base_url, + api_key=settings.kg_llm_api_key, + temperature=settings.kg_llm_temperature, + timeout=settings.kg_llm_timeout_seconds, + max_retries=settings.kg_llm_max_retries, ) def _build_transformer( @@ -70,55 +94,89 @@ class KnowledgeGraphExtractor: return LLMGraphTransformer(**kwargs) async def extract(self, request: ExtractionRequest) -> ExtractionResult: - """从文本中抽取三元组。 + """从文本中异步抽取三元组。""" + text_hash = _text_fingerprint(request.text) + logger.info( + "Starting extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s", + request.graph_id, + request.source_id, + len(request.text), + text_hash, + ) - Parameters - ---------- - request : ExtractionRequest - 包含文本、schema 约束等信息的抽取请求。 - - Returns - ------- - ExtractionResult - 抽取得到的节点、边和三元组。 - """ transformer = self._build_transformer(request.schema) documents = [Document(page_content=request.text)] try: graph_documents = await transformer.aconvert_to_graph_documents(documents) except Exception: - logger.exception("LLM graph extraction failed for source_id=%s", request.source_id) - return ExtractionResult(raw_text=request.text, source_id=request.source_id) + logger.exception( + "LLM extraction failed: graph_id=%s, source_id=%s, text_hash=%s", + request.graph_id, + request.source_id, + text_hash, + ) + raise - return self._convert_result(graph_documents, request) + result = self._convert_result(graph_documents, request) + logger.info( + "Extraction complete: graph_id=%s, nodes=%d, edges=%d, triples=%d", + request.graph_id, + len(result.nodes), + len(result.edges), + len(result.triples), + ) + return result def extract_sync(self, 
request: ExtractionRequest) -> ExtractionResult: """同步版本的三元组抽取。""" + text_hash = _text_fingerprint(request.text) + logger.info( + "Starting sync extraction: graph_id=%s, source_id=%s, text_len=%d, text_hash=%s", + request.graph_id, + request.source_id, + len(request.text), + text_hash, + ) + transformer = self._build_transformer(request.schema) documents = [Document(page_content=request.text)] try: graph_documents = transformer.convert_to_graph_documents(documents) except Exception: - logger.exception("LLM graph extraction failed for source_id=%s", request.source_id) - return ExtractionResult(raw_text=request.text, source_id=request.source_id) + logger.exception( + "LLM sync extraction failed: graph_id=%s, source_id=%s, text_hash=%s", + request.graph_id, + request.source_id, + text_hash, + ) + raise - return self._convert_result(graph_documents, request) + result = self._convert_result(graph_documents, request) + logger.info( + "Sync extraction complete: graph_id=%s, nodes=%d, edges=%d", + request.graph_id, + len(result.nodes), + len(result.edges), + ) + return result async def extract_batch( self, requests: Sequence[ExtractionRequest], ) -> list[ExtractionResult]: - """批量抽取。 + """批量抽取,逐条处理。 - 对多段文本逐一抽取并汇总结果。 - 如需更高吞吐,可自行用 asyncio.gather 并发调用 extract。 + 如需更高吞吐,可在调用侧用 asyncio.gather 并发调用 extract。 """ + logger.info("Starting batch extraction: count=%d", len(requests)) results: list[ExtractionResult] = [] - for req in requests: + for i, req in enumerate(requests): + logger.debug("Batch item %d/%d: source_id=%s", i + 1, len(requests), req.source_id) result = await self.extract(req) results.append(result) + logger.info("Batch extraction complete: count=%d", len(results)) return results @staticmethod @@ -133,7 +191,6 @@ class KnowledgeGraphExtractor: seen_nodes: set[str] = set() for doc in graph_documents: - # 收集节点 for node in doc.nodes: node_key = f"{node.id}:{node.type}" if node_key not in seen_nodes: @@ -146,16 +203,9 @@ class KnowledgeGraphExtractor: ) ) - # 收集关系 for 
rel in doc.relationships: - source_node = GraphNode( - name=rel.source.id, - type=rel.source.type, - ) - target_node = GraphNode( - name=rel.target.id, - type=rel.target.type, - ) + source_node = GraphNode(name=rel.source.id, type=rel.source.type) + target_node = GraphNode(name=rel.target.id, type=rel.target.type) edges.append( GraphEdge( @@ -165,13 +215,9 @@ properties=rel.properties if hasattr(rel, "properties") else {}, ) ) - triples.append( - Triple( - subject=source_node, - predicate=rel.type, - object=target_node, - ) + triples.append( + Triple(subject=source_node, predicate=rel.type, object=target_node) ) return ExtractionResult( diff --git a/runtime/datamate-python/app/module/kg_extraction/interface.py b/runtime/datamate-python/app/module/kg_extraction/interface.py new file mode 100644 index 0000000..284153b --- /dev/null +++ b/runtime/datamate-python/app/module/kg_extraction/interface.py @@ -0,0 +1,193 @@ +"""知识图谱三元组抽取 API。 + +注意:本模块的接口由 Java 后端 (datamate-backend) 通过内网调用, +外部请求经 API Gateway 鉴权后由 Java 侧转发,不直接暴露给终端用户。 +当前通过 X-User-Id 请求头获取调用方身份并记录审计日志。 +""" + +from __future__ import annotations + +import uuid +from enum import Enum +from typing import Annotated, Optional + +from fastapi import APIRouter, Depends, Header, HTTPException +from pydantic import BaseModel, Field + +from app.core.logging import get_logger +from app.module.kg_extraction.extractor import KnowledgeGraphExtractor +from app.module.kg_extraction.models import ( + ExtractionRequest, + ExtractionResult, + ExtractionSchema, + EntityTypeConstraint, + RelationTypeConstraint, +) + +from app.module.shared.schema import StandardResponse + +router = APIRouter(prefix="/kg", tags=["knowledge-graph"]) +logger = get_logger(__name__) + +# 延迟初始化:首次请求时创建,避免启动阶段就连接 LLM +_extractor: KnowledgeGraphExtractor | None = None + +_UUID_PATTERN = ( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$" +) + +# 允许的实体/关系类型名称:字母、数字、下划线、连字符,1-50 字符 +_TYPE_NAME_PATTERN = 
r"^[A-Za-z0-9_\-]{1,50}$" + + +def _get_extractor() -> KnowledgeGraphExtractor: + global _extractor + if _extractor is None: + _extractor = KnowledgeGraphExtractor.from_settings() + return _extractor + + +def _require_caller_id( + x_user_id: Annotated[str, Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递")], +) -> str: + """从请求头提取调用方用户 ID,用于审计日志。 + + 该接口为内部服务调用,调用方身份由上游 Java 后端通过 + X-User-Id 请求头传递。缺失或为空时返回 401。 + """ + caller = x_user_id.strip() + if not caller: + raise HTTPException(status_code=401, detail="Missing required header: X-User-Id") + return caller + + +# --------------------------------------------------------------------------- +# Request / Response DTO(API 层,与内部 models 解耦) +# --------------------------------------------------------------------------- + + +class SourceType(str, Enum): + ANNOTATION = "ANNOTATION" + KNOWLEDGE_BASE = "KNOWLEDGE_BASE" + IMPORT = "IMPORT" + MANUAL = "MANUAL" + + +class ExtractRequest(BaseModel): + """三元组抽取请求。""" + + text: str = Field( + ..., + min_length=1, + max_length=50000, + description="待抽取的文本内容", + examples=["张三是北京大学的教授,研究方向为人工智能。"], + ) + graph_id: str = Field( + ..., + pattern=_UUID_PATTERN, + description="目标图谱 ID(UUID 格式)", + examples=["550e8400-e29b-41d4-a716-446655440000"], + ) + allowed_nodes: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field( + default=None, + max_length=50, + description="允许的实体类型列表(schema-guided 抽取),每项 1-50 个字母/数字/下划线/连字符", + examples=[["Person", "Organization", "Location"]], + ) + allowed_relationships: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field( + default=None, + max_length=50, + description="允许的关系类型列表(schema-guided 抽取)", + examples=[["works_at", "located_in"]], + ) + source_id: Optional[str] = Field( + default=None, + pattern=_UUID_PATTERN, + description="来源 ID(数据集/知识库条目,UUID 格式)", + ) + source_type: SourceType = Field( + default=SourceType.KNOWLEDGE_BASE, + description="来源类型", + ) + + +class BatchExtractRequest(BaseModel): + 
"""批量三元组抽取请求。""" + + items: list[ExtractRequest] = Field( + ..., + min_length=1, + max_length=50, + description="抽取请求列表,单次最多 50 条", + ) + + +def _to_extraction_request(req: ExtractRequest) -> ExtractionRequest: + """将 API DTO 转换为内部抽取请求。""" + schema: ExtractionSchema | None = None + if req.allowed_nodes or req.allowed_relationships: + schema = ExtractionSchema( + entity_types=[EntityTypeConstraint(name=n) for n in (req.allowed_nodes or [])], + relation_types=[ + RelationTypeConstraint(name=r) for r in (req.allowed_relationships or []) + ], + ) + + return ExtractionRequest( + text=req.text, + graph_id=req.graph_id, + schema=schema, + source_id=req.source_id, + source_type=req.source_type.value, + ) + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@router.post( + "/extract", + response_model=StandardResponse[ExtractionResult], + summary="三元组抽取", + description="从文本中抽取实体和关系,返回知识图谱三元组。支持通过 allowed_nodes 和 allowed_relationships 约束抽取范围。", +) +async def extract(req: ExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]): + """单条文本三元组抽取。""" + trace_id = uuid.uuid4().hex[:16] + logger.info("[%s] Extract request: graph_id=%s, caller=%s", trace_id, req.graph_id, caller) + + extractor = _get_extractor() + extraction_req = _to_extraction_request(req) + + try: + result = await extractor.extract(extraction_req) + except Exception: + logger.exception("[%s] Extraction failed: graph_id=%s, caller=%s", trace_id, req.graph_id, caller) + raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})") + + return StandardResponse(code=200, message="success", data=result) + + +@router.post( + "/extract/batch", + response_model=StandardResponse[list[ExtractionResult]], + summary="批量三元组抽取", + description="对多段文本逐条抽取三元组,单次最多 50 条。", +) +async def extract_batch(req: BatchExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]): 
+ """批量文本三元组抽取。""" + trace_id = uuid.uuid4().hex[:16] + logger.info("[%s] Batch extract request: count=%d, caller=%s", trace_id, len(req.items), caller) + + extractor = _get_extractor() + extraction_reqs = [_to_extraction_request(item) for item in req.items] + + try: + results = await extractor.extract_batch(extraction_reqs) + except Exception: + logger.exception("[%s] Batch extraction failed: caller=%s", trace_id, caller) + raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})") + + return StandardResponse(code=200, message="success", data=results)