# -*- coding: utf-8 -*-
"""LLM-based relation extraction operator.

Uses a large language model to identify named entities in text and to
extract relations between them, producing an entity list and relation
triples (subject, relation, object).
"""
import json
import os
import shutil
import time
from typing import Any, Dict, List, Optional

from loguru import logger

from datamate.core.base_op import Mapper

SYSTEM_PROMPT = (
    "你是一个专业的信息抽取专家。你需要从文本中识别命名实体,并抽取实体之间的关系。\n"
    "你必须严格输出 JSON 格式,不要输出任何其他内容。"
)

USER_PROMPT_TEMPLATE = """请从以下文本中识别实体并抽取实体间的关系。

实体类型列表:{entity_types}
关系类型列表:{relation_types}

文本内容:
{text}

请以如下 JSON 格式输出:
{{
  "entities": [
    {{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}
  ],
  "relations": [
    {{
      "subject": {{"text": "主语实体", "type": "PER"}},
      "relation": "关系类型",
      "object": {{"text": "宾语实体", "type": "ORG"}}
    }}
  ]
}}

注意:
- 实体的 type 必须是实体类型列表中的值之一
- 关系的 relation 必须是关系类型列表中的值之一
- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开)
- 如果没有找到任何实体或关系,对应数组返回空 []"""


