"""知识图谱三元组抽取 API。 注意:本模块的接口由 Java 后端 (datamate-backend) 通过内网调用, 外部请求经 API Gateway 鉴权后由 Java 侧转发,不直接暴露给终端用户。 当前通过 X-User-Id 请求头获取调用方身份并记录审计日志。 """ from __future__ import annotations import uuid from enum import Enum from typing import Annotated, Optional from fastapi import APIRouter, Depends, Header, HTTPException from pydantic import BaseModel, Field from app.core.logging import get_logger from app.module.kg_extraction.extractor import KnowledgeGraphExtractor from app.module.kg_extraction.models import ( ExtractionRequest, ExtractionResult, ExtractionSchema, EntityTypeConstraint, RelationTypeConstraint, ) from app.module.shared.schema import StandardResponse router = APIRouter(prefix="/kg", tags=["knowledge-graph"]) logger = get_logger(__name__) # 延迟初始化:首次请求时创建,避免启动阶段就连接 LLM _extractor: KnowledgeGraphExtractor | None = None _UUID_PATTERN = ( r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$" ) # 允许的实体/关系类型名称:字母、数字、下划线、连字符,1-50 字符 _TYPE_NAME_PATTERN = r"^[A-Za-z0-9_\-]{1,50}$" def _get_extractor() -> KnowledgeGraphExtractor: global _extractor if _extractor is None: _extractor = KnowledgeGraphExtractor.from_settings() return _extractor def _require_caller_id( x_user_id: Annotated[str, Header(min_length=1, description="调用方用户 ID,由上游 Java 后端传递")], ) -> str: """从请求头提取调用方用户 ID,用于审计日志。 该接口为内部服务调用,调用方身份由上游 Java 后端通过 X-User-Id 请求头传递。缺失或为空时返回 401。 """ caller = x_user_id.strip() if not caller: raise HTTPException(status_code=401, detail="Missing required header: X-User-Id") return caller # --------------------------------------------------------------------------- # Request / Response DTO(API 层,与内部 models 解耦) # --------------------------------------------------------------------------- class SourceType(str, Enum): ANNOTATION = "ANNOTATION" KNOWLEDGE_BASE = "KNOWLEDGE_BASE" IMPORT = "IMPORT" MANUAL = "MANUAL" class ExtractRequest(BaseModel): """三元组抽取请求。""" text: str = Field( ..., min_length=1, max_length=50000, description="待抽取的文本内容", examples=["张三是北京大学的教授,研究方向为人工智能。"], ) graph_id: str = Field( ..., pattern=_UUID_PATTERN, description="目标图谱 ID(UUID 格式)", examples=["550e8400-e29b-41d4-a716-446655440000"], ) allowed_nodes: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field( default=None, max_length=50, description="允许的实体类型列表(schema-guided 抽取),每项 1-50 个字母/数字/下划线/连字符", examples=[["Person", "Organization", "Location"]], ) allowed_relationships: Optional[list[Annotated[str, Field(pattern=_TYPE_NAME_PATTERN)]]] = Field( default=None, max_length=50, description="允许的关系类型列表(schema-guided 抽取)", examples=[["works_at", "located_in"]], ) source_id: Optional[str] = Field( default=None, pattern=_UUID_PATTERN, description="来源 ID(数据集/知识库条目,UUID 格式)", ) source_type: SourceType = Field( default=SourceType.KNOWLEDGE_BASE, description="来源类型", ) class BatchExtractRequest(BaseModel): """批量三元组抽取请求。""" items: list[ExtractRequest] = Field( ..., min_length=1, max_length=50, description="抽取请求列表,单次最多 50 条", ) def _to_extraction_request(req: ExtractRequest) -> ExtractionRequest: """将 API DTO 转换为内部抽取请求。""" schema: ExtractionSchema | None = None if req.allowed_nodes or req.allowed_relationships: schema = ExtractionSchema( entity_types=[EntityTypeConstraint(name=n) for n in (req.allowed_nodes or [])], relation_types=[ RelationTypeConstraint(name=r) for r in (req.allowed_relationships or []) ], ) return ExtractionRequest( text=req.text, graph_id=req.graph_id, schema=schema, source_id=req.source_id, source_type=req.source_type.value, ) # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @router.post( "/extract", response_model=StandardResponse[ExtractionResult], summary="三元组抽取", description="从文本中抽取实体和关系,返回知识图谱三元组。支持通过 allowed_nodes 和 allowed_relationships 约束抽取范围。", ) async def extract(req: ExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]): """单条文本三元组抽取。""" trace_id = uuid.uuid4().hex[:16] logger.info("[%s] Extract request: graph_id=%s, caller=%s", trace_id, req.graph_id, caller) extractor = _get_extractor() extraction_req = _to_extraction_request(req) try: result = await extractor.extract(extraction_req) except Exception: logger.exception("[%s] Extraction failed: graph_id=%s, caller=%s", trace_id, req.graph_id, caller) raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})") return StandardResponse(code=200, message="success", data=result) @router.post( "/extract/batch", response_model=StandardResponse[list[ExtractionResult]], summary="批量三元组抽取", description="对多段文本逐条抽取三元组,单次最多 50 条。", ) async def extract_batch(req: BatchExtractRequest, caller: Annotated[str, Depends(_require_caller_id)]): """批量文本三元组抽取。""" trace_id = uuid.uuid4().hex[:16] logger.info("[%s] Batch extract request: count=%d, caller=%s", trace_id, len(req.items), caller) extractor = _get_extractor() extraction_reqs = [_to_extraction_request(item) for item in req.items] try: results = await extractor.extract_batch(extraction_reqs) except Exception: logger.exception("[%s] Batch extraction failed: caller=%s", trace_id, caller) raise HTTPException(status_code=502, detail=f"抽取服务暂不可用 (trace: {trace_id})") return StandardResponse(code=200, message="success", data=results)