You've already forked DataMate
Compare commits
5 Commits
06a7cd9abd
...
f707ce9dae
| Author | SHA1 | Date | |
|---|---|---|---|
| f707ce9dae | |||
| 9988ff00f5 | |||
| 2fbfefdb91 | |||
| dc490f03be | |||
| 49f99527cc |
@@ -22,6 +22,9 @@ type CategoryGroup = {
|
|||||||
const ANNOTATION_OPERATOR_ID_WHITELIST = new Set([
|
const ANNOTATION_OPERATOR_ID_WHITELIST = new Set([
|
||||||
"ImageObjectDetectionBoundingBox",
|
"ImageObjectDetectionBoundingBox",
|
||||||
"test_annotation_marker",
|
"test_annotation_marker",
|
||||||
|
"LLMTextClassification",
|
||||||
|
"LLMNamedEntityRecognition",
|
||||||
|
"LLMRelationExtraction",
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const ensureArray = (value: unknown): string[] => {
|
const ensureArray = (value: unknown): string[] => {
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Annotation-related operators (e.g. YOLO detection)."""
|
"""Annotation-related operators (e.g. YOLO detection, LLM-based NLP annotation)."""
|
||||||
|
|
||||||
from . import image_object_detection_bounding_box
|
from . import image_object_detection_bounding_box
|
||||||
from . import test_annotation_marker
|
from . import test_annotation_marker
|
||||||
|
from . import llm_text_classification
|
||||||
|
from . import llm_named_entity_recognition
|
||||||
|
from . import llm_relation_extraction
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"image_object_detection_bounding_box",
|
"image_object_detection_bounding_box",
|
||||||
"test_annotation_marker",
|
"test_annotation_marker",
|
||||||
|
"llm_text_classification",
|
||||||
|
"llm_named_entity_recognition",
|
||||||
|
"llm_relation_extraction",
|
||||||
]
|
]
|
||||||
|
|||||||
168
runtime/ops/annotation/_llm_utils.py
Normal file
168
runtime/ops/annotation/_llm_utils.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""LLM 配置加载 & OpenAI 兼容调用工具(标注算子共享)。
|
||||||
|
|
||||||
|
提供三项核心能力:
|
||||||
|
1. 从 t_model_config 表加载模型配置(按 ID / 按默认)
|
||||||
|
2. 调用 OpenAI 兼容 chat/completions API
|
||||||
|
3. 从 LLM 原始输出中提取 JSON
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 模型配置加载
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_model_config(model_id: str) -> Dict[str, Any]:
    """Fetch the enabled model-configuration row for *model_id* from t_model_config.

    Args:
        model_id: primary key of the model configuration row.

    Returns:
        The row as a plain dict (model_name, provider, base_url, api_key, type).

    Raises:
        ValueError: when no enabled row matches the given id.
    """
    from sqlalchemy import text as sql_text

    from datamate.sql_manager.sql_manager import SQLManager

    query = sql_text(
        """
        SELECT model_name, provider, base_url, api_key, type
        FROM t_model_config
        WHERE id = :model_id AND is_enabled = 1
        LIMIT 1
        """
    )
    with SQLManager.create_connect() as conn:
        record = conn.execute(query, {"model_id": model_id}).fetchone()
    if not record:
        raise ValueError(f"Model config not found or disabled: {model_id}")
    return dict(record._mapping)
|
||||||
|
|
||||||
|
|
||||||
|
def load_default_model_config() -> Dict[str, Any]:
    """Fetch the default chat model configuration (is_default=1, type='chat').

    Returns:
        The row as a plain dict (id, model_name, provider, base_url, api_key, type).

    Raises:
        ValueError: when no enabled default chat model is configured.
    """
    from sqlalchemy import text as sql_text

    from datamate.sql_manager.sql_manager import SQLManager

    query = sql_text(
        """
        SELECT id, model_name, provider, base_url, api_key, type
        FROM t_model_config
        WHERE is_enabled = 1 AND is_default = 1 AND type = 'chat'
        LIMIT 1
        """
    )
    with SQLManager.create_connect() as conn:
        record = conn.execute(query).fetchone()
    if not record:
        raise ValueError("No default chat model configured in t_model_config")
    return dict(record._mapping)
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm_config(model_id: str = "") -> Dict[str, Any]:
    """Resolve an LLM configuration: by explicit id when given, else the system default."""
    if not model_id:
        return load_default_model_config()
    return load_model_config(model_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LLM 调用
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def call_llm(
    config: Dict[str, Any],
    prompt: str,
    system_prompt: str = "",
    temperature: float = 0.1,
    max_retries: int = 2,
) -> str:
    """POST a chat completion to an OpenAI-compatible endpoint and return its text.

    Args:
        config: model config row; requires ``base_url`` and ``model_name``,
            ``api_key`` is optional (sent as a Bearer token when present).
        prompt: the user message.
        system_prompt: optional system message prepended to the conversation.
        temperature: sampling temperature forwarded to the API.
        max_retries: number of retries after the first failed attempt.

    Raises:
        RuntimeError: when every attempt (1 + max_retries) fails.
    """
    import requests as http_requests

    messages = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    messages.append({"role": "user", "content": prompt})

    headers: Dict[str, str] = {"Content-Type": "application/json"}
    api_key = config.get("api_key", "")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # Accept base_url given as a bare host, a .../v1 prefix, or the full
    # chat/completions endpoint.
    endpoint = str(config["base_url"]).rstrip("/")
    if not endpoint.endswith("/chat/completions"):
        if not endpoint.endswith("/v1"):
            endpoint = f"{endpoint}/v1"
        endpoint = f"{endpoint}/chat/completions"

    payload = {
        "model": config["model_name"],
        "messages": messages,
        "temperature": temperature,
    }

    last_err = None
    for attempt in range(1, max_retries + 2):
        try:
            resp = http_requests.post(endpoint, json=payload, headers=headers, timeout=120)
            resp.raise_for_status()
            return resp.json()["choices"][0]["message"]["content"]
        except Exception as e:
            # Retries cover transient network/HTTP failures; response-shape
            # errors are retried too, matching previous behavior.
            last_err = e
            logger.warning(
                "LLM call attempt {}/{} failed: {}",
                attempt,
                max_retries + 1,
                e,
            )

    raise RuntimeError(f"LLM call failed after {max_retries + 1} attempts: {last_err}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JSON 提取
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def extract_json(raw: str) -> Any:
    """Extract a JSON object or array from raw LLM output.

    Tolerates common wrappers: <think>-style reasoning tags, Markdown code
    fences, and explanatory prose before/after the JSON payload.

    Raises:
        ValueError: if *raw* is empty.
        json.JSONDecodeError: if no parsable JSON can be located.
    """
    if not raw:
        raise ValueError("Empty LLM response")

    # 1. Strip chain-of-thought tags such as <think>...</think>.
    for tag in ("think", "thinking", "analysis", "reasoning", "reflection"):
        raw = re.sub(rf"<{tag}>[\s\S]*?</{tag}>", "", raw, flags=re.IGNORECASE)

    # 2. Strip Markdown code-fence markers (``` / ```json).
    raw = re.sub(r"```(?:json)?\s*", "", raw)
    raw = raw.replace("```", "")

    # 3. Slice from the first opening bracket to the last closing bracket.
    first = next((i for i, ch in enumerate(raw) if ch in "{["), None)
    last = None
    for i in range(len(raw) - 1, -1, -1):
        if raw[i] in "]}":
            last = i + 1
            break

    if first is not None and last is not None and first < last:
        return json.loads(raw[first:last])

    # Fallback: assume the remainder is bare JSON.
    return json.loads(raw.strip())
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datamate.core.base_op import OPERATORS
|
||||||
|
from .process import LLMNamedEntityRecognition
|
||||||
|
|
||||||
|
OPERATORS.register_module(
|
||||||
|
module_name="LLMNamedEntityRecognition",
|
||||||
|
module_path="ops.annotation.llm_named_entity_recognition.process",
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["LLMNamedEntityRecognition"]
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
name: 'LLM命名实体识别'
|
||||||
|
name_en: 'LLM Named Entity Recognition'
|
||||||
|
description: '基于大语言模型的命名实体识别算子,支持自定义实体类型。'
|
||||||
|
description_en: 'LLM-based NER operator with custom entity types.'
|
||||||
|
language: 'python'
|
||||||
|
vendor: 'datamate'
|
||||||
|
raw_id: 'LLMNamedEntityRecognition'
|
||||||
|
version: '1.0.0'
|
||||||
|
types:
|
||||||
|
- 'annotation'
|
||||||
|
modal: 'text'
|
||||||
|
inputs: 'text'
|
||||||
|
outputs: 'text'
|
||||||
|
settings:
|
||||||
|
modelId:
|
||||||
|
name: '模型ID'
|
||||||
|
description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: ''
|
||||||
|
entityTypes:
|
||||||
|
name: '实体类型'
|
||||||
|
description: '逗号分隔的实体类型,如:PER,ORG,LOC,DATE'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: 'PER,ORG,LOC,DATE'
|
||||||
|
outputDir:
|
||||||
|
name: '输出目录'
|
||||||
|
description: '算子输出目录(由运行时自动注入)。'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: ''
|
||||||
174
runtime/ops/annotation/llm_named_entity_recognition/process.py
Normal file
174
runtime/ops/annotation/llm_named_entity_recognition/process.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""LLM 命名实体识别 (NER) 算子。
|
||||||
|
|
||||||
|
基于大语言模型从文本中识别命名实体(人名、地名、机构名等),
|
||||||
|
输出实体列表(含文本片段、实体类型、在原文中的起止位置)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from datamate.core.base_op import Mapper
|
||||||
|
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = (
|
||||||
|
"你是一个专业的命名实体识别(NER)专家。根据给定的实体类型列表,"
|
||||||
|
"从输入文本中识别所有命名实体。\n"
|
||||||
|
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
|
||||||
|
)
|
||||||
|
|
||||||
|
USER_PROMPT_TEMPLATE = """请从以下文本中识别所有命名实体。
|
||||||
|
|
||||||
|
实体类型列表:{entity_types}
|
||||||
|
|
||||||
|
实体类型说明:
|
||||||
|
- PER:人名
|
||||||
|
- ORG:组织/机构名
|
||||||
|
- LOC:地点/地名
|
||||||
|
- DATE:日期/时间
|
||||||
|
- EVENT:事件
|
||||||
|
- PRODUCT:产品名
|
||||||
|
- MONEY:金额
|
||||||
|
- PERCENT:百分比
|
||||||
|
|
||||||
|
文本内容:
|
||||||
|
{text}
|
||||||
|
|
||||||
|
请以如下 JSON 格式输出(entities 为实体数组,每个实体包含 text、type、start、end 四个字段):
|
||||||
|
{{"entities": [{{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}]}}
|
||||||
|
|
||||||
|
注意:
|
||||||
|
- type 必须是实体类型列表中的值之一
|
||||||
|
- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开)
|
||||||
|
- 如果没有找到任何实体,返回 {{"entities": []}}"""
|
||||||
|
|
||||||
|
|
||||||
|
class LLMNamedEntityRecognition(Mapper):
    """LLM-backed named-entity recognition operator.

    Reads the text file referenced by the sample, asks the configured LLM for
    entities, validates the response against the allowed entity types, and
    writes an annotation JSON next to a copy of the source text.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Operator settings injected via kwargs (see metadata.yml).
        self._model_id: str = kwargs.get("modelId", "")
        self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC,DATE")
        self._output_dir: str = kwargs.get("outputDir", "") or ""
        self._llm_config = None  # lazily resolved model configuration

    def _get_llm_config(self) -> Dict[str, Any]:
        # Resolve once and cache; import deferred to keep module import cheap.
        if self._llm_config is None:
            from ops.annotation._llm_utils import get_llm_config

            self._llm_config = get_llm_config(self._model_id)
        return self._llm_config

    @staticmethod
    def _validate_entities(
        entities_raw: Any, allowed_types: List[str]
    ) -> List[Dict[str, Any]]:
        """Filter the raw LLM entity list down to well-formed, allowed entries."""
        if not isinstance(entities_raw, list):
            return []

        permitted = {t.strip().upper() for t in allowed_types}
        kept: List[Dict[str, Any]] = []
        for candidate in entities_raw:
            if not isinstance(candidate, dict):
                continue
            span_text = str(candidate.get("text", "")).strip()
            label = str(candidate.get("type", "")).strip().upper()
            if not span_text:
                continue
            # Keep only whitelisted types; an empty whitelist keeps everything.
            if permitted and label not in permitted:
                continue
            kept.append(
                {
                    "text": span_text,
                    "type": label,
                    # NOTE(review): start/end come straight from the LLM and are
                    # not checked against the source text — confirm downstream.
                    "start": candidate.get("start"),
                    "end": candidate.get("end"),
                }
            )
        return kept

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        started = time.time()

        text_path = sample.get(self.text_key)
        if not text_path or not os.path.exists(str(text_path)):
            logger.warning("Text file not found: {}", text_path)
            return sample
        text_path = str(text_path)

        with open(text_path, "r", encoding="utf-8") as f:
            text_content = f.read()
        if not text_content.strip():
            logger.warning("Empty text file: {}", text_path)
            return sample

        # Truncate long inputs to fit the model context window.
        truncated = text_content[:8000]

        from ops.annotation._llm_utils import call_llm, extract_json

        config = self._get_llm_config()
        prompt = USER_PROMPT_TEMPLATE.format(
            entity_types=self._entity_types,
            text=truncated,
        )
        allowed_types = [t.strip() for t in self._entity_types.split(",") if t.strip()]

        try:
            parsed = extract_json(call_llm(config, prompt, system_prompt=SYSTEM_PROMPT))
            entities_raw = parsed.get("entities", []) if isinstance(parsed, dict) else parsed
            entities = self._validate_entities(entities_raw, allowed_types)
        except Exception as e:
            logger.error("LLM NER failed for {}: {}", text_path, e)
            entities = []

        annotation = {
            "file": os.path.basename(text_path),
            "task_type": "ner",
            "entity_types": self._entity_types,
            "model": config.get("model_name", ""),
            "entities": entities,
        }

        # Persist: copy the source text into data/, write the annotation JSON
        # into annotations/ under the resolved output directory.
        output_dir = self._output_dir or os.path.dirname(text_path)
        annotations_dir = os.path.join(output_dir, "annotations")
        data_dir = os.path.join(output_dir, "data")
        os.makedirs(annotations_dir, exist_ok=True)
        os.makedirs(data_dir, exist_ok=True)

        dst_data = os.path.join(data_dir, os.path.basename(text_path))
        if not os.path.exists(dst_data):
            shutil.copy2(text_path, dst_data)

        base_name = os.path.splitext(os.path.basename(text_path))[0]
        json_path = os.path.join(annotations_dir, f"{base_name}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotation, f, indent=2, ensure_ascii=False)

        sample["detection_count"] = len(entities)
        sample["annotations_file"] = json_path
        sample["annotations"] = annotation

        logger.info(
            "NER: {} -> {} entities, Time: {:.2f}s",
            os.path.basename(text_path),
            len(entities),
            time.time() - started,
        )
        return sample
|
||||||
10
runtime/ops/annotation/llm_relation_extraction/__init__.py
Normal file
10
runtime/ops/annotation/llm_relation_extraction/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datamate.core.base_op import OPERATORS
|
||||||
|
from .process import LLMRelationExtraction
|
||||||
|
|
||||||
|
OPERATORS.register_module(
|
||||||
|
module_name="LLMRelationExtraction",
|
||||||
|
module_path="ops.annotation.llm_relation_extraction.process",
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["LLMRelationExtraction"]
|
||||||
34
runtime/ops/annotation/llm_relation_extraction/metadata.yml
Normal file
34
runtime/ops/annotation/llm_relation_extraction/metadata.yml
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
name: 'LLM关系抽取'
|
||||||
|
name_en: 'LLM Relation Extraction'
|
||||||
|
description: '基于大语言模型的关系抽取算子,识别实体并抽取实体间关系三元组。'
|
||||||
|
description_en: 'LLM-based relation extraction operator that identifies entities and extracts relation triples.'
|
||||||
|
language: 'python'
|
||||||
|
vendor: 'datamate'
|
||||||
|
raw_id: 'LLMRelationExtraction'
|
||||||
|
version: '1.0.0'
|
||||||
|
types:
|
||||||
|
- 'annotation'
|
||||||
|
modal: 'text'
|
||||||
|
inputs: 'text'
|
||||||
|
outputs: 'text'
|
||||||
|
settings:
|
||||||
|
modelId:
|
||||||
|
name: '模型ID'
|
||||||
|
description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: ''
|
||||||
|
entityTypes:
|
||||||
|
name: '实体类型'
|
||||||
|
description: '逗号分隔的实体类型,如:PER,ORG,LOC'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: 'PER,ORG,LOC'
|
||||||
|
relationTypes:
|
||||||
|
name: '关系类型'
|
||||||
|
description: '逗号分隔的关系类型,如:属于,位于,创立,工作于'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: '属于,位于,创立,工作于'
|
||||||
|
outputDir:
|
||||||
|
name: '输出目录'
|
||||||
|
description: '算子输出目录(由运行时自动注入)。'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: ''
|
||||||
229
runtime/ops/annotation/llm_relation_extraction/process.py
Normal file
229
runtime/ops/annotation/llm_relation_extraction/process.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""LLM 关系抽取算子。
|
||||||
|
|
||||||
|
基于大语言模型从文本中识别实体,并抽取实体之间的关系,
|
||||||
|
输出实体列表和关系三元组(subject, relation, object)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from datamate.core.base_op import Mapper
|
||||||
|
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = (
|
||||||
|
"你是一个专业的信息抽取专家。你需要从文本中识别命名实体,并抽取实体之间的关系。\n"
|
||||||
|
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
|
||||||
|
)
|
||||||
|
|
||||||
|
USER_PROMPT_TEMPLATE = """请从以下文本中识别实体并抽取实体间的关系。
|
||||||
|
|
||||||
|
实体类型列表:{entity_types}
|
||||||
|
关系类型列表:{relation_types}
|
||||||
|
|
||||||
|
文本内容:
|
||||||
|
{text}
|
||||||
|
|
||||||
|
请以如下 JSON 格式输出:
|
||||||
|
{{
|
||||||
|
"entities": [
|
||||||
|
{{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}
|
||||||
|
],
|
||||||
|
"relations": [
|
||||||
|
{{
|
||||||
|
"subject": {{"text": "主语实体", "type": "PER"}},
|
||||||
|
"relation": "关系类型",
|
||||||
|
"object": {{"text": "宾语实体", "type": "ORG"}}
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
|
||||||
|
注意:
|
||||||
|
- 实体的 type 必须是实体类型列表中的值之一
|
||||||
|
- 关系的 relation 必须是关系类型列表中的值之一
|
||||||
|
- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开)
|
||||||
|
- 如果没有找到任何实体或关系,对应数组返回空 []"""
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRelationExtraction(Mapper):
    """LLM-backed relation-extraction operator.

    Identifies entities in the text file referenced by the sample and extracts
    (subject, relation, object) triples between them, writing the validated
    result as an annotation JSON next to a copy of the source text.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Operator settings injected via kwargs (see metadata.yml).
        self._model_id: str = kwargs.get("modelId", "")
        self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC")
        self._relation_types: str = kwargs.get("relationTypes", "属于,位于,创立,工作于")
        self._output_dir: str = kwargs.get("outputDir", "") or ""
        self._llm_config = None  # lazily resolved model configuration

    def _get_llm_config(self) -> Dict[str, Any]:
        # Resolve once and cache; import deferred to keep module import cheap.
        if self._llm_config is None:
            from ops.annotation._llm_utils import get_llm_config

            self._llm_config = get_llm_config(self._model_id)
        return self._llm_config

    @staticmethod
    def _validate_entities(
        entities_raw: Any, allowed_types: List[str]
    ) -> List[Dict[str, Any]]:
        """Keep only well-formed entities whose type is in the whitelist."""
        if not isinstance(entities_raw, list):
            return []

        permitted = {t.strip().upper() for t in allowed_types} if allowed_types else set()
        kept: List[Dict[str, Any]] = []
        for candidate in entities_raw:
            if not isinstance(candidate, dict):
                continue
            span_text = str(candidate.get("text", "")).strip()
            label = str(candidate.get("type", "")).strip().upper()
            if not span_text:
                continue
            if permitted and label not in permitted:
                continue
            kept.append(
                {
                    "text": span_text,
                    "type": label,
                    # NOTE(review): offsets are LLM-reported and unverified.
                    "start": candidate.get("start"),
                    "end": candidate.get("end"),
                }
            )
        return kept

    @staticmethod
    def _validate_relations(
        relations_raw: Any, allowed_relation_types: List[str]
    ) -> List[Dict[str, Any]]:
        """Keep only triples with dict endpoints and a whitelisted relation."""
        if not isinstance(relations_raw, list):
            return []

        permitted = (
            {t.strip() for t in allowed_relation_types} if allowed_relation_types else set()
        )
        kept: List[Dict[str, Any]] = []
        for candidate in relations_raw:
            if not isinstance(candidate, dict):
                continue
            subject = candidate.get("subject")
            obj = candidate.get("object")
            relation = str(candidate.get("relation", "")).strip()

            if not isinstance(subject, dict) or not isinstance(obj, dict):
                continue
            if not relation:
                continue
            if permitted and relation not in permitted:
                continue

            kept.append(
                {
                    "subject": {
                        "text": str(subject.get("text", "")),
                        "type": str(subject.get("type", "")),
                    },
                    "relation": relation,
                    "object": {
                        "text": str(obj.get("text", "")),
                        "type": str(obj.get("type", "")),
                    },
                }
            )
        return kept

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        started = time.time()

        text_path = sample.get(self.text_key)
        if not text_path or not os.path.exists(str(text_path)):
            logger.warning("Text file not found: {}", text_path)
            return sample
        text_path = str(text_path)

        with open(text_path, "r", encoding="utf-8") as f:
            text_content = f.read()
        if not text_content.strip():
            logger.warning("Empty text file: {}", text_path)
            return sample

        # Truncate long inputs to fit the model context window.
        truncated = text_content[:8000]

        from ops.annotation._llm_utils import call_llm, extract_json

        config = self._get_llm_config()
        prompt = USER_PROMPT_TEMPLATE.format(
            entity_types=self._entity_types,
            relation_types=self._relation_types,
            text=truncated,
        )

        allowed_entity_types = [
            t.strip() for t in self._entity_types.split(",") if t.strip()
        ]
        allowed_relation_types = [
            t.strip() for t in self._relation_types.split(",") if t.strip()
        ]

        try:
            parsed = extract_json(call_llm(config, prompt, system_prompt=SYSTEM_PROMPT))
            if not isinstance(parsed, dict):
                parsed = {}
            entities = self._validate_entities(
                parsed.get("entities", []), allowed_entity_types
            )
            relations = self._validate_relations(
                parsed.get("relations", []), allowed_relation_types
            )
        except Exception as e:
            logger.error("LLM relation extraction failed for {}: {}", text_path, e)
            entities = []
            relations = []

        annotation = {
            "file": os.path.basename(text_path),
            "task_type": "relation_extraction",
            "entity_types": self._entity_types,
            "relation_types": self._relation_types,
            "model": config.get("model_name", ""),
            "entities": entities,
            "relations": relations,
        }

        # Persist: copy the source text into data/, write the annotation JSON
        # into annotations/ under the resolved output directory.
        output_dir = self._output_dir or os.path.dirname(text_path)
        annotations_dir = os.path.join(output_dir, "annotations")
        data_dir = os.path.join(output_dir, "data")
        os.makedirs(annotations_dir, exist_ok=True)
        os.makedirs(data_dir, exist_ok=True)

        dst_data = os.path.join(data_dir, os.path.basename(text_path))
        if not os.path.exists(dst_data):
            shutil.copy2(text_path, dst_data)

        base_name = os.path.splitext(os.path.basename(text_path))[0]
        json_path = os.path.join(annotations_dir, f"{base_name}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotation, f, indent=2, ensure_ascii=False)

        # detection_count mirrors sibling annotation operators; here it counts relations.
        sample["detection_count"] = len(relations)
        sample["annotations_file"] = json_path
        sample["annotations"] = annotation

        logger.info(
            "RelationExtraction: {} -> {} entities, {} relations, Time: {:.2f}s",
            os.path.basename(text_path),
            len(entities),
            len(relations),
            time.time() - started,
        )
        return sample
|
||||||
10
runtime/ops/annotation/llm_text_classification/__init__.py
Normal file
10
runtime/ops/annotation/llm_text_classification/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from datamate.core.base_op import OPERATORS
|
||||||
|
from .process import LLMTextClassification
|
||||||
|
|
||||||
|
OPERATORS.register_module(
|
||||||
|
module_name="LLMTextClassification",
|
||||||
|
module_path="ops.annotation.llm_text_classification.process",
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["LLMTextClassification"]
|
||||||
29
runtime/ops/annotation/llm_text_classification/metadata.yml
Normal file
29
runtime/ops/annotation/llm_text_classification/metadata.yml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
name: 'LLM文本分类'
|
||||||
|
name_en: 'LLM Text Classification'
|
||||||
|
description: '基于大语言模型的文本分类算子,支持自定义类别标签。'
|
||||||
|
description_en: 'LLM-based text classification operator with custom category labels.'
|
||||||
|
language: 'python'
|
||||||
|
vendor: 'datamate'
|
||||||
|
raw_id: 'LLMTextClassification'
|
||||||
|
version: '1.0.0'
|
||||||
|
types:
|
||||||
|
- 'annotation'
|
||||||
|
modal: 'text'
|
||||||
|
inputs: 'text'
|
||||||
|
outputs: 'text'
|
||||||
|
settings:
|
||||||
|
modelId:
|
||||||
|
name: '模型ID'
|
||||||
|
description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: ''
|
||||||
|
categories:
|
||||||
|
name: '分类标签'
|
||||||
|
description: '逗号分隔的分类标签列表,如:正面,负面,中性'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: '正面,负面,中性'
|
||||||
|
outputDir:
|
||||||
|
name: '输出目录'
|
||||||
|
description: '算子输出目录(由运行时自动注入)。'
|
||||||
|
type: 'input'
|
||||||
|
defaultVal: ''
|
||||||
129
runtime/ops/annotation/llm_text_classification/process.py
Normal file
129
runtime/ops/annotation/llm_text_classification/process.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""LLM 文本分类算子。
|
||||||
|
|
||||||
|
基于大语言模型对文本进行分类,输出分类标签、置信度和简短理由。
|
||||||
|
支持通过 categories 参数自定义分类标签体系。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from datamate.core.base_op import Mapper
|
||||||
|
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = (
|
||||||
|
"你是一个专业的文本分类专家。根据给定的类别列表,对输入文本进行分类。\n"
|
||||||
|
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
|
||||||
|
)
|
||||||
|
|
||||||
|
USER_PROMPT_TEMPLATE = """请对以下文本进行分类。
|
||||||
|
|
||||||
|
可选类别:{categories}
|
||||||
|
|
||||||
|
文本内容:
|
||||||
|
{text}
|
||||||
|
|
||||||
|
请以如下 JSON 格式输出(label 必须是可选类别之一,confidence 为 0~1 的浮点数):
|
||||||
|
{{"label": "类别名", "confidence": 0.95, "reasoning": "简短理由"}}"""
|
||||||
|
|
||||||
|
|
||||||
|
class LLMTextClassification(Mapper):
    """LLM-backed text classification operator.

    Classifies the text file referenced by the sample into one of the
    configured categories and writes the result as an annotation JSON next to
    a copy of the source text.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Operator settings injected via kwargs (see metadata.yml).
        self._model_id: str = kwargs.get("modelId", "")
        self._categories: str = kwargs.get("categories", "正面,负面,中性")
        self._output_dir: str = kwargs.get("outputDir", "") or ""
        self._llm_config = None  # lazily resolved model configuration

    def _get_llm_config(self) -> Dict[str, Any]:
        # Resolve once and cache; import deferred to keep module import cheap.
        if self._llm_config is None:
            from ops.annotation._llm_utils import get_llm_config

            self._llm_config = get_llm_config(self._model_id)
        return self._llm_config

    def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        start = time.time()

        text_path = sample.get(self.text_key)
        if not text_path or not os.path.exists(str(text_path)):
            logger.warning("Text file not found: {}", text_path)
            return sample

        text_path = str(text_path)
        with open(text_path, "r", encoding="utf-8") as f:
            text_content = f.read()

        if not text_content.strip():
            logger.warning("Empty text file: {}", text_path)
            return sample

        # Truncate overly long text to fit the LLM context window.
        max_chars = 8000
        truncated = text_content[:max_chars]

        from ops.annotation._llm_utils import call_llm, extract_json

        config = self._get_llm_config()
        prompt = USER_PROMPT_TEMPLATE.format(
            categories=self._categories,
            text=truncated,
        )

        try:
            raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
            result = extract_json(raw_response)
            # BUGFIX: extract_json may return a JSON array; the code below
            # (result.get(...)) requires a dict and previously crashed with
            # AttributeError on list responses. Route non-objects through the
            # existing fallback instead.
            if not isinstance(result, dict):
                raise ValueError(
                    f"Expected a JSON object, got {type(result).__name__}"
                )
        except Exception as e:
            logger.error("LLM classification failed for {}: {}", text_path, e)
            result = {
                "label": "unknown",
                "confidence": 0.0,
                "reasoning": f"LLM call or JSON parse failed: {e}",
            }

        annotation = {
            "file": os.path.basename(text_path),
            "task_type": "text_classification",
            "categories": self._categories,
            "model": config.get("model_name", ""),
            "result": result,
        }

        # Resolve the output directory (falls back to the input file's folder).
        output_dir = self._output_dir or os.path.dirname(text_path)
        annotations_dir = os.path.join(output_dir, "annotations")
        data_dir = os.path.join(output_dir, "data")
        os.makedirs(annotations_dir, exist_ok=True)
        os.makedirs(data_dir, exist_ok=True)

        base_name = os.path.splitext(os.path.basename(text_path))[0]

        # Copy the original text into the data directory.
        dst_data = os.path.join(data_dir, os.path.basename(text_path))
        if not os.path.exists(dst_data):
            shutil.copy2(text_path, dst_data)

        # Write the annotation JSON.
        json_path = os.path.join(annotations_dir, f"{base_name}.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(annotation, f, indent=2, ensure_ascii=False)

        sample["detection_count"] = 1
        sample["annotations_file"] = json_path
        sample["annotations"] = annotation

        elapsed = time.time() - start
        logger.info(
            "TextClassification: {} -> {}, Time: {:.2f}s",
            os.path.basename(text_path),
            result.get("label", "N/A"),
            elapsed,
        )
        return sample
|
||||||
491
runtime/python-executor/datamate/annotation_result_converter.py
Normal file
491
runtime/python-executor/datamate/annotation_result_converter.py
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""将自动标注算子输出转换为 Label Studio 兼容格式。
|
||||||
|
|
||||||
|
支持的算子类型:
|
||||||
|
- LLMTextClassification → choices
|
||||||
|
- LLMNamedEntityRecognition → labels (span)
|
||||||
|
- LLMRelationExtraction → labels + relation
|
||||||
|
- ImageObjectDetectionBoundingBox → rectanglelabels
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import uuid
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 颜色调色板(Label Studio 背景色)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_LABEL_COLORS = [
|
||||||
|
"#e53935",
|
||||||
|
"#fb8c00",
|
||||||
|
"#43a047",
|
||||||
|
"#1e88e5",
|
||||||
|
"#8e24aa",
|
||||||
|
"#00897b",
|
||||||
|
"#d81b60",
|
||||||
|
"#3949ab",
|
||||||
|
"#fdd835",
|
||||||
|
"#6d4c41",
|
||||||
|
"#546e7a",
|
||||||
|
"#f4511e",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_color(index: int) -> str:
|
||||||
|
return _LABEL_COLORS[index % len(_LABEL_COLORS)]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 稳定 LS ID 生成(与 editor.py:118-136 一致)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _stable_ls_id(seed: str) -> int:
|
||||||
|
"""生成稳定的 Label Studio 风格整数 ID(JS 安全整数范围内)。"""
|
||||||
|
digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
|
||||||
|
value = int(digest[:13], 16)
|
||||||
|
return value if value > 0 else 1
|
||||||
|
|
||||||
|
|
||||||
|
def make_ls_task_id(project_id: str, file_id: str) -> int:
    """Stable Label Studio task ID for a (project, file) pair."""
    seed = "task:{}:{}".format(project_id, file_id)
    return _stable_ls_id(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def make_ls_annotation_id(project_id: str, file_id: str) -> int:
    """Stable Label Studio annotation ID for a (project, file) pair."""
    seed = "annotation:{}:{}".format(project_id, file_id)
    return _stable_ls_id(seed)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 确定性 region ID(关系抽取需要稳定引用)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_NAMESPACE_REGION = uuid.UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_region_id(
|
||||||
|
project_id: str, file_id: str, text: str, entity_type: str, start: Any, end: Any
|
||||||
|
) -> str:
|
||||||
|
seed = f"{project_id}:{file_id}:{text}:{entity_type}:{start}:{end}"
|
||||||
|
return str(uuid.uuid5(_NAMESPACE_REGION, seed))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. 文本分类
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def convert_text_classification(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert one LLMTextClassification operator output to an LS annotation.

    Returns None when the payload carries no usable ``result.label``.
    """
    payload = annotation.get("result")
    if not isinstance(payload, dict):
        return None

    label = payload.get("label")
    if not label:
        return None

    choice_region = {
        "id": str(uuid.uuid4()),
        "from_name": "sentiment",
        "to_name": "text",
        "type": "choices",
        "value": {"choices": [str(label)]},
    }

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": [choice_region],
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 2. 命名实体识别
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def convert_ner(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMNamedEntityRecognition output to an LS annotation.

    Entities missing text or span offsets are skipped. Returns None when
    the input has no entity list or nothing usable remains.
    """
    entities = annotation.get("entities")
    if not isinstance(entities, list) or not entities:
        return None

    regions: List[Dict[str, Any]] = []
    for entity in entities:
        if not isinstance(entity, dict):
            continue

        span_text = entity.get("text", "")
        span_type = entity.get("type", "")
        span_start = entity.get("start")
        span_end = entity.get("end")
        if not span_text or span_start is None or span_end is None:
            continue

        regions.append(
            {
                "id": _make_region_id(
                    project_id, file_id, span_text, span_type, span_start, span_end
                ),
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "value": {
                    "start": int(span_start),
                    "end": int(span_end),
                    "text": str(span_text),
                    "labels": [str(span_type)],
                },
            }
        )

    if not regions:
        return None

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": regions,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 3. 关系抽取
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _find_entity_region_id(
|
||||||
|
entity_regions: List[Dict[str, Any]], text: str, entity_type: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""在已生成的 entity regions 中查找匹配的 region ID。"""
|
||||||
|
for region in entity_regions:
|
||||||
|
value = region.get("value", {})
|
||||||
|
if value.get("text") == text and entity_type in value.get("labels", []):
|
||||||
|
return region["id"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def convert_relation_extraction(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert LLMRelationExtraction output to an LS annotation.

    Builds de-duplicated entity label regions first, then relation
    results referencing those regions by their deterministic IDs.
    Relations whose subject or object cannot be matched to a region are
    dropped (with a debug log). Returns None when nothing converts.
    """
    entities = annotation.get("entities", [])
    relations = annotation.get("relations", [])
    if not isinstance(entities, list):
        entities = []
    if not isinstance(relations, list):
        relations = []

    if not entities and not relations:
        return None

    # --- entity label regions (de-duplicated on text/type/span) ---
    entity_regions: List[Dict[str, Any]] = []
    seen: set = set()
    for entity in entities:
        if not isinstance(entity, dict):
            continue

        span_text = str(entity.get("text", "")).strip()
        span_type = str(entity.get("type", "")).strip()
        span_start = entity.get("start")
        span_end = entity.get("end")
        if not span_text or span_start is None or span_end is None:
            continue

        key = (span_text, span_type, int(span_start), int(span_end))
        if key in seen:
            continue
        seen.add(key)

        entity_regions.append(
            {
                "id": _make_region_id(
                    project_id, file_id, span_text, span_type, span_start, span_end
                ),
                "from_name": "label",
                "to_name": "text",
                "type": "labels",
                "value": {
                    "start": int(span_start),
                    "end": int(span_end),
                    "text": span_text,
                    "labels": [span_type],
                },
            }
        )

    # --- relation results referencing the regions above ---
    relation_results: List[Dict[str, Any]] = []
    for relation in relations:
        if not isinstance(relation, dict):
            continue

        subject = relation.get("subject")
        target = relation.get("object")
        relation_type = str(relation.get("relation", "")).strip()
        if not isinstance(subject, dict) or not isinstance(target, dict) or not relation_type:
            continue

        subject_text = str(subject.get("text", "")).strip()
        subject_type = str(subject.get("type", "")).strip()
        target_text = str(target.get("text", "")).strip()
        target_type = str(target.get("type", "")).strip()

        from_id = _find_entity_region_id(entity_regions, subject_text, subject_type)
        to_id = _find_entity_region_id(entity_regions, target_text, target_type)
        if not from_id or not to_id:
            logger.debug(
                "Skipping relation '{}': could not find region for subject='{}' or object='{}'",
                relation_type,
                subject_text,
                target_text,
            )
            continue

        relation_results.append(
            {
                "from_id": from_id,
                "to_id": to_id,
                "type": "relation",
                "direction": "right",
                "labels": [relation_type],
            }
        )

    ls_result = entity_regions + relation_results
    if not ls_result:
        return None

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": ls_result,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 4. 目标检测
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def convert_object_detection(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Convert ImageObjectDetectionBoundingBox output to an LS annotation.

    Pixel boxes (x1, y1, x2, y2) are rescaled to the percentage-based
    geometry Label Studio expects. Returns None when there are no
    detections or the image dimensions are missing/zero.
    """
    detections = annotation.get("detections")
    if not isinstance(detections, list) or not detections:
        return None

    width = annotation.get("width", 0)
    height = annotation.get("height", 0)
    if not width or not height:
        return None

    regions: List[Dict[str, Any]] = []
    for detection in detections:
        if not isinstance(detection, dict):
            continue

        label = detection.get("label", "unknown")
        box = detection.get("bbox_xyxy")
        if not isinstance(box, (list, tuple)) or len(box) < 4:
            continue

        x1, y1, x2, y2 = (float(box[i]) for i in range(4))
        regions.append(
            {
                "id": str(uuid.uuid4()),
                "from_name": "label",
                "to_name": "image",
                "type": "rectanglelabels",
                "original_width": int(width),
                "original_height": int(height),
                "image_rotation": 0,
                "value": {
                    "x": round(x1 * 100.0 / width, 4),
                    "y": round(y1 * 100.0 / height, 4),
                    "width": round((x2 - x1) * 100.0 / width, 4),
                    "height": round((y2 - y1) * 100.0 / height, 4),
                    "rotation": 0,
                    "rectanglelabels": [str(label)],
                },
            }
        )

    if not regions:
        return None

    return {
        "id": make_ls_annotation_id(project_id, file_id),
        "task": make_ls_task_id(project_id, file_id),
        "result": regions,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 分发器
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Maps an operator's task_type string to the function that converts its raw
# output into a Label Studio annotation dict. Dispatch happens in
# convert_annotation().
TASK_TYPE_CONVERTERS = {
    "text_classification": convert_text_classification,
    "ner": convert_ner,
    "relation_extraction": convert_relation_extraction,
    "object_detection": convert_object_detection,
}
|
||||||
|
|
||||||
|
|
||||||
|
def convert_annotation(
    annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
    """Dispatch *annotation* to the converter matching its ``task_type``.

    Legacy detection payloads without an explicit ``task_type`` but with a
    ``detections`` key are treated as object detection. Returns None (and
    logs) when no converter exists or the conversion raises.
    """
    task_type = annotation.get("task_type")
    if task_type is None and "detections" in annotation:
        # Backward compatibility: older detection outputs carry no task_type.
        task_type = "object_detection"

    converter = TASK_TYPE_CONVERTERS.get(task_type)  # type: ignore[arg-type]
    if converter is None:
        logger.warning("No LS converter for task_type: {}", task_type)
        return None

    try:
        return converter(annotation, file_id, project_id)
    except Exception as exc:
        logger.error("Failed to convert annotation (task_type={}): {}", task_type, exc)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# label_config XML 生成
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _split_labels(raw: str) -> List[str]:
|
||||||
|
"""逗号分隔字符串 → 去空白列表。"""
|
||||||
|
return [s.strip() for s in raw.split(",") if s.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_label_config_xml(task_type: str, operator_params: Dict[str, Any]) -> str:
    """Build a Label Studio label_config XML for *task_type*.

    Unknown task types yield a minimal placeholder view so that project
    creation never fails outright.
    """
    generators = {
        "text_classification": _gen_text_classification_xml,
        "ner": _gen_ner_xml,
        "relation_extraction": _gen_relation_extraction_xml,
        "object_detection": _gen_object_detection_xml,
    }
    generator = generators.get(task_type)
    if generator is not None:
        return generator(operator_params)

    # Unknown type: return a minimal usable XML.
    return '<View><Header value="Unknown annotation type"/></View>'
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_text_classification_xml(params: Dict[str, Any]) -> str:
    """<View> with a Text source plus single-choice Choices per category."""
    category_names = _split_labels(str(params.get("categories", "正面,负面,中性")))

    root = ET.Element("View")
    ET.SubElement(root, "Text", name="text", value="$text")
    choice_group = ET.SubElement(
        root,
        "Choices",
        name="sentiment",
        toName="text",
        choice="single",
        showInline="true",
    )
    for name in category_names:
        ET.SubElement(choice_group, "Choice", value=name)

    return ET.tostring(root, encoding="unicode")
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_ner_xml(params: Dict[str, Any]) -> str:
    """<View> with colored span Labels above the Text source."""
    type_names = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC,DATE")))

    root = ET.Element("View")
    label_group = ET.SubElement(root, "Labels", name="label", toName="text")
    for index, name in enumerate(type_names):
        ET.SubElement(label_group, "Label", value=name, background=_pick_color(index))
    ET.SubElement(root, "Text", name="text", value="$text")

    return ET.tostring(root, encoding="unicode")
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_relation_extraction_xml(params: Dict[str, Any]) -> str:
    """<View> with Relations, colored span Labels, and the Text source."""
    type_names = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC")))
    relation_names = _split_labels(
        str(params.get("relationTypes", "属于,位于,创立,工作于"))
    )

    root = ET.Element("View")
    relation_group = ET.SubElement(root, "Relations")
    for name in relation_names:
        ET.SubElement(relation_group, "Relation", value=name)
    label_group = ET.SubElement(root, "Labels", name="label", toName="text")
    for index, name in enumerate(type_names):
        ET.SubElement(label_group, "Label", value=name, background=_pick_color(index))
    ET.SubElement(root, "Text", name="text", value="$text")

    return ET.tostring(root, encoding="unicode")
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_object_detection_xml(params: Dict[str, Any]) -> str:
    """<View> with an Image source and RectangleLabels from detected labels."""
    labels: List[str] = params.get("_detected_labels", [])
    if not labels:
        # No labels collected from results — fall back to a generic one.
        labels = ["object"]

    root = ET.Element("View")
    ET.SubElement(root, "Image", name="image", value="$image")
    rect_group = ET.SubElement(root, "RectangleLabels", name="label", toName="image")
    for index, name in enumerate(labels):
        ET.SubElement(rect_group, "Label", value=name, background=_pick_color(index))

    return ET.tostring(root, encoding="unicode")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pipeline 信息提取
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Annotation operator ID → logical task type used by the converter dispatch.
OPERATOR_TASK_TYPE_MAP: Dict[str, str] = {
    "LLMTextClassification": "text_classification",
    "LLMNamedEntityRecognition": "ner",
    "LLMRelationExtraction": "relation_extraction",
    "ImageObjectDetectionBoundingBox": "object_detection",
}


def infer_task_type_from_pipeline(pipeline: List[Dict[str, Any]]) -> Optional[str]:
    """Infer the annotation task type from a normalized pipeline.

    Returns the task type of the first recognized operator, or None.
    """
    candidates = (
        OPERATOR_TASK_TYPE_MAP.get(str(step.get("operator_id", "")))
        for step in pipeline
    )
    return next((task_type for task_type in candidates if task_type is not None), None)


def extract_operator_params(pipeline: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return the overrides of the first annotation operator in the
    normalized pipeline, or an empty dict when none is present."""
    for step in pipeline:
        if str(step.get("operator_id", "")) in OPERATOR_TASK_TYPE_MAP:
            return dict(step.get("overrides", {}))
    return {}
|
||||||
@@ -2,31 +2,34 @@
|
|||||||
"""Simple background worker for auto-annotation tasks.
|
"""Simple background worker for auto-annotation tasks.
|
||||||
|
|
||||||
This module runs inside the datamate-runtime container (operator_runtime service).
|
This module runs inside the datamate-runtime container (operator_runtime service).
|
||||||
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs YOLO
|
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs annotation
|
||||||
inference using the ImageObjectDetectionBoundingBox operator, updating
|
using configurable operator pipelines (YOLO, LLM text classification, NER,
|
||||||
progress back to the same table so that the datamate-python backend and
|
relation extraction, etc.), updating progress back to the same table so that
|
||||||
frontend can display real-time status.
|
the datamate-python backend and frontend can display real-time status.
|
||||||
|
|
||||||
设计目标(最小可用版本):
|
设计:
|
||||||
- 单实例 worker,串行处理 `pending` 状态的任务。
|
- 多任务并发: 可通过 AUTO_ANNOTATION_WORKER_COUNT 启动多个 worker 线程,
|
||||||
- 对指定数据集下的所有已完成文件逐张执行目标检测。
|
各自独立轮询和认领 pending 任务(run_token 原子 claim 保证不重复)。
|
||||||
- 按已处理图片数更新 `processed_images`、`progress`、`detected_objects`、`status` 等字段。
|
- 任务内文件并发: 可通过 AUTO_ANNOTATION_FILE_WORKERS 配置线程池大小,
|
||||||
- 失败时将任务标记为 `failed` 并记录 `error_message`。
|
单任务内并行处理多个文件(LLM I/O 密集型场景尤其有效)。
|
||||||
|
算子链通过对象池隔离,每个线程使用独立的链实例。
|
||||||
注意:
|
- 进度更新节流: 可通过 AUTO_ANNOTATION_PROGRESS_INTERVAL 控制进度写入频率,
|
||||||
- 为了保持简单,目前不处理 "running" 状态的恢复逻辑;容器重启时,
|
避免大数据集每文件都写 DB 造成的写压力(默认 2 秒间隔)。
|
||||||
已处于 running 的任务不会被重新拉起,需要后续扩展。
|
- 启动时自动恢复心跳超时的 running 任务:未处理文件重置为 pending,
|
||||||
|
已有部分进度的标记为 failed,由用户决定是否手动重试。
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import queue
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -122,9 +125,121 @@ DEFAULT_OUTPUT_ROOT = os.getenv(
|
|||||||
|
|
||||||
DEFAULT_OPERATOR_WHITELIST = os.getenv(
|
DEFAULT_OPERATOR_WHITELIST = os.getenv(
|
||||||
"AUTO_ANNOTATION_OPERATOR_WHITELIST",
|
"AUTO_ANNOTATION_OPERATOR_WHITELIST",
|
||||||
"ImageObjectDetectionBoundingBox,test_annotation_marker",
|
"ImageObjectDetectionBoundingBox,test_annotation_marker,"
|
||||||
|
"LLMTextClassification,LLMNamedEntityRecognition,LLMRelationExtraction",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Seconds without a heartbeat before a 'running' task is considered stale;
# values <= 0 disable stale-task recovery (see _recover_stale_running_tasks).
HEARTBEAT_TIMEOUT_SECONDS = int(os.getenv("AUTO_ANNOTATION_HEARTBEAT_TIMEOUT", "300"))

# Number of concurrent worker threads polling and claiming pending tasks.
WORKER_COUNT = int(os.getenv("AUTO_ANNOTATION_WORKER_COUNT", "1"))

# Thread-pool size for processing files within a single task.
FILE_WORKERS = int(os.getenv("AUTO_ANNOTATION_FILE_WORKERS", "1"))

# Minimum seconds between progress writes to the DB (throttles per-file updates).
PROGRESS_UPDATE_INTERVAL = float(os.getenv("AUTO_ANNOTATION_PROGRESS_INTERVAL", "2.0"))
|
||||||
|
|
||||||
|
|
||||||
|
def _recover_stale_running_tasks() -> int:
    """Recover 'running' tasks whose heartbeat timed out (run at startup).

    - processed_images == 0 -> reset to 'pending' (retried automatically)
    - processed_images > 0  -> marked 'failed' (user decides whether to retry)

    Returns:
        The number of tasks recovered.
    """
    if HEARTBEAT_TIMEOUT_SECONDS <= 0:
        # A non-positive timeout disables the recovery feature entirely.
        logger.info(
            "Heartbeat timeout disabled (HEARTBEAT_TIMEOUT_SECONDS={}), skipping recovery",
            HEARTBEAT_TIMEOUT_SECONDS,
        )
        return 0

    cutoff = datetime.now() - timedelta(seconds=HEARTBEAT_TIMEOUT_SECONDS)

    find_sql = text("""
        SELECT id, processed_images, total_images, heartbeat_at
        FROM t_dm_auto_annotation_tasks
        WHERE status = 'running'
          AND deleted_at IS NULL
          AND (heartbeat_at IS NULL OR heartbeat_at < :cutoff)
    """)

    with SQLManager.create_connect() as conn:
        rows = conn.execute(find_sql, {"cutoff": cutoff}).fetchall()

    if not rows:
        return 0

    recovered = 0
    for row in rows:
        task_id = row[0]
        processed = row[1] or 0
        total = row[2] or 0
        heartbeat = row[3]

        try:
            if processed == 0:
                # No progress yet: reset to pending for automatic retry.
                reset_sql = text("""
                    UPDATE t_dm_auto_annotation_tasks
                    SET status = 'pending',
                        run_token = NULL,
                        heartbeat_at = NULL,
                        started_at = NULL,
                        error_message = NULL,
                        updated_at = :now
                    WHERE id = :task_id AND status = 'running'
                """)
                with SQLManager.create_connect() as conn:
                    result = conn.execute(
                        reset_sql, {"task_id": task_id, "now": datetime.now()}
                    )
                # rowcount guard: skip if another worker already transitioned it.
                if int(getattr(result, "rowcount", 0) or 0) > 0:
                    recovered += 1
                    logger.info(
                        "Recovered stale task {} -> pending (no progress, heartbeat={})",
                        task_id,
                        heartbeat,
                    )
            else:
                # Partial progress exists: mark failed so a human decides on retry.
                error_msg = (
                    f"Worker 心跳超时(上次心跳: {heartbeat},"
                    f"超时阈值: {HEARTBEAT_TIMEOUT_SECONDS}秒)。"
                    f"已处理 {processed}/{total} 个文件。请检查后手动重试。"
                )
                fail_sql = text("""
                    UPDATE t_dm_auto_annotation_tasks
                    SET status = 'failed',
                        run_token = NULL,
                        error_message = :error_message,
                        completed_at = :now,
                        updated_at = :now
                    WHERE id = :task_id AND status = 'running'
                """)
                with SQLManager.create_connect() as conn:
                    result = conn.execute(
                        fail_sql,
                        {
                            "task_id": task_id,
                            # Truncate to fit the error_message column.
                            "error_message": error_msg[:2000],
                            "now": datetime.now(),
                        },
                    )
                if int(getattr(result, "rowcount", 0) or 0) > 0:
                    recovered += 1
                    logger.warning(
                        "Recovered stale task {} -> failed (processed {}/{}, heartbeat={})",
                        task_id,
                        processed,
                        total,
                        heartbeat,
                    )
        except Exception as exc:
            # Best-effort recovery: one failing task must not block the others.
            logger.error("Failed to recover stale task {}: {}", task_id, exc)

    return recovered
|
||||||
|
|
||||||
|
|
||||||
def _fetch_pending_task() -> Optional[Dict[str, Any]]:
|
def _fetch_pending_task() -> Optional[Dict[str, Any]]:
|
||||||
"""原子 claim 一个 pending 任务并返回任务详情。"""
|
"""原子 claim 一个 pending 任务并返回任务详情。"""
|
||||||
@@ -856,6 +971,139 @@ def _register_output_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_labeling_project_with_annotations(
    task_id: str,
    dataset_id: str,
    dataset_name: str,
    task_name: str,
    dataset_type: str,
    normalized_pipeline: List[Dict[str, Any]],
    file_results: List[Tuple[str, Dict[str, Any]]],
    all_file_ids: List[str],
) -> None:
    """Convert auto-annotation results to Label Studio format, create a
    labeling project, and persist the converted annotation results.

    Args:
        task_id: Auto-annotation task ID (traceability / naming fallback).
        dataset_id: Dataset the new labeling project is attached to.
        dataset_name: Dataset display name (project-name fallback).
        task_name: Task display name (preferred for the project name).
        dataset_type: Dataset media type (not referenced in this body).
        normalized_pipeline: Normalized operator pipeline; used to infer the
            annotation task type and the operator parameters.
        file_results: (file_id, raw annotation dict) pairs to convert.
        all_file_ids: All dataset file IDs; snapshotted into the project.
    """
    # Lazy import keeps the converter dependency out of worker import time.
    from datamate.annotation_result_converter import (
        convert_annotation,
        extract_operator_params,
        generate_label_config_xml,
        infer_task_type_from_pipeline,
    )

    task_type = infer_task_type_from_pipeline(normalized_pipeline)
    if not task_type:
        logger.warning(
            "Cannot infer task_type from pipeline for task {}, skipping labeling project creation",
            task_id,
        )
        return

    operator_params = extract_operator_params(normalized_pipeline)

    # Object detection: collect the unique label set from actual detection
    # results so the generated label config only lists labels that occur.
    if task_type == "object_detection":
        all_labels: set = set()
        for _, ann in file_results:
            for det in ann.get("detections", []):
                if isinstance(det, dict):
                    all_labels.add(str(det.get("label", "unknown")))
        operator_params["_detected_labels"] = sorted(all_labels)

    label_config = generate_label_config_xml(task_type, operator_params)

    project_id = str(uuid.uuid4())
    # 8-digit zero-padded numeric external project ID.
    labeling_project_id = str(uuid.uuid4().int % 10**8).zfill(8)
    project_name = f"自动标注 - {task_name or dataset_name or task_id[:8]}"[:100]

    now = datetime.now()
    configuration = json.dumps(
        {
            "label_config": label_config,
            "description": f"由自动标注任务 {task_id[:8]} 自动创建",
            "auto_annotation_task_id": task_id,
        },
        ensure_ascii=False,
    )

    insert_project_sql = text(
        """
        INSERT INTO t_dm_labeling_projects
            (id, dataset_id, name, labeling_project_id, template_id, configuration, created_at, updated_at)
        VALUES
            (:id, :dataset_id, :name, :labeling_project_id, NULL, :configuration, :now, :now)
        """
    )
    insert_snapshot_sql = text(
        """
        INSERT INTO t_dm_labeling_project_files (id, project_id, file_id, created_at)
        VALUES (:id, :project_id, :file_id, :now)
        """
    )
    insert_annotation_sql = text(
        """
        INSERT INTO t_dm_annotation_results
            (id, project_id, file_id, annotation, annotation_status, file_version, created_at, updated_at)
        VALUES
            (:id, :project_id, :file_id, :annotation, :annotation_status, :file_version, :now, :now)
        """
    )

    with SQLManager.create_connect() as conn:
        # 1. Create the labeling project row.
        conn.execute(
            insert_project_sql,
            {
                "id": project_id,
                "dataset_id": dataset_id,
                "name": project_name,
                "labeling_project_id": labeling_project_id,
                "configuration": configuration,
                "now": now,
            },
        )

        # 2. Snapshot the project's file list.
        for file_id in all_file_ids:
            conn.execute(
                insert_snapshot_sql,
                {
                    "id": str(uuid.uuid4()),
                    "project_id": project_id,
                    "file_id": file_id,
                    "now": now,
                },
            )

        # 3. Convert each result and persist it as a pre-made annotation.
        converted_count = 0
        for file_id, annotation in file_results:
            ls_annotation = convert_annotation(annotation, file_id, project_id)
            if ls_annotation is None:
                # Unconvertible result (unknown type / empty payload) — skip.
                continue

            conn.execute(
                insert_annotation_sql,
                {
                    "id": str(uuid.uuid4()),
                    "project_id": project_id,
                    "file_id": file_id,
                    "annotation": json.dumps(ls_annotation, ensure_ascii=False),
                    "annotation_status": "ANNOTATED",
                    "file_version": 1,
                    "now": now,
                },
            )
            converted_count += 1

    logger.info(
        "Created labeling project {} ({}) with {} annotations for auto-annotation task {}",
        project_id,
        project_name,
        converted_count,
        task_id,
    )
|
||||||
|
|
||||||
|
|
||||||
def _process_single_task(task: Dict[str, Any]) -> None:
|
def _process_single_task(task: Dict[str, Any]) -> None:
|
||||||
"""执行单个自动标注任务。"""
|
"""执行单个自动标注任务。"""
|
||||||
|
|
||||||
@@ -914,7 +1162,7 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
|||||||
else:
|
else:
|
||||||
all_files = _load_dataset_files(dataset_id)
|
all_files = _load_dataset_files(dataset_id)
|
||||||
|
|
||||||
files = [(path, name) for _, path, name in all_files]
|
files = all_files # [(file_id, file_path, file_name)]
|
||||||
|
|
||||||
total_images = len(files)
|
total_images = len(files)
|
||||||
if total_images == 0:
|
if total_images == 0:
|
||||||
@@ -962,10 +1210,6 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
|||||||
raise RuntimeError("Pipeline is empty after normalization")
|
raise RuntimeError("Pipeline is empty after normalization")
|
||||||
|
|
||||||
_validate_pipeline_whitelist(normalized_pipeline)
|
_validate_pipeline_whitelist(normalized_pipeline)
|
||||||
|
|
||||||
chain = _build_operator_chain(normalized_pipeline)
|
|
||||||
if not chain:
|
|
||||||
raise RuntimeError("No valid operator instances initialized")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to init operator pipeline for task {}: {}", task_id, e)
|
logger.error("Failed to init operator pipeline for task {}: {}", task_id, e)
|
||||||
_update_task_status(
|
_update_task_status(
|
||||||
@@ -980,76 +1224,149 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
processed = 0
|
# --- 构建算子链池(每个线程使用独立的链实例,避免线程安全问题)---
|
||||||
detected_total = 0
|
effective_file_workers = max(1, FILE_WORKERS)
|
||||||
|
chain_pool: queue.Queue = queue.Queue()
|
||||||
try:
|
try:
|
||||||
|
for _ in range(effective_file_workers):
|
||||||
for file_path, file_name in files:
|
c = _build_operator_chain(normalized_pipeline)
|
||||||
if _is_stop_requested(task_id, run_token):
|
if not c:
|
||||||
logger.info("Task stop requested during processing: {}", task_id)
|
raise RuntimeError("No valid operator instances initialized")
|
||||||
_update_task_status(
|
chain_pool.put(c)
|
||||||
task_id,
|
except Exception as e:
|
||||||
run_token=run_token,
|
logger.error("Failed to build operator chain pool for task {}: {}", task_id, e)
|
||||||
status="stopped",
|
|
||||||
progress=int(processed * 100 / total_images) if total_images > 0 else 0,
|
|
||||||
processed_images=processed,
|
|
||||||
detected_objects=detected_total,
|
|
||||||
total_images=total_images,
|
|
||||||
output_path=output_dir,
|
|
||||||
output_dataset_id=output_dataset_id,
|
|
||||||
completed=True,
|
|
||||||
clear_run_token=True,
|
|
||||||
error_message="Task stopped by request",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
sample_key = _get_sample_key(dataset_type)
|
|
||||||
sample = {
|
|
||||||
sample_key: file_path,
|
|
||||||
"filename": file_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
result = _run_pipeline_sample(sample, chain)
|
|
||||||
detected_total += _count_detections(result)
|
|
||||||
processed += 1
|
|
||||||
|
|
||||||
progress = int(processed * 100 / total_images) if total_images > 0 else 100
|
|
||||||
|
|
||||||
_update_task_status(
|
|
||||||
task_id,
|
|
||||||
run_token=run_token,
|
|
||||||
status="running",
|
|
||||||
progress=progress,
|
|
||||||
processed_images=processed,
|
|
||||||
detected_objects=detected_total,
|
|
||||||
total_images=total_images,
|
|
||||||
output_path=output_dir,
|
|
||||||
output_dataset_id=output_dataset_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"Failed to process file for task {}: file_path={}, error={}",
|
|
||||||
task_id,
|
|
||||||
file_path,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
_update_task_status(
|
_update_task_status(
|
||||||
task_id,
|
task_id,
|
||||||
run_token=run_token,
|
run_token=run_token,
|
||||||
status="completed",
|
status="failed",
|
||||||
progress=100,
|
|
||||||
processed_images=processed,
|
|
||||||
detected_objects=detected_total,
|
|
||||||
total_images=total_images,
|
total_images=total_images,
|
||||||
output_path=output_dir,
|
processed_images=0,
|
||||||
output_dataset_id=output_dataset_id,
|
detected_objects=0,
|
||||||
completed=True,
|
error_message=f"Init pipeline failed: {e}",
|
||||||
clear_run_token=True,
|
clear_run_token=True,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
processed = 0
|
||||||
|
detected_total = 0
|
||||||
|
file_results: List[Tuple[str, Dict[str, Any]]] = [] # (file_id, annotations)
|
||||||
|
stopped = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# --- 线程安全的进度跟踪 ---
|
||||||
|
progress_lock = threading.Lock()
|
||||||
|
stop_event = threading.Event()
|
||||||
|
|
||||||
|
def _process_file(
|
||||||
|
file_id: str, file_path: str, file_name: str,
|
||||||
|
) -> Optional[Tuple[str, Dict[str, Any]]]:
|
||||||
|
"""在线程池中处理单个文件。"""
|
||||||
|
if stop_event.is_set():
|
||||||
|
return None
|
||||||
|
chain = chain_pool.get()
|
||||||
|
try:
|
||||||
|
sample_key = _get_sample_key(dataset_type)
|
||||||
|
sample: Dict[str, Any] = {
|
||||||
|
sample_key: file_path,
|
||||||
|
"filename": file_name,
|
||||||
|
}
|
||||||
|
result = _run_pipeline_sample(sample, chain)
|
||||||
|
return (file_id, result)
|
||||||
|
finally:
|
||||||
|
chain_pool.put(chain)
|
||||||
|
|
||||||
|
# --- 并发文件处理 ---
|
||||||
|
stop_check_interval = max(1, effective_file_workers * 2)
|
||||||
|
completed_since_check = 0
|
||||||
|
last_progress_update = time.monotonic()
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=effective_file_workers) as executor:
|
||||||
|
future_to_file = {
|
||||||
|
executor.submit(_process_file, fid, fpath, fname): (fid, fpath, fname)
|
||||||
|
for fid, fpath, fname in files
|
||||||
|
}
|
||||||
|
|
||||||
|
for future in as_completed(future_to_file):
|
||||||
|
fid, fpath, fname = future_to_file[future]
|
||||||
|
try:
|
||||||
|
result = future.result()
|
||||||
|
if result is None:
|
||||||
|
continue
|
||||||
|
file_id_out, sample_result = result
|
||||||
|
detections = _count_detections(sample_result)
|
||||||
|
ann = sample_result.get("annotations")
|
||||||
|
|
||||||
|
with progress_lock:
|
||||||
|
processed += 1
|
||||||
|
detected_total += detections
|
||||||
|
if isinstance(ann, dict):
|
||||||
|
file_results.append((file_id_out, ann))
|
||||||
|
current_processed = processed
|
||||||
|
current_detected = detected_total
|
||||||
|
|
||||||
|
now = time.monotonic()
|
||||||
|
if PROGRESS_UPDATE_INTERVAL <= 0 or (now - last_progress_update) >= PROGRESS_UPDATE_INTERVAL:
|
||||||
|
progress = int(current_processed * 100 / total_images) if total_images > 0 else 100
|
||||||
|
_update_task_status(
|
||||||
|
task_id,
|
||||||
|
run_token=run_token,
|
||||||
|
status="running",
|
||||||
|
progress=progress,
|
||||||
|
processed_images=current_processed,
|
||||||
|
detected_objects=current_detected,
|
||||||
|
total_images=total_images,
|
||||||
|
output_path=output_dir,
|
||||||
|
output_dataset_id=output_dataset_id,
|
||||||
|
)
|
||||||
|
last_progress_update = now
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to process file for task {}: file_path={}, error={}",
|
||||||
|
task_id,
|
||||||
|
fpath,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
completed_since_check += 1
|
||||||
|
if completed_since_check >= stop_check_interval:
|
||||||
|
completed_since_check = 0
|
||||||
|
if _is_stop_requested(task_id, run_token):
|
||||||
|
stop_event.set()
|
||||||
|
for f in future_to_file:
|
||||||
|
f.cancel()
|
||||||
|
stopped = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if stopped:
|
||||||
|
logger.info("Task stop requested during processing: {}", task_id)
|
||||||
|
_update_task_status(
|
||||||
|
task_id,
|
||||||
|
run_token=run_token,
|
||||||
|
status="stopped",
|
||||||
|
progress=int(processed * 100 / total_images) if total_images > 0 else 0,
|
||||||
|
processed_images=processed,
|
||||||
|
detected_objects=detected_total,
|
||||||
|
total_images=total_images,
|
||||||
|
output_path=output_dir,
|
||||||
|
output_dataset_id=output_dataset_id,
|
||||||
|
completed=True,
|
||||||
|
clear_run_token=True,
|
||||||
|
error_message="Task stopped by request",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_update_task_status(
|
||||||
|
task_id,
|
||||||
|
run_token=run_token,
|
||||||
|
status="completed",
|
||||||
|
progress=100,
|
||||||
|
processed_images=processed,
|
||||||
|
detected_objects=detected_total,
|
||||||
|
total_images=total_images,
|
||||||
|
output_path=output_dir,
|
||||||
|
output_dataset_id=output_dataset_id,
|
||||||
|
completed=True,
|
||||||
|
clear_run_token=True,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Completed auto-annotation task: id={}, total_images={}, processed={}, detected_objects={}, output_path={}",
|
"Completed auto-annotation task: id={}, total_images={}, processed={}, detected_objects={}, output_path={}",
|
||||||
@@ -1076,6 +1393,26 @@ def _process_single_task(task: Dict[str, Any]) -> None:
|
|||||||
task_id,
|
task_id,
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 将自动标注结果转换为 Label Studio 格式并写入标注项目
|
||||||
|
if file_results:
|
||||||
|
try:
|
||||||
|
_create_labeling_project_with_annotations(
|
||||||
|
task_id=task_id,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
dataset_name=source_dataset_name,
|
||||||
|
task_name=task_name,
|
||||||
|
dataset_type=dataset_type,
|
||||||
|
normalized_pipeline=normalized_pipeline,
|
||||||
|
file_results=file_results,
|
||||||
|
all_file_ids=[fid for fid, _, _ in all_files],
|
||||||
|
)
|
||||||
|
except Exception as e: # pragma: no cover - 防御性日志
|
||||||
|
logger.error(
|
||||||
|
"Failed to create labeling project for auto-annotation task {}: {}",
|
||||||
|
task_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Task execution failed for task {}: {}", task_id, e)
|
logger.error("Task execution failed for task {}: {}", task_id, e)
|
||||||
_update_task_status(
|
_update_task_status(
|
||||||
@@ -1097,9 +1434,10 @@ def _worker_loop() -> None:
|
|||||||
"""Worker 主循环,在独立线程中运行。"""
|
"""Worker 主循环,在独立线程中运行。"""
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Auto-annotation worker started with poll interval {} seconds, output root {}",
|
"Auto-annotation worker started (poll_interval={}s, output_root={}, file_workers={})",
|
||||||
POLL_INTERVAL_SECONDS,
|
POLL_INTERVAL_SECONDS,
|
||||||
DEFAULT_OUTPUT_ROOT,
|
DEFAULT_OUTPUT_ROOT,
|
||||||
|
FILE_WORKERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -1118,6 +1456,20 @@ def _worker_loop() -> None:
|
|||||||
def start_auto_annotation_worker() -> None:
|
def start_auto_annotation_worker() -> None:
|
||||||
"""在后台线程中启动自动标注 worker。"""
|
"""在后台线程中启动自动标注 worker。"""
|
||||||
|
|
||||||
thread = threading.Thread(target=_worker_loop, name="auto-annotation-worker", daemon=True)
|
# 启动前执行一次恢复(在 worker 线程启动前运行,避免多线程重复恢复)
|
||||||
thread.start()
|
try:
|
||||||
logger.info("Auto-annotation worker thread started: {}", thread.name)
|
recovered = _recover_stale_running_tasks()
|
||||||
|
if recovered > 0:
|
||||||
|
logger.info("Recovered {} stale running task(s) on startup", recovered)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to run startup task recovery: {}", e)
|
||||||
|
|
||||||
|
count = max(1, WORKER_COUNT)
|
||||||
|
for i in range(count):
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=_worker_loop,
|
||||||
|
name=f"auto-annotation-worker-{i}",
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
logger.info("Auto-annotation worker thread started: {}", thread.name)
|
||||||
|
|||||||
390
runtime/python-executor/tests/test_worker_concurrency.py
Normal file
390
runtime/python-executor/tests/test_worker_concurrency.py
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Tests for auto_annotation_worker concurrency features (improvement #4).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Multi-worker startup (WORKER_COUNT)
|
||||||
|
- Intra-task file parallelism (FILE_WORKERS)
|
||||||
|
- Chain pool acquire/release
|
||||||
|
- Thread-safe progress tracking
|
||||||
|
- Stop request handling during concurrent processing
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
from unittest.mock import MagicMock, patch, call
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Ensure the module under test can be imported
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
RUNTIME_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
|
||||||
|
EXECUTOR_ROOT = os.path.join(os.path.dirname(__file__), "..")
|
||||||
|
for p in (RUNTIME_ROOT, EXECUTOR_ROOT):
|
||||||
|
abs_p = os.path.abspath(p)
|
||||||
|
if abs_p not in sys.path:
|
||||||
|
sys.path.insert(0, abs_p)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerCountConfig(unittest.TestCase):
|
||||||
|
"""Test that start_auto_annotation_worker launches WORKER_COUNT threads."""
|
||||||
|
|
||||||
|
@patch("datamate.auto_annotation_worker._recover_stale_running_tasks", return_value=0)
|
||||||
|
@patch("datamate.auto_annotation_worker._worker_loop")
|
||||||
|
@patch("datamate.auto_annotation_worker.WORKER_COUNT", 3)
|
||||||
|
def test_multiple_worker_threads_launched(self, mock_loop, mock_recover):
|
||||||
|
"""WORKER_COUNT=3 should launch 3 daemon threads."""
|
||||||
|
from datamate.auto_annotation_worker import start_auto_annotation_worker
|
||||||
|
|
||||||
|
started_threads: List[threading.Thread] = []
|
||||||
|
original_thread_init = threading.Thread.__init__
|
||||||
|
|
||||||
|
def track_thread(self_thread, *args, **kwargs):
|
||||||
|
original_thread_init(self_thread, *args, **kwargs)
|
||||||
|
started_threads.append(self_thread)
|
||||||
|
|
||||||
|
with patch.object(threading.Thread, "__init__", track_thread):
|
||||||
|
with patch.object(threading.Thread, "start"):
|
||||||
|
start_auto_annotation_worker()
|
||||||
|
|
||||||
|
self.assertEqual(len(started_threads), 3)
|
||||||
|
for i, t in enumerate(started_threads):
|
||||||
|
self.assertEqual(t.name, f"auto-annotation-worker-{i}")
|
||||||
|
self.assertTrue(t.daemon)
|
||||||
|
|
||||||
|
@patch("datamate.auto_annotation_worker._recover_stale_running_tasks", return_value=0)
|
||||||
|
@patch("datamate.auto_annotation_worker._worker_loop")
|
||||||
|
@patch("datamate.auto_annotation_worker.WORKER_COUNT", 1)
|
||||||
|
def test_single_worker_default(self, mock_loop, mock_recover):
|
||||||
|
"""WORKER_COUNT=1 (default) should launch exactly 1 thread."""
|
||||||
|
from datamate.auto_annotation_worker import start_auto_annotation_worker
|
||||||
|
|
||||||
|
started_threads: List[threading.Thread] = []
|
||||||
|
original_thread_init = threading.Thread.__init__
|
||||||
|
|
||||||
|
def track_thread(self_thread, *args, **kwargs):
|
||||||
|
original_thread_init(self_thread, *args, **kwargs)
|
||||||
|
started_threads.append(self_thread)
|
||||||
|
|
||||||
|
with patch.object(threading.Thread, "__init__", track_thread):
|
||||||
|
with patch.object(threading.Thread, "start"):
|
||||||
|
start_auto_annotation_worker()
|
||||||
|
|
||||||
|
self.assertEqual(len(started_threads), 1)
|
||||||
|
|
||||||
|
@patch("datamate.auto_annotation_worker._recover_stale_running_tasks", side_effect=RuntimeError("db down"))
|
||||||
|
@patch("datamate.auto_annotation_worker._worker_loop")
|
||||||
|
@patch("datamate.auto_annotation_worker.WORKER_COUNT", 2)
|
||||||
|
def test_recovery_failure_doesnt_block_workers(self, mock_loop, mock_recover):
|
||||||
|
"""Recovery failure should not prevent worker threads from starting."""
|
||||||
|
from datamate.auto_annotation_worker import start_auto_annotation_worker
|
||||||
|
|
||||||
|
started_threads: List[threading.Thread] = []
|
||||||
|
original_thread_init = threading.Thread.__init__
|
||||||
|
|
||||||
|
def track_thread(self_thread, *args, **kwargs):
|
||||||
|
original_thread_init(self_thread, *args, **kwargs)
|
||||||
|
started_threads.append(self_thread)
|
||||||
|
|
||||||
|
with patch.object(threading.Thread, "__init__", track_thread):
|
||||||
|
with patch.object(threading.Thread, "start"):
|
||||||
|
start_auto_annotation_worker()
|
||||||
|
|
||||||
|
# Workers should still be launched despite recovery failure
|
||||||
|
self.assertEqual(len(started_threads), 2)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChainPool(unittest.TestCase):
|
||||||
|
"""Test the chain pool pattern used for operator instance isolation."""
|
||||||
|
|
||||||
|
def test_pool_acquire_release(self):
|
||||||
|
"""Each thread should get its own chain and return it after use."""
|
||||||
|
pool: queue.Queue = queue.Queue()
|
||||||
|
chains = [f"chain-{i}" for i in range(3)]
|
||||||
|
for c in chains:
|
||||||
|
pool.put(c)
|
||||||
|
|
||||||
|
acquired: List[str] = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
chain = pool.get()
|
||||||
|
with lock:
|
||||||
|
acquired.append(chain)
|
||||||
|
time.sleep(0.01)
|
||||||
|
pool.put(chain)
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=worker) for _ in range(6)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# All 6 workers should have acquired a chain
|
||||||
|
self.assertEqual(len(acquired), 6)
|
||||||
|
# Pool should have all 3 chains back
|
||||||
|
self.assertEqual(pool.qsize(), 3)
|
||||||
|
returned = set()
|
||||||
|
while not pool.empty():
|
||||||
|
returned.add(pool.get())
|
||||||
|
self.assertEqual(returned, set(chains))
|
||||||
|
|
||||||
|
def test_pool_blocks_when_empty(self):
|
||||||
|
"""When pool is empty, threads should block until a chain is returned."""
|
||||||
|
pool: queue.Queue = queue.Queue()
|
||||||
|
pool.put("only-chain")
|
||||||
|
|
||||||
|
acquired_times: List[float] = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
chain = pool.get()
|
||||||
|
with lock:
|
||||||
|
acquired_times.append(time.monotonic())
|
||||||
|
time.sleep(0.05)
|
||||||
|
pool.put(chain)
|
||||||
|
|
||||||
|
t1 = threading.Thread(target=worker)
|
||||||
|
t2 = threading.Thread(target=worker)
|
||||||
|
t1.start()
|
||||||
|
time.sleep(0.01) # Ensure t1 starts first
|
||||||
|
t2.start()
|
||||||
|
t1.join()
|
||||||
|
t2.join()
|
||||||
|
|
||||||
|
# t2 should have acquired after t1 returned (at least 0.04s gap)
|
||||||
|
self.assertEqual(len(acquired_times), 2)
|
||||||
|
gap = acquired_times[1] - acquired_times[0]
|
||||||
|
self.assertGreater(gap, 0.03)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrentFileProcessing(unittest.TestCase):
|
||||||
|
"""Test the concurrent file processing logic from _process_single_task."""
|
||||||
|
|
||||||
|
def test_threadpool_processes_all_files(self):
|
||||||
|
"""ThreadPoolExecutor should process all submitted files."""
|
||||||
|
results: List[str] = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def process_file(file_id):
|
||||||
|
time.sleep(0.01)
|
||||||
|
with lock:
|
||||||
|
results.append(file_id)
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
files = [f"file-{i}" for i in range(10)]
|
||||||
|
with ThreadPoolExecutor(max_workers=4) as executor:
|
||||||
|
futures = {executor.submit(process_file, f): f for f in files}
|
||||||
|
for future in as_completed(futures):
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
self.assertEqual(sorted(results), sorted(files))
|
||||||
|
|
||||||
|
def test_stop_event_cancels_pending_futures(self):
|
||||||
|
"""Setting stop_event should prevent unstarted files from processing."""
|
||||||
|
stop_event = threading.Event()
|
||||||
|
processed: List[str] = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def process_file(file_id):
|
||||||
|
if stop_event.is_set():
|
||||||
|
return None
|
||||||
|
time.sleep(0.05)
|
||||||
|
with lock:
|
||||||
|
processed.append(file_id)
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
files = [f"file-{i}" for i in range(20)]
|
||||||
|
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
future_to_file = {
|
||||||
|
executor.submit(process_file, f): f for f in files
|
||||||
|
}
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for future in as_completed(future_to_file):
|
||||||
|
result = future.result()
|
||||||
|
count += 1
|
||||||
|
if count >= 3:
|
||||||
|
stop_event.set()
|
||||||
|
for f in future_to_file:
|
||||||
|
f.cancel()
|
||||||
|
break
|
||||||
|
|
||||||
|
# Should have processed some but not all files
|
||||||
|
self.assertGreater(len(processed), 0)
|
||||||
|
self.assertLess(len(processed), 20)
|
||||||
|
|
||||||
|
def test_thread_safe_counter_updates(self):
|
||||||
|
"""Counters updated inside lock should be accurate under concurrency."""
|
||||||
|
processed = 0
|
||||||
|
detected = 0
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def process_and_count(file_id):
|
||||||
|
nonlocal processed, detected
|
||||||
|
time.sleep(0.001)
|
||||||
|
with lock:
|
||||||
|
processed += 1
|
||||||
|
detected += 2
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||||
|
futures = [executor.submit(process_and_count, f"f-{i}") for i in range(100)]
|
||||||
|
for f in as_completed(futures):
|
||||||
|
f.result()
|
||||||
|
|
||||||
|
self.assertEqual(processed, 100)
|
||||||
|
self.assertEqual(detected, 200)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileWorkersConfig(unittest.TestCase):
|
||||||
|
"""Test FILE_WORKERS configuration behavior."""
|
||||||
|
|
||||||
|
def test_file_workers_one_is_serial(self):
|
||||||
|
"""FILE_WORKERS=1 should process files sequentially."""
|
||||||
|
order: List[int] = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def process(idx):
|
||||||
|
with lock:
|
||||||
|
order.append(idx)
|
||||||
|
time.sleep(0.01)
|
||||||
|
return idx
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
futures = [executor.submit(process, i) for i in range(5)]
|
||||||
|
for f in as_completed(futures):
|
||||||
|
f.result()
|
||||||
|
|
||||||
|
# With max_workers=1, execution is serial (though completion order
|
||||||
|
# via as_completed might differ; the key is that only 1 runs at a time)
|
||||||
|
self.assertEqual(len(order), 5)
|
||||||
|
|
||||||
|
def test_file_workers_gt_one_is_parallel(self):
|
||||||
|
"""FILE_WORKERS>1 should process files concurrently."""
|
||||||
|
start_times: List[float] = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def process(idx):
|
||||||
|
with lock:
|
||||||
|
start_times.append(time.monotonic())
|
||||||
|
time.sleep(0.05)
|
||||||
|
return idx
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=4) as executor:
|
||||||
|
futures = [executor.submit(process, i) for i in range(4)]
|
||||||
|
for f in as_completed(futures):
|
||||||
|
f.result()
|
||||||
|
|
||||||
|
# All 4 should start nearly simultaneously
|
||||||
|
self.assertEqual(len(start_times), 4)
|
||||||
|
time_spread = max(start_times) - min(start_times)
|
||||||
|
# With parallel execution, spread should be < 0.04s
|
||||||
|
# (serial would be ~0.15s with 0.05s sleep each)
|
||||||
|
self.assertLess(time_spread, 0.04)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerLoopSimplified(unittest.TestCase):
|
||||||
|
"""Test that _worker_loop no longer calls recovery."""
|
||||||
|
|
||||||
|
@patch("datamate.auto_annotation_worker._process_single_task")
|
||||||
|
@patch("datamate.auto_annotation_worker._fetch_pending_task")
|
||||||
|
@patch("datamate.auto_annotation_worker._recover_stale_running_tasks")
|
||||||
|
def test_worker_loop_does_not_call_recovery(self, mock_recover, mock_fetch, mock_process):
|
||||||
|
"""_worker_loop should NOT call _recover_stale_running_tasks."""
|
||||||
|
from datamate.auto_annotation_worker import _worker_loop
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
def side_effect():
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count >= 2:
|
||||||
|
raise KeyboardInterrupt("stop test")
|
||||||
|
return None
|
||||||
|
|
||||||
|
mock_fetch.side_effect = side_effect
|
||||||
|
|
||||||
|
with patch("datamate.auto_annotation_worker.POLL_INTERVAL_SECONDS", 0.001):
|
||||||
|
try:
|
||||||
|
_worker_loop()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_recover.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestProgressThrottling(unittest.TestCase):
|
||||||
|
"""Test time-based progress update throttling (improvement #5)."""
|
||||||
|
|
||||||
|
def test_progress_updates_throttled(self):
|
||||||
|
"""With PROGRESS_UPDATE_INTERVAL>0, rapid completions should batch DB writes."""
|
||||||
|
update_calls: List[float] = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def mock_update(*args, **kwargs):
|
||||||
|
with lock:
|
||||||
|
update_calls.append(time.monotonic())
|
||||||
|
|
||||||
|
interval = 0.05 # 50ms throttle interval
|
||||||
|
processed = 0
|
||||||
|
# Initialize in the past so the first file triggers an update
|
||||||
|
last_progress_update = time.monotonic() - interval
|
||||||
|
total_files = 50
|
||||||
|
|
||||||
|
# Simulate the throttled update loop from _process_single_task
|
||||||
|
for i in range(total_files):
|
||||||
|
processed += 1
|
||||||
|
now = time.monotonic()
|
||||||
|
if interval <= 0 or (now - last_progress_update) >= interval:
|
||||||
|
mock_update(processed=processed, total=total_files)
|
||||||
|
last_progress_update = now
|
||||||
|
# Simulate very fast file processing (~1ms)
|
||||||
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
# With 50 files at ~1ms each (~50ms total) and 50ms interval,
|
||||||
|
# should get far fewer updates than total_files
|
||||||
|
self.assertLess(len(update_calls), total_files)
|
||||||
|
self.assertGreater(len(update_calls), 0)
|
||||||
|
|
||||||
|
def test_progress_interval_zero_updates_every_file(self):
|
||||||
|
"""PROGRESS_UPDATE_INTERVAL=0 should update on every file completion."""
|
||||||
|
update_count = 0
|
||||||
|
interval = 0.0
|
||||||
|
total_files = 20
|
||||||
|
last_progress_update = time.monotonic()
|
||||||
|
|
||||||
|
for i in range(total_files):
|
||||||
|
now = time.monotonic()
|
||||||
|
if interval <= 0 or (now - last_progress_update) >= interval:
|
||||||
|
update_count += 1
|
||||||
|
last_progress_update = now
|
||||||
|
|
||||||
|
self.assertEqual(update_count, total_files)
|
||||||
|
|
||||||
|
def test_progress_throttle_with_slow_processing(self):
|
||||||
|
"""When each file takes longer than the interval, every file triggers an update."""
|
||||||
|
update_count = 0
|
||||||
|
interval = 0.01 # 10ms interval
|
||||||
|
total_files = 5
|
||||||
|
last_progress_update = time.monotonic() - 1.0 # Start in the past
|
||||||
|
|
||||||
|
for i in range(total_files):
|
||||||
|
time.sleep(0.02) # 20ms per file > 10ms interval
|
||||||
|
now = time.monotonic()
|
||||||
|
if interval <= 0 or (now - last_progress_update) >= interval:
|
||||||
|
update_count += 1
|
||||||
|
last_progress_update = now
|
||||||
|
|
||||||
|
# Every file should trigger an update since processing time > interval
|
||||||
|
self.assertEqual(update_count, total_files)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user