class LLMRelationExtraction(Mapper):
    """LLM-based relation extraction operator.

    Reads a text file referenced by ``sample[self.text_key]``, asks an LLM
    to extract entities and relations, validates the response against the
    configured type whitelists, and writes the annotation JSON (plus a copy
    of the source text) to the output directory.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Identifier used to resolve the LLM endpoint configuration lazily.
        self._model_id: str = kwargs.get("modelId", "")
        # Comma-separated whitelists; empty strings disable filtering.
        self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC")
        self._relation_types: str = kwargs.get("relationTypes", "属于,位于,创立,工作于")
        self._output_dir: str = kwargs.get("outputDir", "") or ""
        # Maximum number of characters of text sent to the LLM.  Previously a
        # hard-coded constant (8000) inside execute(); now configurable while
        # keeping the same default.
        self._max_chars: int = int(kwargs.get("maxChars", 8000))
        # Cached LLM configuration, populated on first use.
        self._llm_config: Optional[Dict[str, Any]] = None

    def _get_llm_config(self) -> Dict[str, Any]:
        """Resolve and cache the LLM configuration for ``self._model_id``."""
        if self._llm_config is None:
            # Lazy import: avoids the dependency cost when the op is never run.
            from ops.annotation._llm_utils import get_llm_config

            self._llm_config = get_llm_config(self._model_id)
        return self._llm_config

    @staticmethod
    def _to_offset(value: Any) -> Optional[int]:
        """Coerce an LLM-provided character offset to ``int``.

        LLM output is untrusted: offsets may arrive as strings, floats or
        garbage.  Returns ``None`` when the value is not convertible.
        """
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _validate_entities(
        entities_raw: Any, allowed_types: List[str]
    ) -> List[Dict[str, Any]]:
        """Filter and normalize raw entity dicts from the LLM response.

        Keeps only dict items with non-empty text whose (upper-cased) type is
        in ``allowed_types``; an empty whitelist accepts any type.  Character
        offsets are coerced to ``int`` or ``None``.
        """
        if not isinstance(entities_raw, list):
            return []
        validated: List[Dict[str, Any]] = []
        # Entity types are compared case-insensitively (e.g. "per" == "PER").
        allowed_set = {t.strip().upper() for t in allowed_types} if allowed_types else set()
        for ent in entities_raw:
            if not isinstance(ent, dict):
                continue
            ent_text = str(ent.get("text", "")).strip()
            ent_type = str(ent.get("type", "")).strip().upper()
            if not ent_text:
                continue
            if allowed_set and ent_type not in allowed_set:
                continue
            validated.append(
                {
                    "text": ent_text,
                    "type": ent_type,
                    "start": LLMRelationExtraction._to_offset(ent.get("start")),
                    "end": LLMRelationExtraction._to_offset(ent.get("end")),
                }
            )
        return validated

    @staticmethod
    def _validate_relations(
        relations_raw: Any, allowed_relation_types: List[str]
    ) -> List[Dict[str, Any]]:
        """Filter and normalize raw relation dicts from the LLM response.

        Keeps only dict items with dict subject/object and a non-empty
        relation contained in ``allowed_relation_types``; an empty whitelist
        accepts any relation.  Relation names are matched case-sensitively
        (they may be non-Latin, e.g. Chinese).
        """
        if not isinstance(relations_raw, list):
            return []
        validated: List[Dict[str, Any]] = []
        allowed_set = (
            {t.strip() for t in allowed_relation_types}
            if allowed_relation_types
            else set()
        )
        for rel in relations_raw:
            if not isinstance(rel, dict):
                continue
            subject = rel.get("subject")
            relation = str(rel.get("relation", "")).strip()
            obj = rel.get("object")
            if not isinstance(subject, dict) or not isinstance(obj, dict):
                continue
            if not relation:
                continue
            if allowed_set and relation not in allowed_set:
                continue
            validated.append(
                {
                    "subject": {
                        "text": str(subject.get("text", "")),
                        "type": str(subject.get("type", "")),
                    },
                    "relation": relation,
                    "object": {
                        "text": str(obj.get("text", "")),
                        "type": str(obj.get("type", "")),
                    },
                }
            )
        return validated

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Run relation extraction on one sample.

        Reads the text file at ``sample[self.text_key]``, queries the LLM,
        validates the parsed entities/relations, writes the annotation JSON
        next to a copy of the source file, and records results on ``sample``
        (``detection_count``, ``annotations_file``, ``annotations``).
        Returns the sample unchanged when the file is missing or empty.
        """
        start = time.time()
        text_path = sample.get(self.text_key)
        if not text_path or not os.path.exists(str(text_path)):
            logger.warning("Text file not found: {}", text_path)
            return sample
        text_path = str(text_path)

        with open(text_path, "r", encoding="utf-8") as f:
            text_content = f.read()
        if not text_content.strip():
            logger.warning("Empty text file: {}", text_path)
            return sample

        # Truncate to keep the prompt within the model's context budget.
        truncated = text_content[: self._max_chars]

        # Lazy import, consistent with _get_llm_config().
        from ops.annotation._llm_utils import call_llm, extract_json

        config = self._get_llm_config()
        prompt = USER_PROMPT_TEMPLATE.format(
            entity_types=self._entity_types,
            relation_types=self._relation_types,
            text=truncated,
        )
        allowed_entity_types = [
            t.strip() for t in self._entity_types.split(",") if t.strip()
        ]
        allowed_relation_types = [
            t.strip() for t in self._relation_types.split(",") if t.strip()
        ]

        try:
            raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
            parsed = extract_json(raw_response)
            if not isinstance(parsed, dict):
                parsed = {}
            entities = self._validate_entities(
                parsed.get("entities", []), allowed_entity_types
            )
            relations = self._validate_relations(
                parsed.get("relations", []), allowed_relation_types
            )
        except Exception as e:
            # Best-effort: a failed LLM call yields an empty annotation rather
            # than aborting the pipeline.
            logger.error("LLM relation extraction failed for {}: {}", text_path, e)
            entities = []
            relations = []

        annotation = {
            "file": os.path.basename(text_path),
            "task_type": "relation_extraction",
            "entity_types": self._entity_types,
            "relation_types": self._relation_types,
            "model": config.get("model_name", ""),
            "entities": entities,
            "relations": relations,
        }

        # Write outputs: annotations/<name>.json plus a copy of the source
        # text under data/, mirroring the expected dataset layout.
        output_dir = self._output_dir or os.path.dirname(text_path)
        annotations_dir = os.path.join(output_dir, "annotations")
        data_dir = os.path.join(output_dir, "data")
        os.makedirs(annotations_dir, exist_ok=True)
        os.makedirs(data_dir, exist_ok=True)

        base_name = os.path.splitext(os.path.basename(text_path))[0]
        dst_data = os.path.join(data_dir, os.path.basename(text_path))
        if not os.path.exists(dst_data):
            shutil.copy2(text_path, dst_data)

        json_path = os.path.join(annotations_dir, f"{base_name}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotation, f, indent=2, ensure_ascii=False)

        sample["detection_count"] = len(relations)
        sample["annotations_file"] = json_path
        sample["annotations"] = annotation

        elapsed = time.time() - start
        logger.info(
            "RelationExtraction: {} -> {} entities, {} relations, Time: {:.2f}s",
            os.path.basename(text_path),
            len(entities),
            len(relations),
            elapsed,
        )
        return sample