Compare commits


5 Commits

Author SHA1 Message Date
f707ce9dae feat(auto-annotation): add batch progress updates to reduce DB write pressure
Throttle progress updates to reduce database write operations during large dataset processing.

Key features:
- Add PROGRESS_UPDATE_INTERVAL config (default 2.0s, configurable via AUTO_ANNOTATION_PROGRESS_INTERVAL env)
- Conditional progress updates: Only write to DB when (now - last_update) >= interval
- Use time.monotonic() for timing (immune to system clock adjustments)
- Final status updates (completed/stopped/failed) always execute (not throttled)

Implementation:
- Initialize last_progress_update timestamp before as_completed() loop
- Replace unconditional _update_task_status() with conditional call based on time interval
- Update docstring to reflect throttling capability

Performance impact (T=2s):
- 1,000 files / 100s processing: DB writes reduced from 1,000 to ~50 (95% reduction)
- 10,000 files / 500s processing: DB writes reduced from 10,000 to ~250 (97.5% reduction)
- Small datasets (10 files): Minimal difference

Backward compatibility:
- PROGRESS_UPDATE_INTERVAL=0: Updates every file (identical to previous behavior)
- Heartbeat mechanism unaffected (2s interval << 300s timeout)
- Stop check mechanism independent of progress updates
- Final status updates always execute

Testing:
- 14 unit tests all passed (11 existing + 3 new):
  * Fast processing with throttling
  * PROGRESS_UPDATE_INTERVAL=0 updates every file
  * Slow processing (per-file > T) updates every file
- py_compile syntax check passed

Edge cases handled:
- Single file task: Works normally
- Very slow processing: Degrades to per-file updates
- Concurrent FILE_WORKERS > 1: Counters accurate (lock-protected); DB reflects progress with at most T seconds of delay
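
The throttling pattern described above can be sketched in isolation; the helper names (`process_one`, `update_status`) are illustrative, not the worker's actual API:

```python
import time

PROGRESS_UPDATE_INTERVAL = 2.0  # seconds; 0 restores update-every-file behavior

def process_with_throttled_progress(files, process_one, update_status):
    """Process files, writing progress to the DB at most once per interval."""
    last_update = time.monotonic()  # monotonic clock: immune to wall-clock adjustments
    processed = 0
    for f in files:
        process_one(f)
        processed += 1
        now = time.monotonic()
        if PROGRESS_UPDATE_INTERVAL <= 0 or now - last_update >= PROGRESS_UPDATE_INTERVAL:
            update_status(processed)   # throttled DB write
            last_update = now
    update_status(processed)           # final status update always executes
    return processed
```

With fast processing, intermediate writes are skipped and only the final update fires, which is where the claimed write reduction comes from.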
2026-02-10 16:49:37 +08:00
9988ff00f5 feat(auto-annotation): add concurrent processing support
Enable parallel processing for auto-annotation tasks with configurable worker count and file-level parallelism.

Key features:
- Multi-worker support: WORKER_COUNT env var (default 1) controls number of worker threads
- Intra-task file parallelism: FILE_WORKERS env var (default 1) controls concurrent file processing within a single task
- Operator chain pooling: Pre-create N independent chain instances to avoid thread-safety issues
- Thread-safe progress tracking: Use threading.Lock to protect shared counters
- Stop signal handling: threading.Event for graceful cancellation during concurrent processing

Implementation details:
- Refactor _process_single_task() to use ThreadPoolExecutor + as_completed()
- Chain pool (queue.Queue): Each worker thread acquires/releases a chain instance
- Protected counters: processed_images, detected_total, file_results with Lock
- Stop check: Periodic check of _is_stop_requested() during concurrent processing
- Refactor start_auto_annotation_worker(): Move recovery logic here, start WORKER_COUNT threads
- Simplify _worker_loop(): Remove recovery call, keep only polling + processing

Backward compatibility:
- Default config (WORKER_COUNT=1, FILE_WORKERS=1) behaves identically to previous version
- No breaking changes to existing deployments

Testing:
- 11 unit tests all passed:
  * Multi-worker startup
  * Chain pool acquire/release
  * Concurrent file processing
  * Stop signal handling
  * Thread-safe counter updates
  * Backward compatibility (FILE_WORKERS=1)
- py_compile syntax check passed

Performance benefits:
- WORKER_COUNT=3: Process 3 tasks simultaneously
- FILE_WORKERS=4: Process 4 files in parallel within each task
- Combined: Up to 12x throughput improvement (3 workers × 4 files)
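
A minimal sketch of the chain-pool pattern this commit describes, assuming a hypothetical `make_chain` factory that builds one independent operator chain per worker:

```python
import queue
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_task_files(files, make_chain, file_workers=4):
    """Run files through pooled chain instances with thread-safe counters."""
    # Pre-create one independent chain per worker to avoid thread-safety issues
    chain_pool = queue.Queue()
    for _ in range(file_workers):
        chain_pool.put(make_chain())

    lock = threading.Lock()
    counters = {"processed_images": 0}

    def run_one(path):
        chain = chain_pool.get()      # acquire a chain instance
        try:
            result = chain(path)
        finally:
            chain_pool.put(chain)     # always release back to the pool
        with lock:                    # protect the shared counter
            counters["processed_images"] += 1
        return result

    with ThreadPoolExecutor(max_workers=file_workers) as executor:
        futures = [executor.submit(run_one, f) for f in files]
        results = [fut.result() for fut in as_completed(futures)]
    return counters["processed_images"], results
```

With `file_workers=1` the pool holds a single chain, so execution degenerates to the previous sequential behavior.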
2026-02-10 16:36:34 +08:00
2fbfefdb91 feat(auto-annotation): add worker recovery mechanism for stale tasks
Automatically recover running tasks with stale heartbeats on worker startup, preventing tasks from being permanently stuck after container restarts.

Key changes:
- Add HEARTBEAT_TIMEOUT_SECONDS constant (default 300s, configurable via env)
- Add _recover_stale_running_tasks() function:
  * Scans for status='running' tasks with heartbeat timeout
  * No progress (processed=0) → reset to pending (auto-retry)
  * Has progress (processed>0) → mark as failed with Chinese error message
  * Each task recovery is independent (single failure doesn't affect others)
  * Skip recovery if timeout is 0 or negative (disable feature)
- Call recovery function in _worker_loop() before polling loop
- Update file header comments to reflect recovery mechanism

Recovery logic:
- Query: status='running' AND (heartbeat_at IS NULL OR heartbeat_at < NOW() - timeout)
- Decision based on processed_images count
- Clear run_token to allow other workers to claim
- Single transaction per task for atomicity
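
The recovery decision above can be sketched as a pure function (timestamps as epoch seconds; the names are illustrative, not the worker's actual API):

```python
HEARTBEAT_TIMEOUT_SECONDS = 300  # configurable via env in the real worker

def plan_recovery(tasks, now, timeout_s=HEARTBEAT_TIMEOUT_SECONDS):
    """tasks: iterable of (task_id, heartbeat_at_or_None, processed_images).

    Returns (task_id, new_status) pairs for stale 'running' tasks."""
    if timeout_s <= 0:
        return []  # 0 or negative disables the recovery feature
    decisions = []
    for task_id, heartbeat_at, processed in tasks:
        # NULL heartbeat (crash right after claim) also counts as stale
        stale = heartbeat_at is None or (now - heartbeat_at) > timeout_s
        if not stale:
            continue
        # No progress -> retry from pending; partial progress -> mark failed
        decisions.append((task_id, "pending" if processed == 0 else "failed"))
    return decisions
```

In the real worker each decision is then applied as an `UPDATE ... WHERE status='running'` in its own transaction, which also clears `run_token`.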

Edge cases handled:
- Database unavailable: recovery failure doesn't block worker startup
- Concurrent recovery: UPDATE WHERE status='running' prevents duplicates
- NULL heartbeat: extreme case (crash right after claim) also recovered
- stop_requested tasks: automatically excluded by _fetch_pending_task()

Testing:
- 8 unit tests all passed:
  * No timeout tasks
  * Timeout disabled
  * No progress → pending
  * Has progress → failed
  * NULL heartbeat recovery
  * Multiple tasks mixed processing
  * DB error doesn't crash
  * Negative timeout disables feature
2026-02-10 16:19:22 +08:00
dc490f03be feat(auto-annotation): unify annotation results with Label Studio format
Automatically convert auto-annotation outputs to Label Studio format and write to t_dm_annotation_results table, enabling seamless editing in the annotation editor.

New file:
- runtime/python-executor/datamate/annotation_result_converter.py
  * 4 converters for different annotation types:
    - convert_text_classification → choices type
    - convert_ner → labels (span) type
    - convert_relation_extraction → labels + relation type
    - convert_object_detection → rectanglelabels type
  * convert_annotation() dispatcher (auto-detects task_type)
  * generate_label_config_xml() for dynamic XML generation
  * Pipeline introspection utilities
  * Label Studio ID generation logic

Modified file:
- runtime/python-executor/datamate/auto_annotation_worker.py
  * Preserve file_id through processing loop (line 918)
  * Collect file_results as (file_id, annotations) pairs
  * New _create_labeling_project_with_annotations() function:
    - Creates labeling project linked to source dataset
    - Snapshots all files
    - Converts results to Label Studio format
    - Writes to t_dm_annotation_results in single transaction
  * label_config XML stored in t_dm_labeling_projects.configuration

Key features:
- Supports 4 annotation types: text classification, NER, relation extraction, object detection
- Deterministic region IDs for entity references in relation extraction
- Pixel to percentage conversion for object detection
- XML escaping handled by xml.etree.ElementTree
- Partial results preserved on task stop
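
For example, the pixel-to-percentage conversion for object detection boxes amounts to the following (function name is illustrative):

```python
def bbox_pixels_to_percent(x, y, w, h, img_w, img_h):
    """Convert a pixel bounding box to Label Studio percentage coordinates."""
    return {
        "x": 100.0 * x / img_w,
        "y": 100.0 * y / img_h,
        "width": 100.0 * w / img_w,
        "height": 100.0 * h / img_h,
    }
```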

Users can now view and edit auto-annotation results seamlessly in the annotation editor.
2026-02-10 16:06:40 +08:00
49f99527cc feat(auto-annotation): add LLM-based annotation operators
Add three new LLM-powered auto-annotation operators:
- LLMTextClassification: Text classification using LLM
- LLMNamedEntityRecognition: Named entity recognition with type validation
- LLMRelationExtraction: Relation extraction with entity and relation type validation

Key features:
- Load LLM config from t_model_config table via modelId parameter
- Lazy loading of LLM configuration on first execute()
- Result validation with whitelist checking for entity/relation types
- Fault-tolerant: returns empty results on LLM failure instead of throwing
- Fully compatible with existing Worker pipeline

Files added:
- runtime/ops/annotation/_llm_utils.py: Shared LLM utilities
- runtime/ops/annotation/llm_text_classification/: Text classification operator
- runtime/ops/annotation/llm_named_entity_recognition/: NER operator
- runtime/ops/annotation/llm_relation_extraction/: Relation extraction operator

Files modified:
- runtime/ops/annotation/__init__.py: Register 3 new operators
- runtime/python-executor/datamate/auto_annotation_worker.py: Add to Worker whitelist
- frontend/src/pages/DataAnnotation/OperatorCreate/hooks/useOperatorOperations.ts: Add to frontend whitelist
2026-02-10 15:22:23 +08:00
15 changed files with 2152 additions and 88 deletions


@@ -22,6 +22,9 @@ type CategoryGroup = {
const ANNOTATION_OPERATOR_ID_WHITELIST = new Set([
"ImageObjectDetectionBoundingBox",
"test_annotation_marker",
"LLMTextClassification",
"LLMNamedEntityRecognition",
"LLMRelationExtraction",
]);
const ensureArray = (value: unknown): string[] => {


@@ -1,10 +1,16 @@
# -*- coding: utf-8 -*-
"""Annotation-related operators (e.g. YOLO detection, LLM-based NLP annotation)."""
from . import image_object_detection_bounding_box
from . import test_annotation_marker
from . import llm_text_classification
from . import llm_named_entity_recognition
from . import llm_relation_extraction
__all__ = [
"image_object_detection_bounding_box",
"test_annotation_marker",
"llm_text_classification",
"llm_named_entity_recognition",
"llm_relation_extraction",
]


@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
"""Shared LLM utilities for annotation operators: config loading & OpenAI-compatible calls.
Provides three core capabilities:
1. Load model configs from the t_model_config table (by ID / by default)
2. Call an OpenAI-compatible chat/completions API
3. Extract JSON from raw LLM output
"""
import json
import re
from typing import Any, Dict
from loguru import logger
# ---------------------------------------------------------------------------
# Model config loading
# ---------------------------------------------------------------------------
def load_model_config(model_id: str) -> Dict[str, Any]:
"""Load an enabled model config from t_model_config by model_id."""
from datamate.sql_manager.sql_manager import SQLManager
from sqlalchemy import text as sql_text
sql = sql_text(
"""
SELECT model_name, provider, base_url, api_key, type
FROM t_model_config
WHERE id = :model_id AND is_enabled = 1
LIMIT 1
"""
)
with SQLManager.create_connect() as conn:
row = conn.execute(sql, {"model_id": model_id}).fetchone()
if not row:
raise ValueError(f"Model config not found or disabled: {model_id}")
return dict(row._mapping)
def load_default_model_config() -> Dict[str, Any]:
"""Load the default chat model config (is_default=1 and type='chat')."""
from datamate.sql_manager.sql_manager import SQLManager
from sqlalchemy import text as sql_text
sql = sql_text(
"""
SELECT id, model_name, provider, base_url, api_key, type
FROM t_model_config
WHERE is_enabled = 1 AND is_default = 1 AND type = 'chat'
LIMIT 1
"""
)
with SQLManager.create_connect() as conn:
row = conn.execute(sql).fetchone()
if not row:
raise ValueError("No default chat model configured in t_model_config")
return dict(row._mapping)
def get_llm_config(model_id: str = "") -> Dict[str, Any]:
"""Load by model_id when given; otherwise fall back to the default model."""
if model_id:
return load_model_config(model_id)
return load_default_model_config()
# ---------------------------------------------------------------------------
# LLM invocation
# ---------------------------------------------------------------------------
def call_llm(
config: Dict[str, Any],
prompt: str,
system_prompt: str = "",
temperature: float = 0.1,
max_retries: int = 2,
) -> str:
"""Call an OpenAI-compatible chat/completions API and return the text content."""
import requests as http_requests
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
headers: Dict[str, str] = {"Content-Type": "application/json"}
api_key = config.get("api_key", "")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
base_url = str(config["base_url"]).rstrip("/")
# Handle base_url whether or not it already includes /v1
if not base_url.endswith("/chat/completions"):
if not base_url.endswith("/v1"):
base_url = f"{base_url}/v1"
url = f"{base_url}/chat/completions"
else:
url = base_url
body = {
"model": config["model_name"],
"messages": messages,
"temperature": temperature,
}
last_err = None
for attempt in range(max_retries + 1):
try:
resp = http_requests.post(url, json=body, headers=headers, timeout=120)
resp.raise_for_status()
content = resp.json()["choices"][0]["message"]["content"]
return content
except Exception as e:
last_err = e
logger.warning(
"LLM call attempt {}/{} failed: {}",
attempt + 1,
max_retries + 1,
e,
)
raise RuntimeError(f"LLM call failed after {max_retries + 1} attempts: {last_err}")
# ---------------------------------------------------------------------------
# JSON extraction
# ---------------------------------------------------------------------------
def extract_json(raw: str) -> Any:
"""Extract a JSON object/array from raw LLM output.
Handles common noise: Markdown code fences, <think> tags, surrounding prose.
"""
if not raw:
raise ValueError("Empty LLM response")
# 1. Strip <think>...</think>-style reasoning tags
thought_tags = ["think", "thinking", "analysis", "reasoning", "reflection"]
for tag in thought_tags:
raw = re.sub(rf"<{tag}>[\s\S]*?</{tag}>", "", raw, flags=re.IGNORECASE)
# 2. Strip Markdown code fence markers
raw = re.sub(r"```(?:json)?\s*", "", raw)
raw = raw.replace("```", "")
# 3. Locate the span from the first { or [ to the last } or ]
start = None
end = None
for i, ch in enumerate(raw):
if ch in "{[":
start = i
break
for i in range(len(raw) - 1, -1, -1):
if raw[i] in "]}":
end = i + 1
break
if start is not None and end is not None and start < end:
return json.loads(raw[start:end])
# Fallback: try parsing the stripped string directly
return json.loads(raw.strip())


@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from datamate.core.base_op import OPERATORS
from .process import LLMNamedEntityRecognition
OPERATORS.register_module(
module_name="LLMNamedEntityRecognition",
module_path="ops.annotation.llm_named_entity_recognition.process",
)
__all__ = ["LLMNamedEntityRecognition"]


@@ -0,0 +1,29 @@
name: 'LLM命名实体识别'
name_en: 'LLM Named Entity Recognition'
description: '基于大语言模型的命名实体识别算子,支持自定义实体类型。'
description_en: 'LLM-based NER operator with custom entity types.'
language: 'python'
vendor: 'datamate'
raw_id: 'LLMNamedEntityRecognition'
version: '1.0.0'
types:
- 'annotation'
modal: 'text'
inputs: 'text'
outputs: 'text'
settings:
modelId:
name: '模型ID'
description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。'
type: 'input'
defaultVal: ''
entityTypes:
name: '实体类型'
description: '逗号分隔的实体类型,如:PER,ORG,LOC,DATE'
type: 'input'
defaultVal: 'PER,ORG,LOC,DATE'
outputDir:
name: '输出目录'
description: '算子输出目录(由运行时自动注入)。'
type: 'input'
defaultVal: ''


@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
"""LLM-based named entity recognition (NER) operator.
Uses a large language model to identify named entities (people, places,
organizations, etc.) in text, outputting an entity list with text span,
entity type, and character offsets into the source text.
"""
import json
import os
import shutil
import time
from typing import Any, Dict, List
from loguru import logger
from datamate.core.base_op import Mapper
SYSTEM_PROMPT = (
"你是一个专业的命名实体识别(NER)专家。根据给定的实体类型列表,"
"从输入文本中识别所有命名实体。\n"
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
)
USER_PROMPT_TEMPLATE = """请从以下文本中识别所有命名实体。
实体类型列表:{entity_types}
实体类型说明:
- PER:人名
- ORG:组织/机构名
- LOC:地点/地名
- DATE:日期/时间
- EVENT:事件
- PRODUCT:产品名
- MONEY:金额
- PERCENT:百分比
文本内容:
{text}
请以如下 JSON 格式输出(entities 为实体数组,每个实体包含 text、type、start、end 四个字段):
{{"entities": [{{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}]}}
注意:
- type 必须是实体类型列表中的值之一
- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开)
- 如果没有找到任何实体,返回 {{"entities": []}}"""
class LLMNamedEntityRecognition(Mapper):
"""LLM-based named entity recognition operator."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._model_id: str = kwargs.get("modelId", "")
self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC,DATE")
self._output_dir: str = kwargs.get("outputDir", "") or ""
self._llm_config = None
def _get_llm_config(self) -> Dict[str, Any]:
if self._llm_config is None:
from ops.annotation._llm_utils import get_llm_config
self._llm_config = get_llm_config(self._model_id)
return self._llm_config
@staticmethod
def _validate_entities(
entities_raw: Any, allowed_types: List[str]
) -> List[Dict[str, Any]]:
"""Validate and filter the entity list, keeping only allowed types."""
if not isinstance(entities_raw, list):
return []
validated: List[Dict[str, Any]] = []
allowed_set = {t.strip().upper() for t in allowed_types}
for ent in entities_raw:
if not isinstance(ent, dict):
continue
ent_type = str(ent.get("type", "")).strip().upper()
ent_text = str(ent.get("text", "")).strip()
if not ent_text:
continue
# Keep matching types; keep everything when the allowed list is empty
if allowed_set and ent_type not in allowed_set:
continue
validated.append(
{
"text": ent_text,
"type": ent_type,
"start": ent.get("start"),
"end": ent.get("end"),
}
)
return validated
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
start = time.time()
text_path = sample.get(self.text_key)
if not text_path or not os.path.exists(str(text_path)):
logger.warning("Text file not found: {}", text_path)
return sample
text_path = str(text_path)
with open(text_path, "r", encoding="utf-8") as f:
text_content = f.read()
if not text_content.strip():
logger.warning("Empty text file: {}", text_path)
return sample
max_chars = 8000
truncated = text_content[:max_chars]
from ops.annotation._llm_utils import call_llm, extract_json
config = self._get_llm_config()
prompt = USER_PROMPT_TEMPLATE.format(
entity_types=self._entity_types,
text=truncated,
)
allowed_types = [t.strip() for t in self._entity_types.split(",") if t.strip()]
try:
raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
parsed = extract_json(raw_response)
entities_raw = parsed.get("entities", []) if isinstance(parsed, dict) else parsed
entities = self._validate_entities(entities_raw, allowed_types)
except Exception as e:
logger.error("LLM NER failed for {}: {}", text_path, e)
entities = []
annotation = {
"file": os.path.basename(text_path),
"task_type": "ner",
"entity_types": self._entity_types,
"model": config.get("model_name", ""),
"entities": entities,
}
# Write outputs
output_dir = self._output_dir or os.path.dirname(text_path)
annotations_dir = os.path.join(output_dir, "annotations")
data_dir = os.path.join(output_dir, "data")
os.makedirs(annotations_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)
base_name = os.path.splitext(os.path.basename(text_path))[0]
dst_data = os.path.join(data_dir, os.path.basename(text_path))
if not os.path.exists(dst_data):
shutil.copy2(text_path, dst_data)
json_path = os.path.join(annotations_dir, f"{base_name}.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(annotation, f, indent=2, ensure_ascii=False)
sample["detection_count"] = len(entities)
sample["annotations_file"] = json_path
sample["annotations"] = annotation
elapsed = time.time() - start
logger.info(
"NER: {} -> {} entities, Time: {:.2f}s",
os.path.basename(text_path),
len(entities),
elapsed,
)
return sample


@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from datamate.core.base_op import OPERATORS
from .process import LLMRelationExtraction
OPERATORS.register_module(
module_name="LLMRelationExtraction",
module_path="ops.annotation.llm_relation_extraction.process",
)
__all__ = ["LLMRelationExtraction"]


@@ -0,0 +1,34 @@
name: 'LLM关系抽取'
name_en: 'LLM Relation Extraction'
description: '基于大语言模型的关系抽取算子,识别实体并抽取实体间关系三元组。'
description_en: 'LLM-based relation extraction operator that identifies entities and extracts relation triples.'
language: 'python'
vendor: 'datamate'
raw_id: 'LLMRelationExtraction'
version: '1.0.0'
types:
- 'annotation'
modal: 'text'
inputs: 'text'
outputs: 'text'
settings:
modelId:
name: '模型ID'
description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。'
type: 'input'
defaultVal: ''
entityTypes:
name: '实体类型'
description: '逗号分隔的实体类型,如:PER,ORG,LOC'
type: 'input'
defaultVal: 'PER,ORG,LOC'
relationTypes:
name: '关系类型'
description: '逗号分隔的关系类型,如:属于,位于,创立,工作于'
type: 'input'
defaultVal: '属于,位于,创立,工作于'
outputDir:
name: '输出目录'
description: '算子输出目录(由运行时自动注入)。'
type: 'input'
defaultVal: ''


@@ -0,0 +1,229 @@
# -*- coding: utf-8 -*-
"""LLM-based relation extraction operator.
Uses a large language model to identify entities in text and extract the
relations between them, outputting an entity list and relation triples
(subject, relation, object).
"""
import json
import os
import shutil
import time
from typing import Any, Dict, List
from loguru import logger
from datamate.core.base_op import Mapper
SYSTEM_PROMPT = (
"你是一个专业的信息抽取专家。你需要从文本中识别命名实体,并抽取实体之间的关系。\n"
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
)
USER_PROMPT_TEMPLATE = """请从以下文本中识别实体并抽取实体间的关系。
实体类型列表:{entity_types}
关系类型列表:{relation_types}
文本内容:
{text}
请以如下 JSON 格式输出:
{{
"entities": [
{{"text": "实体文本", "type": "PER", "start": 0, "end": 3}}
],
"relations": [
{{
"subject": {{"text": "主语实体", "type": "PER"}},
"relation": "关系类型",
"object": {{"text": "宾语实体", "type": "ORG"}}
}}
]
}}
注意:
- 实体的 type 必须是实体类型列表中的值之一
- 关系的 relation 必须是关系类型列表中的值之一
- start 和 end 是实体在原文中的字符偏移位置(从 0 开始,左闭右开)
- 如果没有找到任何实体或关系,对应数组返回空 []"""
class LLMRelationExtraction(Mapper):
"""LLM-based relation extraction operator."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._model_id: str = kwargs.get("modelId", "")
self._entity_types: str = kwargs.get("entityTypes", "PER,ORG,LOC")
self._relation_types: str = kwargs.get("relationTypes", "属于,位于,创立,工作于")
self._output_dir: str = kwargs.get("outputDir", "") or ""
self._llm_config = None
def _get_llm_config(self) -> Dict[str, Any]:
if self._llm_config is None:
from ops.annotation._llm_utils import get_llm_config
self._llm_config = get_llm_config(self._model_id)
return self._llm_config
@staticmethod
def _validate_entities(
entities_raw: Any, allowed_types: List[str]
) -> List[Dict[str, Any]]:
if not isinstance(entities_raw, list):
return []
validated: List[Dict[str, Any]] = []
allowed_set = {t.strip().upper() for t in allowed_types} if allowed_types else set()
for ent in entities_raw:
if not isinstance(ent, dict):
continue
ent_text = str(ent.get("text", "")).strip()
ent_type = str(ent.get("type", "")).strip().upper()
if not ent_text:
continue
if allowed_set and ent_type not in allowed_set:
continue
validated.append(
{
"text": ent_text,
"type": ent_type,
"start": ent.get("start"),
"end": ent.get("end"),
}
)
return validated
@staticmethod
def _validate_relations(
relations_raw: Any, allowed_relation_types: List[str]
) -> List[Dict[str, Any]]:
if not isinstance(relations_raw, list):
return []
validated: List[Dict[str, Any]] = []
allowed_set = {t.strip() for t in allowed_relation_types} if allowed_relation_types else set()
for rel in relations_raw:
if not isinstance(rel, dict):
continue
subject = rel.get("subject")
relation = str(rel.get("relation", "")).strip()
obj = rel.get("object")
if not isinstance(subject, dict) or not isinstance(obj, dict):
continue
if not relation:
continue
if allowed_set and relation not in allowed_set:
continue
validated.append(
{
"subject": {
"text": str(subject.get("text", "")),
"type": str(subject.get("type", "")),
},
"relation": relation,
"object": {
"text": str(obj.get("text", "")),
"type": str(obj.get("type", "")),
},
}
)
return validated
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
start = time.time()
text_path = sample.get(self.text_key)
if not text_path or not os.path.exists(str(text_path)):
logger.warning("Text file not found: {}", text_path)
return sample
text_path = str(text_path)
with open(text_path, "r", encoding="utf-8") as f:
text_content = f.read()
if not text_content.strip():
logger.warning("Empty text file: {}", text_path)
return sample
max_chars = 8000
truncated = text_content[:max_chars]
from ops.annotation._llm_utils import call_llm, extract_json
config = self._get_llm_config()
prompt = USER_PROMPT_TEMPLATE.format(
entity_types=self._entity_types,
relation_types=self._relation_types,
text=truncated,
)
allowed_entity_types = [
t.strip() for t in self._entity_types.split(",") if t.strip()
]
allowed_relation_types = [
t.strip() for t in self._relation_types.split(",") if t.strip()
]
try:
raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
parsed = extract_json(raw_response)
if not isinstance(parsed, dict):
parsed = {}
entities = self._validate_entities(
parsed.get("entities", []), allowed_entity_types
)
relations = self._validate_relations(
parsed.get("relations", []), allowed_relation_types
)
except Exception as e:
logger.error("LLM relation extraction failed for {}: {}", text_path, e)
entities = []
relations = []
annotation = {
"file": os.path.basename(text_path),
"task_type": "relation_extraction",
"entity_types": self._entity_types,
"relation_types": self._relation_types,
"model": config.get("model_name", ""),
"entities": entities,
"relations": relations,
}
# Write outputs
output_dir = self._output_dir or os.path.dirname(text_path)
annotations_dir = os.path.join(output_dir, "annotations")
data_dir = os.path.join(output_dir, "data")
os.makedirs(annotations_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)
base_name = os.path.splitext(os.path.basename(text_path))[0]
dst_data = os.path.join(data_dir, os.path.basename(text_path))
if not os.path.exists(dst_data):
shutil.copy2(text_path, dst_data)
json_path = os.path.join(annotations_dir, f"{base_name}.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(annotation, f, indent=2, ensure_ascii=False)
sample["detection_count"] = len(relations)
sample["annotations_file"] = json_path
sample["annotations"] = annotation
elapsed = time.time() - start
logger.info(
"RelationExtraction: {} -> {} entities, {} relations, Time: {:.2f}s",
os.path.basename(text_path),
len(entities),
len(relations),
elapsed,
)
return sample


@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from datamate.core.base_op import OPERATORS
from .process import LLMTextClassification
OPERATORS.register_module(
module_name="LLMTextClassification",
module_path="ops.annotation.llm_text_classification.process",
)
__all__ = ["LLMTextClassification"]


@@ -0,0 +1,29 @@
name: 'LLM文本分类'
name_en: 'LLM Text Classification'
description: '基于大语言模型的文本分类算子,支持自定义类别标签。'
description_en: 'LLM-based text classification operator with custom category labels.'
language: 'python'
vendor: 'datamate'
raw_id: 'LLMTextClassification'
version: '1.0.0'
types:
- 'annotation'
modal: 'text'
inputs: 'text'
outputs: 'text'
settings:
modelId:
name: '模型ID'
description: '已配置的 LLM 模型 ID(留空使用系统默认模型)。'
type: 'input'
defaultVal: ''
categories:
name: '分类标签'
description: '逗号分隔的分类标签列表,如:正面,负面,中性'
type: 'input'
defaultVal: '正面,负面,中性'
outputDir:
name: '输出目录'
description: '算子输出目录(由运行时自动注入)。'
type: 'input'
defaultVal: ''


@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""LLM-based text classification operator.
Classifies text with a large language model, outputting a label, a confidence
score, and a short rationale. The label set is customizable via the
categories parameter.
"""
import json
import os
import shutil
import time
from typing import Any, Dict
from loguru import logger
from datamate.core.base_op import Mapper
SYSTEM_PROMPT = (
"你是一个专业的文本分类专家。根据给定的类别列表,对输入文本进行分类。\n"
"你必须严格输出 JSON 格式,不要输出任何其他内容。"
)
USER_PROMPT_TEMPLATE = """请对以下文本进行分类。
可选类别:{categories}
文本内容:
{text}
请以如下 JSON 格式输出(label 必须是可选类别之一,confidence 为 0~1 的浮点数):
{{"label": "类别名", "confidence": 0.95, "reasoning": "简短理由"}}"""
class LLMTextClassification(Mapper):
"""LLM-based text classification operator."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._model_id: str = kwargs.get("modelId", "")
self._categories: str = kwargs.get("categories", "正面,负面,中性")
self._output_dir: str = kwargs.get("outputDir", "") or ""
self._llm_config = None
def _get_llm_config(self) -> Dict[str, Any]:
if self._llm_config is None:
from ops.annotation._llm_utils import get_llm_config
self._llm_config = get_llm_config(self._model_id)
return self._llm_config
def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
start = time.time()
text_path = sample.get(self.text_key)
if not text_path or not os.path.exists(str(text_path)):
logger.warning("Text file not found: {}", text_path)
return sample
text_path = str(text_path)
with open(text_path, "r", encoding="utf-8") as f:
text_content = f.read()
if not text_content.strip():
logger.warning("Empty text file: {}", text_path)
return sample
# Truncate long text to fit the LLM context window
max_chars = 8000
truncated = text_content[:max_chars]
from ops.annotation._llm_utils import call_llm, extract_json
config = self._get_llm_config()
prompt = USER_PROMPT_TEMPLATE.format(
categories=self._categories,
text=truncated,
)
try:
raw_response = call_llm(config, prompt, system_prompt=SYSTEM_PROMPT)
result = extract_json(raw_response)
except Exception as e:
logger.error("LLM classification failed for {}: {}", text_path, e)
result = {
"label": "unknown",
"confidence": 0.0,
"reasoning": f"LLM call or JSON parse failed: {e}",
}
annotation = {
"file": os.path.basename(text_path),
"task_type": "text_classification",
"categories": self._categories,
"model": config.get("model_name", ""),
"result": result,
}
# Resolve the output directory
output_dir = self._output_dir or os.path.dirname(text_path)
annotations_dir = os.path.join(output_dir, "annotations")
data_dir = os.path.join(output_dir, "data")
os.makedirs(annotations_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)
base_name = os.path.splitext(os.path.basename(text_path))[0]
# Copy the source text into the data directory
dst_data = os.path.join(data_dir, os.path.basename(text_path))
if not os.path.exists(dst_data):
shutil.copy2(text_path, dst_data)
# Write the annotation JSON
json_path = os.path.join(annotations_dir, f"{base_name}.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(annotation, f, indent=2, ensure_ascii=False)
sample["detection_count"] = 1
sample["annotations_file"] = json_path
sample["annotations"] = annotation
elapsed = time.time() - start
logger.info(
"TextClassification: {} -> {}, Time: {:.2f}s",
os.path.basename(text_path),
result.get("label", "N/A"),
elapsed,
)
return sample


@@ -0,0 +1,491 @@
# -*- coding: utf-8 -*-
"""Convert auto-annotation operator outputs into Label Studio-compatible format.
Supported operator types:
- LLMTextClassification → choices
- LLMNamedEntityRecognition → labels (span)
- LLMRelationExtraction → labels + relation
- ImageObjectDetectionBoundingBox → rectanglelabels
"""
from __future__ import annotations
import hashlib
import uuid
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional, Tuple
from loguru import logger
# ---------------------------------------------------------------------------
# Color palette (Label Studio background colors)
# ---------------------------------------------------------------------------
_LABEL_COLORS = [
"#e53935",
"#fb8c00",
"#43a047",
"#1e88e5",
"#8e24aa",
"#00897b",
"#d81b60",
"#3949ab",
"#fdd835",
"#6d4c41",
"#546e7a",
"#f4511e",
]
def _pick_color(index: int) -> str:
return _LABEL_COLORS[index % len(_LABEL_COLORS)]
# ---------------------------------------------------------------------------
# Stable LS ID generation (consistent with editor.py:118-136)
# ---------------------------------------------------------------------------
def _stable_ls_id(seed: str) -> int:
"""Generate a stable Label Studio-style integer ID (within the JS safe-integer range)."""
digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
value = int(digest[:13], 16)
return value if value > 0 else 1
def make_ls_task_id(project_id: str, file_id: str) -> int:
return _stable_ls_id(f"task:{project_id}:{file_id}")
def make_ls_annotation_id(project_id: str, file_id: str) -> int:
return _stable_ls_id(f"annotation:{project_id}:{file_id}")
# ---------------------------------------------------------------------------
# Deterministic region IDs (relation extraction needs stable entity references)
# ---------------------------------------------------------------------------
_NAMESPACE_REGION = uuid.UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890")
def _make_region_id(
project_id: str, file_id: str, text: str, entity_type: str, start: Any, end: Any
) -> str:
seed = f"{project_id}:{file_id}:{text}:{entity_type}:{start}:{end}"
return str(uuid.uuid5(_NAMESPACE_REGION, seed))
# ---------------------------------------------------------------------------
# 1. Text classification
# ---------------------------------------------------------------------------
def convert_text_classification(
annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
"""将 LLMTextClassification 算子输出转换为 LS annotation。"""
result_data = annotation.get("result")
if not isinstance(result_data, dict):
return None
label = result_data.get("label")
if not label:
return None
region_id = str(uuid.uuid4())
ls_result = [
{
"id": region_id,
"from_name": "sentiment",
"to_name": "text",
"type": "choices",
"value": {"choices": [str(label)]},
}
]
return {
"id": make_ls_annotation_id(project_id, file_id),
"task": make_ls_task_id(project_id, file_id),
"result": ls_result,
}
# ---------------------------------------------------------------------------
# 2. Named entity recognition
# ---------------------------------------------------------------------------
def convert_ner(
annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
"""将 LLMNamedEntityRecognition 算子输出转换为 LS annotation。"""
entities = annotation.get("entities")
if not isinstance(entities, list) or not entities:
return None
ls_result: List[Dict[str, Any]] = []
for ent in entities:
if not isinstance(ent, dict):
continue
ent_text = ent.get("text", "")
ent_type = ent.get("type", "")
start = ent.get("start")
end = ent.get("end")
if not ent_text or start is None or end is None:
continue
region_id = _make_region_id(project_id, file_id, ent_text, ent_type, start, end)
ls_result.append(
{
"id": region_id,
"from_name": "label",
"to_name": "text",
"type": "labels",
"value": {
"start": int(start),
"end": int(end),
"text": str(ent_text),
"labels": [str(ent_type)],
},
}
)
if not ls_result:
return None
return {
"id": make_ls_annotation_id(project_id, file_id),
"task": make_ls_task_id(project_id, file_id),
"result": ls_result,
}
# ---------------------------------------------------------------------------
# 3. Relation extraction
# ---------------------------------------------------------------------------
def _find_entity_region_id(
entity_regions: List[Dict[str, Any]], text: str, entity_type: str
) -> Optional[str]:
"""在已生成的 entity regions 中查找匹配的 region ID。"""
for region in entity_regions:
value = region.get("value", {})
if value.get("text") == text and entity_type in value.get("labels", []):
return region["id"]
return None
def convert_relation_extraction(
annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
"""将 LLMRelationExtraction 算子输出转换为 LS annotation。"""
entities = annotation.get("entities", [])
relations = annotation.get("relations", [])
if not isinstance(entities, list):
entities = []
if not isinstance(relations, list):
relations = []
if not entities and not relations:
return None
# Build entity label regions (deduplicated)
entity_regions: List[Dict[str, Any]] = []
seen_keys: set = set()
for ent in entities:
if not isinstance(ent, dict):
continue
ent_text = str(ent.get("text", "")).strip()
ent_type = str(ent.get("type", "")).strip()
start = ent.get("start")
end = ent.get("end")
if not ent_text or start is None or end is None:
continue
dedup_key = (ent_text, ent_type, int(start), int(end))
if dedup_key in seen_keys:
continue
seen_keys.add(dedup_key)
region_id = _make_region_id(project_id, file_id, ent_text, ent_type, start, end)
entity_regions.append(
{
"id": region_id,
"from_name": "label",
"to_name": "text",
"type": "labels",
"value": {
"start": int(start),
"end": int(end),
"text": ent_text,
"labels": [ent_type],
},
}
)
# Build relations
relation_results: List[Dict[str, Any]] = []
for rel in relations:
if not isinstance(rel, dict):
continue
subject = rel.get("subject")
obj = rel.get("object")
relation_type = str(rel.get("relation", "")).strip()
if not isinstance(subject, dict) or not isinstance(obj, dict) or not relation_type:
continue
subj_text = str(subject.get("text", "")).strip()
subj_type = str(subject.get("type", "")).strip()
obj_text = str(obj.get("text", "")).strip()
obj_type = str(obj.get("type", "")).strip()
from_id = _find_entity_region_id(entity_regions, subj_text, subj_type)
to_id = _find_entity_region_id(entity_regions, obj_text, obj_type)
if not from_id or not to_id:
logger.debug(
"Skipping relation '{}': could not find region for subject='{}' or object='{}'",
relation_type,
subj_text,
obj_text,
)
continue
relation_results.append(
{
"from_id": from_id,
"to_id": to_id,
"type": "relation",
"direction": "right",
"labels": [relation_type],
}
)
ls_result = entity_regions + relation_results
if not ls_result:
return None
return {
"id": make_ls_annotation_id(project_id, file_id),
"task": make_ls_task_id(project_id, file_id),
"result": ls_result,
}
# ---------------------------------------------------------------------------
# 4. Object detection
# ---------------------------------------------------------------------------
def convert_object_detection(
annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
"""将 ImageObjectDetectionBoundingBox 算子输出转换为 LS annotation。"""
detections = annotation.get("detections")
if not isinstance(detections, list) or not detections:
return None
img_width = annotation.get("width", 0)
img_height = annotation.get("height", 0)
if not img_width or not img_height:
return None
ls_result: List[Dict[str, Any]] = []
for det in detections:
if not isinstance(det, dict):
continue
label = det.get("label", "unknown")
bbox = det.get("bbox_xyxy")
if not isinstance(bbox, (list, tuple)) or len(bbox) < 4:
continue
x1, y1, x2, y2 = float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])
x_pct = x1 * 100.0 / img_width
y_pct = y1 * 100.0 / img_height
w_pct = (x2 - x1) * 100.0 / img_width
h_pct = (y2 - y1) * 100.0 / img_height
region_id = str(uuid.uuid4())
ls_result.append(
{
"id": region_id,
"from_name": "label",
"to_name": "image",
"type": "rectanglelabels",
"original_width": int(img_width),
"original_height": int(img_height),
"image_rotation": 0,
"value": {
"x": round(x_pct, 4),
"y": round(y_pct, 4),
"width": round(w_pct, 4),
"height": round(h_pct, 4),
"rotation": 0,
"rectanglelabels": [str(label)],
},
}
)
if not ls_result:
return None
return {
"id": make_ls_annotation_id(project_id, file_id),
"task": make_ls_task_id(project_id, file_id),
"result": ls_result,
}
# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------
TASK_TYPE_CONVERTERS = {
"text_classification": convert_text_classification,
"ner": convert_ner,
"relation_extraction": convert_relation_extraction,
"object_detection": convert_object_detection,
}
def convert_annotation(
annotation: Dict[str, Any], file_id: str, project_id: str
) -> Optional[Dict[str, Any]]:
"""根据 task_type 分发到对应的转换函数。"""
task_type = annotation.get("task_type")
if task_type is None and "detections" in annotation:
task_type = "object_detection"
converter = TASK_TYPE_CONVERTERS.get(task_type) # type: ignore[arg-type]
if converter is None:
logger.warning("No LS converter for task_type: {}", task_type)
return None
try:
return converter(annotation, file_id, project_id)
except Exception as exc:
logger.error("Failed to convert annotation (task_type={}): {}", task_type, exc)
return None
# ---------------------------------------------------------------------------
# label_config XML generation
# ---------------------------------------------------------------------------
def _split_labels(raw: str) -> List[str]:
"""逗号分隔字符串 → 去空白列表。"""
return [s.strip() for s in raw.split(",") if s.strip()]
def generate_label_config_xml(task_type: str, operator_params: Dict[str, Any]) -> str:
"""根据标注类型和算子参数生成 Label Studio label_config XML。"""
if task_type == "text_classification":
return _gen_text_classification_xml(operator_params)
if task_type == "ner":
return _gen_ner_xml(operator_params)
if task_type == "relation_extraction":
return _gen_relation_extraction_xml(operator_params)
if task_type == "object_detection":
return _gen_object_detection_xml(operator_params)
# Unknown type: return a minimal usable XML
return "<View><Header value=\"Unknown annotation type\"/></View>"
def _gen_text_classification_xml(params: Dict[str, Any]) -> str:
categories = _split_labels(str(params.get("categories", "正面,负面,中性")))
view = ET.Element("View")
ET.SubElement(view, "Text", name="text", value="$text")
choices = ET.SubElement(
view,
"Choices",
name="sentiment",
toName="text",
choice="single",
showInline="true",
)
for cat in categories:
ET.SubElement(choices, "Choice", value=cat)
return ET.tostring(view, encoding="unicode")
def _gen_ner_xml(params: Dict[str, Any]) -> str:
entity_types = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC,DATE")))
view = ET.Element("View")
labels = ET.SubElement(view, "Labels", name="label", toName="text")
for i, et in enumerate(entity_types):
ET.SubElement(labels, "Label", value=et, background=_pick_color(i))
ET.SubElement(view, "Text", name="text", value="$text")
return ET.tostring(view, encoding="unicode")
def _gen_relation_extraction_xml(params: Dict[str, Any]) -> str:
entity_types = _split_labels(str(params.get("entityTypes", "PER,ORG,LOC")))
relation_types = _split_labels(
str(params.get("relationTypes", "属于,位于,创立,工作于"))
)
view = ET.Element("View")
relations = ET.SubElement(view, "Relations")
for rt in relation_types:
ET.SubElement(relations, "Relation", value=rt)
labels = ET.SubElement(view, "Labels", name="label", toName="text")
for i, et in enumerate(entity_types):
ET.SubElement(labels, "Label", value=et, background=_pick_color(i))
ET.SubElement(view, "Text", name="text", value="$text")
return ET.tostring(view, encoding="unicode")
def _gen_object_detection_xml(params: Dict[str, Any]) -> str:
detected_labels: List[str] = params.get("_detected_labels", [])
if not detected_labels:
detected_labels = ["object"]
view = ET.Element("View")
ET.SubElement(view, "Image", name="image", value="$image")
rect_labels = ET.SubElement(view, "RectangleLabels", name="label", toName="image")
for i, lbl in enumerate(detected_labels):
ET.SubElement(rect_labels, "Label", value=lbl, background=_pick_color(i))
return ET.tostring(view, encoding="unicode")
# ---------------------------------------------------------------------------
# Pipeline info extraction
# ---------------------------------------------------------------------------
OPERATOR_TASK_TYPE_MAP: Dict[str, str] = {
"LLMTextClassification": "text_classification",
"LLMNamedEntityRecognition": "ner",
"LLMRelationExtraction": "relation_extraction",
"ImageObjectDetectionBoundingBox": "object_detection",
}
def infer_task_type_from_pipeline(pipeline: List[Dict[str, Any]]) -> Optional[str]:
"""从 normalized pipeline 中推断标注类型(取第一个匹配的算子)。"""
for step in pipeline:
operator_id = str(step.get("operator_id", ""))
task_type = OPERATOR_TASK_TYPE_MAP.get(operator_id)
if task_type is not None:
return task_type
return None
def extract_operator_params(pipeline: List[Dict[str, Any]]) -> Dict[str, Any]:
"""从 normalized pipeline 中提取第一个标注算子的 overrides 参数。"""
for step in pipeline:
operator_id = str(step.get("operator_id", ""))
if operator_id in OPERATOR_TASK_TYPE_MAP:
return dict(step.get("overrides", {}))
return {}
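The stable-ID helpers above are what make re-running a conversion idempotent: the same `(project_id, file_id)` pair always maps to the same task and annotation IDs. A minimal standalone sketch of the same scheme (the `stable_ls_id` name and the seed strings below are illustrative, not part of the module's API):

```python
import hashlib

def stable_ls_id(seed: str) -> int:
    # 13 hex digits = 52 bits, which stays below JavaScript's
    # Number.MAX_SAFE_INTEGER (2**53 - 1), so the Label Studio frontend
    # can hold the ID without precision loss.
    digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()
    value = int(digest[:13], 16)
    return value if value > 0 else 1

# The same seed always yields the same ID, so re-importing annotations
# overwrites rather than duplicates.
a = stable_ls_id("task:proj-1:file-1")
assert a == stable_ls_id("task:proj-1:file-1")
assert 0 < a < 2**53
assert stable_ls_id("task:proj-1:file-2") != a
```

The same idea backs `_make_region_id`, which uses `uuid.uuid5` with a fixed namespace so relation results can reference entity regions by a reproducible ID.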

View File

@@ -2,31 +2,34 @@
"""Simple background worker for auto-annotation tasks. """Simple background worker for auto-annotation tasks.
This module runs inside the datamate-runtime container (operator_runtime service). This module runs inside the datamate-runtime container (operator_runtime service).
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs YOLO It polls `t_dm_auto_annotation_tasks` for pending tasks and performs annotation
inference using the ImageObjectDetectionBoundingBox operator, updating using configurable operator pipelines (YOLO, LLM text classification, NER,
progress back to the same table so that the datamate-python backend and relation extraction, etc.), updating progress back to the same table so that
frontend can display real-time status. the datamate-python backend and frontend can display real-time status.
设计目标(最小可用版本): 设计:
- 单实例 worker,串行处理 `pending` 状态的任务。 - 多任务并发: 可通过 AUTO_ANNOTATION_WORKER_COUNT 启动多个 worker 线程,
- 对指定数据集下的所有已完成文件逐张执行目标检测 各自独立轮询和认领 pending 任务(run_token 原子 claim 保证不重复)
- 按已处理图片数更新 `processed_images`、`progress`、`detected_objects`、`status` 等字段。 - 任务内文件并发: 可通过 AUTO_ANNOTATION_FILE_WORKERS 配置线程池大小,
- 失败时将任务标记为 `failed` 并记录 `error_message` 单任务内并行处理多个文件(LLM I/O 密集型场景尤其有效)
算子链通过对象池隔离,每个线程使用独立的链实例。
注意: - 进度更新节流: 可通过 AUTO_ANNOTATION_PROGRESS_INTERVAL 控制进度写入频率,
- 为了保持简单,目前不处理 "running" 状态的恢复逻辑;容器重启时, 避免大数据集每文件都写 DB 造成的写压力(默认 2 秒间隔)。
已处于 running 任务不会被重新拉起,需要后续扩展。 - 启动时自动恢复心跳超时的 running 任务:未处理文件重置为 pending,
已有部分进度的标记为 failed,由用户决定是否手动重试。
""" """
from __future__ import annotations from __future__ import annotations
import importlib import importlib
import json import json
import os import os
import queue
import sys import sys
import threading import threading
import time import time
import uuid import uuid
from datetime import datetime from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@@ -122,9 +125,121 @@ DEFAULT_OUTPUT_ROOT = os.getenv(
DEFAULT_OPERATOR_WHITELIST = os.getenv(
"AUTO_ANNOTATION_OPERATOR_WHITELIST",
"ImageObjectDetectionBoundingBox,test_annotation_marker,"
"LLMTextClassification,LLMNamedEntityRecognition,LLMRelationExtraction",
)
HEARTBEAT_TIMEOUT_SECONDS = int(os.getenv("AUTO_ANNOTATION_HEARTBEAT_TIMEOUT", "300"))
WORKER_COUNT = int(os.getenv("AUTO_ANNOTATION_WORKER_COUNT", "1"))
FILE_WORKERS = int(os.getenv("AUTO_ANNOTATION_FILE_WORKERS", "1"))
PROGRESS_UPDATE_INTERVAL = float(os.getenv("AUTO_ANNOTATION_PROGRESS_INTERVAL", "2.0"))
def _recover_stale_running_tasks() -> int:
"""启动时恢复心跳超时的 running 任务。
- processed_images = 0 → 重置为 pending(自动重试)
- processed_images > 0 → 标记为 failed(需用户干预)
Returns:
恢复的任务数量。
"""
if HEARTBEAT_TIMEOUT_SECONDS <= 0:
logger.info(
"Heartbeat timeout disabled (HEARTBEAT_TIMEOUT_SECONDS={}), skipping recovery",
HEARTBEAT_TIMEOUT_SECONDS,
)
return 0
cutoff = datetime.now() - timedelta(seconds=HEARTBEAT_TIMEOUT_SECONDS)
find_sql = text("""
SELECT id, processed_images, total_images, heartbeat_at
FROM t_dm_auto_annotation_tasks
WHERE status = 'running'
AND deleted_at IS NULL
AND (heartbeat_at IS NULL OR heartbeat_at < :cutoff)
""")
with SQLManager.create_connect() as conn:
rows = conn.execute(find_sql, {"cutoff": cutoff}).fetchall()
if not rows:
return 0
recovered = 0
for row in rows:
task_id = row[0]
processed = row[1] or 0
total = row[2] or 0
heartbeat = row[3]
try:
if processed == 0:
# No files processed yet; reset to pending for automatic retry
reset_sql = text("""
UPDATE t_dm_auto_annotation_tasks
SET status = 'pending',
run_token = NULL,
heartbeat_at = NULL,
started_at = NULL,
error_message = NULL,
updated_at = :now
WHERE id = :task_id AND status = 'running'
""")
with SQLManager.create_connect() as conn:
result = conn.execute(
reset_sql, {"task_id": task_id, "now": datetime.now()}
)
if int(getattr(result, "rowcount", 0) or 0) > 0:
recovered += 1
logger.info(
"Recovered stale task {} -> pending (no progress, heartbeat={})",
task_id,
heartbeat,
)
else:
# Partial progress exists; mark as failed
error_msg = (
f"Worker 心跳超时(上次心跳: {heartbeat},"
f"超时阈值: {HEARTBEAT_TIMEOUT_SECONDS}秒)。"
f"已处理 {processed}/{total} 个文件。请检查后手动重试。"
)
)
fail_sql = text("""
UPDATE t_dm_auto_annotation_tasks
SET status = 'failed',
run_token = NULL,
error_message = :error_message,
completed_at = :now,
updated_at = :now
WHERE id = :task_id AND status = 'running'
""")
with SQLManager.create_connect() as conn:
result = conn.execute(
fail_sql,
{
"task_id": task_id,
"error_message": error_msg[:2000],
"now": datetime.now(),
},
)
if int(getattr(result, "rowcount", 0) or 0) > 0:
recovered += 1
logger.warning(
"Recovered stale task {} -> failed (processed {}/{}, heartbeat={})",
task_id,
processed,
total,
heartbeat,
)
except Exception as exc:
logger.error("Failed to recover stale task {}: {}", task_id, exc)
return recovered
def _fetch_pending_task() -> Optional[Dict[str, Any]]:
"""Atomically claim one pending task and return its details."""
@@ -856,6 +971,139 @@ def _register_output_dataset(
)
def _create_labeling_project_with_annotations(
task_id: str,
dataset_id: str,
dataset_name: str,
task_name: str,
dataset_type: str,
normalized_pipeline: List[Dict[str, Any]],
file_results: List[Tuple[str, Dict[str, Any]]],
all_file_ids: List[str],
) -> None:
"""将自动标注结果转换为 Label Studio 格式,创建标注项目并写入标注结果。"""
from datamate.annotation_result_converter import (
convert_annotation,
extract_operator_params,
generate_label_config_xml,
infer_task_type_from_pipeline,
)
task_type = infer_task_type_from_pipeline(normalized_pipeline)
if not task_type:
logger.warning(
"Cannot infer task_type from pipeline for task {}, skipping labeling project creation",
task_id,
)
return
operator_params = extract_operator_params(normalized_pipeline)
# Object detection: collect the unique label list from actual detection results
if task_type == "object_detection":
all_labels: set = set()
for _, ann in file_results:
for det in ann.get("detections", []):
if isinstance(det, dict):
all_labels.add(str(det.get("label", "unknown")))
operator_params["_detected_labels"] = sorted(all_labels)
label_config = generate_label_config_xml(task_type, operator_params)
project_id = str(uuid.uuid4())
labeling_project_id = str(uuid.uuid4().int % 10**8).zfill(8)
project_name = f"自动标注 - {task_name or dataset_name or task_id[:8]}"[:100]
now = datetime.now()
configuration = json.dumps(
{
"label_config": label_config,
"description": f"由自动标注任务 {task_id[:8]} 自动创建",
"auto_annotation_task_id": task_id,
},
ensure_ascii=False,
)
insert_project_sql = text(
"""
INSERT INTO t_dm_labeling_projects
(id, dataset_id, name, labeling_project_id, template_id, configuration, created_at, updated_at)
VALUES
(:id, :dataset_id, :name, :labeling_project_id, NULL, :configuration, :now, :now)
"""
)
insert_snapshot_sql = text(
"""
INSERT INTO t_dm_labeling_project_files (id, project_id, file_id, created_at)
VALUES (:id, :project_id, :file_id, :now)
"""
)
insert_annotation_sql = text(
"""
INSERT INTO t_dm_annotation_results
(id, project_id, file_id, annotation, annotation_status, file_version, created_at, updated_at)
VALUES
(:id, :project_id, :file_id, :annotation, :annotation_status, :file_version, :now, :now)
"""
)
with SQLManager.create_connect() as conn:
# 1. Create the labeling project
conn.execute(
insert_project_sql,
{
"id": project_id,
"dataset_id": dataset_id,
"name": project_name,
"labeling_project_id": labeling_project_id,
"configuration": configuration,
"now": now,
},
)
# 2. Create the project file snapshot
for file_id in all_file_ids:
conn.execute(
insert_snapshot_sql,
{
"id": str(uuid.uuid4()),
"project_id": project_id,
"file_id": file_id,
"now": now,
},
)
# 3. Convert and write annotation results
converted_count = 0
for file_id, annotation in file_results:
ls_annotation = convert_annotation(annotation, file_id, project_id)
if ls_annotation is None:
continue
conn.execute(
insert_annotation_sql,
{
"id": str(uuid.uuid4()),
"project_id": project_id,
"file_id": file_id,
"annotation": json.dumps(ls_annotation, ensure_ascii=False),
"annotation_status": "ANNOTATED",
"file_version": 1,
"now": now,
},
)
converted_count += 1
logger.info(
"Created labeling project {} ({}) with {} annotations for auto-annotation task {}",
project_id,
project_name,
converted_count,
task_id,
)
def _process_single_task(task: Dict[str, Any]) -> None:
"""Execute a single auto-annotation task."""
@@ -914,7 +1162,7 @@ def _process_single_task(task: Dict[str, Any]) -> None:
else:
all_files = _load_dataset_files(dataset_id)
files = all_files  # [(file_id, file_path, file_name)]
total_images = len(files)
if total_images == 0:
@@ -962,10 +1210,6 @@ def _process_single_task(task: Dict[str, Any]) -> None:
raise RuntimeError("Pipeline is empty after normalization")
_validate_pipeline_whitelist(normalized_pipeline)
chain = _build_operator_chain(normalized_pipeline)
if not chain:
raise RuntimeError("No valid operator instances initialized")
except Exception as e:
logger.error("Failed to init operator pipeline for task {}: {}", task_id, e)
_update_task_status(
@@ -980,13 +1224,120 @@ def _process_single_task(task: Dict[str, Any]) -> None:
)
return
# --- Build the operator chain pool (each thread uses its own chain instance to avoid thread-safety issues) ---
effective_file_workers = max(1, FILE_WORKERS)
chain_pool: queue.Queue = queue.Queue()
try:
for _ in range(effective_file_workers):
c = _build_operator_chain(normalized_pipeline)
if not c:
raise RuntimeError("No valid operator instances initialized")
chain_pool.put(c)
except Exception as e:
logger.error("Failed to build operator chain pool for task {}: {}", task_id, e)
_update_task_status(
task_id,
run_token=run_token,
status="failed",
total_images=total_images,
processed_images=0,
detected_objects=0,
error_message=f"Init pipeline failed: {e}",
clear_run_token=True,
)
return
processed = 0
detected_total = 0
file_results: List[Tuple[str, Dict[str, Any]]] = [] # (file_id, annotations)
stopped = False
try:
# --- Thread-safe progress tracking ---
progress_lock = threading.Lock()
stop_event = threading.Event()
def _process_file(
file_id: str, file_path: str, file_name: str,
) -> Optional[Tuple[str, Dict[str, Any]]]:
"""在线程池中处理单个文件。"""
if stop_event.is_set():
return None
chain = chain_pool.get()
try:
sample_key = _get_sample_key(dataset_type)
sample: Dict[str, Any] = {
sample_key: file_path,
"filename": file_name,
}
result = _run_pipeline_sample(sample, chain)
return (file_id, result)
finally:
chain_pool.put(chain)
# --- Concurrent file processing ---
stop_check_interval = max(1, effective_file_workers * 2)
completed_since_check = 0
last_progress_update = time.monotonic()
with ThreadPoolExecutor(max_workers=effective_file_workers) as executor:
future_to_file = {
executor.submit(_process_file, fid, fpath, fname): (fid, fpath, fname)
for fid, fpath, fname in files
}
for future in as_completed(future_to_file):
fid, fpath, fname = future_to_file[future]
try:
result = future.result()
if result is None:
continue
file_id_out, sample_result = result
detections = _count_detections(sample_result)
ann = sample_result.get("annotations")
with progress_lock:
processed += 1
detected_total += detections
if isinstance(ann, dict):
file_results.append((file_id_out, ann))
current_processed = processed
current_detected = detected_total
now = time.monotonic()
if PROGRESS_UPDATE_INTERVAL <= 0 or (now - last_progress_update) >= PROGRESS_UPDATE_INTERVAL:
progress = int(current_processed * 100 / total_images) if total_images > 0 else 100
_update_task_status(
task_id,
run_token=run_token,
status="running",
progress=progress,
processed_images=current_processed,
detected_objects=current_detected,
total_images=total_images,
output_path=output_dir,
output_dataset_id=output_dataset_id,
)
last_progress_update = now
except Exception as e:
logger.error(
"Failed to process file for task {}: file_path={}, error={}",
task_id,
fpath,
e,
)
completed_since_check += 1
if completed_since_check >= stop_check_interval:
completed_since_check = 0
if _is_stop_requested(task_id, run_token):
stop_event.set()
for f in future_to_file:
f.cancel()
stopped = True
break
if stopped:
logger.info("Task stop requested during processing: {}", task_id) logger.info("Task stop requested during processing: {}", task_id)
_update_task_status(
task_id,
@@ -1002,41 +1353,7 @@ def _process_single_task(task: Dict[str, Any]) -> None:
clear_run_token=True,
error_message="Task stopped by request",
)
else:
try:
sample_key = _get_sample_key(dataset_type)
sample = {
sample_key: file_path,
"filename": file_name,
}
result = _run_pipeline_sample(sample, chain)
detected_total += _count_detections(result)
processed += 1
progress = int(processed * 100 / total_images) if total_images > 0 else 100
_update_task_status(
task_id,
run_token=run_token,
status="running",
progress=progress,
processed_images=processed,
detected_objects=detected_total,
total_images=total_images,
output_path=output_dir,
output_dataset_id=output_dataset_id,
)
except Exception as e:
logger.error(
"Failed to process file for task {}: file_path={}, error={}",
task_id,
file_path,
e,
)
continue
_update_task_status(
task_id,
run_token=run_token,
@@ -1076,6 +1393,26 @@ def _process_single_task(task: Dict[str, Any]) -> None:
task_id,
e,
)
# Convert auto-annotation results to Label Studio format and write them into a labeling project
if file_results:
try:
_create_labeling_project_with_annotations(
task_id=task_id,
dataset_id=dataset_id,
dataset_name=source_dataset_name,
task_name=task_name,
dataset_type=dataset_type,
normalized_pipeline=normalized_pipeline,
file_results=file_results,
all_file_ids=[fid for fid, _, _ in all_files],
)
except Exception as e:  # pragma: no cover - defensive logging
logger.error(
"Failed to create labeling project for auto-annotation task {}: {}",
task_id,
e,
)
except Exception as e:
logger.error("Task execution failed for task {}: {}", task_id, e)
_update_task_status(
@@ -1097,9 +1434,10 @@ def _worker_loop() -> None:
"""Worker 主循环,在独立线程中运行。""" """Worker 主循环,在独立线程中运行。"""
logger.info( logger.info(
"Auto-annotation worker started with poll interval {} seconds, output root {}", "Auto-annotation worker started (poll_interval={}s, output_root={}, file_workers={})",
POLL_INTERVAL_SECONDS, POLL_INTERVAL_SECONDS,
DEFAULT_OUTPUT_ROOT, DEFAULT_OUTPUT_ROOT,
FILE_WORKERS,
) )
while True: while True:
@@ -1118,6 +1456,20 @@ def _worker_loop() -> None:
def start_auto_annotation_worker() -> None:
"""Start the auto-annotation worker(s) in background threads."""
# Run recovery once before any worker thread starts (avoids duplicate recovery across threads)
try:
recovered = _recover_stale_running_tasks()
if recovered > 0:
logger.info("Recovered {} stale running task(s) on startup", recovered)
except Exception as e:
logger.error("Failed to run startup task recovery: {}", e)
count = max(1, WORKER_COUNT)
for i in range(count):
thread = threading.Thread(
target=_worker_loop,
name=f"auto-annotation-worker-{i}",
daemon=True,
)
thread.start()
logger.info("Auto-annotation worker thread started: {}", thread.name)

View File

@@ -0,0 +1,390 @@
# -*- coding: utf-8 -*-
"""Tests for auto_annotation_worker concurrency features (improvement #4).
Covers:
- Multi-worker startup (WORKER_COUNT)
- Intra-task file parallelism (FILE_WORKERS)
- Chain pool acquire/release
- Thread-safe progress tracking
- Stop request handling during concurrent processing
"""
from __future__ import annotations
import os
import queue
import sys
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import MagicMock, patch, call
# ---------------------------------------------------------------------------
# Ensure the module under test can be imported
# ---------------------------------------------------------------------------
RUNTIME_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
EXECUTOR_ROOT = os.path.join(os.path.dirname(__file__), "..")
for p in (RUNTIME_ROOT, EXECUTOR_ROOT):
abs_p = os.path.abspath(p)
if abs_p not in sys.path:
sys.path.insert(0, abs_p)
class TestWorkerCountConfig(unittest.TestCase):
"""Test that start_auto_annotation_worker launches WORKER_COUNT threads."""
@patch("datamate.auto_annotation_worker._recover_stale_running_tasks", return_value=0)
@patch("datamate.auto_annotation_worker._worker_loop")
@patch("datamate.auto_annotation_worker.WORKER_COUNT", 3)
def test_multiple_worker_threads_launched(self, mock_loop, mock_recover):
"""WORKER_COUNT=3 should launch 3 daemon threads."""
from datamate.auto_annotation_worker import start_auto_annotation_worker
started_threads: List[threading.Thread] = []
original_thread_init = threading.Thread.__init__
def track_thread(self_thread, *args, **kwargs):
original_thread_init(self_thread, *args, **kwargs)
started_threads.append(self_thread)
with patch.object(threading.Thread, "__init__", track_thread):
with patch.object(threading.Thread, "start"):
start_auto_annotation_worker()
self.assertEqual(len(started_threads), 3)
for i, t in enumerate(started_threads):
self.assertEqual(t.name, f"auto-annotation-worker-{i}")
self.assertTrue(t.daemon)
@patch("datamate.auto_annotation_worker._recover_stale_running_tasks", return_value=0)
@patch("datamate.auto_annotation_worker._worker_loop")
@patch("datamate.auto_annotation_worker.WORKER_COUNT", 1)
def test_single_worker_default(self, mock_loop, mock_recover):
"""WORKER_COUNT=1 (default) should launch exactly 1 thread."""
from datamate.auto_annotation_worker import start_auto_annotation_worker
started_threads: List[threading.Thread] = []
original_thread_init = threading.Thread.__init__
def track_thread(self_thread, *args, **kwargs):
original_thread_init(self_thread, *args, **kwargs)
started_threads.append(self_thread)
with patch.object(threading.Thread, "__init__", track_thread):
with patch.object(threading.Thread, "start"):
start_auto_annotation_worker()
self.assertEqual(len(started_threads), 1)
@patch("datamate.auto_annotation_worker._recover_stale_running_tasks", side_effect=RuntimeError("db down"))
@patch("datamate.auto_annotation_worker._worker_loop")
@patch("datamate.auto_annotation_worker.WORKER_COUNT", 2)
def test_recovery_failure_doesnt_block_workers(self, mock_loop, mock_recover):
"""Recovery failure should not prevent worker threads from starting."""
from datamate.auto_annotation_worker import start_auto_annotation_worker
started_threads: List[threading.Thread] = []
original_thread_init = threading.Thread.__init__
def track_thread(self_thread, *args, **kwargs):
original_thread_init(self_thread, *args, **kwargs)
started_threads.append(self_thread)
with patch.object(threading.Thread, "__init__", track_thread):
with patch.object(threading.Thread, "start"):
start_auto_annotation_worker()
# Workers should still be launched despite recovery failure
self.assertEqual(len(started_threads), 2)
class TestChainPool(unittest.TestCase):
"""Test the chain pool pattern used for operator instance isolation."""
def test_pool_acquire_release(self):
"""Each thread should get its own chain and return it after use."""
pool: queue.Queue = queue.Queue()
chains = [f"chain-{i}" for i in range(3)]
for c in chains:
pool.put(c)
acquired: List[str] = []
lock = threading.Lock()
def worker():
chain = pool.get()
with lock:
acquired.append(chain)
time.sleep(0.01)
pool.put(chain)
threads = [threading.Thread(target=worker) for _ in range(6)]
for t in threads:
t.start()
for t in threads:
t.join()
# All 6 workers should have acquired a chain
self.assertEqual(len(acquired), 6)
# Pool should have all 3 chains back
self.assertEqual(pool.qsize(), 3)
returned = set()
while not pool.empty():
returned.add(pool.get())
self.assertEqual(returned, set(chains))
def test_pool_blocks_when_empty(self):
"""When pool is empty, threads should block until a chain is returned."""
pool: queue.Queue = queue.Queue()
pool.put("only-chain")
acquired_times: List[float] = []
lock = threading.Lock()
def worker():
chain = pool.get()
with lock:
acquired_times.append(time.monotonic())
time.sleep(0.05)
pool.put(chain)
t1 = threading.Thread(target=worker)
t2 = threading.Thread(target=worker)
t1.start()
time.sleep(0.01) # Ensure t1 starts first
t2.start()
t1.join()
t2.join()
# t2 should only acquire after t1 returns the chain (~0.05s later;
# assert a 0.03s lower bound to allow for scheduler timing jitter)
self.assertEqual(len(acquired_times), 2)
gap = acquired_times[1] - acquired_times[0]
self.assertGreater(gap, 0.03)
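The acquire/release pattern these two tests exercise can be packaged as a context manager so a chain is always returned to the pool, even when processing raises. This is an illustrative sketch, not the worker's actual API; `make_chain` stands in for whatever builds an operator chain:

```python
import queue
from contextlib import contextmanager

class ChainPool:
    """Fixed-size pool of reusable chain instances."""

    def __init__(self, size, make_chain):
        self._pool: queue.Queue = queue.Queue()
        for i in range(size):
            self._pool.put(make_chain(i))

    @contextmanager
    def chain(self):
        c = self._pool.get()  # blocks until a chain is free
        try:
            yield c
        finally:
            self._pool.put(c)  # always returned, even on exception

pool = ChainPool(2, lambda i: f"chain-{i}")
with pool.chain() as c:
    print(c)  # → chain-0 (queue.Queue is FIFO)
```

Because `pool.get()` blocks when the pool is empty, pool size caps concurrency exactly as `test_pool_blocks_when_empty` demonstrates.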
class TestConcurrentFileProcessing(unittest.TestCase):
"""Test the concurrent file processing logic from _process_single_task."""
def test_threadpool_processes_all_files(self):
"""ThreadPoolExecutor should process all submitted files."""
results: List[str] = []
lock = threading.Lock()
def process_file(file_id):
time.sleep(0.01)
with lock:
results.append(file_id)
return file_id
files = [f"file-{i}" for i in range(10)]
with ThreadPoolExecutor(max_workers=4) as executor:
futures = {executor.submit(process_file, f): f for f in files}
for future in as_completed(futures):
future.result()
self.assertEqual(sorted(results), sorted(files))
def test_stop_event_cancels_pending_futures(self):
"""Setting stop_event should prevent unstarted files from processing."""
stop_event = threading.Event()
processed: List[str] = []
lock = threading.Lock()
def process_file(file_id):
if stop_event.is_set():
return None
time.sleep(0.05)
with lock:
processed.append(file_id)
return file_id
files = [f"file-{i}" for i in range(20)]
with ThreadPoolExecutor(max_workers=2) as executor:
future_to_file = {
executor.submit(process_file, f): f for f in files
}
count = 0
for future in as_completed(future_to_file):
future.result()
count += 1
if count >= 3:
stop_event.set()
for f in future_to_file:
f.cancel()
break
# Should have processed some but not all files
self.assertGreater(len(processed), 0)
self.assertLess(len(processed), 20)
def test_thread_safe_counter_updates(self):
"""Counters updated inside lock should be accurate under concurrency."""
processed = 0
detected = 0
lock = threading.Lock()
def process_and_count(file_id):
nonlocal processed, detected
time.sleep(0.001)
with lock:
processed += 1
detected += 2
return file_id
with ThreadPoolExecutor(max_workers=8) as executor:
futures = [executor.submit(process_and_count, f"f-{i}") for i in range(100)]
for f in as_completed(futures):
f.result()
self.assertEqual(processed, 100)
self.assertEqual(detected, 200)
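The three behaviours tested above (fan-out over a pool, stop-event cancellation, lock-protected counters) combine into a loop shaped like the one in `_process_single_task`. A sketch with hypothetical names (`run_files`, `process_one`), not the worker's exact code:

```python
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_files(files, process_one, stop_event, max_workers=4):
    processed = 0
    lock = threading.Lock()
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_file = {executor.submit(process_one, f): f for f in files}
        for future in as_completed(future_to_file):
            future.result()  # re-raise any worker exception
            with lock:
                processed += 1
            if stop_event.is_set():
                for f in future_to_file:
                    f.cancel()  # only cancels not-yet-started futures
                break
    return processed
```

Note that `Future.cancel()` cannot interrupt an already-running call, which is why the stop test above also checks the event inside `process_file`; already-running files finish, pending ones are dropped.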
class TestFileWorkersConfig(unittest.TestCase):
"""Test FILE_WORKERS configuration behavior."""
def test_file_workers_one_is_serial(self):
"""FILE_WORKERS=1 should process files sequentially."""
order: List[int] = []
lock = threading.Lock()
def process(idx):
with lock:
order.append(idx)
time.sleep(0.01)
return idx
with ThreadPoolExecutor(max_workers=1) as executor:
futures = [executor.submit(process, i) for i in range(5)]
for f in as_completed(futures):
f.result()
# With max_workers=1, execution is serial (though completion order
# via as_completed might differ; the key is that only 1 runs at a time)
self.assertEqual(len(order), 5)
def test_file_workers_gt_one_is_parallel(self):
"""FILE_WORKERS>1 should process files concurrently."""
start_times: List[float] = []
lock = threading.Lock()
def process(idx):
with lock:
start_times.append(time.monotonic())
time.sleep(0.05)
return idx
with ThreadPoolExecutor(max_workers=4) as executor:
futures = [executor.submit(process, i) for i in range(4)]
for f in as_completed(futures):
f.result()
# All 4 should start nearly simultaneously
self.assertEqual(len(start_times), 4)
time_spread = max(start_times) - min(start_times)
# With parallel execution, spread should be < 0.04s
# (serial would be ~0.15s with 0.05s sleep each)
self.assertLess(time_spread, 0.04)
class TestWorkerLoopSimplified(unittest.TestCase):
"""Test that _worker_loop no longer calls recovery."""
@patch("datamate.auto_annotation_worker._process_single_task")
@patch("datamate.auto_annotation_worker._fetch_pending_task")
@patch("datamate.auto_annotation_worker._recover_stale_running_tasks")
def test_worker_loop_does_not_call_recovery(self, mock_recover, mock_fetch, mock_process):
"""_worker_loop should NOT call _recover_stale_running_tasks."""
from datamate.auto_annotation_worker import _worker_loop
call_count = 0
def side_effect():
nonlocal call_count
call_count += 1
if call_count >= 2:
raise KeyboardInterrupt("stop test")
return None
mock_fetch.side_effect = side_effect
with patch("datamate.auto_annotation_worker.POLL_INTERVAL_SECONDS", 0.001):
try:
_worker_loop()
except KeyboardInterrupt:
pass
mock_recover.assert_not_called()
class TestProgressThrottling(unittest.TestCase):
"""Test time-based progress update throttling (improvement #5)."""
def test_progress_updates_throttled(self):
"""With PROGRESS_UPDATE_INTERVAL>0, rapid completions should batch DB writes."""
update_calls: List[float] = []
lock = threading.Lock()
def mock_update(*args, **kwargs):
with lock:
update_calls.append(time.monotonic())
interval = 0.05 # 50ms throttle interval
processed = 0
# Initialize in the past so the first file triggers an update
last_progress_update = time.monotonic() - interval
total_files = 50
# Simulate the throttled update loop from _process_single_task
for i in range(total_files):
processed += 1
now = time.monotonic()
if interval <= 0 or (now - last_progress_update) >= interval:
mock_update(processed=processed, total=total_files)
last_progress_update = now
# Simulate very fast file processing (~1ms)
time.sleep(0.001)
# With 50 files at ~1ms each (~50ms total) and 50ms interval,
# should get far fewer updates than total_files
self.assertLess(len(update_calls), total_files)
self.assertGreater(len(update_calls), 0)
def test_progress_interval_zero_updates_every_file(self):
"""PROGRESS_UPDATE_INTERVAL=0 should update on every file completion."""
update_count = 0
interval = 0.0
total_files = 20
last_progress_update = time.monotonic()
for i in range(total_files):
now = time.monotonic()
if interval <= 0 or (now - last_progress_update) >= interval:
update_count += 1
last_progress_update = now
self.assertEqual(update_count, total_files)
def test_progress_throttle_with_slow_processing(self):
"""When each file takes longer than the interval, every file triggers an update."""
update_count = 0
interval = 0.01 # 10ms interval
total_files = 5
last_progress_update = time.monotonic() - 1.0 # Start in the past
for i in range(total_files):
time.sleep(0.02) # 20ms per file > 10ms interval
now = time.monotonic()
if interval <= 0 or (now - last_progress_update) >= interval:
update_count += 1
last_progress_update = now
# Every file should trigger an update since processing time > interval
self.assertEqual(update_count, total_files)
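The throttling condition repeated across the three tests above can be factored into a small helper, sketched here under the same assumptions the tests encode (monotonic clock, `interval <= 0` means update every time, timestamp initialized in the past so the first completion always flushes). `ProgressThrottler` is an illustrative name, not the worker's actual class:

```python
import time

class ProgressThrottler:
    """Gate DB progress writes to at most one per `interval` seconds."""

    def __init__(self, interval: float):
        self.interval = interval
        # Start in the past so the first completion always triggers a write
        self._last = time.monotonic() - interval

    def should_update(self) -> bool:
        now = time.monotonic()
        if self.interval <= 0 or (now - self._last) >= self.interval:
            self._last = now
            return True
        return False

throttler = ProgressThrottler(0.05)
updates = sum(throttler.should_update() for _ in range(100))
# With near-instant iterations, updates is far fewer than 100
```

Final status writes (completed/stopped/failed) would bypass `should_update()` entirely, matching the commit's guarantee that terminal states are never throttled.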
if __name__ == "__main__":
unittest.main()