# -*- coding: utf-8 -*-
"""LLM-based named entity recognition (NER) operator.

Uses a large language model to identify named entities (person names, places,
organizations, ...) in a text file and emits an entity list in which every
entity carries its text span, entity type and start/end character offsets in
the source text.
"""

import json
import os
import shutil
import time
from typing import Any, Dict, List, Optional, Tuple

from loguru import logger

from datamate.core.base_op import Mapper

# Entity types requested when the operator is configured without "entityTypes".
DEFAULT_ENTITY_TYPES = "PER,ORG,LOC,DATE"

# Number of leading characters of a file sent to the LLM unless overridden
# through the "maxChars" parameter.
DEFAULT_MAX_CHARS = 8000

SYSTEM_PROMPT = (
    "你是一个专业的命名实体识别(NER)专家。根据给定的实体类型列表,"
    "从输入文本中识别所有命名实体。\n"
    "你必须严格输出 JSON 格式,不要输出任何其他内容。"
)

USER_PROMPT_TEMPLATE = """请从以下文本中识别所有命名实体。

实体类型列表:{entity_types}

实体类型说明:
- PER:人名
- ORG:组织/机构名
- LOC:地点/地名
- DATE:日期/时间
- EVENT:事件
- PRODUCT:产品名
- MONEY:金额
- PERCENT:百分比

文本内容:
{text}

请以如下 JSON 格式输出(entities 为实体数组,每个实体包含 text、type、start、end 四个字段):
{{"entities": [{{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}]}}

注意:
- type 必须是实体类型列表中的值之一
- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开)
- 如果没有找到任何实体,返回 {{"entities": []}}"""


