# -*- coding: utf-8 -*-
"""Tests for auto_annotation_worker concurrency features (improvement #4).

Covers:
- Multi-worker startup (WORKER_COUNT)
- Intra-task file parallelism (FILE_WORKERS)
- Chain pool acquire/release
- Thread-safe progress tracking
- Stop request handling during concurrent processing
"""
from __future__ import annotations

import os
import queue
import sys
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import MagicMock, patch, call

# ---------------------------------------------------------------------------
# Ensure the module under test can be imported
# ---------------------------------------------------------------------------
RUNTIME_ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
EXECUTOR_ROOT = os.path.join(os.path.dirname(__file__), "..")
for p in (RUNTIME_ROOT, EXECUTOR_ROOT):
    abs_p = os.path.abspath(p)
    if abs_p not in sys.path:
        sys.path.insert(0, abs_p)


def _start_workers_and_capture_threads() -> List[threading.Thread]:
    """Run ``start_auto_annotation_worker`` and return the threads it built.

    ``threading.Thread.__init__`` is wrapped so every constructed thread is
    recorded, and ``threading.Thread.start`` is replaced by a mock so no
    thread actually runs.  The import is done here (not at module level) so
    that the ``@patch`` decorators on the calling test are already active.
    """
    from datamate.auto_annotation_worker import start_auto_annotation_worker

    created: List[threading.Thread] = []
    real_init = threading.Thread.__init__

    def recording_init(thread, *args, **kwargs):
        real_init(thread, *args, **kwargs)
        created.append(thread)

    with patch.object(threading.Thread, "__init__", recording_init), \
            patch.object(threading.Thread, "start"):
        start_auto_annotation_worker()

    return created


class TestWorkerCountConfig(unittest.TestCase):
    """Test that start_auto_annotation_worker launches WORKER_COUNT threads."""

    @patch("datamate.auto_annotation_worker._recover_stale_running_tasks", return_value=0)
    @patch("datamate.auto_annotation_worker._worker_loop")
    @patch("datamate.auto_annotation_worker.WORKER_COUNT", 3)
    def test_multiple_worker_threads_launched(self, mock_loop, mock_recover):
        """WORKER_COUNT=3 should launch 3 daemon threads."""
        started_threads = _start_workers_and_capture_threads()

        self.assertEqual(len(started_threads), 3)
        for i, t in enumerate(started_threads):
            self.assertEqual(t.name, f"auto-annotation-worker-{i}")
            self.assertTrue(t.daemon)

    @patch("datamate.auto_annotation_worker._recover_stale_running_tasks", return_value=0)
    @patch("datamate.auto_annotation_worker._worker_loop")
    @patch("datamate.auto_annotation_worker.WORKER_COUNT", 1)
    def test_single_worker_default(self, mock_loop, mock_recover):
        """WORKER_COUNT=1 (default) should launch exactly 1 thread."""
        started_threads = _start_workers_and_capture_threads()

        self.assertEqual(len(started_threads), 1)

    @patch(
        "datamate.auto_annotation_worker._recover_stale_running_tasks",
        side_effect=RuntimeError("db down"),
    )
    @patch("datamate.auto_annotation_worker._worker_loop")
    @patch("datamate.auto_annotation_worker.WORKER_COUNT", 2)
    def test_recovery_failure_doesnt_block_workers(self, mock_loop, mock_recover):
        """Recovery failure should not prevent worker threads from starting."""
        started_threads = _start_workers_and_capture_threads()

        # Workers should still be launched despite recovery failure
        self.assertEqual(len(started_threads), 2)


