DataMate/runtime/python-executor/datamate/auto_annotation_worker.py
Jerry Yan 9988ff00f5 feat(auto-annotation): add concurrent processing support
Enable parallel processing for auto-annotation tasks with configurable worker count and file-level parallelism.

Key features:
- Multi-worker support: AUTO_ANNOTATION_WORKER_COUNT env var (default 1) controls the number of worker threads
- Intra-task file parallelism: AUTO_ANNOTATION_FILE_WORKERS env var (default 1) controls concurrent file processing within a single task
- Operator chain pooling: Pre-create N independent chain instances to avoid thread-safety issues
- Thread-safe progress tracking: Use threading.Lock to protect shared counters
- Stop signal handling: threading.Event for graceful cancellation during concurrent processing

Implementation details:
- Refactor _process_single_task() to use ThreadPoolExecutor + as_completed()
- Chain pool (queue.Queue): Each worker thread acquires/releases a chain instance
- Protected counters: processed_images, detected_total, file_results with Lock
- Stop check: Periodic check of _is_stop_requested() during concurrent processing
- Refactor start_auto_annotation_worker(): Move recovery logic here, start WORKER_COUNT threads
- Simplify _worker_loop(): Remove recovery call, keep only polling + processing

Backward compatibility:
- Default config (WORKER_COUNT=1, FILE_WORKERS=1) behaves identically to previous version
- No breaking changes to existing deployments

Testing:
- All 11 unit tests passed:
  * Multi-worker startup
  * Chain pool acquire/release
  * Concurrent file processing
  * Stop signal handling
  * Thread-safe counter updates
  * Backward compatibility (FILE_WORKERS=1)
- py_compile syntax check passed

Performance benefits:
- WORKER_COUNT=3: Process 3 tasks simultaneously
- FILE_WORKERS=4: Process 4 files in parallel within each task
- Combined: Up to 12x throughput improvement (3 workers × 4 files)
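
For example, a deployment targeting the combined speedup above might set
(illustrative values, not shipped defaults):

    AUTO_ANNOTATION_WORKER_COUNT=3   # 3 task-level worker threads
    AUTO_ANNOTATION_FILE_WORKERS=4   # 4 files in flight per task

Actual gains depend on operator cost and I/O latency; linear scaling is an upper bound.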
2026-02-10 16:36:34 +08:00


# -*- coding: utf-8 -*-
"""Simple background worker for auto-annotation tasks.

This module runs inside the datamate-runtime container (operator_runtime service).
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs annotation
using configurable operator pipelines (YOLO, LLM text classification, NER,
relation extraction, etc.), updating progress back to the same table so that
the datamate-python backend and frontend can display real-time status.

Design:
- Multi-task concurrency: AUTO_ANNOTATION_WORKER_COUNT starts multiple worker
  threads, each polling and claiming pending tasks independently (the atomic
  run_token claim guarantees a task is never picked up twice).
- Intra-task file concurrency: AUTO_ANNOTATION_FILE_WORKERS sets the thread
  pool size so a single task can process multiple files in parallel
  (especially effective for I/O-bound LLM workloads). Operator chains are
  isolated through an object pool; every thread uses its own chain instance.
- On startup, running tasks whose heartbeat has timed out are recovered
  automatically: tasks with no processed files are reset to pending, tasks
  with partial progress are marked failed so the user can decide whether to
  retry manually.
"""
from __future__ import annotations
import importlib
import json
import os
import queue
import sys
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from loguru import logger
from sqlalchemy import text

try:
    import datamate.ops  # noqa: F401
except Exception as import_ops_err:  # pragma: no cover - fallback logging
    logger.warning("Failed to import datamate.ops package for operator registry: {}", import_ops_err)
try:
    import ops.annotation  # type: ignore # noqa: F401
except Exception as import_annotation_ops_err:  # pragma: no cover - fallback logging
    logger.warning(
        "Failed to import ops.annotation package for operator registry: {}",
        import_annotation_ops_err,
    )
try:
    from datamate.core.base_op import OPERATORS
except Exception:  # pragma: no cover - fallback
    OPERATORS = None  # type: ignore

from datamate.sql_manager.sql_manager import SQLManager

# Try several import paths to accommodate different packaging/installation layouts.
ImageObjectDetectionBoundingBox = None  # type: ignore
try:
    # Prefer the datamate.ops path (source COPY'd to /opt/runtime/datamate/ops).
    from datamate.ops.annotation.image_object_detection_bounding_box.process import (  # type: ignore
        ImageObjectDetectionBoundingBox,
    )
    logger.info(
        "Imported ImageObjectDetectionBoundingBox from datamate.ops.annotation.image_object_detection_bounding_box",
    )
except Exception as e1:  # pragma: no cover - log import failures instead of crashing
    logger.error(
        "Failed to import ImageObjectDetectionBoundingBox via datamate.ops: {}",
        e1,
    )
    try:
        # Fall back to a top-level ops package install (exposed via ops.pth).
        from ops.annotation.image_object_detection_bounding_box.process import (  # type: ignore
            ImageObjectDetectionBoundingBox,
        )
        logger.info(
            "Imported ImageObjectDetectionBoundingBox from top-level ops.annotation.image_object_detection_bounding_box",
        )
    except Exception as e2:
        logger.error(
            "Failed to import ImageObjectDetectionBoundingBox via top-level ops package: {}",
            e2,
        )
        ImageObjectDetectionBoundingBox = None
# Last resort: load the operator straight from the local runtime/ops directory
# (a common development-environment layout).
if ImageObjectDetectionBoundingBox is None:
    try:
        project_root = Path(__file__).resolve().parents[2]
        ops_root = project_root / "ops"
        if ops_root.is_dir():
            # Make sure the parent of ops is on sys.path so "ops.xxx" imports resolve.
            if str(project_root) not in sys.path:
                sys.path.insert(0, str(project_root))
            from ops.annotation.image_object_detection_bounding_box.process import (  # type: ignore
                ImageObjectDetectionBoundingBox,
            )
            logger.info(
                "Imported ImageObjectDetectionBoundingBox from local runtime/ops.annotation.image_object_detection_bounding_box",
            )
        else:
            logger.warning(
                "Local runtime/ops directory not found when trying to import ImageObjectDetectionBoundingBox: {}",
                ops_root,
            )
    except Exception as e3:  # pragma: no cover - log and continue on fallback failure
        logger.error(
            "Failed to import ImageObjectDetectionBoundingBox from local runtime/ops: {}",
            e3,
        )
        ImageObjectDetectionBoundingBox = None

POLL_INTERVAL_SECONDS = float(os.getenv("AUTO_ANNOTATION_POLL_INTERVAL", "5"))
DEFAULT_OUTPUT_ROOT = os.getenv(
    "AUTO_ANNOTATION_OUTPUT_ROOT", "/dataset"
)
DEFAULT_OPERATOR_WHITELIST = os.getenv(
    "AUTO_ANNOTATION_OPERATOR_WHITELIST",
    "ImageObjectDetectionBoundingBox,test_annotation_marker,"
    "LLMTextClassification,LLMNamedEntityRecognition,LLMRelationExtraction",
)
HEARTBEAT_TIMEOUT_SECONDS = int(os.getenv("AUTO_ANNOTATION_HEARTBEAT_TIMEOUT", "300"))
WORKER_COUNT = int(os.getenv("AUTO_ANNOTATION_WORKER_COUNT", "1"))
FILE_WORKERS = int(os.getenv("AUTO_ANNOTATION_FILE_WORKERS", "1"))
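
# Tuning example (hypothetical values): AUTO_ANNOTATION_WORKER_COUNT=3 with
# AUTO_ANNOTATION_FILE_WORKERS=4 lets three claimed tasks each keep four files
# in flight; the defaults of 1/1 reproduce the old single-threaded behavior.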


def _recover_stale_running_tasks() -> int:
    """Recover running tasks whose heartbeat has timed out (called on startup).

    - processed_images = 0 -> reset to pending (automatic retry)
    - processed_images > 0 -> mark as failed (user intervention required)

    Returns:
        The number of recovered tasks.
    """
    if HEARTBEAT_TIMEOUT_SECONDS <= 0:
        logger.info(
            "Heartbeat timeout disabled (HEARTBEAT_TIMEOUT_SECONDS={}), skipping recovery",
            HEARTBEAT_TIMEOUT_SECONDS,
        )
        return 0
    cutoff = datetime.now() - timedelta(seconds=HEARTBEAT_TIMEOUT_SECONDS)
    find_sql = text("""
        SELECT id, processed_images, total_images, heartbeat_at
        FROM t_dm_auto_annotation_tasks
        WHERE status = 'running'
          AND deleted_at IS NULL
          AND (heartbeat_at IS NULL OR heartbeat_at < :cutoff)
    """)
    with SQLManager.create_connect() as conn:
        rows = conn.execute(find_sql, {"cutoff": cutoff}).fetchall()
    if not rows:
        return 0
    recovered = 0
    for row in rows:
        task_id = row[0]
        processed = row[1] or 0
        total = row[2] or 0
        heartbeat = row[3]
        try:
            if processed == 0:
                # No progress yet: reset to pending for an automatic retry.
                reset_sql = text("""
                    UPDATE t_dm_auto_annotation_tasks
                    SET status = 'pending',
                        run_token = NULL,
                        heartbeat_at = NULL,
                        started_at = NULL,
                        error_message = NULL,
                        updated_at = :now
                    WHERE id = :task_id AND status = 'running'
                """)
                with SQLManager.create_connect() as conn:
                    result = conn.execute(
                        reset_sql, {"task_id": task_id, "now": datetime.now()}
                    )
                    if int(getattr(result, "rowcount", 0) or 0) > 0:
                        recovered += 1
                        logger.info(
                            "Recovered stale task {} -> pending (no progress, heartbeat={})",
                            task_id,
                            heartbeat,
                        )
            else:
                # Partial progress exists: mark as failed.
                error_msg = (
                    f"Worker heartbeat timed out (last heartbeat: {heartbeat}, "
                    f"timeout threshold: {HEARTBEAT_TIMEOUT_SECONDS}s). "
                    f"Processed {processed}/{total} files. Please inspect and retry manually."
                )
                fail_sql = text("""
                    UPDATE t_dm_auto_annotation_tasks
                    SET status = 'failed',
                        run_token = NULL,
                        error_message = :error_message,
                        completed_at = :now,
                        updated_at = :now
                    WHERE id = :task_id AND status = 'running'
                """)
                with SQLManager.create_connect() as conn:
                    result = conn.execute(
                        fail_sql,
                        {
                            "task_id": task_id,
                            "error_message": error_msg[:2000],
                            "now": datetime.now(),
                        },
                    )
                    if int(getattr(result, "rowcount", 0) or 0) > 0:
                        recovered += 1
                        logger.warning(
                            "Recovered stale task {} -> failed (processed {}/{}, heartbeat={})",
                            task_id,
                            processed,
                            total,
                            heartbeat,
                        )
        except Exception as exc:
            logger.error("Failed to recover stale task {}: {}", task_id, exc)
    return recovered
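

# A claim works by stamping a fresh run_token onto exactly one pending row and
# then reading the row back by that token; the UPDATE's rowcount tells the
# worker whether it won the race. The nested
# "SELECT id FROM (SELECT ... LIMIT 1) AS pending_task" shape is the usual
# MySQL workaround for updating a table that the subquery also reads.
# Hypothetical usage sketch:
#
#     task = _fetch_pending_task()
#     if task is not None:
#         # This worker owns the task; task["run_token"] guards all later
#         # status updates against competing workers.
#         _process_single_task(task)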


def _fetch_pending_task() -> Optional[Dict[str, Any]]:
    """Atomically claim one pending task and return its details."""

    def _parse_json_field(value: Any, default: Any) -> Any:
        if value is None:
            return default
        if isinstance(value, (dict, list)):
            return value
        if isinstance(value, str):
            text_value = value.strip()
            if not text_value:
                return default
            try:
                return json.loads(text_value)
            except Exception:
                return default
        return default

    run_token = str(uuid.uuid4())
    now = datetime.now()
    claim_sql = text(
        """
        UPDATE t_dm_auto_annotation_tasks
        SET status = 'running',
            run_token = :run_token,
            started_at = COALESCE(started_at, :now),
            heartbeat_at = :now,
            updated_at = :now,
            error_message = NULL
        WHERE id = (
            SELECT id FROM (
                SELECT id
                FROM t_dm_auto_annotation_tasks
                WHERE status = 'pending'
                  AND deleted_at IS NULL
                  AND COALESCE(stop_requested, 0) = 0
                ORDER BY created_at ASC
                LIMIT 1
            ) AS pending_task
        )
        AND status = 'pending'
        AND deleted_at IS NULL
        AND COALESCE(stop_requested, 0) = 0
        """
    )
    query_sql = text(
        """
        SELECT id, name, dataset_id, dataset_name, dataset_type, created_by,
               config, file_ids, pipeline,
               task_mode, executor_type,
               status, stop_requested, run_token,
               total_images, processed_images, detected_objects,
               output_path, output_dataset_id
        FROM t_dm_auto_annotation_tasks
        WHERE run_token = :run_token
        LIMIT 1
        """
    )
    with SQLManager.create_connect() as conn:
        claim_result = conn.execute(claim_sql, {"run_token": run_token, "now": now})
        if not claim_result or int(getattr(claim_result, "rowcount", 0) or 0) <= 0:
            return None
        result = conn.execute(query_sql, {"run_token": run_token}).fetchone()
    if not result:
        return None
    row = dict(result._mapping)  # type: ignore[attr-defined]
    row["config"] = _parse_json_field(row.get("config"), {})
    parsed_file_ids = _parse_json_field(row.get("file_ids"), None)
    row["file_ids"] = parsed_file_ids if parsed_file_ids else None
    parsed_pipeline = _parse_json_field(row.get("pipeline"), None)
    row["pipeline"] = parsed_pipeline if parsed_pipeline else None
    return row


def _update_task_status(
    task_id: str,
    *,
    status: str,
    run_token: Optional[str] = None,
    progress: Optional[int] = None,
    processed_images: Optional[int] = None,
    detected_objects: Optional[int] = None,
    total_images: Optional[int] = None,
    output_path: Optional[str] = None,
    output_dataset_id: Optional[str] = None,
    error_message: Optional[str] = None,
    completed: bool = False,
    clear_run_token: bool = False,
) -> None:
    """Update the task's status and statistics fields."""
    fields: List[str] = ["status = :status", "updated_at = :updated_at"]
    params: Dict[str, Any] = {
        "task_id": task_id,
        "status": status,
        "updated_at": datetime.now(),
    }
    if progress is not None:
        fields.append("progress = :progress")
        params["progress"] = int(progress)
    if processed_images is not None:
        fields.append("processed_images = :processed_images")
        params["processed_images"] = int(processed_images)
    if detected_objects is not None:
        fields.append("detected_objects = :detected_objects")
        params["detected_objects"] = int(detected_objects)
    if total_images is not None:
        fields.append("total_images = :total_images")
        params["total_images"] = int(total_images)
    if output_path is not None:
        fields.append("output_path = :output_path")
        params["output_path"] = output_path
    if output_dataset_id is not None:
        fields.append("output_dataset_id = :output_dataset_id")
        params["output_dataset_id"] = output_dataset_id
    if error_message is not None:
        fields.append("error_message = :error_message")
        params["error_message"] = error_message[:2000]
    if status == "running":
        fields.append("heartbeat_at = :heartbeat_at")
        params["heartbeat_at"] = datetime.now()
    if completed:
        fields.append("completed_at = :completed_at")
        params["completed_at"] = datetime.now()
    if clear_run_token:
        fields.append("run_token = NULL")
    where_clause = "id = :task_id"
    if run_token:
        where_clause += " AND run_token = :run_token"
        params["run_token"] = run_token
    sql = text(
        f"""
        UPDATE t_dm_auto_annotation_tasks
        SET {', '.join(fields)}
        WHERE {where_clause}
        """
    )
    with SQLManager.create_connect() as conn:
        result = conn.execute(sql, params)
        if int(getattr(result, "rowcount", 0) or 0) <= 0:
            logger.warning(
                "No rows updated for task status change: task_id={}, status={}, run_token={}",
                task_id,
                status,
                run_token,
            )


def _is_stop_requested(task_id: str, run_token: Optional[str] = None) -> bool:
    """Check whether a stop has been requested for the task."""
    where_clause = "id = :task_id"
    params: Dict[str, Any] = {"task_id": task_id}
    if run_token:
        where_clause += " AND run_token = :run_token"
        params["run_token"] = run_token
    sql = text(
        f"""
        SELECT COALESCE(stop_requested, 0)
        FROM t_dm_auto_annotation_tasks
        WHERE {where_clause}
        LIMIT 1
        """
    )
    with SQLManager.create_connect() as conn:
        row = conn.execute(sql, params).fetchone()
    if not row:
        # Task not found (or the run_token is no longer valid): stop conservatively.
        return True
    return bool(row[0])


def _extract_step_overrides(step: Dict[str, Any]) -> Dict[str, Any]:
    """Merge the parameter overrides carried by a pipeline node."""
    overrides: Dict[str, Any] = {}
    for key in ("overrides", "settingsOverride", "settings_override"):
        value = step.get(key)
        if value is None:
            continue
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except Exception:
                continue
        if isinstance(value, dict):
            overrides.update(value)
    return overrides


def _build_legacy_pipeline(config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Map a legacy_yolo config onto a single-step pipeline."""
    return [
        {
            "operatorId": "ImageObjectDetectionBoundingBox",
            "overrides": {
                "modelSize": config.get("modelSize", "l"),
                "confThreshold": float(config.get("confThreshold", 0.7)),
                "targetClasses": config.get("targetClasses", []) or [],
            },
        }
    ]


def _get_output_dataset_name(
    task_id: str,
    dataset_id: str,
    source_dataset_name: str,
    task_name: str,
    config: Dict[str, Any],
    pipeline_raw: Optional[List[Any]],
) -> str:
    """Determine the name of the output dataset."""
    output_name = config.get("outputDatasetName")
    if output_name:
        return str(output_name)
    if pipeline_raw:
        for step in pipeline_raw:
            if not isinstance(step, dict):
                continue
            overrides = _extract_step_overrides(step)
            output_name = overrides.get("outputDatasetName") or overrides.get("output_dataset_name")
            if output_name:
                return str(output_name)
    base_name = source_dataset_name or task_name or f"dataset-{dataset_id[:8]}"
    return f"{base_name}_auto_{task_id[:8]}"


def _normalize_pipeline(
    task_mode: str,
    config: Dict[str, Any],
    pipeline_raw: Optional[List[Any]],
    output_dir: str,
) -> List[Dict[str, Any]]:
    """Normalize the pipeline structure and inject outputDir."""
    source_pipeline = pipeline_raw
    if task_mode == "legacy_yolo" or not source_pipeline:
        source_pipeline = _build_legacy_pipeline(config)
    normalized: List[Dict[str, Any]] = []
    for step in source_pipeline:
        if not isinstance(step, dict):
            continue
        operator_id: Optional[str] = None
        overrides: Dict[str, Any] = {}
        # Also accept the [{"OpName": {...}}] style.
        if (
            "operatorId" not in step
            and "operator_id" not in step
            and "id" not in step
            and len(step) == 1
        ):
            first_key = next(iter(step.keys()))
            first_value = step.get(first_key)
            if isinstance(first_key, str):
                operator_id = first_key
                if isinstance(first_value, dict):
                    overrides.update(first_value)
        operator_id = operator_id or step.get("operatorId") or step.get("operator_id") or step.get("id")
        if not operator_id:
            continue
        overrides.update(_extract_step_overrides(step))
        overrides.setdefault("outputDir", output_dir)
        normalized.append(
            {
                "operator_id": str(operator_id),
                "overrides": overrides,
            }
        )
    return normalized
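

# Illustration with hypothetical values: both accepted step shapes,
#
#     {"operatorId": "LLMTextClassification", "overrides": {"labels": ["a", "b"]}}
#     {"LLMTextClassification": {"labels": ["a", "b"]}}
#
# normalize to
#
#     {"operator_id": "LLMTextClassification",
#      "overrides": {"labels": ["a", "b"], "outputDir": "<output_dir>"}}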


def _resolve_operator_class(operator_id: str):
    """Resolve the operator class for an operator_id."""
    if operator_id == "ImageObjectDetectionBoundingBox":
        if ImageObjectDetectionBoundingBox is None:
            raise ImportError("ImageObjectDetectionBoundingBox is not available")
        return ImageObjectDetectionBoundingBox
    registry_item = OPERATORS.get(operator_id) if OPERATORS is not None else None
    if registry_item is None:
        try:
            from core.base_op import OPERATORS as relative_operators  # type: ignore
            registry_item = relative_operators.get(operator_id)
        except Exception:
            registry_item = None
    if registry_item is None:
        raise ImportError(f"Operator not found in registry: {operator_id}")
    if isinstance(registry_item, str):
        submodule = importlib.import_module(registry_item)
        operator_cls = getattr(submodule, operator_id, None)
        if operator_cls is None:
            raise ImportError(
                f"Operator class {operator_id} not found in module {registry_item}"
            )
        return operator_cls
    return registry_item


def _build_operator_chain(pipeline: List[Dict[str, Any]]) -> List[Tuple[str, Any]]:
    """Initialize the operator chain."""
    chain: List[Tuple[str, Any]] = []
    for step in pipeline:
        operator_id = step.get("operator_id")
        overrides = dict(step.get("overrides") or {})
        if not operator_id:
            continue
        operator_cls = _resolve_operator_class(str(operator_id))
        operator = operator_cls(**overrides)
        chain.append((str(operator_id), operator))
    return chain


def _run_pipeline_sample(sample: Dict[str, Any], chain: List[Tuple[str, Any]]) -> Dict[str, Any]:
    """Run the pipeline on a single sample."""
    current_sample: Dict[str, Any] = dict(sample)
    for operator_id, operator in chain:
        if hasattr(operator, "execute") and callable(getattr(operator, "execute")):
            result = operator.execute(current_sample)
        elif callable(operator):
            result = operator(current_sample)
        else:
            raise RuntimeError(f"Operator {operator_id} is not executable")
        if result is None:
            continue
        if isinstance(result, dict):
            current_sample.update(result)
            continue
        if isinstance(result, list):
            # Take only the first dict result, for operators that return lists.
            if result and isinstance(result[0], dict):
                current_sample.update(result[0])
            continue
        logger.debug(
            "Operator {} returned unsupported result type: {}",
            operator_id,
            type(result).__name__,
        )
    return current_sample
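

# Merge semantics, with hypothetical values: an operator returning
# {"annotations": {...}, "detection_count": 3} has those keys merged into the
# running sample dict; an operator returning a list contributes only its first
# dict element; any other return type is logged at debug level and ignored,
# leaving the sample unchanged for the next operator in the chain.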


def _count_detections(sample: Dict[str, Any]) -> int:
    """Extract the detection count from a sample."""
    annotations = sample.get("annotations")
    if isinstance(annotations, dict):
        detections = annotations.get("detections")
        if isinstance(detections, list):
            return len(detections)
    detection_count = sample.get("detection_count")
    if detection_count is None:
        return 0
    try:
        return max(int(detection_count), 0)
    except Exception:
        return 0


# ---------------------------------------------------------------------------
# Dataset type -> sample key / output subdirectory mappings
# ---------------------------------------------------------------------------
DATASET_TYPE_SAMPLE_KEY: Dict[str, str] = {
    "IMAGE": "image",
    "TEXT": "text",
    "AUDIO": "audio",
    "VIDEO": "video",
}
DATASET_TYPE_DATA_DIR: Dict[str, str] = {
    "IMAGE": "images",
    "TEXT": "data",
    "AUDIO": "data",
    "VIDEO": "data",
}


def _get_sample_key(dataset_type: str) -> str:
    """Return the sample-dict key that holds the primary data for a dataset type."""
    return DATASET_TYPE_SAMPLE_KEY.get(dataset_type.upper(), "image")


def _get_data_dir_name(dataset_type: str) -> str:
    """Return the output subdirectory name for a dataset type."""
    return DATASET_TYPE_DATA_DIR.get(dataset_type.upper(), "images")


def _get_operator_whitelist() -> Optional[set[str]]:
    """Return the rollout whitelist; None means all operators are allowed."""
    raw = str(DEFAULT_OPERATOR_WHITELIST or "").strip()
    if not raw:
        return None
    normalized = raw.lower()
    if normalized in {"*", "all", "any"}:
        return None
    allow_set = {
        item.strip()
        for item in raw.split(",")
        if item and item.strip()
    }
    return allow_set or None


def _validate_pipeline_whitelist(pipeline: List[Dict[str, Any]]) -> None:
    """Verify that every pipeline operator is on the rollout whitelist."""
    allow_set = _get_operator_whitelist()
    if allow_set is None:
        return
    blocked = []
    for step in pipeline:
        operator_id = str(step.get("operator_id") or "")
        if not operator_id:
            continue
        if operator_id not in allow_set:
            blocked.append(operator_id)
    if blocked:
        raise RuntimeError(
            "Operator not in whitelist: " + ", ".join(sorted(set(blocked)))
        )
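

# Example values (hypothetical): AUTO_ANNOTATION_OPERATOR_WHITELIST="*" (or
# "all"/"any"/empty) disables filtering entirely, while a list such as
# "LLMTextClassification,LLMNamedEntityRecognition" fails the task during
# pipeline init if any other operator is requested.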


def _load_dataset_files(dataset_id: str) -> List[Tuple[str, str, str]]:
    """Load all completed files belonging to the given dataset."""
    sql = text(
        """
        SELECT id, file_path, file_name
        FROM t_dm_dataset_files
        WHERE dataset_id = :dataset_id
          AND status = 'ACTIVE'
        ORDER BY created_at ASC
        """
    )
    with SQLManager.create_connect() as conn:
        rows = conn.execute(sql, {"dataset_id": dataset_id}).fetchall()
    return [(str(r[0]), str(r[1]), str(r[2])) for r in rows]


def _load_dataset_meta(dataset_id: str) -> Optional[Dict[str, Any]]:
    """Load basic dataset metadata (including parent ID and path)."""
    sql = text(
        """
        SELECT id, name, parent_dataset_id, path, dataset_type
        FROM t_dm_datasets
        WHERE id = :dataset_id
        """
    )
    with SQLManager.create_connect() as conn:
        row = conn.execute(sql, {"dataset_id": dataset_id}).fetchone()
    if not row:
        return None
    return dict(row._mapping)  # type: ignore[attr-defined]


def _resolve_output_parent(source_dataset_id: str) -> Tuple[Optional[str], str]:
    """Determine the output dataset's parent and base path from the source dataset
    (outputs are attached under the source's parent)."""
    base_path = DEFAULT_OUTPUT_ROOT.rstrip("/") or "/dataset"
    source_meta = _load_dataset_meta(source_dataset_id)
    if not source_meta:
        return None, base_path
    parent_dataset_id = source_meta.get("parent_dataset_id")
    if not parent_dataset_id:
        return None, base_path
    parent_meta = _load_dataset_meta(str(parent_dataset_id))
    parent_path = parent_meta.get("path") if parent_meta else None
    if not parent_path:
        return None, base_path
    return str(parent_dataset_id), str(parent_path)


def _load_files_by_ids(file_ids: List[str]) -> List[Tuple[str, str, str]]:
    """Load file records by ID list; the IDs may span multiple datasets."""
    if not file_ids:
        return []
    placeholders = ", ".join(f":id{i}" for i in range(len(file_ids)))
    sql = text(
        f"""
        SELECT id, file_path, file_name
        FROM t_dm_dataset_files
        WHERE id IN ({placeholders})
          AND status = 'ACTIVE'
        ORDER BY created_at ASC
        """
    )
    params = {f"id{i}": str(fid) for i, fid in enumerate(file_ids)}
    with SQLManager.create_connect() as conn:
        rows = conn.execute(sql, params).fetchall()
    return [(str(r[0]), str(r[1]), str(r[2])) for r in rows]


def _ensure_output_dir(output_dir: str, dataset_type: str = "IMAGE") -> str:
    """Ensure the output directory and its data/annotations subdirectories exist."""
    os.makedirs(output_dir, exist_ok=True)
    data_dir_name = _get_data_dir_name(dataset_type)
    os.makedirs(os.path.join(output_dir, data_dir_name), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "annotations"), exist_ok=True)
    return output_dir


def _create_output_dataset(
    source_dataset_id: str,
    source_dataset_name: str,
    output_dataset_name: str,
    dataset_type: str = "IMAGE",
) -> Tuple[str, str]:
    """Create a new dataset for the auto-annotation results; return (dataset_id, path)."""
    new_dataset_id = str(uuid.uuid4())
    parent_dataset_id, dataset_base_path = _resolve_output_parent(source_dataset_id)
    output_dir = os.path.join(dataset_base_path, new_dataset_id)
    description = (
        f"Auto annotations for dataset {source_dataset_name or source_dataset_id}"[:255]
    )
    sql = text(
        """
        INSERT INTO t_dm_datasets (id, parent_dataset_id, name, description, dataset_type, path, status)
        VALUES (:id, :parent_dataset_id, :name, :description, :dataset_type, :path, :status)
        """
    )
    params = {
        "id": new_dataset_id,
        "parent_dataset_id": parent_dataset_id,
        "name": output_dataset_name,
        "description": description,
        "dataset_type": dataset_type,
        "path": output_dir,
        "status": "ACTIVE",
    }
    with SQLManager.create_connect() as conn:
        conn.execute(sql, params)
    return new_dataset_id, output_dir


def _register_output_dataset(
    task_id: str,
    output_dataset_id: str,
    output_dir: str,
    output_dataset_name: str,
    total_images: int,
    dataset_type: str = "IMAGE",
) -> None:
    """Register the auto-annotation output files with the newly created dataset."""
    data_dir_name = _get_data_dir_name(dataset_type)
    data_dir = os.path.join(output_dir, data_dir_name)
    # Stay compatible with older tasks and IMAGE operators (they write to images/).
    if not os.path.isdir(data_dir):
        fallback_dir = os.path.join(output_dir, "images")
        if os.path.isdir(fallback_dir):
            data_dir = fallback_dir
        else:
            logger.warning(
                "Auto-annotation data directory not found for task {}: {}",
                task_id,
                data_dir,
            )
            return
    data_files: List[Tuple[str, str, int]] = []
    annotation_files: List[Tuple[str, str, int]] = []
    total_size = 0
    for file_name in sorted(os.listdir(data_dir)):
        file_path = os.path.join(data_dir, file_name)
        if not os.path.isfile(file_path):
            continue
        try:
            file_size = os.path.getsize(file_path)
        except OSError:
            file_size = 0
        data_files.append((file_name, file_path, int(file_size)))
        total_size += int(file_size)
    annotations_dir = os.path.join(output_dir, "annotations")
    if os.path.isdir(annotations_dir):
        for file_name in sorted(os.listdir(annotations_dir)):
            file_path = os.path.join(annotations_dir, file_name)
            if not os.path.isfile(file_path):
                continue
            try:
                file_size = os.path.getsize(file_path)
            except OSError:
                file_size = 0
            annotation_files.append((file_name, file_path, int(file_size)))
            total_size += int(file_size)
    if not data_files:
        logger.warning(
            "No data files found in auto-annotation output for task {}: {}",
            task_id,
            data_dir,
        )
        return
    insert_file_sql = text(
        """
        INSERT INTO t_dm_dataset_files (
            id, dataset_id, file_name, file_path, logical_path, version, file_type, file_size, status
        ) VALUES (
            :id, :dataset_id, :file_name, :file_path, :logical_path, :version, :file_type, :file_size, :status
        )
        """
    )
    update_dataset_stat_sql = text(
        """
        UPDATE t_dm_datasets
        SET file_count = COALESCE(file_count, 0) + :add_count,
            size_bytes = COALESCE(size_bytes, 0) + :add_size
        WHERE id = :dataset_id
        """
    )
    with SQLManager.create_connect() as conn:
        added_count = 0
        for file_name, file_path, file_size in data_files:
            ext = os.path.splitext(file_name)[1].lstrip(".").upper() or None
            logical_path = os.path.relpath(file_path, output_dir).replace("\\", "/")
            conn.execute(
                insert_file_sql,
                {
                    "id": str(uuid.uuid4()),
                    "dataset_id": output_dataset_id,
                    "file_name": file_name,
                    "file_path": file_path,
                    "logical_path": logical_path,
                    "version": 1,
                    "file_type": ext,
                    "file_size": int(file_size),
                    "status": "ACTIVE",
                },
            )
            added_count += 1
        for file_name, file_path, file_size in annotation_files:
            ext = os.path.splitext(file_name)[1].lstrip(".").upper() or None
            logical_path = os.path.relpath(file_path, output_dir).replace("\\", "/")
            conn.execute(
                insert_file_sql,
                {
                    "id": str(uuid.uuid4()),
                    "dataset_id": output_dataset_id,
                    "file_name": file_name,
                    "file_path": file_path,
                    "logical_path": logical_path,
                    "version": 1,
                    "file_type": ext,
                    "file_size": int(file_size),
                    "status": "ACTIVE",
                },
            )
            added_count += 1
        if added_count > 0:
            conn.execute(
                update_dataset_stat_sql,
                {
                    "dataset_id": output_dataset_id,
                    "add_count": added_count,
                    "add_size": int(total_size),
                },
            )
    logger.info(
        "Registered auto-annotation output into dataset: dataset_id={}, name={}, added_files={}, added_size_bytes={}, task_id={}, output_dir={}",
        output_dataset_id,
        output_dataset_name,
        len(data_files) + len(annotation_files),
        total_size,
        task_id,
        output_dir,
    )


def _create_labeling_project_with_annotations(
    task_id: str,
    dataset_id: str,
    dataset_name: str,
    task_name: str,
    dataset_type: str,
    normalized_pipeline: List[Dict[str, Any]],
    file_results: List[Tuple[str, Dict[str, Any]]],
    all_file_ids: List[str],
) -> None:
    """Convert auto-annotation results to Label Studio format, create a labeling
    project, and write the annotation results."""
    from datamate.annotation_result_converter import (
        convert_annotation,
        extract_operator_params,
        generate_label_config_xml,
        infer_task_type_from_pipeline,
    )
    task_type = infer_task_type_from_pipeline(normalized_pipeline)
    if not task_type:
        logger.warning(
            "Cannot infer task_type from pipeline for task {}, skipping labeling project creation",
            task_id,
        )
        return
    operator_params = extract_operator_params(normalized_pipeline)
    # Object detection: collect the unique label list from the actual detections.
    if task_type == "object_detection":
        all_labels: set = set()
        for _, ann in file_results:
            for det in ann.get("detections", []):
                if isinstance(det, dict):
                    all_labels.add(str(det.get("label", "unknown")))
        operator_params["_detected_labels"] = sorted(all_labels)
    label_config = generate_label_config_xml(task_type, operator_params)
    project_id = str(uuid.uuid4())
    labeling_project_id = str(uuid.uuid4().int % 10**8).zfill(8)
    project_name = f"Auto annotation - {task_name or dataset_name or task_id[:8]}"[:100]
    now = datetime.now()
    configuration = json.dumps(
        {
            "label_config": label_config,
            "description": f"Created automatically by auto-annotation task {task_id[:8]}",
            "auto_annotation_task_id": task_id,
        },
        ensure_ascii=False,
    )
    insert_project_sql = text(
        """
        INSERT INTO t_dm_labeling_projects
            (id, dataset_id, name, labeling_project_id, template_id, configuration, created_at, updated_at)
        VALUES
            (:id, :dataset_id, :name, :labeling_project_id, NULL, :configuration, :now, :now)
        """
    )
    insert_snapshot_sql = text(
        """
        INSERT INTO t_dm_labeling_project_files (id, project_id, file_id, created_at)
        VALUES (:id, :project_id, :file_id, :now)
        """
    )
    insert_annotation_sql = text(
        """
        INSERT INTO t_dm_annotation_results
            (id, project_id, file_id, annotation, annotation_status, file_version, created_at, updated_at)
        VALUES
            (:id, :project_id, :file_id, :annotation, :annotation_status, :file_version, :now, :now)
        """
    )
    with SQLManager.create_connect() as conn:
        # 1. Create the labeling project.
        conn.execute(
            insert_project_sql,
            {
                "id": project_id,
                "dataset_id": dataset_id,
                "name": project_name,
                "labeling_project_id": labeling_project_id,
                "configuration": configuration,
                "now": now,
            },
        )
        # 2. Create the project file snapshot.
        for file_id in all_file_ids:
            conn.execute(
                insert_snapshot_sql,
                {
                    "id": str(uuid.uuid4()),
                    "project_id": project_id,
                    "file_id": file_id,
                    "now": now,
                },
            )
        # 3. Convert and write the annotation results.
        converted_count = 0
        for file_id, annotation in file_results:
            ls_annotation = convert_annotation(annotation, file_id, project_id)
            if ls_annotation is None:
                continue
            conn.execute(
                insert_annotation_sql,
                {
                    "id": str(uuid.uuid4()),
                    "project_id": project_id,
                    "file_id": file_id,
                    "annotation": json.dumps(ls_annotation, ensure_ascii=False),
                    "annotation_status": "ANNOTATED",
                    "file_version": 1,
                    "now": now,
                },
            )
            converted_count += 1
    logger.info(
        "Created labeling project {} ({}) with {} annotations for auto-annotation task {}",
        project_id,
        project_name,
        converted_count,
        task_id,
    )


def _process_single_task(task: Dict[str, Any]) -> None:
    """Execute a single auto-annotation task."""
    task_id = str(task["id"])
    dataset_id = str(task["dataset_id"])
    task_name = str(task.get("name") or "")
    source_dataset_name = str(task.get("dataset_name") or "")
    run_token = str(task.get("run_token") or "")
    task_mode = str(task.get("task_mode") or "legacy_yolo")
    executor_type = str(task.get("executor_type") or "annotation_local")
    cfg: Dict[str, Any] = task.get("config") or {}
    pipeline_raw = task.get("pipeline")
    selected_file_ids: Optional[List[str]] = task.get("file_ids") or None
    # Resolve the dataset type, falling back to the dataset metadata.
    dataset_type = str(task.get("dataset_type") or "").upper() or "IMAGE"
    if dataset_type == "IMAGE" and not task.get("dataset_type"):
        source_meta = _load_dataset_meta(dataset_id)
        if source_meta and source_meta.get("dataset_type"):
            dataset_type = str(source_meta["dataset_type"]).upper()
    output_dataset_name = _get_output_dataset_name(
        task_id=task_id,
        dataset_id=dataset_id,
        source_dataset_name=source_dataset_name,
        task_name=task_name,
        config=cfg,
        pipeline_raw=pipeline_raw if isinstance(pipeline_raw, list) else None,
    )
    logger.info(
        "Start processing auto-annotation task: id={}, dataset_id={}, task_mode={}, executor_type={}, output_dataset_name={}",
        task_id,
        dataset_id,
        task_mode,
        executor_type,
        output_dataset_name,
    )
    if _is_stop_requested(task_id, run_token):
        logger.info("Task stop requested before processing started: {}", task_id)
        _update_task_status(
            task_id,
            run_token=run_token,
            status="stopped",
            completed=True,
            clear_run_token=True,
            error_message="Task stopped before start",
        )
        return
    _update_task_status(task_id, run_token=run_token, status="running", progress=0)
    if selected_file_ids:
        all_files = _load_files_by_ids(selected_file_ids)
    else:
        all_files = _load_dataset_files(dataset_id)
    files = all_files  # [(file_id, file_path, file_name)]
    total_images = len(files)
    if total_images == 0:
        logger.warning("No files found for dataset {} when running auto-annotation task {}", dataset_id, task_id)
        _update_task_status(
            task_id,
            run_token=run_token,
            status="completed",
            progress=100,
            total_images=0,
            processed_images=0,
            detected_objects=0,
            completed=True,
            output_path=None,
            clear_run_token=True,
        )
        return
    output_dataset_id, output_dir = _create_output_dataset(
        source_dataset_id=dataset_id,
        source_dataset_name=source_dataset_name,
        output_dataset_name=output_dataset_name,
        dataset_type=dataset_type,
    )
    output_dir = _ensure_output_dir(output_dir, dataset_type=dataset_type)
    _update_task_status(
        task_id,
        run_token=run_token,
        status="running",
        total_images=total_images,
        output_path=output_dir,
        output_dataset_id=output_dataset_id,
    )
    try:
        normalized_pipeline = _normalize_pipeline(
            task_mode=task_mode,
            config=cfg,
            pipeline_raw=pipeline_raw if isinstance(pipeline_raw, list) else None,
            output_dir=output_dir,
        )
        if not normalized_pipeline:
            raise RuntimeError("Pipeline is empty after normalization")
        _validate_pipeline_whitelist(normalized_pipeline)
    except Exception as e:
        logger.error("Failed to init operator pipeline for task {}: {}", task_id, e)
        _update_task_status(
            task_id,
            run_token=run_token,
            status="failed",
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            error_message=f"Init pipeline failed: {e}",
            clear_run_token=True,
        )
        return
    # --- Build the operator chain pool (each thread gets its own chain
    # instance to avoid thread-safety issues) ---
    effective_file_workers = max(1, FILE_WORKERS)
    chain_pool: queue.Queue = queue.Queue()
    try:
        for _ in range(effective_file_workers):
            c = _build_operator_chain(normalized_pipeline)
            if not c:
                raise RuntimeError("No valid operator instances initialized")
            chain_pool.put(c)
    except Exception as e:
        logger.error("Failed to build operator chain pool for task {}: {}", task_id, e)
        _update_task_status(
            task_id,
            run_token=run_token,
            status="failed",
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            error_message=f"Init pipeline failed: {e}",
            clear_run_token=True,
        )
        return
    processed = 0
    detected_total = 0
    file_results: List[Tuple[str, Dict[str, Any]]] = []  # (file_id, annotations)
    stopped = False
    try:
        # --- Thread-safe progress tracking ---
        progress_lock = threading.Lock()
        stop_event = threading.Event()

        def _process_file(
            file_id: str, file_path: str, file_name: str,
        ) -> Optional[Tuple[str, Dict[str, Any]]]:
            """Process a single file inside the thread pool."""
            if stop_event.is_set():
                return None
            chain = chain_pool.get()
            try:
                sample_key = _get_sample_key(dataset_type)
                sample: Dict[str, Any] = {
                    sample_key: file_path,
                    "filename": file_name,
                }
                result = _run_pipeline_sample(sample, chain)
                return (file_id, result)
            finally:
                chain_pool.put(chain)
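
        # Note: the pool holds exactly effective_file_workers chains and the
        # executor below runs at most that many threads, so chain_pool.get()
        # never blocks in practice; the queue's job is just to hand each
        # thread an isolated chain instance.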
        # --- Concurrent file processing ---
        stop_check_interval = max(1, effective_file_workers * 2)
        completed_since_check = 0
        with ThreadPoolExecutor(max_workers=effective_file_workers) as executor:
            future_to_file = {
                executor.submit(_process_file, fid, fpath, fname): (fid, fpath, fname)
                for fid, fpath, fname in files
            }
            for future in as_completed(future_to_file):
                fid, fpath, fname = future_to_file[future]
                try:
                    result = future.result()
                    if result is None:
                        continue
                    file_id_out, sample_result = result
                    detections = _count_detections(sample_result)
                    ann = sample_result.get("annotations")
                    with progress_lock:
                        processed += 1
                        detected_total += detections
                        if isinstance(ann, dict):
                            file_results.append((file_id_out, ann))
                        current_processed = processed
                        current_detected = detected_total
                    progress = int(current_processed * 100 / total_images) if total_images > 0 else 100
                    _update_task_status(
                        task_id,
                        run_token=run_token,
                        status="running",
                        progress=progress,
                        processed_images=current_processed,
                        detected_objects=current_detected,
                        total_images=total_images,
                        output_path=output_dir,
                        output_dataset_id=output_dataset_id,
                    )
                except Exception as e:
                    logger.error(
                        "Failed to process file for task {}: file_path={}, error={}",
                        task_id,
                        fpath,
                        e,
                    )
                completed_since_check += 1
                if completed_since_check >= stop_check_interval:
                    completed_since_check = 0
                    if _is_stop_requested(task_id, run_token):
                        stop_event.set()
                        for f in future_to_file:
                            f.cancel()
                        stopped = True
                        break
        if stopped:
            logger.info("Task stop requested during processing: {}", task_id)
            _update_task_status(
                task_id,
                run_token=run_token,
                status="stopped",
                progress=int(processed * 100 / total_images) if total_images > 0 else 0,
                processed_images=processed,
                detected_objects=detected_total,
                total_images=total_images,
                output_path=output_dir,
                output_dataset_id=output_dataset_id,
                completed=True,
                clear_run_token=True,
                error_message="Task stopped by request",
            )
        else:
            _update_task_status(
                task_id,
                run_token=run_token,
                status="completed",
                progress=100,
                processed_images=processed,
                detected_objects=detected_total,
                total_images=total_images,
                output_path=output_dir,
                output_dataset_id=output_dataset_id,
                completed=True,
                clear_run_token=True,
            )
            logger.info(
                "Completed auto-annotation task: id={}, total_images={}, processed={}, detected_objects={}, output_path={}",
                task_id,
                total_images,
                processed,
                detected_total,
                output_dir,
            )
            if output_dataset_name and output_dataset_id:
                try:
                    _register_output_dataset(
                        task_id=task_id,
                        output_dataset_id=output_dataset_id,
                        output_dir=output_dir,
                        output_dataset_name=output_dataset_name,
                        total_images=total_images,
                        dataset_type=dataset_type,
                    )
                except Exception as e:  # pragma: no cover - defensive logging
                    logger.error(
                        "Failed to register auto-annotation output as dataset for task {}: {}",
                        task_id,
                        e,
                    )
            # Convert the auto-annotation results to Label Studio format and
            # write them into a labeling project.
            if file_results:
                try:
                    _create_labeling_project_with_annotations(
                        task_id=task_id,
                        dataset_id=dataset_id,
                        dataset_name=source_dataset_name,
                        task_name=task_name,
                        dataset_type=dataset_type,
                        normalized_pipeline=normalized_pipeline,
                        file_results=file_results,
                        all_file_ids=[fid for fid, _, _ in all_files],
                    )
                except Exception as e:  # pragma: no cover - defensive logging
                    logger.error(
                        "Failed to create labeling project for auto-annotation task {}: {}",
                        task_id,
                        e,
                    )
    except Exception as e:
        logger.error("Task execution failed for task {}: {}", task_id, e)
        _update_task_status(
            task_id,
            run_token=run_token,
            status="failed",
            progress=int(processed * 100 / total_images) if total_images > 0 else 0,
            processed_images=processed,
            detected_objects=detected_total,
            total_images=total_images,
            output_path=output_dir,
            output_dataset_id=output_dataset_id,
            error_message=f"Execute pipeline failed: {e}",
            clear_run_token=True,
        )


def _worker_loop() -> None:
    """Main worker loop; runs in its own thread."""
    logger.info(
        "Auto-annotation worker started (poll_interval={}s, output_root={}, file_workers={})",
        POLL_INTERVAL_SECONDS,
        DEFAULT_OUTPUT_ROOT,
        FILE_WORKERS,
    )
    while True:
        try:
            task = _fetch_pending_task()
            if not task:
                time.sleep(POLL_INTERVAL_SECONDS)
                continue
            _process_single_task(task)
        except Exception as e:  # pragma: no cover - defensive logging
            logger.error("Auto-annotation worker loop error: {}", e)
            time.sleep(POLL_INTERVAL_SECONDS)


def start_auto_annotation_worker() -> None:
    """Start the auto-annotation worker(s) in background threads."""
    # Run recovery once before any worker thread starts, so multiple threads
    # cannot repeat the recovery work.
    try:
        recovered = _recover_stale_running_tasks()
        if recovered > 0:
            logger.info("Recovered {} stale running task(s) on startup", recovered)
    except Exception as e:
        logger.error("Failed to run startup task recovery: {}", e)
    count = max(1, WORKER_COUNT)
    for i in range(count):
        thread = threading.Thread(
            target=_worker_loop,
            name=f"auto-annotation-worker-{i}",
            daemon=True,
        )
        thread.start()
        logger.info("Auto-annotation worker thread started: {}", thread.name)