You've already forked DataMate
feat(auto-annotation): add LLM-based annotation operators
Add three new LLM-powered auto-annotation operators: - LLMTextClassification: Text classification using LLM - LLMNamedEntityRecognition: Named entity recognition with type validation - LLMRelationExtraction: Relation extraction with entity and relation type validation Key features: - Load LLM config from t_model_config table via modelId parameter - Lazy loading of LLM configuration on first execute() - Result validation with whitelist checking for entity/relation types - Fault-tolerant: returns empty results on LLM failure instead of throwing - Fully compatible with existing Worker pipeline Files added: - runtime/ops/annotation/_llm_utils.py: Shared LLM utilities - runtime/ops/annotation/llm_text_classification/: Text classification operator - runtime/ops/annotation/llm_named_entity_recognition/: NER operator - runtime/ops/annotation/llm_relation_extraction/: Relation extraction operator Files modified: - runtime/ops/annotation/__init__.py: Register 3 new operators - runtime/python-executor/datamate/auto_annotation_worker.py: Add to Worker whitelist - frontend/src/pages/DataAnnotation/OperatorCreate/hooks/useOperatorOperations.ts: Add to frontend whitelist
This commit is contained in:
@@ -22,6 +22,9 @@ type CategoryGroup = {
|
||||
const ANNOTATION_OPERATOR_ID_WHITELIST = new Set([
|
||||
"ImageObjectDetectionBoundingBox",
|
||||
"test_annotation_marker",
|
||||
"LLMTextClassification",
|
||||
"LLMNamedEntityRecognition",
|
||||
"LLMRelationExtraction",
|
||||
]);
|
||||
|
||||
const ensureArray = (value: unknown): string[] => {
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Annotation-related operators (e.g. YOLO detection)."""
|
||||
"""Annotation-related operators (e.g. YOLO detection, LLM-based NLP annotation)."""
|
||||
|
||||
from . import image_object_detection_bounding_box
|
||||
from . import test_annotation_marker
|
||||
from . import llm_text_classification
|
||||
from . import llm_named_entity_recognition
|
||||
from . import llm_relation_extraction
|
||||
|
||||
__all__ = [
|
||||
"image_object_detection_bounding_box",
|
||||
"test_annotation_marker",
|
||||
"llm_text_classification",
|
||||
"llm_named_entity_recognition",
|
||||
"llm_relation_extraction",
|
||||
]
|
||||
|
||||
168
runtime/ops/annotation/_llm_utils.py
Normal file
168
runtime/ops/annotation/_llm_utils.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""LLM 配置加载 & OpenAI 兼容调用工具(标注算子共享)。
|
||||
|
||||
提供三项核心能力:
|
||||
1. 从 t_model_config 表加载模型配置(按 ID / 按默认)
|
||||
2. 调用 OpenAI 兼容 chat/completions API
|
||||
3. 从 LLM 原始输出中提取 JSON
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 模型配置加载
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_model_config(model_id: str) -> Dict[str, Any]:
|
||||
"""根据 model_id 从 t_model_config 读取已启用的模型配置。"""
|
||||
|
||||
from datamate.sql_manager.sql_manager import SQLManager
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
sql = sql_text(
|
||||
"""
|
||||
SELECT model_name, provider, base_url, api_key, type
|
||||
FROM t_model_config
|
||||
WHERE id = :model_id AND is_enabled = 1
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
with SQLManager.create_connect() as conn:
|
||||
row = conn.execute(sql, {"model_id": model_id}).fetchone()
|
||||
if not row:
|
||||
raise ValueError(f"Model config not found or disabled: {model_id}")
|
||||
return dict(row._mapping)
|
||||
|
||||
|
||||
def load_default_model_config() -> Dict[str, Any]:
|
||||
"""加载默认的 chat 模型配置(is_default=1 且 type='chat')。"""
|
||||
|
||||
from datamate.sql_manager.sql_manager import SQLManager
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
sql = sql_text(
|
||||
"""
|
||||
SELECT id, model_name, provider, base_url, api_key, type
|
||||
FROM t_model_config
|
||||
WHERE is_enabled = 1 AND is_default = 1 AND type = 'chat'
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
with SQLManager.create_connect() as conn:
|
||||
row = conn.execute(sql).fetchone()
|
||||
if not row:
|
||||
raise ValueError("No default chat model configured in t_model_config")
|
||||
return dict(row._mapping)
|
||||
|
||||
|
||||
def get_llm_config(model_id: str = "") -> Dict[str, Any]:
|
||||
"""优先按 model_id 加载,未提供则加载默认模型。"""
|
||||
|
||||
if model_id:
|
||||
return load_model_config(model_id)
|
||||
return load_default_model_config()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM 调用
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def call_llm(
|
||||
config: Dict[str, Any],
|
||||
prompt: str,
|
||||
system_prompt: str = "",
|
||||
temperature: float = 0.1,
|
||||
max_retries: int = 2,
|
||||
) -> str:
|
||||
"""调用 OpenAI 兼容的 chat/completions API 并返回文本内容。"""
|
||||
|
||||
import requests as http_requests
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
headers: Dict[str, str] = {"Content-Type": "application/json"}
|
||||
api_key = config.get("api_key", "")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
base_url = str(config["base_url"]).rstrip("/")
|
||||
# 兼容 base_url 已包含 /v1 或不包含的情况
|
||||
if not base_url.endswith("/chat/completions"):
|
||||
if not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url}/v1"
|
||||
url = f"{base_url}/chat/completions"
|
||||
else:
|
||||
url = base_url
|
||||
|
||||
body = {
|
||||
"model": config["model_name"],
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
last_err = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
resp = http_requests.post(url, json=body, headers=headers, timeout=120)
|
||||
resp.raise_for_status()
|
||||
content = resp.json()["choices"][0]["message"]["content"]
|
||||
return content
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
logger.warning(
|
||||
"LLM call attempt {}/{} failed: {}",
|
||||
attempt + 1,
|
||||
max_retries + 1,
|
||||
e,
|
||||
)
|
||||
|
||||
raise RuntimeError(f"LLM call failed after {max_retries + 1} attempts: {last_err}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON 提取
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def extract_json(raw: str) -> Any:
|
||||
"""从 LLM 原始输出中提取 JSON 对象/数组。
|
||||
|
||||
处理常见干扰:Markdown 代码块、<think> 标签、前后说明文字。
|
||||
"""
|
||||
|
||||
if not raw:
|
||||
raise ValueError("Empty LLM response")
|
||||
|
||||
# 1. 去除 <think>...</think> 等思考标签
|
||||
thought_tags = ["think", "thinking", "analysis", "reasoning", "reflection"]
|
||||
for tag in thought_tags:
|
||||
raw = re.sub(rf"<{tag}>[\s\S]*?</{tag}>", "", raw, flags=re.IGNORECASE)
|
||||
|
||||
# 2. 去除 Markdown 代码块标记
|
||||
raw = re.sub(r"```(?:json)?\s*", "", raw)
|
||||
raw = raw.replace("```", "")
|
||||
|
||||
# 3. 定位第一个 { 或 [ 到最后一个 } 或 ]
|
||||
start = None
|
||||
end = None
|
||||
for i, ch in enumerate(raw):
|
||||
if ch in "{[":
|
||||
start = i
|
||||
break
|
||||
for i in range(len(raw) - 1, -1, -1):
|
||||
if raw[i] in "]}":
|
||||
end = i + 1
|
||||
break
|
||||
|
||||
if start is not None and end is not None and start < end:
|
||||
return json.loads(raw[start:end])
|
||||
|
||||
# 兜底:直接尝试解析
|
||||
return json.loads(raw.strip())
|
||||
@@ -0,0 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datamate.core.base_op import OPERATORS
|
||||
from .process import LLMNamedEntityRecognition
|
||||
|
||||
OPERATORS.register_module(
|
||||
module_name="LLMNamedEntityRecognition",
|
||||
module_path="ops.annotation.llm_named_entity_recognition.process",
|
||||
)
|
||||
|
||||
__all__ = ["LLMNamedEntityRecognition"]
|
||||
@@ -0,0 +1,29 @@
|
||||
name: 'LLM命名实体识别'
|
||||
name_en: 'LLM Named Entity Recognition'
|
||||
description: '基于大语言模型的命名实体识别算子,支持自定义实体类型。'
|
||||
description_en: 'LLM-based NER operator with custom entity types.'
|
||||
language: 'python'
|
||||
vendor: 'datamate'
|
||||
raw_id: 'LLMNamedEntityRecognition'
|
||||
version: '1.0.0'
|
||||
types:
|
||||
- 'annotation'
|
||||
modal: 'text'
|
||||
inputs: 'text'
|
||||
outputs: 'text'
|
||||
settings:
|
||||
modelId:
|
||||
name: '模型ID'
|
||||
description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。'
|
||||
type: 'input'
|
||||
defaultVal: ''
|
||||
entityTypes:
|
||||
name: '实体类型'
|
||||
description: '逗号分隔的实体类型,如:PER,ORG,LOC,DATE'
|
||||
type: 'input'
|
||||
defaultVal: 'PER,ORG,LOC,DATE'
|
||||
outputDir:
|
||||
name: '输出目录'
|
||||
description: '算子输出目录(由运行时自动注入)。'
|
||||
type: 'input'
|
||||
defaultVal: ''
|
||||
174
runtime/ops/annotation/llm_named_entity_recognition/process.py
Normal file
174
runtime/ops/annotation/llm_named_entity_recognition/process.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""LLM 命名实体识别 (NER) 算子。
|
||||
|
||||
基于大语言模型从文本中识别命名实体(人名、地名、机构名等),
|
||||
输出实体列表(含文本片段、实体类型、在原文中的起止位置)。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from datamate.core.base_op import Mapper
|
||||
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"你是一个专业的命名实体识别(NER)专家。根据给定的实体类型列表,"
|
||||
"从输入文本中识别所有命名实体。\n"
|
||||
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
|
||||
)
|
||||
|
||||
USER_PROMPT_TEMPLATE = """请从以下文本中识别所有命名实体。
|
||||
|
||||
实体类型列表:{entity_types}
|
||||
|
||||
实体类型说明:
|
||||
- PER:人名
|
||||
- ORG:组织/机构名
|
||||
- LOC:地点/地名
|
||||
- DATE:日期/时间
|
||||
- EVENT:事件
|
||||
- PRODUCT:产品名
|
||||
- MONEY:金额
|
||||
- PERCENT:百分比
|
||||
|
||||
文本内容:
|
||||
{text}
|
||||
|
||||
请以如下 JSON 格式输出(entities 为实体数组,每个实体包含 text、type、start、end 四个字段):
|
||||
{{"entities": [{{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}]}}
|
||||
|
||||
注意:
|
||||
- type 必须是实体类型列表中的值之一
|
||||
- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开)
|
||||
- 如果没有找到任何实体,返回 {{"entities": []}}"""
|
||||
|
||||
|
||||
class LLMNamedEntityRecognition(Mapper):
|
||||
"""基于 LLM 的命名实体识别算子。"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._model_id: str = kwargs.get("modelId", "")
|
||||
self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC,DATE")
|
||||
self._output_dir: str = kwargs.get("outputDir", "") or ""
|
||||
self._llm_config = None
|
||||
|
||||
def _get_llm_config(self) -> Dict[str, Any]:
|
||||
if self._llm_config is None:
|
||||
from ops.annotation._llm_utils import get_llm_config
|
||||
|
||||
self._llm_config = get_llm_config(self._model_id)
|
||||
return self._llm_config
|
||||
|
||||
@staticmethod
|
||||
def _validate_entities(
|
||||
entities_raw: Any, allowed_types: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""校验并过滤实体列表,确保类型在允许范围内。"""
|
||||
|
||||
if not isinstance(entities_raw, list):
|
||||
return []
|
||||
|
||||
validated: List[Dict[str, Any]] = []
|
||||
allowed_set = {t.strip().upper() for t in allowed_types}
|
||||
|
||||
for ent in entities_raw:
|
||||
if not isinstance(ent, dict):
|
||||
continue
|
||||
ent_type = str(ent.get("type", "")).strip().upper()
|
||||
ent_text = str(ent.get("text", "")).strip()
|
||||
if not ent_text:
|
||||
continue
|
||||
# 保留匹配的类型,或在类型列表为空时全部保留
|
||||
if allowed_set and ent_type not in allowed_set:
|
||||
continue
|
||||
validated.append(
|
||||
{
|
||||
"text": ent_text,
|
||||
"type": ent_type,
|
||||
"start": ent.get("start"),
|
||||
"end": ent.get("end"),
|
||||
}
|
||||
)
|
||||
return validated
|
||||
|
||||
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
start = time.time()
|
||||
|
||||
text_path = sample.get(self.text_key)
|
||||
if not text_path or not os.path.exists(str(text_path)):
|
||||
logger.warning("Text file not found: {}", text_path)
|
||||
return sample
|
||||
|
||||
text_path = str(text_path)
|
||||
with open(text_path, "r", encoding="utf-8") as f:
|
||||
text_content = f.read()
|
||||
|
||||
if not text_content.strip():
|
||||
logger.warning("Empty text file: {}", text_path)
|
||||
return sample
|
||||
|
||||
max_chars = 8000
|
||||
truncated = text_content[:max_chars]
|
||||
|
||||
from ops.annotation._llm_utils import call_llm, extract_json
|
||||
|
||||
config = self._get_llm_config()
|
||||
prompt = USER_PROMPT_TEMPLATE.format(
|
||||
entity_types=self._entity_types,
|
||||
text=truncated,
|
||||
)
|
||||
|
||||
allowed_types = [t.strip() for t in self._entity_types.split(",") if t.strip()]
|
||||
|
||||
try:
|
||||
raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
|
||||
parsed = extract_json(raw_response)
|
||||
entities_raw = parsed.get("entities", []) if isinstance(parsed, dict) else parsed
|
||||
entities = self._validate_entities(entities_raw, allowed_types)
|
||||
except Exception as e:
|
||||
logger.error("LLM NER failed for {}: {}", text_path, e)
|
||||
entities = []
|
||||
|
||||
annotation = {
|
||||
"file": os.path.basename(text_path),
|
||||
"task_type": "ner",
|
||||
"entity_types": self._entity_types,
|
||||
"model": config.get("model_name", ""),
|
||||
"entities": entities,
|
||||
}
|
||||
|
||||
# 写入输出
|
||||
output_dir = self._output_dir or os.path.dirname(text_path)
|
||||
annotations_dir = os.path.join(output_dir, "annotations")
|
||||
data_dir = os.path.join(output_dir, "data")
|
||||
os.makedirs(annotations_dir, exist_ok=True)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(text_path))[0]
|
||||
|
||||
dst_data = os.path.join(data_dir, os.path.basename(text_path))
|
||||
if not os.path.exists(dst_data):
|
||||
shutil.copy2(text_path, dst_data)
|
||||
|
||||
json_path = os.path.join(annotations_dir, f"{base_name}.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(annotation, f, indent=2, ensure_ascii=False)
|
||||
|
||||
sample["detection_count"] = len(entities)
|
||||
sample["annotations_file"] = json_path
|
||||
sample["annotations"] = annotation
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info(
|
||||
"NER: {} -> {} entities, Time: {:.2f}s",
|
||||
os.path.basename(text_path),
|
||||
len(entities),
|
||||
elapsed,
|
||||
)
|
||||
return sample
|
||||
10
runtime/ops/annotation/llm_relation_extraction/__init__.py
Normal file
10
runtime/ops/annotation/llm_relation_extraction/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datamate.core.base_op import OPERATORS
|
||||
from .process import LLMRelationExtraction
|
||||
|
||||
OPERATORS.register_module(
|
||||
module_name="LLMRelationExtraction",
|
||||
module_path="ops.annotation.llm_relation_extraction.process",
|
||||
)
|
||||
|
||||
__all__ = ["LLMRelationExtraction"]
|
||||
34
runtime/ops/annotation/llm_relation_extraction/metadata.yml
Normal file
34
runtime/ops/annotation/llm_relation_extraction/metadata.yml
Normal file
@@ -0,0 +1,34 @@
|
||||
name: 'LLM关系抽取'
|
||||
name_en: 'LLM Relation Extraction'
|
||||
description: '基于大语言模型的关系抽取算子,识别实体并抽取实体间关系三元组。'
|
||||
description_en: 'LLM-based relation extraction operator that identifies entities and extracts relation triples.'
|
||||
language: 'python'
|
||||
vendor: 'datamate'
|
||||
raw_id: 'LLMRelationExtraction'
|
||||
version: '1.0.0'
|
||||
types:
|
||||
- 'annotation'
|
||||
modal: 'text'
|
||||
inputs: 'text'
|
||||
outputs: 'text'
|
||||
settings:
|
||||
modelId:
|
||||
name: '模型ID'
|
||||
description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。'
|
||||
type: 'input'
|
||||
defaultVal: ''
|
||||
entityTypes:
|
||||
name: '实体类型'
|
||||
description: '逗号分隔的实体类型,如:PER,ORG,LOC'
|
||||
type: 'input'
|
||||
defaultVal: 'PER,ORG,LOC'
|
||||
relationTypes:
|
||||
name: '关系类型'
|
||||
description: '逗号分隔的关系类型,如:属于,位于,创立,工作于'
|
||||
type: 'input'
|
||||
defaultVal: '属于,位于,创立,工作于'
|
||||
outputDir:
|
||||
name: '输出目录'
|
||||
description: '算子输出目录(由运行时自动注入)。'
|
||||
type: 'input'
|
||||
defaultVal: ''
|
||||
229
runtime/ops/annotation/llm_relation_extraction/process.py
Normal file
229
runtime/ops/annotation/llm_relation_extraction/process.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""LLM 关系抽取算子。
|
||||
|
||||
基于大语言模型从文本中识别实体,并抽取实体之间的关系,
|
||||
输出实体列表和关系三元组(subject, relation, object)。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from datamate.core.base_op import Mapper
|
||||
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"你是一个专业的信息抽取专家。你需要从文本中识别命名实体,并抽取实体之间的关系。\n"
|
||||
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
|
||||
)
|
||||
|
||||
USER_PROMPT_TEMPLATE = """请从以下文本中识别实体并抽取实体间的关系。
|
||||
|
||||
实体类型列表:{entity_types}
|
||||
关系类型列表:{relation_types}
|
||||
|
||||
文本内容:
|
||||
{text}
|
||||
|
||||
请以如下 JSON 格式输出:
|
||||
{{
|
||||
"entities": [
|
||||
{{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}
|
||||
],
|
||||
"relations": [
|
||||
{{
|
||||
"subject": {{"text": "主语实体", "type": "PER"}},
|
||||
"relation": "关系类型",
|
||||
"object": {{"text": "宾语实体", "type": "ORG"}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
注意:
|
||||
- 实体的 type 必须是实体类型列表中的值之一
|
||||
- 关系的 relation 必须是关系类型列表中的值之一
|
||||
- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开)
|
||||
- 如果没有找到任何实体或关系,对应数组返回空 []"""
|
||||
|
||||
|
||||
class LLMRelationExtraction(Mapper):
|
||||
"""基于 LLM 的关系抽取算子。"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._model_id: str = kwargs.get("modelId", "")
|
||||
self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC")
|
||||
self._relation_types: str = kwargs.get("relationTypes", "属于,位于,创立,工作于")
|
||||
self._output_dir: str = kwargs.get("outputDir", "") or ""
|
||||
self._llm_config = None
|
||||
|
||||
def _get_llm_config(self) -> Dict[str, Any]:
|
||||
if self._llm_config is None:
|
||||
from ops.annotation._llm_utils import get_llm_config
|
||||
|
||||
self._llm_config = get_llm_config(self._model_id)
|
||||
return self._llm_config
|
||||
|
||||
@staticmethod
|
||||
def _validate_entities(
|
||||
entities_raw: Any, allowed_types: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not isinstance(entities_raw, list):
|
||||
return []
|
||||
|
||||
validated: List[Dict[str, Any]] = []
|
||||
allowed_set = {t.strip().upper() for t in allowed_types} if allowed_types else set()
|
||||
|
||||
for ent in entities_raw:
|
||||
if not isinstance(ent, dict):
|
||||
continue
|
||||
ent_text = str(ent.get("text", "")).strip()
|
||||
ent_type = str(ent.get("type", "")).strip().upper()
|
||||
if not ent_text:
|
||||
continue
|
||||
if allowed_set and ent_type not in allowed_set:
|
||||
continue
|
||||
validated.append(
|
||||
{
|
||||
"text": ent_text,
|
||||
"type": ent_type,
|
||||
"start": ent.get("start"),
|
||||
"end": ent.get("end"),
|
||||
}
|
||||
)
|
||||
return validated
|
||||
|
||||
@staticmethod
|
||||
def _validate_relations(
|
||||
relations_raw: Any, allowed_relation_types: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not isinstance(relations_raw, list):
|
||||
return []
|
||||
|
||||
validated: List[Dict[str, Any]] = []
|
||||
allowed_set = {t.strip() for t in allowed_relation_types} if allowed_relation_types else set()
|
||||
|
||||
for rel in relations_raw:
|
||||
if not isinstance(rel, dict):
|
||||
continue
|
||||
subject = rel.get("subject")
|
||||
relation = str(rel.get("relation", "")).strip()
|
||||
obj = rel.get("object")
|
||||
|
||||
if not isinstance(subject, dict) or not isinstance(obj, dict):
|
||||
continue
|
||||
if not relation:
|
||||
continue
|
||||
if allowed_set and relation not in allowed_set:
|
||||
continue
|
||||
|
||||
validated.append(
|
||||
{
|
||||
"subject": {
|
||||
"text": str(subject.get("text", "")),
|
||||
"type": str(subject.get("type", "")),
|
||||
},
|
||||
"relation": relation,
|
||||
"object": {
|
||||
"text": str(obj.get("text", "")),
|
||||
"type": str(obj.get("type", "")),
|
||||
},
|
||||
}
|
||||
)
|
||||
return validated
|
||||
|
||||
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
start = time.time()
|
||||
|
||||
text_path = sample.get(self.text_key)
|
||||
if not text_path or not os.path.exists(str(text_path)):
|
||||
logger.warning("Text file not found: {}", text_path)
|
||||
return sample
|
||||
|
||||
text_path = str(text_path)
|
||||
with open(text_path, "r", encoding="utf-8") as f:
|
||||
text_content = f.read()
|
||||
|
||||
if not text_content.strip():
|
||||
logger.warning("Empty text file: {}", text_path)
|
||||
return sample
|
||||
|
||||
max_chars = 8000
|
||||
truncated = text_content[:max_chars]
|
||||
|
||||
from ops.annotation._llm_utils import call_llm, extract_json
|
||||
|
||||
config = self._get_llm_config()
|
||||
prompt = USER_PROMPT_TEMPLATE.format(
|
||||
entity_types=self._entity_types,
|
||||
relation_types=self._relation_types,
|
||||
text=truncated,
|
||||
)
|
||||
|
||||
allowed_entity_types = [
|
||||
t.strip() for t in self._entity_types.split(",") if t.strip()
|
||||
]
|
||||
allowed_relation_types = [
|
||||
t.strip() for t in self._relation_types.split(",") if t.strip()
|
||||
]
|
||||
|
||||
try:
|
||||
raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
|
||||
parsed = extract_json(raw_response)
|
||||
if not isinstance(parsed, dict):
|
||||
parsed = {}
|
||||
entities = self._validate_entities(
|
||||
parsed.get("entities", []), allowed_entity_types
|
||||
)
|
||||
relations = self._validate_relations(
|
||||
parsed.get("relations", []), allowed_relation_types
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("LLM relation extraction failed for {}: {}", text_path, e)
|
||||
entities = []
|
||||
relations = []
|
||||
|
||||
annotation = {
|
||||
"file": os.path.basename(text_path),
|
||||
"task_type": "relation_extraction",
|
||||
"entity_types": self._entity_types,
|
||||
"relation_types": self._relation_types,
|
||||
"model": config.get("model_name", ""),
|
||||
"entities": entities,
|
||||
"relations": relations,
|
||||
}
|
||||
|
||||
# 写入输出
|
||||
output_dir = self._output_dir or os.path.dirname(text_path)
|
||||
annotations_dir = os.path.join(output_dir, "annotations")
|
||||
data_dir = os.path.join(output_dir, "data")
|
||||
os.makedirs(annotations_dir, exist_ok=True)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(text_path))[0]
|
||||
|
||||
dst_data = os.path.join(data_dir, os.path.basename(text_path))
|
||||
if not os.path.exists(dst_data):
|
||||
shutil.copy2(text_path, dst_data)
|
||||
|
||||
json_path = os.path.join(annotations_dir, f"{base_name}.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(annotation, f, indent=2, ensure_ascii=False)
|
||||
|
||||
sample["detection_count"] = len(relations)
|
||||
sample["annotations_file"] = json_path
|
||||
sample["annotations"] = annotation
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info(
|
||||
"RelationExtraction: {} -> {} entities, {} relations, Time: {:.2f}s",
|
||||
os.path.basename(text_path),
|
||||
len(entities),
|
||||
len(relations),
|
||||
elapsed,
|
||||
)
|
||||
return sample
|
||||
10
runtime/ops/annotation/llm_text_classification/__init__.py
Normal file
10
runtime/ops/annotation/llm_text_classification/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datamate.core.base_op import OPERATORS
|
||||
from .process import LLMTextClassification
|
||||
|
||||
OPERATORS.register_module(
|
||||
module_name="LLMTextClassification",
|
||||
module_path="ops.annotation.llm_text_classification.process",
|
||||
)
|
||||
|
||||
__all__ = ["LLMTextClassification"]
|
||||
29
runtime/ops/annotation/llm_text_classification/metadata.yml
Normal file
29
runtime/ops/annotation/llm_text_classification/metadata.yml
Normal file
@@ -0,0 +1,29 @@
|
||||
name: 'LLM文本分类'
|
||||
name_en: 'LLM Text Classification'
|
||||
description: '基于大语言模型的文本分类算子,支持自定义类别标签。'
|
||||
description_en: 'LLM-based text classification operator with custom category labels.'
|
||||
language: 'python'
|
||||
vendor: 'datamate'
|
||||
raw_id: 'LLMTextClassification'
|
||||
version: '1.0.0'
|
||||
types:
|
||||
- 'annotation'
|
||||
modal: 'text'
|
||||
inputs: 'text'
|
||||
outputs: 'text'
|
||||
settings:
|
||||
modelId:
|
||||
name: '模型ID'
|
||||
description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。'
|
||||
type: 'input'
|
||||
defaultVal: ''
|
||||
categories:
|
||||
name: '分类标签'
|
||||
description: '逗号分隔的分类标签列表,如:正面,负面,中性'
|
||||
type: 'input'
|
||||
defaultVal: '正面,负面,中性'
|
||||
outputDir:
|
||||
name: '输出目录'
|
||||
description: '算子输出目录(由运行时自动注入)。'
|
||||
type: 'input'
|
||||
defaultVal: ''
|
||||
129
runtime/ops/annotation/llm_text_classification/process.py
Normal file
129
runtime/ops/annotation/llm_text_classification/process.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""LLM 文本分类算子。
|
||||
|
||||
基于大语言模型对文本进行分类,输出分类标签、置信度和简短理由。
|
||||
支持通过 categories 参数自定义分类标签体系。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from datamate.core.base_op import Mapper
|
||||
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"你是一个专业的文本分类专家。根据给定的类别列表,对输入文本进行分类。\n"
|
||||
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
|
||||
)
|
||||
|
||||
USER_PROMPT_TEMPLATE = """请对以下文本进行分类。
|
||||
|
||||
可选类别:{categories}
|
||||
|
||||
文本内容:
|
||||
{text}
|
||||
|
||||
请以如下 JSON 格式输出(label 必须是可选类别之一,confidence 为 0~1 的浮点数):
|
||||
{{"label": "类别名", "confidence": 0.95, "reasoning": "简短理由"}}"""
|
||||
|
||||
|
||||
class LLMTextClassification(Mapper):
|
||||
"""基于 LLM 的文本分类算子。"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._model_id: str = kwargs.get("modelId", "")
|
||||
self._categories: str = kwargs.get("categories", "正面,负面,中性")
|
||||
self._output_dir: str = kwargs.get("outputDir", "") or ""
|
||||
self._llm_config = None
|
||||
|
||||
def _get_llm_config(self) -> Dict[str, Any]:
|
||||
if self._llm_config is None:
|
||||
from ops.annotation._llm_utils import get_llm_config
|
||||
|
||||
self._llm_config = get_llm_config(self._model_id)
|
||||
return self._llm_config
|
||||
|
||||
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
start = time.time()
|
||||
|
||||
text_path = sample.get(self.text_key)
|
||||
if not text_path or not os.path.exists(str(text_path)):
|
||||
logger.warning("Text file not found: {}", text_path)
|
||||
return sample
|
||||
|
||||
text_path = str(text_path)
|
||||
with open(text_path, "r", encoding="utf-8") as f:
|
||||
text_content = f.read()
|
||||
|
||||
if not text_content.strip():
|
||||
logger.warning("Empty text file: {}", text_path)
|
||||
return sample
|
||||
|
||||
# 截断过长文本以适应 LLM 上下文窗口
|
||||
max_chars = 8000
|
||||
truncated = text_content[:max_chars]
|
||||
|
||||
from ops.annotation._llm_utils import call_llm, extract_json
|
||||
|
||||
config = self._get_llm_config()
|
||||
prompt = USER_PROMPT_TEMPLATE.format(
|
||||
categories=self._categories,
|
||||
text=truncated,
|
||||
)
|
||||
|
||||
try:
|
||||
raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
|
||||
result = extract_json(raw_response)
|
||||
except Exception as e:
|
||||
logger.error("LLM classification failed for {}: {}", text_path, e)
|
||||
result = {
|
||||
"label": "unknown",
|
||||
"confidence": 0.0,
|
||||
"reasoning": f"LLM call or JSON parse failed: {e}",
|
||||
}
|
||||
|
||||
annotation = {
|
||||
"file": os.path.basename(text_path),
|
||||
"task_type": "text_classification",
|
||||
"categories": self._categories,
|
||||
"model": config.get("model_name", ""),
|
||||
"result": result,
|
||||
}
|
||||
|
||||
# 确定输出目录
|
||||
output_dir = self._output_dir or os.path.dirname(text_path)
|
||||
annotations_dir = os.path.join(output_dir, "annotations")
|
||||
data_dir = os.path.join(output_dir, "data")
|
||||
os.makedirs(annotations_dir, exist_ok=True)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(text_path))[0]
|
||||
|
||||
# 复制原始文本到 data 目录
|
||||
dst_data = os.path.join(data_dir, os.path.basename(text_path))
|
||||
if not os.path.exists(dst_data):
|
||||
shutil.copy2(text_path, dst_data)
|
||||
|
||||
# 写入标注 JSON
|
||||
json_path = os.path.join(annotations_dir, f"{base_name}.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(annotation, f, indent=2, ensure_ascii=False)
|
||||
|
||||
sample["detection_count"] = 1
|
||||
sample["annotations_file"] = json_path
|
||||
sample["annotations"] = annotation
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info(
|
||||
"TextClassification: {} -> {}, Time: {:.2f}s",
|
||||
os.path.basename(text_path),
|
||||
result.get("label", "N/A"),
|
||||
elapsed,
|
||||
)
|
||||
return sample
|
||||
@@ -122,7 +122,8 @@ DEFAULT_OUTPUT_ROOT = os.getenv(
|
||||
|
||||
DEFAULT_OPERATOR_WHITELIST = os.getenv(
|
||||
"AUTO_ANNOTATION_OPERATOR_WHITELIST",
|
||||
"ImageObjectDetectionBoundingBox,test_annotation_marker",
|
||||
"ImageObjectDetectionBoundingBox,test_annotation_marker,"
|
||||
"LLMTextClassification,LLMNamedEntityRecognition,LLMRelationExtraction",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user