# -*- coding: utf-8 -*-
"""LLM text-classification operator.

Uses a large language model to classify a text file, producing a category
label, a confidence score and a short rationale. The label set can be
customised through the ``categories`` parameter.
"""
import json
import os
import shutil
import time
from typing import Any, Dict, List

from loguru import logger

from datamate.core.base_op import Mapper

SYSTEM_PROMPT = (
    "你是一个专业的文本分类专家。根据给定的类别列表,对输入文本进行分类。\n"
    "你必须严格输出 JSON 格式,不要输出任何其他内容。"
)

USER_PROMPT_TEMPLATE = """请对以下文本进行分类。

可选类别:{categories}

文本内容:
{text}

请以如下 JSON 格式输出(label 必须是可选类别之一,confidence 为 0~1 的浮点数):
{{"label": "类别名", "confidence": 0.95, "reasoning": "简短理由"}}"""

# Default cap on the number of characters of file text placed in the prompt,
# so long documents do not overflow the LLM context window. Can be overridden
# per operator instance with the ``maxChars`` parameter.
DEFAULT_MAX_CHARS = 8000


class LLMTextClassification(Mapper):
    """LLM-based text classification operator.

    For each sample, reads the text file referenced by ``sample[self.text_key]``,
    asks the configured LLM to choose one of ``categories`` and writes:

    * ``<output_dir>/data/<file name>`` - a copy of the source text (only if
      no file with that name exists there yet);
    * ``<output_dir>/annotations/<stem>.json`` - the classification record.

    Keyword parameters (read from ``kwargs``):
        modelId: model identifier passed to ``get_llm_config``.
        categories: comma-separated label list (default ``"正面,负面,中性"``).
        outputDir: output root; defaults to the directory of the text file.
        maxChars: maximum number of text characters sent to the LLM
            (default ``DEFAULT_MAX_CHARS``); invalid/non-positive values fall
            back to the default.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._model_id: str = kwargs.get("modelId", "")
        self._categories: str = kwargs.get("categories", "正面,负面,中性")
        self._output_dir: str = kwargs.get("outputDir", "") or ""
        self._max_chars: int = self._parse_max_chars(kwargs.get("maxChars"))
        # LLM config is resolved lazily on first use and then reused.
        self._llm_config = None

    @staticmethod
    def _parse_max_chars(value: Any) -> int:
        """Return a positive character limit from ``value``, else the default."""
        if value is None or value == "":
            return DEFAULT_MAX_CHARS
        try:
            limit = int(value)
        except (TypeError, ValueError):
            logger.warning(
                "Invalid maxChars {!r}; using {}", value, DEFAULT_MAX_CHARS
            )
            return DEFAULT_MAX_CHARS
        if limit <= 0:
            logger.warning(
                "Non-positive maxChars {}; using {}", limit, DEFAULT_MAX_CHARS
            )
            return DEFAULT_MAX_CHARS
        return limit

    def _get_llm_config(self) -> Dict[str, Any]:
        """Load (once) and return the LLM config for ``self._model_id``."""
        if self._llm_config is None:
            # Imported lazily, presumably to avoid import cost/cycles at
            # module load time - NOTE(review): confirm.
            from ops.annotation._llm_utils import get_llm_config

            self._llm_config = get_llm_config(self._model_id)
        return self._llm_config

    def _category_list(self) -> List[str]:
        """Split the configured categories string into trimmed, non-empty labels.

        Both ASCII ``,`` and full-width ``，`` are accepted as separators.
        """
        normalized = str(self._categories).replace("，", ",")
        return [item.strip() for item in normalized.split(",") if item.strip()]

    @staticmethod
    def _fallback_result(reason: str) -> Dict[str, Any]:
        """Build the placeholder result recorded when classification fails."""
        return {"label": "unknown", "confidence": 0.0, "reasoning": reason}

    def _classify(
        self, text_path: str, text: str, config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ask the LLM to classify ``text`` and return a result dict.

        Always returns a dict: LLM/parse errors and non-object JSON replies are
        replaced by a ``label="unknown"`` fallback record instead of raising.
        """
        from ops.annotation._llm_utils import call_llm, extract_json

        # Truncate overly long text to fit the LLM context window.
        prompt = USER_PROMPT_TEMPLATE.format(
            categories=self._categories,
            text=text[: self._max_chars],
        )
        try:
            raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
            result = extract_json(raw_response)
        except Exception as e:  # operator boundary: record the failure, keep going
            logger.error("LLM classification failed for {}: {}", text_path, e)
            return self._fallback_result(f"LLM call or JSON parse failed: {e}")

        # extract_json's return type is not guaranteed here; anything other
        # than a JSON object would break the annotation schema and the
        # ``result.get`` calls downstream.
        if not isinstance(result, dict):
            logger.error(
                "LLM returned non-object JSON for {}: {!r}", text_path, result
            )
            return self._fallback_result(
                f"LLM response is not a JSON object: {result!r}"
            )

        allowed = self._category_list()
        label = result.get("label")
        if allowed and label not in allowed:
            # Kept as-is (not overwritten) but flagged: the prompt requires the
            # label to be one of the configured categories.
            logger.warning(
                "Label {!r} for {} is not one of the categories {}",
                label,
                text_path,
                allowed,
            )
        return result

    def _write_outputs(self, text_path: str, annotation: Dict[str, Any]) -> str:
        """Copy the source text and write the annotation JSON; return its path."""
        # Output root: explicit outputDir, else the text file's own directory.
        output_dir = self._output_dir or os.path.dirname(text_path)
        annotations_dir = os.path.join(output_dir, "annotations")
        data_dir = os.path.join(output_dir, "data")
        os.makedirs(annotations_dir, exist_ok=True)
        os.makedirs(data_dir, exist_ok=True)

        file_name = os.path.basename(text_path)
        base_name = os.path.splitext(file_name)[0]

        # Copy the original text into data/ (an existing copy is not replaced).
        dst_data = os.path.join(data_dir, file_name)
        if not os.path.exists(dst_data):
            shutil.copy2(text_path, dst_data)

        json_path = os.path.join(annotations_dir, f"{base_name}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotation, f, indent=2, ensure_ascii=False)
        return json_path

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Classify the text file referenced by ``sample`` and record the result.

        Missing, non-regular or whitespace-only files are logged and the sample
        is returned unchanged. Otherwise the sample is updated in place with
        ``detection_count``, ``annotations_file`` and ``annotations`` and
        returned.
        """
        start = time.perf_counter()

        # ``self.text_key`` is provided by the Mapper base class.
        text_path = sample.get(self.text_key)
        # isfile (not exists): a directory path would otherwise crash open().
        if not text_path or not os.path.isfile(str(text_path)):
            logger.warning("Text file not found: {}", text_path)
            return sample
        text_path = str(text_path)

        with open(text_path, "r", encoding="utf-8") as f:
            text_content = f.read()
        if not text_content.strip():
            logger.warning("Empty text file: {}", text_path)
            return sample

        config = self._get_llm_config()
        result = self._classify(text_path, text_content, config)

        annotation = {
            "file": os.path.basename(text_path),
            "task_type": "text_classification",
            "categories": self._categories,
            "model": config.get("model_name", ""),
            "result": result,
        }
        json_path = self._write_outputs(text_path, annotation)

        sample["detection_count"] = 1
        sample["annotations_file"] = json_path
        sample["annotations"] = annotation

        elapsed = time.perf_counter() - start
        logger.info(
            "TextClassification: {} -> {}, Time: {:.2f}s",
            os.path.basename(text_path),
            result.get("label", "N/A"),
            elapsed,
        )
        return sample