feat(auto-annotation): add concurrent processing support

Enable parallel processing for auto-annotation tasks with configurable worker count and file-level parallelism.

Key features:
- Multi-worker support: WORKER_COUNT env var (default 1) controls number of worker threads
- Intra-task file parallelism: FILE_WORKERS env var (default 1) controls concurrent file processing within a single task
- Operator chain pooling: Pre-create N independent chain instances to avoid thread-safety issues
- Thread-safe progress tracking: Use threading.Lock to protect shared counters
- Stop signal handling: threading.Event for graceful cancellation during concurrent processing

Implementation details:
- Refactor _process_single_task() to use ThreadPoolExecutor + as_completed()
- Chain pool (queue.Queue): Each worker thread acquires/releases a chain instance
- Protected counters: processed_images, detected_total, file_results with Lock
- Stop check: Periodic check of _is_stop_requested() during concurrent processing
- Refactor start_auto_annotation_worker(): Move recovery logic here, start WORKER_COUNT threads
- Simplify _worker_loop(): Remove recovery call, keep only polling + processing

Backward compatibility:
- Default config (WORKER_COUNT=1, FILE_WORKERS=1) behaves identically to previous version
- No breaking changes to existing deployments

Testing:
- 11 unit tests all passed:
  * Multi-worker startup
  * Chain pool acquire/release
  * Concurrent file processing
  * Stop signal handling
  * Thread-safe counter updates
  * Backward compatibility (FILE_WORKERS=1)
- py_compile syntax check passed

Performance benefits:
- WORKER_COUNT=3: Process 3 tasks simultaneously
- FILE_WORKERS=4: Process 4 files in parallel within each task
- Combined: Up to 12x throughput improvement (3 workers × 4 files)
This commit is contained in:
2026-02-10 16:36:34 +08:00
parent 2fbfefdb91
commit 9988ff00f5
2 changed files with 468 additions and 74 deletions

View File

