diff --git a/frontend/src/pages/DataAnnotation/OperatorCreate/hooks/useOperatorOperations.ts b/frontend/src/pages/DataAnnotation/OperatorCreate/hooks/useOperatorOperations.ts index ec07c45..0df429c 100644 --- a/frontend/src/pages/DataAnnotation/OperatorCreate/hooks/useOperatorOperations.ts +++ b/frontend/src/pages/DataAnnotation/OperatorCreate/hooks/useOperatorOperations.ts @@ -22,6 +22,9 @@ type CategoryGroup = { const ANNOTATION_OPERATOR_ID_WHITELIST = new Set([ "ImageObjectDetectionBoundingBox", "test_annotation_marker", + "LLMTextClassification", + "LLMNamedEntityRecognition", + "LLMRelationExtraction", ]); const ensureArray = (value: unknown): string[] => { diff --git a/runtime/ops/annotation/__init__.py b/runtime/ops/annotation/__init__.py index 7df3288..2a430e3 100644 --- a/runtime/ops/annotation/__init__.py +++ b/runtime/ops/annotation/__init__.py @@ -1,10 +1,16 @@ # -*- coding: utf-8 -*- -"""Annotation-related operators (e.g. YOLO detection).""" +"""Annotation-related operators (e.g. YOLO detection, LLM-based NLP annotation).""" from . import image_object_detection_bounding_box from . import test_annotation_marker +from . import llm_text_classification +from . import llm_named_entity_recognition +from . import llm_relation_extraction __all__ = [ "image_object_detection_bounding_box", "test_annotation_marker", + "llm_text_classification", + "llm_named_entity_recognition", + "llm_relation_extraction", ] diff --git a/runtime/ops/annotation/_llm_utils.py b/runtime/ops/annotation/_llm_utils.py new file mode 100644 index 0000000..5d5a1f4 --- /dev/null +++ b/runtime/ops/annotation/_llm_utils.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +"""LLM 配置加载 & OpenAI 兼容调用工具(标注算子共享)。 + +提供三项核心能力: +1. 从 t_model_config 表加载模型配置(按 ID / 按默认) +2. 调用 OpenAI 兼容 chat/completions API +3. 从 LLM 原始输出中提取 JSON +""" + +import json +import re +from typing import Any, Dict + +from loguru import logger + + +# --------------------------------------------------------------------------- +# 模型配置加载 +# --------------------------------------------------------------------------- + +def load_model_config(model_id: str) -> Dict[str, Any]: + """根据 model_id 从 t_model_config 读取已启用的模型配置。""" + + from datamate.sql_manager.sql_manager import SQLManager + from sqlalchemy import text as sql_text + + sql = sql_text( + """ + SELECT model_name, provider, base_url, api_key, type + FROM t_model_config + WHERE id = :model_id AND is_enabled = 1 + LIMIT 1 + """ + ) + with SQLManager.create_connect() as conn: + row = conn.execute(sql, {"model_id": model_id}).fetchone() + if not row: + raise ValueError(f"Model config not found or disabled: {model_id}") + return dict(row._mapping) + + +def load_default_model_config() -> Dict[str, Any]: + """加载默认的 chat 模型配置(is_default=1 且 type='chat')。""" + + from datamate.sql_manager.sql_manager import SQLManager + from sqlalchemy import text as sql_text + + sql = sql_text( + """ + SELECT id, model_name, provider, base_url, api_key, type + FROM t_model_config + WHERE is_enabled = 1 AND is_default = 1 AND type = 'chat' + LIMIT 1 + """ + ) + with SQLManager.create_connect() as conn: + row = conn.execute(sql).fetchone() + if not row: + raise ValueError("No default chat model configured in t_model_config") + return dict(row._mapping) + + +def get_llm_config(model_id: str = "") -> Dict[str, Any]: + """优先按 model_id 加载,未提供则加载默认模型。""" + + if model_id: + return load_model_config(model_id) + return load_default_model_config() + + +# --------------------------------------------------------------------------- +# LLM 调用 +# --------------------------------------------------------------------------- + +def call_llm( + config: Dict[str, Any], + prompt: str, + system_prompt: str = "", + temperature: float = 0.1, + max_retries: int = 2, +) -> str: + """调用 OpenAI 兼容的 chat/completions API 并返回文本内容。""" + + import requests as http_requests + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + headers: Dict[str, str] = {"Content-Type": "application/json"} + api_key = config.get("api_key", "") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + base_url = str(config["base_url"]).rstrip("/") + # 兼容 base_url 已包含 /v1 或不包含的情况 + if not base_url.endswith("/chat/completions"): + if not base_url.endswith("/v1"): + base_url = f"{base_url}/v1" + url = f"{base_url}/chat/completions" + else: + url = base_url + + body = { + "model": config["model_name"], + "messages": messages, + "temperature": temperature, + } + + last_err = None + for attempt in range(max_retries + 1): + try: + resp = http_requests.post(url, json=body, headers=headers, timeout=120) + resp.raise_for_status() + content = resp.json()["choices"][0]["message"]["content"] + return content + except Exception as e: + last_err = e + logger.warning( + "LLM call attempt {}/{} failed: {}", + attempt + 1, + max_retries + 1, + e, + ) + + raise RuntimeError(f"LLM call failed after {max_retries + 1} attempts: {last_err}") + + +# --------------------------------------------------------------------------- +# JSON 提取 +# --------------------------------------------------------------------------- + +def extract_json(raw: str) -> Any: + """从 LLM 原始输出中提取 JSON 对象/数组。 + + 处理常见干扰:Markdown 代码块、 标签、前后说明文字。 + """ + + if not raw: + raise ValueError("Empty LLM response") + + # 1. 去除 ... 等思考标签 + thought_tags = ["think", "thinking", "analysis", "reasoning", "reflection"] + for tag in thought_tags: + raw = re.sub(rf"<{tag}>[\s\S]*?", "", raw, flags=re.IGNORECASE) + + # 2. 去除 Markdown 代码块标记 + raw = re.sub(r"```(?:json)?\s*", "", raw) + raw = raw.replace("```", "") + + # 3. 定位第一个 { 或 [ 到最后一个 } 或 ] + start = None + end = None + for i, ch in enumerate(raw): + if ch in "{[": + start = i + break + for i in range(len(raw) - 1, -1, -1): + if raw[i] in "]}": + end = i + 1 + break + + if start is not None and end is not None and start < end: + return json.loads(raw[start:end]) + + # 兜底:直接尝试解析 + return json.loads(raw.strip()) diff --git a/runtime/ops/annotation/llm_named_entity_recognition/__init__.py b/runtime/ops/annotation/llm_named_entity_recognition/__init__.py new file mode 100644 index 0000000..8dd9dd4 --- /dev/null +++ b/runtime/ops/annotation/llm_named_entity_recognition/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +from datamate.core.base_op import OPERATORS +from .process import LLMNamedEntityRecognition + +OPERATORS.register_module( + module_name="LLMNamedEntityRecognition", + module_path="ops.annotation.llm_named_entity_recognition.process", +) + +__all__ = ["LLMNamedEntityRecognition"] diff --git a/runtime/ops/annotation/llm_named_entity_recognition/metadata.yml b/runtime/ops/annotation/llm_named_entity_recognition/metadata.yml new file mode 100644 index 0000000..c9b51ac --- /dev/null +++ b/runtime/ops/annotation/llm_named_entity_recognition/metadata.yml @@ -0,0 +1,29 @@ +name: 'LLM命名实体识别' +name_en: 'LLM Named Entity Recognition' +description: '基于大语言模型的命名实体识别算子,支持自定义实体类型。' +description_en: 'LLM-based NER operator with custom entity types.' +language: 'python' +vendor: 'datamate' +raw_id: 'LLMNamedEntityRecognition' +version: '1.0.0' +types: + - 'annotation' +modal: 'text' +inputs: 'text' +outputs: 'text' +settings: + modelId: + name: '模型ID' + description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。' + type: 'input' + defaultVal: '' + entityTypes: + name: '实体类型' + description: '逗号分隔的实体类型,如:PER,ORG,LOC,DATE' + type: 'input' + defaultVal: 'PER,ORG,LOC,DATE' + outputDir: + name: '输出目录' + description: '算子输出目录(由运行时自动注入)。' + type: 'input' + defaultVal: '' diff --git a/runtime/ops/annotation/llm_named_entity_recognition/process.py b/runtime/ops/annotation/llm_named_entity_recognition/process.py new file mode 100644 index 0000000..badb950 --- /dev/null +++ b/runtime/ops/annotation/llm_named_entity_recognition/process.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +"""LLM 命名实体识别 (NER) 算子。 + +基于大语言模型从文本中识别命名实体(人名、地名、机构名等), +输出实体列表(含文本片段、实体类型、在原文中的起止位置)。 +""" + +import json +import os +import shutil +import time +from typing import Any, Dict, List + +from loguru import logger + +from datamate.core.base_op import Mapper + + +SYSTEM_PROMPT = ( + "你是一个专业的命名实体识别(NER)专家。根据给定的实体类型列表," + "从输入文本中识别所有命名实体。\n" + "你必须严格输出 JSON 格式,不要输出任何其他内容。" +) + +USER_PROMPT_TEMPLATE = """请从以下文本中识别所有命名实体。 + +实体类型列表:{entity_types} + +实体类型说明: +- PER:人名 +- ORG:组织/机构名 +- LOC:地点/地名 +- DATE:日期/时间 +- EVENT:事件 +- PRODUCT:产品名 +- MONEY:金额 +- PERCENT:百分比 + +文本内容: +{text} + +请以如下 JSON 格式输出(entities 为实体数组,每个实体包含 text、type、start、end 四个字段): +{{"entities": [{{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}]}} + +注意: +- type 必须是实体类型列表中的值之一 +- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开) +- 如果没有找到任何实体,返回 {{"entities": []}}""" + + +class LLMNamedEntityRecognition(Mapper): + """基于 LLM 的命名实体识别算子。""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._model_id: str = kwargs.get("modelId", "") + self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC,DATE") + self._output_dir: str = kwargs.get("outputDir", "") or "" + self._llm_config = None + + def _get_llm_config(self) -> Dict[str, Any]: + if self._llm_config is None: + from ops.annotation._llm_utils import get_llm_config + + self._llm_config = get_llm_config(self._model_id) + return self._llm_config + + @staticmethod + def _validate_entities( + entities_raw: Any, allowed_types: List[str] + ) -> List[Dict[str, Any]]: + """校验并过滤实体列表,确保类型在允许范围内。""" + + if not isinstance(entities_raw, list): + return [] + + validated: List[Dict[str, Any]] = [] + allowed_set = {t.strip().upper() for t in allowed_types} + + for ent in entities_raw: + if not isinstance(ent, dict): + continue + ent_type = str(ent.get("type", "")).strip().upper() + ent_text = str(ent.get("text", "")).strip() + if not ent_text: + continue + # 保留匹配的类型,或在类型列表为空时全部保留 + if allowed_set and ent_type not in allowed_set: + continue + validated.append( + { + "text": ent_text, + "type": ent_type, + "start": ent.get("start"), + "end": ent.get("end"), + } + ) + return validated + + def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]: + start = time.time() + + text_path = sample.get(self.text_key) + if not text_path or not os.path.exists(str(text_path)): + logger.warning("Text file not found: {}", text_path) + return sample + + text_path = str(text_path) + with open(text_path, "r", encoding="utf-8") as f: + text_content = f.read() + + if not text_content.strip(): + logger.warning("Empty text file: {}", text_path) + return sample + + max_chars = 8000 + truncated = text_content[:max_chars] + + from ops.annotation._llm_utils import call_llm, extract_json + + config = self._get_llm_config() + prompt = USER_PROMPT_TEMPLATE.format( + entity_types=self._entity_types, + text=truncated, + ) + + allowed_types = [t.strip() for t in self._entity_types.split(",") if t.strip()] + + try: + raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT) + parsed = extract_json(raw_response) + entities_raw = parsed.get("entities", []) if isinstance(parsed, dict) else parsed + entities = self._validate_entities(entities_raw, allowed_types) + except Exception as e: + logger.error("LLM NER failed for {}: {}", text_path, e) + entities = [] + + annotation = { + "file": os.path.basename(text_path), + "task_type": "ner", + "entity_types": self._entity_types, + "model": config.get("model_name", ""), + "entities": entities, + } + + # 写入输出 + output_dir = self._output_dir or os.path.dirname(text_path) + annotations_dir = os.path.join(output_dir, "annotations") + data_dir = os.path.join(output_dir, "data") + os.makedirs(annotations_dir, exist_ok=True) + os.makedirs(data_dir, exist_ok=True) + + base_name = os.path.splitext(os.path.basename(text_path))[0] + + dst_data = os.path.join(data_dir, os.path.basename(text_path)) + if not os.path.exists(dst_data): + shutil.copy2(text_path, dst_data) + + json_path = os.path.join(annotations_dir, f"{base_name}.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump(annotation, f, indent=2, ensure_ascii=False) + + sample["detection_count"] = len(entities) + sample["annotations_file"] = json_path + sample["annotations"] = annotation + + elapsed = time.time() - start + logger.info( + "NER: {} -> {} entities, Time: {:.2f}s", + os.path.basename(text_path), + len(entities), + elapsed, + ) + return sample diff --git a/runtime/ops/annotation/llm_relation_extraction/__init__.py b/runtime/ops/annotation/llm_relation_extraction/__init__.py new file mode 100644 index 0000000..204ca41 --- /dev/null +++ b/runtime/ops/annotation/llm_relation_extraction/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +from datamate.core.base_op import OPERATORS +from .process import LLMRelationExtraction + +OPERATORS.register_module( + module_name="LLMRelationExtraction", + module_path="ops.annotation.llm_relation_extraction.process", +) + +__all__ = ["LLMRelationExtraction"] diff --git a/runtime/ops/annotation/llm_relation_extraction/metadata.yml b/runtime/ops/annotation/llm_relation_extraction/metadata.yml new file mode 100644 index 0000000..aa331eb --- /dev/null +++ b/runtime/ops/annotation/llm_relation_extraction/metadata.yml @@ -0,0 +1,34 @@ +name: 'LLM关系抽取' +name_en: 'LLM Relation Extraction' +description: '基于大语言模型的关系抽取算子,识别实体并抽取实体间关系三元组。' +description_en: 'LLM-based relation extraction operator that identifies entities and extracts relation triples.' +language: 'python' +vendor: 'datamate' +raw_id: 'LLMRelationExtraction' +version: '1.0.0' +types: + - 'annotation' +modal: 'text' +inputs: 'text' +outputs: 'text' +settings: + modelId: + name: '模型ID' + description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。' + type: 'input' + defaultVal: '' + entityTypes: + name: '实体类型' + description: '逗号分隔的实体类型,如:PER,ORG,LOC' + type: 'input' + defaultVal: 'PER,ORG,LOC' + relationTypes: + name: '关系类型' + description: '逗号分隔的关系类型,如:属于,位于,创立,工作于' + type: 'input' + defaultVal: '属于,位于,创立,工作于' + outputDir: + name: '输出目录' + description: '算子输出目录(由运行时自动注入)。' + type: 'input' + defaultVal: '' diff --git a/runtime/ops/annotation/llm_relation_extraction/process.py b/runtime/ops/annotation/llm_relation_extraction/process.py new file mode 100644 index 0000000..e8ad194 --- /dev/null +++ b/runtime/ops/annotation/llm_relation_extraction/process.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +"""LLM 关系抽取算子。 + +基于大语言模型从文本中识别实体,并抽取实体之间的关系, +输出实体列表和关系三元组(subject, relation, object)。 +""" + +import json +import os +import shutil +import time +from typing import Any, Dict, List + +from loguru import logger + +from datamate.core.base_op import Mapper + + +SYSTEM_PROMPT = ( + "你是一个专业的信息抽取专家。你需要从文本中识别命名实体,并抽取实体之间的关系。\n" + "你必须严格输出 JSON 格式,不要输出任何其他内容。" +) + +USER_PROMPT_TEMPLATE = """请从以下文本中识别实体并抽取实体间的关系。 + +实体类型列表:{entity_types} +关系类型列表:{relation_types} + +文本内容: +{text} + +请以如下 JSON 格式输出: +{{ + "entities": [ + {{"text": "实体文本", "type": "PER", "start": 0, "end": 3}} + ], + "relations": [ + {{ + "subject": {{"text": "主语实体", "type": "PER"}}, + "relation": "关系类型", + "object": {{"text": "宾语实体", "type": "ORG"}} + }} + ] +}} + +注意: +- 实体的 type 必须是实体类型列表中的值之一 +- 关系的 relation 必须是关系类型列表中的值之一 +- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开) +- 如果没有找到任何实体或关系,对应数组返回空 []""" + + +class LLMRelationExtraction(Mapper): + """基于 LLM 的关系抽取算子。""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._model_id: str = kwargs.get("modelId", "") + self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC") + self._relation_types: str = kwargs.get("relationTypes", "属于,位于,创立,工作于") + self._output_dir: str = kwargs.get("outputDir", "") or "" + self._llm_config = None + + def _get_llm_config(self) -> Dict[str, Any]: + if self._llm_config is None: + from ops.annotation._llm_utils import get_llm_config + + self._llm_config = get_llm_config(self._model_id) + return self._llm_config + + @staticmethod + def _validate_entities( + entities_raw: Any, allowed_types: List[str] + ) -> List[Dict[str, Any]]: + if not isinstance(entities_raw, list): + return [] + + validated: List[Dict[str, Any]] = [] + allowed_set = {t.strip().upper() for t in allowed_types} if allowed_types else set() + + for ent in entities_raw: + if not isinstance(ent, dict): + continue + ent_text = str(ent.get("text", "")).strip() + ent_type = str(ent.get("type", "")).strip().upper() + if not ent_text: + continue + if allowed_set and ent_type not in allowed_set: + continue + validated.append( + { + "text": ent_text, + "type": ent_type, + "start": ent.get("start"), + "end": ent.get("end"), + } + ) + return validated + + @staticmethod + def _validate_relations( + relations_raw: Any, allowed_relation_types: List[str] + ) -> List[Dict[str, Any]]: + if not isinstance(relations_raw, list): + return [] + + validated: List[Dict[str, Any]] = [] + allowed_set = {t.strip() for t in allowed_relation_types} if allowed_relation_types else set() + + for rel in relations_raw: + if not isinstance(rel, dict): + continue + subject = rel.get("subject") + relation = str(rel.get("relation", "")).strip() + obj = rel.get("object") + + if not isinstance(subject, dict) or not isinstance(obj, dict): + continue + if not relation: + continue + if allowed_set and relation not in allowed_set: + continue + + validated.append( + { + "subject": { + "text": str(subject.get("text", "")), + "type": str(subject.get("type", "")), + }, + "relation": relation, + "object": { + "text": str(obj.get("text", "")), + "type": str(obj.get("type", "")), + }, + } + ) + return validated + + def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]: + start = time.time() + + text_path = sample.get(self.text_key) + if not text_path or not os.path.exists(str(text_path)): + logger.warning("Text file not found: {}", text_path) + return sample + + text_path = str(text_path) + with open(text_path, "r", encoding="utf-8") as f: + text_content = f.read() + + if not text_content.strip(): + logger.warning("Empty text file: {}", text_path) + return sample + + max_chars = 8000 + truncated = text_content[:max_chars] + + from ops.annotation._llm_utils import call_llm, extract_json + + config = self._get_llm_config() + prompt = USER_PROMPT_TEMPLATE.format( + entity_types=self._entity_types, + relation_types=self._relation_types, + text=truncated, + ) + + allowed_entity_types = [ + t.strip() for t in self._entity_types.split(",") if t.strip() + ] + allowed_relation_types = [ + t.strip() for t in self._relation_types.split(",") if t.strip() + ] + + try: + raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT) + parsed = extract_json(raw_response) + if not isinstance(parsed, dict): + parsed = {} + entities = self._validate_entities( + parsed.get("entities", []), allowed_entity_types + ) + relations = self._validate_relations( + parsed.get("relations", []), allowed_relation_types + ) + except Exception as e: + logger.error("LLM relation extraction failed for {}: {}", text_path, e) + entities = [] + relations = [] + + annotation = { + "file": os.path.basename(text_path), + "task_type": "relation_extraction", + "entity_types": self._entity_types, + "relation_types": self._relation_types, + "model": config.get("model_name", ""), + "entities": entities, + "relations": relations, + } + + # 写入输出 + output_dir = self._output_dir or os.path.dirname(text_path) + annotations_dir = os.path.join(output_dir, "annotations") + data_dir = os.path.join(output_dir, "data") + os.makedirs(annotations_dir, exist_ok=True) + os.makedirs(data_dir, exist_ok=True) + + base_name = os.path.splitext(os.path.basename(text_path))[0] + + dst_data = os.path.join(data_dir, os.path.basename(text_path)) + if not os.path.exists(dst_data): + shutil.copy2(text_path, dst_data) + + json_path = os.path.join(annotations_dir, f"{base_name}.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump(annotation, f, indent=2, ensure_ascii=False) + + sample["detection_count"] = len(relations) + sample["annotations_file"] = json_path + sample["annotations"] = annotation + + elapsed = time.time() - start + logger.info( + "RelationExtraction: {} -> {} entities, {} relations, Time: {:.2f}s", + os.path.basename(text_path), + len(entities), + len(relations), + elapsed, + ) + return sample diff --git a/runtime/ops/annotation/llm_text_classification/__init__.py b/runtime/ops/annotation/llm_text_classification/__init__.py new file mode 100644 index 0000000..ad87729 --- /dev/null +++ b/runtime/ops/annotation/llm_text_classification/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +from datamate.core.base_op import OPERATORS +from .process import LLMTextClassification + +OPERATORS.register_module( + module_name="LLMTextClassification", + module_path="ops.annotation.llm_text_classification.process", +) + +__all__ = ["LLMTextClassification"] diff --git a/runtime/ops/annotation/llm_text_classification/metadata.yml b/runtime/ops/annotation/llm_text_classification/metadata.yml new file mode 100644 index 0000000..a5cc844 --- /dev/null +++ b/runtime/ops/annotation/llm_text_classification/metadata.yml @@ -0,0 +1,29 @@ +name: 'LLM文本分类' +name_en: 'LLM Text Classification' +description: '基于大语言模型的文本分类算子,支持自定义类别标签。' +description_en: 'LLM-based text classification operator with custom category labels.' +language: 'python' +vendor: 'datamate' +raw_id: 'LLMTextClassification' +version: '1.0.0' +types: + - 'annotation' +modal: 'text' +inputs: 'text' +outputs: 'text' +settings: + modelId: + name: '模型ID' + description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。' + type: 'input' + defaultVal: '' + categories: + name: '分类标签' + description: '逗号分隔的分类标签列表,如:正面,负面,中性' + type: 'input' + defaultVal: '正面,负面,中性' + outputDir: + name: '输出目录' + description: '算子输出目录(由运行时自动注入)。' + type: 'input' + defaultVal: '' diff --git a/runtime/ops/annotation/llm_text_classification/process.py b/runtime/ops/annotation/llm_text_classification/process.py new file mode 100644 index 0000000..15dd524 --- /dev/null +++ b/runtime/ops/annotation/llm_text_classification/process.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +"""LLM 文本分类算子。 + +基于大语言模型对文本进行分类,输出分类标签、置信度和简短理由。 +支持通过 categories 参数自定义分类标签体系。 +""" + +import json +import os +import shutil +import time +from typing import Any, Dict + +from loguru import logger + +from datamate.core.base_op import Mapper + + +SYSTEM_PROMPT = ( + "你是一个专业的文本分类专家。根据给定的类别列表,对输入文本进行分类。\n" + "你必须严格输出 JSON 格式,不要输出任何其他内容。" +) + +USER_PROMPT_TEMPLATE = """请对以下文本进行分类。 + +可选类别:{categories} + +文本内容: +{text} + +请以如下 JSON 格式输出(label 必须是可选类别之一,confidence 为 0~1 的浮点数): +{{"label": "类别名", "confidence": 0.95, "reasoning": "简短理由"}}""" + + +class LLMTextClassification(Mapper): + """基于 LLM 的文本分类算子。""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._model_id: str = kwargs.get("modelId", "") + self._categories: str = kwargs.get("categories", "正面,负面,中性") + self._output_dir: str = kwargs.get("outputDir", "") or "" + self._llm_config = None + + def _get_llm_config(self) -> Dict[str, Any]: + if self._llm_config is None: + from ops.annotation._llm_utils import get_llm_config + + self._llm_config = get_llm_config(self._model_id) + return self._llm_config + + def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]: + start = time.time() + + text_path = sample.get(self.text_key) + if not text_path or not os.path.exists(str(text_path)): + logger.warning("Text file not found: {}", text_path) + return sample + + text_path = str(text_path) + with open(text_path, "r", encoding="utf-8") as f: + text_content = f.read() + + if not text_content.strip(): + logger.warning("Empty text file: {}", text_path) + return sample + + # 截断过长文本以适应 LLM 上下文窗口 + max_chars = 8000 + truncated = text_content[:max_chars] + + from ops.annotation._llm_utils import call_llm, extract_json + + config = self._get_llm_config() + prompt = USER_PROMPT_TEMPLATE.format( + categories=self._categories, + text=truncated, + ) + + try: + raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT) + result = extract_json(raw_response) + except Exception as e: + logger.error("LLM classification failed for {}: {}", text_path, e) + result = { + "label": "unknown", + "confidence": 0.0, + "reasoning": f"LLM call or JSON parse failed: {e}", + } + + annotation = { + "file": os.path.basename(text_path), + "task_type": "text_classification", + "categories": self._categories, + "model": config.get("model_name", ""), + "result": result, + } + + # 确定输出目录 + output_dir = self._output_dir or os.path.dirname(text_path) + annotations_dir = os.path.join(output_dir, "annotations") + data_dir = os.path.join(output_dir, "data") + os.makedirs(annotations_dir, exist_ok=True) + os.makedirs(data_dir, exist_ok=True) + + base_name = os.path.splitext(os.path.basename(text_path))[0] + + # 复制原始文本到 data 目录 + dst_data = os.path.join(data_dir, os.path.basename(text_path)) + if not os.path.exists(dst_data): + shutil.copy2(text_path, dst_data) + + # 写入标注 JSON + json_path = os.path.join(annotations_dir, f"{base_name}.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump(annotation, f, indent=2, ensure_ascii=False) + + sample["detection_count"] = 1 + sample["annotations_file"] = json_path + sample["annotations"] = annotation + + elapsed = time.time() - start + logger.info( + "TextClassification: {} -> {}, Time: {:.2f}s", + os.path.basename(text_path), + result.get("label", "N/A"), + elapsed, + ) + return sample diff --git a/runtime/python-executor/datamate/auto_annotation_worker.py b/runtime/python-executor/datamate/auto_annotation_worker.py index fd4a613..6487cd0 100644 --- a/runtime/python-executor/datamate/auto_annotation_worker.py +++ b/runtime/python-executor/datamate/auto_annotation_worker.py @@ -122,7 +122,8 @@ DEFAULT_OUTPUT_ROOT = os.getenv( DEFAULT_OPERATOR_WHITELIST = os.getenv( "AUTO_ANNOTATION_OPERATOR_WHITELIST", - "ImageObjectDetectionBoundingBox,test_annotation_marker", + "ImageObjectDetectionBoundingBox,test_annotation_marker," + "LLMTextClassification,LLMNamedEntityRecognition,LLMRelationExtraction", )