# -*- coding: utf-8 -*-
"""LLM config loading & OpenAI-compatible call utilities (shared by annotation operators).

Provides three core capabilities:
1. Load model configs from the t_model_config table (by ID / by default)
2. Call an OpenAI-compatible chat/completions API
3. Extract JSON from raw LLM output
"""
import json
import re
from typing import Any, Dict

from loguru import logger


# ---------------------------------------------------------------------------
# Model config loading
# ---------------------------------------------------------------------------
def load_model_config(model_id: str) -> Dict[str, Any]:
    """Load an enabled model config from t_model_config by model_id."""
    from datamate.sql_manager.sql_manager import SQLManager
    from sqlalchemy import text as sql_text

    sql = sql_text(
        """
        SELECT model_name, provider, base_url, api_key, type
        FROM t_model_config
        WHERE id = :model_id AND is_enabled = 1
        LIMIT 1
        """
    )
    with SQLManager.create_connect() as conn:
        row = conn.execute(sql, {"model_id": model_id}).fetchone()
    if not row:
        raise ValueError(f"Model config not found or disabled: {model_id}")
    return dict(row._mapping)


def load_default_model_config() -> Dict[str, Any]:
    """Load the default chat model config (is_default=1 and type='chat')."""
    from datamate.sql_manager.sql_manager import SQLManager
    from sqlalchemy import text as sql_text

    sql = sql_text(
        """
        SELECT id, model_name, provider, base_url, api_key, type
        FROM t_model_config
        WHERE is_enabled = 1 AND is_default = 1 AND type = 'chat'
        LIMIT 1
        """
    )
    with SQLManager.create_connect() as conn:
        row = conn.execute(sql).fetchone()
    if not row:
        raise ValueError("No default chat model configured in t_model_config")
    return dict(row._mapping)


def get_llm_config(model_id: str = "") -> Dict[str, Any]:
    """Load by model_id when provided; otherwise fall back to the default model."""
    if model_id:
        return load_model_config(model_id)
    return load_default_model_config()


# ---------------------------------------------------------------------------
# LLM invocation
# ---------------------------------------------------------------------------
def call_llm(
    config: Dict[str, Any],
    prompt: str,
    system_prompt: str = "",
    temperature: float = 0.1,
    max_retries: int = 2,
) -> str:
    """Call an OpenAI-compatible chat/completions API and return the text content."""
    import requests as http_requests

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    headers: Dict[str, str] = {"Content-Type": "application/json"}
    api_key = config.get("api_key", "")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    base_url = str(config["base_url"]).rstrip("/")
    # Accept base_url values both with and without a trailing /v1
    if not base_url.endswith("/chat/completions"):
        if not base_url.endswith("/v1"):
            base_url = f"{base_url}/v1"
        url = f"{base_url}/chat/completions"
    else:
        url = base_url

    body = {
        "model": config["model_name"],
        "messages": messages,
        "temperature": temperature,
    }

    last_err = None
    for attempt in range(max_retries + 1):
        try:
            resp = http_requests.post(url, json=body, headers=headers, timeout=120)
            resp.raise_for_status()
            content = resp.json()["choices"][0]["message"]["content"]
            return content
        except Exception as e:
            last_err = e
            logger.warning(
                "LLM call attempt {}/{} failed: {}",
                attempt + 1,
                max_retries + 1,
                e,
            )
    raise RuntimeError(f"LLM call failed after {max_retries + 1} attempts: {last_err}")


# ---------------------------------------------------------------------------
# JSON extraction
# ---------------------------------------------------------------------------
def extract_json(raw: str) -> Any:
    """Extract a JSON object/array from raw LLM output.

    Handles common noise: Markdown code fences, <think>-style tags, and
    surrounding explanatory text.
    """
    if not raw:
        raise ValueError("Empty LLM response")

    # 1. Strip <think>...</think> and similar reasoning tags
    thought_tags = ["think", "thinking", "analysis", "reasoning", "reflection"]
    for tag in thought_tags:
        raw = re.sub(rf"<{tag}>[\s\S]*?</{tag}>", "", raw, flags=re.IGNORECASE)

    # 2. Strip Markdown code-fence markers
    raw = re.sub(r"```(?:json)?\s*", "", raw)
    raw = raw.replace("```", "")

    # 3. Locate the span from the first { or [ to the last } or ]
    start = None
    end = None
    for i, ch in enumerate(raw):
        if ch in "{[":
            start = i
            break
    for i in range(len(raw) - 1, -1, -1):
        if raw[i] in "]}":
            end = i + 1
            break

    if start is not None and end is not None and start < end:
        return json.loads(raw[start:end])

    # Fallback: try to parse the stripped string directly
    return json.loads(raw.strip())
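

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only)
# ---------------------------------------------------------------------------
# A minimal sketch of how the three helpers compose, assuming a reachable
# default chat model is configured in t_model_config. The prompt text and the
# expectation of a JSON-array reply are placeholders for illustration, not
# part of the production annotation operators.
if __name__ == "__main__":
    config = get_llm_config()  # no model_id given -> default chat model
    raw_output = call_llm(
        config,
        prompt="List three colors as a JSON array of strings.",
        system_prompt="Reply with JSON only.",
    )
    try:
        parsed = extract_json(raw_output)
        logger.info("Parsed LLM output: {}", parsed)
    except ValueError as e:
        logger.error("Could not parse LLM output as JSON: {}", e)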