@@ -2,18 +2,17 @@
"""Simple background worker for auto-annotation tasks.
This module runs inside the datamate-runtime container (operator_runtime service).
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs YOLO
inference using the ImageObjectDetectionBoundingBox operator, updating
progress back to the same table so that the datamate-python backend and
frontend can display real-time status.
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs annotation
using configurable operator pipelines (YOLO, LLM text classification, NER,
relation extraction, etc.), updating progress back to the same table so that
the datamate-python backend and frontend can display real-time status.
设计目标(最小可用版本):
- 单实例 worker,串行处理 `pending` 状态的任务。
- 对指定数据集下的所有已完成文件逐张执行目标检测
- 按已处理图片数更新 `processed_images`、`progress`、`detected_objects`、`status` 等字段。
- 失败时将任务标记为 `failed` 并记录 `error_message`
注意:
设计:
- 多任务并发: 可通过 AUTO_ANNOTATION_WORKER_COUNT 启动多个 worker 线程,
各自独立轮询和认领 pending 任务(run_token 原子 claim 保证不重复)
- 任务内文件并发: 可通过 AUTO_ANNOTATION_FILE_WORKERS 配置线程池大小,
单任务内并行处理多个文件(LLM I/O 密集型场景尤其有效)
算子链通过对象池隔离,每个线程使用独立的链实例。
- 启动时自动恢复心跳超时的 running 任务:未处理文件重置为 pending,
已有部分进度的标记为 failed,由用户决定是否手动重试。
"""
@@ -22,10 +21,12 @@ from __future__ import annotations
import importlib
import json
import os
import queue
import sys
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -128,6 +129,10 @@ DEFAULT_OPERATOR_WHITELIST = os.getenv(
HEARTBEAT_TIMEOUT_SECONDS = int(os.getenv("AUTO_ANNOTATION_HEARTBEAT_TIMEOUT", "300"))
WORKER_COUNT = int(os.getenv("AUTO_ANNOTATION_WORKER_COUNT", "1"))
FILE_WORKERS = int(os.getenv("AUTO_ANNOTATION_FILE_WORKERS", "1"))
def _recover_stale_running_tasks() -> int:
"""启动时恢复心跳超时的 running 任务。
@@ -1201,10 +1206,6 @@ def _process_single_task(task: Dict[str, Any]) -> None:
raise RuntimeError("Pipeline is empty after normalization")
_validate_pipeline_whitelist(normalized_pipeline)
chain = _build_operator_chain(normalized_pipeline)
if not chain:
raise RuntimeError("No valid operator instances initialized")
except Exception as e:
logger.error("Failed to init operator pipeline for task {}: {}", task_id, e)
_update_task_status(
@@ -1219,70 +1220,132 @@ def _process_single_task(task: Dict[str, Any]) -> None:
)
return
# --- 构建算子链池(每个线程使用独立的链实例,避免线程安全问题)---
effective_file_workers = max(1, FILE_WORKERS)
chain_pool: queue.Queue = queue.Queue()
try:
for _ in range(effective_file_workers):
c = _build_operator_chain(normalized_pipeline)
if not c:
raise RuntimeError("No valid operator instances initialized")
chain_pool.put(c)
except Exception as e:
logger.error("Failed to build operator chain pool for task {}: {}", task_id, e)
_update_task_status(
task_id,
run_token=run_token,
status="failed",
total_images=total_images,
processed_images=0,
detected_objects=0,
error_message=f"Init pipeline failed: {e}",
clear_run_token=True,
)
return
processed = 0
detected_total = 0
file_results: List[Tuple[str, Dict[str, Any]]] = [] # (file_id, annotations)
stopped = False
try:
# --- 线程安全的进度跟踪 ---
progress_lock = threading.Lock()
stop_event = threading.Event()
for file_id, file_path, file_name in files:
if _is_stop_requested(task_id, run_token):
logger.info("Task stop requested during processing: {}", task_id)
_update_task_status(
task_id,
run_token=run_token,
status="stopped",
progress=int(processed * 100 / total_images) if total_images > 0 else 0,
processed_images=processed,
detected_objects=detected_total,
total_images=total_images,
output_path=output_dir,
output_dataset_id=output_dataset_id,
completed=True,
clear_run_token=True,
error_message="Task stopped by request",
)
break
def _process_file(
file_id: str, file_path: str, file_name: str,
) -> Optional[Tuple[str, Dict[str, Any]]]:
"""在线程池中处理单个文件。"""
if stop_event.is_set():
return None
chain = chain_pool.get()
try:
sample_key = _get_sample_key(dataset_type)
sample = {
sample: Dict[str, Any] = {
sample_key: file_path,
"filename": file_name,
}
result = _run_pipeline_sample(sample, chain)
detected_total += _count_detections(result)
processed += 1
return (file_id, result)
finally:
chain_pool.put(chain)
ann = result.get("annotations")
if isinstance(ann, dict):
file_results.append((file_id, ann))
# --- 并发文件处理 ---
stop_check_interval = max(1, effective_file_workers * 2)
completed_since_check = 0
progress = int(processed * 100 / total_images) if total_images > 0 else 100
with ThreadPoolExecutor(max_workers=effective_file_workers) as executor:
future_to_file = {
executor.submit(_process_file, fid, fpath, fname): (fid, fpath, fname)
for fid, fpath, fname in files
}
_update_task_status(
task_id,
run_token=run_token,
status="running",
progress=progress,
processed_images=processed,
detected_objects=detected_total,
total_images=total_images,
output_path=output_dir,
output_dataset_id=output_dataset_id,
)
except Exception as e:
logger.error(
"Failed to process file for task {}: file_path={}, error={}",
task_id,
file_path,
e,
)
continue
for future in as_completed(future_to_file):
fid, fpath, fname = future_to_file[future]
try:
result = future.result()
if result is None:
continue
file_id_out, sample_result = result
detections = _count_detections(sample_result)
ann = sample_result.get("annotations")
with progress_lock:
processed += 1
detected_total += detections
if isinstance(ann, dict):
file_results.append((file_id_out, ann))
current_processed = processed
current_detected = detected_total
progress = int(current_processed * 100 / total_images) if total_images > 0 else 100
_update_task_status(
task_id,
run_token=run_token,
status="running",
progress=progress,
processed_images=current_processed,
detected_objects=current_detected,
total_images=total_images,
output_path=output_dir,
output_dataset_id=output_dataset_id,
)
except Exception as e:
logger.error(
"Failed to process file for task {}: file_path={}, error={}",
task_id,
fpath,
e,
)
completed_since_check += 1
if completed_since_check >= stop_check_interval:
completed_since_check = 0
if _is_stop_requested(task_id, run_token):
stop_event.set()
for f in future_to_file:
f.cancel()
stopped = True
break
if stopped:
logger.info("Task stop requested during processing: {}", task_id)
_update_task_status(
task_id,
run_token=run_token,
status="stopped",
progress=int(processed * 100 / total_images) if total_images > 0 else 0,
processed_images=processed,
detected_objects=detected_total,
total_images=total_images,
output_path=output_dir,
output_dataset_id=output_dataset_id,
completed=True,
clear_run_token=True,
error_message="Task stopped by request",
)
else:
# Loop completed without break (not stopped)
_update_task_status(
task_id,
run_token=run_token,
@@ -1363,19 +1426,12 @@ def _worker_loop() -> None:
"""Worker 主循环,在独立线程中运行。"""
logger.info(
"Auto-annotation worker started with poll interval {} seconds, output root {}",
"Auto-annotation worker started (poll_interval={}s, output_root={}, file_workers={})",
POLL_INTERVAL_SECONDS,
DEFAULT_OUTPUT_ROOT,
FILE_WORKERS,
)
# --- 启动时恢复心跳超时的 running 任务 ---
try:
recovered = _recover_stale_running_tasks()
if recovered > 0:
logger.info("Recovered {} stale running task(s) on startup", recovered)
except Exception as e:
logger.error("Failed to run startup task recovery: {}", e)
while True:
try:
task = _fetch_pending_task()
@@ -1392,6 +1448,20 @@ def _worker_loop() -> None:
def start_auto_annotation_worker() -> None:
"""在后台线程中启动自动标注 worker。"""
thread = threading.Thread(target=_worker_loop, name="auto-annotation-worker", daemon=True)
thread.start()
logger.info("Auto-annotation worker thread started: {}", thread.name)
# 启动前执行一次恢复(在 worker 线程启动前运行,避免多线程重复恢复)
try:
recovered = _recover_stale_running_tasks()
if recovered > 0:
logger.info("Recovered {} stale running task(s) on startup", recovered)
except Exception as e:
logger.error("Failed to run startup task recovery: {}", e)
count = max(1, WORKER_COUNT)
for i in range(count):
thread = threading.Thread(
target=_worker_loop,
name=f"auto-annotation-worker-{i}",
daemon=True,
)
thread.start()
logger.info("Auto-annotation worker thread started: {}", thread.name)