You've already forked DataMate
Add three new LLM-powered auto-annotation operators: - LLMTextClassification: Text classification using LLM - LLMNamedEntityRecognition: Named entity recognition with type validation - LLMRelationExtraction: Relation extraction with entity and relation type validation Key features: - Load LLM config from t_model_config table via modelId parameter - Lazy loading of LLM configuration on first execute() - Result validation with whitelist checking for entity/relation types - Fault-tolerant: returns empty results on LLM failure instead of throwing - Fully compatible with existing Worker pipeline Files added: - runtime/ops/annotation/_llm_utils.py: Shared LLM utilities - runtime/ops/annotation/llm_text_classification/: Text classification operator - runtime/ops/annotation/llm_named_entity_recognition/: NER operator - runtime/ops/annotation/llm_relation_extraction/: Relation extraction operator Files modified: - runtime/ops/annotation/__init__.py: Register 3 new operators - runtime/python-executor/datamate/auto_annotation_worker.py: Add to Worker whitelist - frontend/src/pages/DataAnnotation/OperatorCreate/hooks/useOperatorOperations.ts: Add to frontend whitelist
230 lines
7.6 KiB
Python
230 lines
7.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""LLM 关系抽取算子。
|
|
|
|
基于大语言模型从文本中识别实体,并抽取实体之间的关系,
|
|
输出实体列表和关系三元组(subject, relation, object)。
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import shutil
|
|
import time
|
|
from typing import Any, Dict, List
|
|
|
|
from loguru import logger
|
|
|
|
from datamate.core.base_op import Mapper
|
|
|
|
|
|
# System prompt: pins the model to the information-extraction role and forces
# strict JSON-only output. Kept in Chinese to match the target corpora.
SYSTEM_PROMPT = (
"你是一个专业的信息抽取专家。你需要从文本中识别命名实体,并抽取实体之间的关系。\n"
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
)
|
|
|
|
USER_PROMPT_TEMPLATE = """请从以下文本中识别实体并抽取实体间的关系。
|
|
|
|
实体类型列表:{entity_types}
|
|
关系类型列表:{relation_types}
|
|
|
|
文本内容:
|
|
{text}
|
|
|
|
请以如下 JSON 格式输出:
|
|
{{
|
|
"entities": [
|
|
{{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}
|
|
],
|
|
"relations": [
|
|
{{
|
|
"subject": {{"text": "主语实体", "type": "PER"}},
|
|
"relation": "关系类型",
|
|
"object": {{"text": "宾语实体", "type": "ORG"}}
|
|
}}
|
|
]
|
|
}}
|
|
|
|
注意:
|
|
- 实体的 type 必须是实体类型列表中的值之一
|
|
- 关系的 relation 必须是关系类型列表中的值之一
|
|
- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开)
|
|
- 如果没有找到任何实体或关系,对应数组返回空 []"""
|
|
|
|
|
|
class LLMRelationExtraction(Mapper):
    """LLM-based relation extraction operator.

    For each sample, reads the text file referenced by
    ``sample[self.text_key]``, asks the configured LLM to identify named
    entities and extract (subject, relation, object) triples, validates the
    response against the configured entity/relation type whitelists, and
    writes the annotation as JSON under ``<output_dir>/annotations`` with a
    copy of the source text under ``<output_dir>/data``.

    Fault tolerant: any failure while loading the model config, calling the
    LLM, or parsing its response yields empty ``entities``/``relations``
    instead of raising, so the Worker pipeline keeps processing samples.
    """

    # Default cap on the number of characters sent to the LLM per request.
    DEFAULT_MAX_CHARS = 8000

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Operator parameters, supplied by the pipeline config as kwargs.
        self._model_id: str = kwargs.get("modelId", "")
        self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC")
        self._relation_types: str = kwargs.get("relationTypes", "属于,位于,创立,工作于")
        self._output_dir: str = kwargs.get("outputDir", "") or ""
        # Optional override of the prompt-size cap (backward compatible:
        # defaults to the previous hard-coded 8000).
        self._max_chars: int = int(kwargs.get("maxChars", self.DEFAULT_MAX_CHARS))
        # Loaded lazily on first execute() so worker start-up does not need
        # to reach the model-config store (t_model_config).
        self._llm_config = None

    def _get_llm_config(self) -> Dict[str, Any]:
        """Return the LLM config for ``self._model_id``, caching after first load."""
        if self._llm_config is None:
            # Imported lazily to keep operator import cheap and avoid cycles.
            from ops.annotation._llm_utils import get_llm_config

            self._llm_config = get_llm_config(self._model_id)
        return self._llm_config

    @staticmethod
    def _validate_entities(
        entities_raw: Any, allowed_types: List[str]
    ) -> List[Dict[str, Any]]:
        """Keep only well-formed entities whose type is in the whitelist.

        Entity types are compared case-insensitively (normalized to upper
        case). When ``allowed_types`` is empty, any type is accepted.
        ``start``/``end`` offsets are passed through unchecked (may be None).
        """
        if not isinstance(entities_raw, list):
            return []

        validated: List[Dict[str, Any]] = []
        allowed_set = {t.strip().upper() for t in allowed_types} if allowed_types else set()

        for ent in entities_raw:
            if not isinstance(ent, dict):
                continue
            ent_text = str(ent.get("text", "")).strip()
            ent_type = str(ent.get("type", "")).strip().upper()
            if not ent_text:
                continue
            if allowed_set and ent_type not in allowed_set:
                continue
            validated.append(
                {
                    "text": ent_text,
                    "type": ent_type,
                    "start": ent.get("start"),
                    "end": ent.get("end"),
                }
            )
        return validated

    @staticmethod
    def _validate_relations(
        relations_raw: Any, allowed_relation_types: List[str]
    ) -> List[Dict[str, Any]]:
        """Keep only well-formed triples whose relation is in the whitelist.

        Relation names are compared case-sensitively (they are typically
        Chinese labels). When ``allowed_relation_types`` is empty, any
        relation is accepted. Subject and object must both be dicts.
        """
        if not isinstance(relations_raw, list):
            return []

        validated: List[Dict[str, Any]] = []
        allowed_set = {t.strip() for t in allowed_relation_types} if allowed_relation_types else set()

        for rel in relations_raw:
            if not isinstance(rel, dict):
                continue
            subject = rel.get("subject")
            relation = str(rel.get("relation", "")).strip()
            obj = rel.get("object")

            if not isinstance(subject, dict) or not isinstance(obj, dict):
                continue
            if not relation:
                continue
            if allowed_set and relation not in allowed_set:
                continue

            validated.append(
                {
                    "subject": {
                        "text": str(subject.get("text", "")),
                        "type": str(subject.get("type", "")),
                    },
                    "relation": relation,
                    "object": {
                        "text": str(obj.get("text", "")),
                        "type": str(obj.get("type", "")),
                    },
                }
            )
        return validated

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Annotate one sample; return it with annotation fields added.

        On missing, unreadable, or empty input the sample is returned
        untouched. On any LLM or config failure, empty entity/relation lists
        are written so downstream steps still get a well-formed annotation.
        """
        start = time.time()

        text_path = sample.get(self.text_key)
        if not text_path or not os.path.exists(str(text_path)):
            logger.warning("Text file not found: {}", text_path)
            return sample

        text_path = str(text_path)
        try:
            with open(text_path, "r", encoding="utf-8") as f:
                text_content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            # Unreadable or non-UTF-8 input: skip this sample instead of
            # aborting the whole pipeline.
            logger.warning("Failed to read text file {}: {}", text_path, e)
            return sample

        if not text_content.strip():
            logger.warning("Empty text file: {}", text_path)
            return sample

        truncated = text_content[: self._max_chars]

        from ops.annotation._llm_utils import call_llm, extract_json

        prompt = USER_PROMPT_TEMPLATE.format(
            entity_types=self._entity_types,
            relation_types=self._relation_types,
            text=truncated,
        )

        allowed_entity_types = [
            t.strip() for t in self._entity_types.split(",") if t.strip()
        ]
        allowed_relation_types = [
            t.strip() for t in self._relation_types.split(",") if t.strip()
        ]

        config: Dict[str, Any] = {}
        entities: List[Dict[str, Any]] = []
        relations: List[Dict[str, Any]] = []
        try:
            # Config loading sits inside the try: a missing or broken model
            # config now degrades to empty results instead of raising, which
            # is the documented fault-tolerance contract of this operator
            # (previously it was called before this block and could crash
            # the worker).
            config = self._get_llm_config()
            raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
            parsed = extract_json(raw_response)
            if not isinstance(parsed, dict):
                parsed = {}
            entities = self._validate_entities(
                parsed.get("entities", []), allowed_entity_types
            )
            relations = self._validate_relations(
                parsed.get("relations", []), allowed_relation_types
            )
        except Exception as e:
            logger.error("LLM relation extraction failed for {}: {}", text_path, e)
            entities = []
            relations = []

        annotation = {
            "file": os.path.basename(text_path),
            "task_type": "relation_extraction",
            "entity_types": self._entity_types,
            "relation_types": self._relation_types,
            "model": config.get("model_name", ""),
            "entities": entities,
            "relations": relations,
        }

        # Write outputs: annotation JSON plus a copy of the source text.
        output_dir = self._output_dir or os.path.dirname(text_path)
        annotations_dir = os.path.join(output_dir, "annotations")
        data_dir = os.path.join(output_dir, "data")
        os.makedirs(annotations_dir, exist_ok=True)
        os.makedirs(data_dir, exist_ok=True)

        base_name = os.path.splitext(os.path.basename(text_path))[0]

        dst_data = os.path.join(data_dir, os.path.basename(text_path))
        if not os.path.exists(dst_data):
            shutil.copy2(text_path, dst_data)

        json_path = os.path.join(annotations_dir, f"{base_name}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotation, f, indent=2, ensure_ascii=False)

        sample["detection_count"] = len(relations)
        sample["annotations_file"] = json_path
        sample["annotations"] = annotation

        elapsed = time.time() - start
        logger.info(
            "RelationExtraction: {} -> {} entities, {} relations, Time: {:.2f}s",
            os.path.basename(text_path),
            len(entities),
            len(relations),
            elapsed,
        )
        return sample
|