def _coerce_offset(value: Any) -> Optional[int]:
    """Convert an LLM-supplied offset to ``int``; return ``None`` if impossible.

    Accepts ints, integral floats (``3.0``) and numeric strings (``"3"``);
    booleans, non-integral floats and anything else yield ``None``.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, float):
        return int(value) if value.is_integer() else None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _resolve_span(
    text: str, ent_text: str, start: Any, end: Any
) -> Tuple[Optional[int], Optional[int]]:
    """Return ``(start, end)`` offsets of ``ent_text`` inside ``text``.

    LLMs frequently miscount character offsets, so the claimed span is only
    trusted when ``text[start:end] == ent_text``. Otherwise the entity text is
    searched for in ``text`` and the occurrence closest to the claimed start
    (or the first one, if no usable start was given) is used. When the entity
    text does not occur in ``text`` at all, ``(None, None)`` is returned.
    Offsets are half-open, matching the prompt's convention.
    """
    s = _coerce_offset(start)
    e = _coerce_offset(end)
    if s is not None and e is not None and 0 <= s < e <= len(text):
        if text[s:e] == ent_text:
            return s, e

    occurrences: List[int] = []
    pos = text.find(ent_text)
    while pos != -1:
        occurrences.append(pos)
        pos = text.find(ent_text, pos + 1)
    if not occurrences:
        return None, None

    anchor = s if s is not None else 0
    best = min(occurrences, key=lambda p: abs(p - anchor))
    return best, best + len(ent_text)


class LLMNamedEntityRecognition(Mapper):
    """Named entity recognition operator backed by an LLM.

    Constructor keyword arguments read here:
        modelId: identifier passed to ``get_llm_config`` to look up the model.
        entityTypes: comma-separated entity types (a list/tuple is also
            accepted); defaults to ``"PER,ORG,LOC,DATE"``. An empty value
            means entities of any type returned by the model are kept.
        outputDir: root for the ``annotations/`` and ``data/`` output
            directories; defaults to the source file's directory.
        maxChars: number of leading characters sent to the LLM
            (default 8000); non-positive or non-numeric values fall back to
            the default.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._model_id: str = kwargs.get("modelId", "")

        entity_types = kwargs.get("entityTypes", DEFAULT_ENTITY_TYPES)
        if isinstance(entity_types, (list, tuple)):
            entity_types = ",".join(
                str(t).strip() for t in entity_types if str(t).strip()
            )
        # Guard against an explicit None, which would break ``split`` later.
        self._entity_types: str = str(entity_types or "")

        self._output_dir: str = kwargs.get("outputDir", "") or ""

        try:
            max_chars = int(kwargs.get("maxChars") or DEFAULT_MAX_CHARS)
        except (TypeError, ValueError):
            max_chars = DEFAULT_MAX_CHARS
        self._max_chars: int = max_chars if max_chars > 0 else DEFAULT_MAX_CHARS

        # Resolved lazily on first use by ``_get_llm_config``.
        self._llm_config: Optional[Dict[str, Any]] = None

    def _get_llm_config(self) -> Dict[str, Any]:
        """Return the LLM config for ``modelId``, loading it once and caching it."""
        if self._llm_config is None:
            from ops.annotation._llm_utils import get_llm_config

            self._llm_config = get_llm_config(self._model_id)
        return self._llm_config

    @staticmethod
    def _validate_entities(
        entities_raw: Any,
        allowed_types: List[str],
        source_text: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Validate and filter the entity list returned by the LLM.

        Args:
            entities_raw: parsed ``entities`` value from the model; anything
                other than a list yields ``[]``.
            allowed_types: permitted entity types (compared case-insensitively).
                An empty list keeps every type.
            source_text: the exact text sent to the model. When given, each
                entity's offsets are checked against it and repaired (see
                ``_resolve_span``), and exact duplicates are dropped. When
                omitted, the model's raw ``start``/``end`` values are passed
                through unchanged.

        Returns:
            A list of ``{"text", "type", "start", "end"}`` dicts with the type
            upper-cased and the text stripped; entries that are not dicts or
            have empty text are skipped.
        """
        if not isinstance(entities_raw, list):
            return []

        allowed_set = {t.strip().upper() for t in allowed_types}
        validated: List[Dict[str, Any]] = []
        seen = set()
        for ent in entities_raw:
            if not isinstance(ent, dict):
                continue
            ent_type = str(ent.get("type", "")).strip().upper()
            ent_text = str(ent.get("text", "")).strip()
            if not ent_text:
                continue
            # Keep entities of an allowed type; keep everything when no types
            # are configured.
            if allowed_set and ent_type not in allowed_set:
                continue

            start, end = ent.get("start"), ent.get("end")
            if source_text is not None:
                start, end = _resolve_span(source_text, ent_text, start, end)
                # Resolved offsets are ints or None, so the key is hashable.
                key = (ent_text, ent_type, start, end)
                if key in seen:
                    continue
                seen.add(key)

            validated.append(
                {
                    "text": ent_text,
                    "type": ent_type,
                    "start": start,
                    "end": end,
                }
            )
        return validated

    def _recognize_entities(
        self, config: Dict[str, Any], text: str, text_path: str
    ) -> List[Dict[str, Any]]:
        """Ask the LLM for entities in ``text`` and return the validated list.

        Any exception raised while calling the LLM, parsing its reply or
        validating the result is logged and produces an empty list, so the
        caller still writes an annotation for the file.
        """
        from ops.annotation._llm_utils import call_llm, extract_json

        prompt = USER_PROMPT_TEMPLATE.format(
            entity_types=self._entity_types,
            text=text,
        )
        allowed_types = [t.strip() for t in self._entity_types.split(",") if t.strip()]

        try:
            raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
            parsed = extract_json(raw_response)
            entities_raw = parsed.get("entities", []) if isinstance(parsed, dict) else parsed
            return self._validate_entities(entities_raw, allowed_types, source_text=text)
        except Exception as e:
            logger.error("LLM NER failed for {}: {}", text_path, e)
            return []

    def _write_outputs(self, text_path: str, annotation: Dict[str, Any]) -> str:
        """Copy the source file to ``data/`` and write ``annotations/<name>.json``.

        Both directories are created under ``outputDir`` or, when that is
        unset, under the source file's directory. An existing file of the same
        name in ``data/`` is not overwritten. Returns the annotation JSON path.
        """
        output_dir = self._output_dir or os.path.dirname(text_path)
        annotations_dir = os.path.join(output_dir, "annotations")
        data_dir = os.path.join(output_dir, "data")
        os.makedirs(annotations_dir, exist_ok=True)
        os.makedirs(data_dir, exist_ok=True)

        file_name = os.path.basename(text_path)
        dst_data = os.path.join(data_dir, file_name)
        if not os.path.exists(dst_data):
            shutil.copy2(text_path, dst_data)

        base_name = os.path.splitext(file_name)[0]
        json_path = os.path.join(annotations_dir, f"{base_name}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotation, f, indent=2, ensure_ascii=False)
        return json_path

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Run NER on the text file referenced by ``sample[self.text_key]``.

        Missing or whitespace-only files are logged and the sample is returned
        unchanged. Otherwise the first ``maxChars`` characters are sent to the
        LLM, the result is written to disk, and the sample gains
        ``detection_count``, ``annotations_file`` and ``annotations``.
        Errors from loading the LLM config are not caught here.
        """
        started = time.perf_counter()

        text_path = sample.get(self.text_key)
        if not text_path or not os.path.exists(str(text_path)):
            logger.warning("Text file not found: {}", text_path)
            return sample
        text_path = str(text_path)

        with open(text_path, "r", encoding="utf-8") as f:
            text_content = f.read()
        if not text_content.strip():
            logger.warning("Empty text file: {}", text_path)
            return sample

        # Offsets computed on this prefix are also valid for the full text.
        truncated = text_content[: self._max_chars]
        if len(truncated) < len(text_content):
            logger.warning(
                "Text truncated from {} to {} chars for NER: {}",
                len(text_content),
                self._max_chars,
                text_path,
            )

        config = self._get_llm_config()
        entities = self._recognize_entities(config, truncated, text_path)

        annotation = {
            "file": os.path.basename(text_path),
            "task_type": "ner",
            "entity_types": self._entity_types,
            "model": config.get("model_name", ""),
            "entities": entities,
        }
        json_path = self._write_outputs(text_path, annotation)

        sample["detection_count"] = len(entities)
        sample["annotations_file"] = json_path
        sample["annotations"] = annotation

        elapsed = time.perf_counter() - started
        logger.info(
            "NER: {} -> {} entities, Time: {:.2f}s",
            os.path.basename(text_path),
            len(entities),
            elapsed,
        )
        return sample