diff --git a/runtime/python-executor/datamate/auto_annotation_worker.py b/runtime/python-executor/datamate/auto_annotation_worker.py index b6bfa01..ae54bf6 100644 --- a/runtime/python-executor/datamate/auto_annotation_worker.py +++ b/runtime/python-executor/datamate/auto_annotation_worker.py @@ -14,8 +14,8 @@ frontend can display real-time status. - 失败时将任务标记为 `failed` 并记录 `error_message`。 注意: -- 为了保持简单,目前不处理 "running" 状态的恢复逻辑;容器重启时, - 已处于 running 的任务不会被重新拉起,需要后续扩展。 +- 启动时自动恢复心跳超时的 running 任务:未处理文件重置为 pending, + 已有部分进度的标记为 failed,由用户决定是否手动重试。 """ from __future__ import annotations @@ -26,7 +26,7 @@ import sys import threading import time import uuid -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -126,6 +126,111 @@ DEFAULT_OPERATOR_WHITELIST = os.getenv( "LLMTextClassification,LLMNamedEntityRecognition,LLMRelationExtraction", ) +HEARTBEAT_TIMEOUT_SECONDS = int(os.getenv("AUTO_ANNOTATION_HEARTBEAT_TIMEOUT", "300")) + + +def _recover_stale_running_tasks() -> int: + """启动时恢复心跳超时的 running 任务。 + + - processed_images = 0 → 重置为 pending(自动重试) + - processed_images > 0 → 标记为 failed(需用户干预) + + Returns: + 恢复的任务数量。 + """ + if HEARTBEAT_TIMEOUT_SECONDS <= 0: + logger.info( + "Heartbeat timeout disabled (HEARTBEAT_TIMEOUT_SECONDS={}), skipping recovery", + HEARTBEAT_TIMEOUT_SECONDS, + ) + return 0 + + cutoff = datetime.now() - timedelta(seconds=HEARTBEAT_TIMEOUT_SECONDS) + + find_sql = text(""" + SELECT id, processed_images, total_images, heartbeat_at + FROM t_dm_auto_annotation_tasks + WHERE status = 'running' + AND deleted_at IS NULL + AND (heartbeat_at IS NULL OR heartbeat_at < :cutoff) + """) + + with SQLManager.create_connect() as conn: + rows = conn.execute(find_sql, {"cutoff": cutoff}).fetchall() + + if not rows: + return 0 + + recovered = 0 + for row in rows: + task_id = row[0] + processed = row[1] or 0 + total = row[2] or 0 + heartbeat = row[3] + + try: + if processed == 0: + # 未开始处理,重置为 pending 自动重试 + reset_sql = text(""" + UPDATE t_dm_auto_annotation_tasks + SET status = 'pending', + run_token = NULL, + heartbeat_at = NULL, + started_at = NULL, + error_message = NULL, + updated_at = :now + WHERE id = :task_id AND status = 'running' + """) + with SQLManager.create_connect() as conn: + result = conn.execute( + reset_sql, {"task_id": task_id, "now": datetime.now()} + ) + if int(getattr(result, "rowcount", 0) or 0) > 0: + recovered += 1 + logger.info( + "Recovered stale task {} -> pending (no progress, heartbeat={})", + task_id, + heartbeat, + ) + else: + # 已有部分进度,标记为 failed + error_msg = ( + f"Worker 心跳超时(上次心跳: {heartbeat}," + f"超时阈值: {HEARTBEAT_TIMEOUT_SECONDS}秒)。" + f"已处理 {processed}/{total} 个文件。请检查后手动重试。" + ) + fail_sql = text(""" + UPDATE t_dm_auto_annotation_tasks + SET status = 'failed', + run_token = NULL, + error_message = :error_message, + completed_at = :now, + updated_at = :now + WHERE id = :task_id AND status = 'running' + """) + with SQLManager.create_connect() as conn: + result = conn.execute( + fail_sql, + { + "task_id": task_id, + "error_message": error_msg[:2000], + "now": datetime.now(), + }, + ) + if int(getattr(result, "rowcount", 0) or 0) > 0: + recovered += 1 + logger.warning( + "Recovered stale task {} -> failed (processed {}/{}, heartbeat={})", + task_id, + processed, + total, + heartbeat, + ) + except Exception as exc: + logger.error("Failed to recover stale task {}: {}", task_id, exc) + + return recovered + def _fetch_pending_task() -> Optional[Dict[str, Any]]: """原子 claim 一个 pending 任务并返回任务详情。""" @@ -1263,6 +1368,14 @@ def _worker_loop() -> None: DEFAULT_OUTPUT_ROOT, ) + # --- 启动时恢复心跳超时的 running 任务 --- + try: + recovered = _recover_stale_running_tasks() + if recovered > 0: + logger.info("Recovered {} stale running task(s) on startup", recovered) + except Exception as e: + logger.error("Failed to run startup task recovery: {}", e) + while True: try: task = _fetch_pending_task()