class TestChainPool(unittest.TestCase):
    """Test the chain pool pattern used for operator instance isolation."""

    def test_pool_acquire_release(self):
        """Each thread should get its own chain and return it after use."""
        pool: queue.Queue = queue.Queue()
        chains = [f"chain-{i}" for i in range(3)]
        for c in chains:
            pool.put(c)

        acquired: List[str] = []
        lock = threading.Lock()

        def worker():
            chain = pool.get()
            with lock:
                acquired.append(chain)
            time.sleep(0.01)
            pool.put(chain)

        threads = [threading.Thread(target=worker) for _ in range(6)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # All 6 workers should have acquired a chain
        self.assertEqual(len(acquired), 6)
        # Pool should have all 3 chains back
        self.assertEqual(pool.qsize(), 3)
        returned = set()
        while not pool.empty():
            returned.add(pool.get())
        self.assertEqual(returned, set(chains))

    def test_pool_blocks_when_empty(self):
        """When pool is empty, threads should block until a chain is returned."""
        pool: queue.Queue = queue.Queue()
        pool.put("only-chain")

        acquired_times: List[float] = []
        lock = threading.Lock()

        def worker():
            chain = pool.get()
            with lock:
                acquired_times.append(time.monotonic())
            time.sleep(0.05)
            pool.put(chain)

        t1 = threading.Thread(target=worker)
        t2 = threading.Thread(target=worker)
        t1.start()
        time.sleep(0.01)  # Ensure t1 starts first
        t2.start()
        t1.join()
        t2.join()

        # Whichever thread acquires first holds the only chain for 0.05s, so
        # the second acquisition must come at least that long after the
        # first; 0.03 leaves headroom for timer granularity.
        self.assertEqual(len(acquired_times), 2)
        gap = acquired_times[1] - acquired_times[0]
        self.assertGreater(gap, 0.03)


class TestConcurrentFileProcessing(unittest.TestCase):
    """Test the concurrent file processing logic from _process_single_task."""

    def test_threadpool_processes_all_files(self):
        """ThreadPoolExecutor should process all submitted files."""
        results: List[str] = []
        lock = threading.Lock()

        def process_file(file_id):
            time.sleep(0.01)
            with lock:
                results.append(file_id)
            return file_id

        files = [f"file-{i}" for i in range(10)]
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = {executor.submit(process_file, f): f for f in files}
            for future in as_completed(futures):
                future.result()

        self.assertEqual(sorted(results), sorted(files))

    def test_stop_event_cancels_pending_futures(self):
        """Setting stop_event should prevent unstarted files from processing."""
        stop_event = threading.Event()
        processed: List[str] = []
        lock = threading.Lock()

        def process_file(file_id):
            if stop_event.is_set():
                return None
            time.sleep(0.05)
            with lock:
                processed.append(file_id)
            return file_id

        files = [f"file-{i}" for i in range(20)]
        with ThreadPoolExecutor(max_workers=2) as executor:
            future_to_file = {
                executor.submit(process_file, f): f for f in files
            }

            count = 0
            for future in as_completed(future_to_file):
                future.result()  # surface any exception from the worker
                count += 1
                if count >= 3:
                    # Request a stop and cancel everything not yet running;
                    # futures already running are unaffected by cancel().
                    stop_event.set()
                    for pending in future_to_file:
                        pending.cancel()
                    break

        # Should have processed some but not all files
        self.assertGreater(len(processed), 0)
        self.assertLess(len(processed), 20)

    def test_thread_safe_counter_updates(self):
        """Counters updated inside lock should be accurate under concurrency."""
        processed = 0
        detected = 0
        lock = threading.Lock()

        def process_and_count(file_id):
            nonlocal processed, detected
            time.sleep(0.001)
            with lock:
                processed += 1
                detected += 2
            return file_id

        with ThreadPoolExecutor(max_workers=8) as executor:
            futures = [executor.submit(process_and_count, f"f-{i}") for i in range(100)]
            for f in as_completed(futures):
                f.result()

        self.assertEqual(processed, 100)
        self.assertEqual(detected, 200)


class TestFileWorkersConfig(unittest.TestCase):
    """Test FILE_WORKERS configuration behavior."""

    def test_file_workers_one_is_serial(self):
        """FILE_WORKERS=1 should process files sequentially."""
        order: List[int] = []
        lock = threading.Lock()
        active = 0
        peak_active = 0

        def process(idx):
            nonlocal active, peak_active
            with lock:
                active += 1
                peak_active = max(peak_active, active)
                order.append(idx)
            time.sleep(0.01)
            with lock:
                active -= 1
            return idx

        with ThreadPoolExecutor(max_workers=1) as executor:
            futures = [executor.submit(process, i) for i in range(5)]
            for f in as_completed(futures):
                f.result()

        # With max_workers=1 the key property is that only one task runs at
        # a time, so the observed peak concurrency must be exactly 1.
        self.assertEqual(len(order), 5)
        self.assertEqual(peak_active, 1)

    def test_file_workers_gt_one_is_parallel(self):
        """FILE_WORKERS>1 should process files concurrently."""
        worker_total = 4
        # Every task must reach the barrier before any may proceed.  That is
        # only possible if all tasks are running at the same time; if they
        # ran one after another the wait would time out and raise
        # BrokenBarrierError, which future.result() re-raises below.
        barrier = threading.Barrier(worker_total, timeout=2.0)
        started: List[int] = []
        lock = threading.Lock()

        def process(idx):
            with lock:
                started.append(idx)
            barrier.wait()
            return idx

        with ThreadPoolExecutor(max_workers=worker_total) as executor:
            futures = [executor.submit(process, i) for i in range(worker_total)]
            for f in as_completed(futures):
                f.result()

        # All 4 tasks started and met at the barrier simultaneously
        self.assertEqual(sorted(started), list(range(worker_total)))


class TestWorkerLoopSimplified(unittest.TestCase):
    """Test that _worker_loop no longer calls recovery."""

    @patch("datamate.auto_annotation_worker._process_single_task")
    @patch("datamate.auto_annotation_worker._fetch_pending_task")
    @patch("datamate.auto_annotation_worker._recover_stale_running_tasks")
    def test_worker_loop_does_not_call_recovery(self, mock_recover, mock_fetch, mock_process):
        """_worker_loop should NOT call _recover_stale_running_tasks."""
        from datamate.auto_annotation_worker import _worker_loop

        call_count = 0

        def fake_fetch(*args, **kwargs):
            # Accept any arguments so a signature change in the caller cannot
            # turn into a TypeError instead of the intended loop exit.
            nonlocal call_count
            call_count += 1
            if call_count >= 2:
                # KeyboardInterrupt is used to break out of the worker loop;
                # NOTE(review): assumes _worker_loop lets it propagate.
                raise KeyboardInterrupt("stop test")
            return None

        mock_fetch.side_effect = fake_fetch

        with patch("datamate.auto_annotation_worker.POLL_INTERVAL_SECONDS", 0.001):
            try:
                _worker_loop()
            except KeyboardInterrupt:
                pass

        mock_recover.assert_not_called()


if __name__ == "__main__":
    unittest.main()