You've already forked DataMate
feat(auto-annotation): add LLM-based annotation operators
Add three new LLM-powered auto-annotation operators: - LLMTextClassification: Text classification using LLM - LLMNamedEntityRecognition: Named entity recognition with type validation - LLMRelationExtraction: Relation extraction with entity and relation type validation Key features: - Load LLM config from t_model_config table via modelId parameter - Lazy loading of LLM configuration on first execute() - Result validation with whitelist checking for entity/relation types - Fault-tolerant: returns empty results on LLM failure instead of throwing - Fully compatible with existing Worker pipeline Files added: - runtime/ops/annotation/_llm_utils.py: Shared LLM utilities - runtime/ops/annotation/llm_text_classification/: Text classification operator - runtime/ops/annotation/llm_named_entity_recognition/: NER operator - runtime/ops/annotation/llm_relation_extraction/: Relation extraction operator Files modified: - runtime/ops/annotation/__init__.py: Register 3 new operators - runtime/python-executor/datamate/auto_annotation_worker.py: Add to Worker whitelist - frontend/src/pages/DataAnnotation/OperatorCreate/hooks/useOperatorOperations.ts: Add to frontend whitelist
This commit is contained in:
168
runtime/ops/annotation/_llm_utils.py
Normal file
168
runtime/ops/annotation/_llm_utils.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""LLM 配置加载 & OpenAI 兼容调用工具(标注算子共享)。
|
||||
|
||||
提供三项核心能力:
|
||||
1. 从 t_model_config 表加载模型配置(按 ID / 按默认)
|
||||
2. 调用 OpenAI 兼容 chat/completions API
|
||||
3. 从 LLM 原始输出中提取 JSON
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 模型配置加载
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_model_config(model_id: str) -> Dict[str, Any]:
    """Fetch the enabled model configuration row for *model_id* from t_model_config.

    Raises:
        ValueError: If no enabled row with that ID exists.
    """
    from sqlalchemy import text as sql_text

    from datamate.sql_manager.sql_manager import SQLManager

    query = sql_text(
        """
        SELECT model_name, provider, base_url, api_key, type
        FROM t_model_config
        WHERE id = :model_id AND is_enabled = 1
        LIMIT 1
        """
    )
    with SQLManager.create_connect() as conn:
        record = conn.execute(query, {"model_id": model_id}).fetchone()
    if not record:
        raise ValueError(f"Model config not found or disabled: {model_id}")
    return dict(record._mapping)
|
||||
|
||||
|
||||
def load_default_model_config() -> Dict[str, Any]:
    """Fetch the default chat model config (is_enabled=1, is_default=1, type='chat').

    Raises:
        ValueError: If no default chat model row is configured.
    """
    from sqlalchemy import text as sql_text

    from datamate.sql_manager.sql_manager import SQLManager

    query = sql_text(
        """
        SELECT id, model_name, provider, base_url, api_key, type
        FROM t_model_config
        WHERE is_enabled = 1 AND is_default = 1 AND type = 'chat'
        LIMIT 1
        """
    )
    with SQLManager.create_connect() as conn:
        record = conn.execute(query).fetchone()
    if not record:
        raise ValueError("No default chat model configured in t_model_config")
    return dict(record._mapping)
|
||||
|
||||
|
||||
def get_llm_config(model_id: str = "") -> Dict[str, Any]:
    """Resolve an LLM config: by explicit *model_id* when given, else the default chat model."""
    return load_model_config(model_id) if model_id else load_default_model_config()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM 调用
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def call_llm(
    config: Dict[str, Any],
    prompt: str,
    system_prompt: str = "",
    temperature: float = 0.1,
    max_retries: int = 2,
) -> str:
    """Call an OpenAI-compatible chat/completions API and return the text content.

    Args:
        config: Model config row; must contain ``base_url`` and ``model_name``,
            and may contain ``api_key``.
        prompt: User message content.
        system_prompt: Optional system message prepended to the conversation.
        temperature: Sampling temperature forwarded to the API.
        max_retries: Number of retries after the first attempt (total attempts
            is ``max_retries + 1``).

    Returns:
        The ``message.content`` of the first choice in the response.

    Raises:
        RuntimeError: If every attempt fails; chained (``from``) to the last
            underlying error so the original traceback is preserved.
    """
    import requests as http_requests

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    headers: Dict[str, str] = {"Content-Type": "application/json"}
    api_key = config.get("api_key", "")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    base_url = str(config["base_url"]).rstrip("/")
    # Accept base_url with or without a trailing /v1, or a full endpoint path.
    if not base_url.endswith("/chat/completions"):
        if not base_url.endswith("/v1"):
            base_url = f"{base_url}/v1"
        url = f"{base_url}/chat/completions"
    else:
        url = base_url

    body = {
        "model": config["model_name"],
        "messages": messages,
        "temperature": temperature,
    }

    last_err = None
    for attempt in range(max_retries + 1):
        try:
            resp = http_requests.post(url, json=body, headers=headers, timeout=120)
            resp.raise_for_status()
            # KeyError/IndexError on a malformed payload is also treated as a
            # retryable failure by the broad handler below.
            content = resp.json()["choices"][0]["message"]["content"]
            return content
        except Exception as e:  # deliberately broad: any failure triggers a retry
            last_err = e
            logger.warning(
                "LLM call attempt {}/{} failed: {}",
                attempt + 1,
                max_retries + 1,
                e,
            )

    # Fix: chain the last underlying exception instead of discarding it, so
    # callers see the real cause (timeout, HTTP status, bad JSON, ...).
    raise RuntimeError(
        f"LLM call failed after {max_retries + 1} attempts: {last_err}"
    ) from last_err
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON 提取
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def extract_json(raw: str) -> Any:
    """Pull a JSON object or array out of a raw LLM response.

    Strips the usual noise before parsing: <think>-style reasoning tags,
    Markdown code fences, and explanatory text surrounding the payload.

    Raises:
        ValueError: If *raw* is empty.
        json.JSONDecodeError: If no parseable JSON can be found.
    """
    if not raw:
        raise ValueError("Empty LLM response")

    # 1. Remove reasoning/thinking tags such as <think>...</think>.
    for tag in ("think", "thinking", "analysis", "reasoning", "reflection"):
        raw = re.sub(rf"<{tag}>[\s\S]*?</{tag}>", "", raw, flags=re.IGNORECASE)

    # 2. Remove Markdown code-fence markers.
    raw = re.sub(r"```(?:json)?\s*", "", raw)
    raw = raw.replace("```", "")

    # 3. Slice from the first '{'/'[' to the last '}'/']'.
    opener = next((i for i, ch in enumerate(raw) if ch in "{["), None)
    closer = next(
        (i for i in range(len(raw) - 1, -1, -1) if raw[i] in "]}"), None
    )

    if opener is not None and closer is not None and opener < closer + 1:
        return json.loads(raw[opener : closer + 1])

    # Fallback: try parsing the cleaned text as-is.
    return json.loads(raw.strip())
|
||||
Reference in New Issue
Block a user