Remove the IMAGE-only restriction from the auto-annotation task creation flow so that TEXT, AUDIO, and VIDEO datasets can also be used for auto-annotation tasks.

- New database migration: add a dataset_type column to the t_dm_auto_annotation_tasks table
- Backend passes dataset_type through the full schema/API/service chain
- Worker dynamically builds the sample key (image/text/audio/video) and the output directory
- Frontend drops the dataset-type validation; the dropdown now shows a dataset-type badge
- The output dataset inherits the source dataset's type instead of being hard-coded to IMAGE
- Backward compatible: the default is IMAGE, and the worker falls back to dataset metadata and to the images/ directory

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1124 lines · 38 KiB · Python
# -*- coding: utf-8 -*-
"""Simple background worker for auto-annotation tasks.

This module runs inside the datamate-runtime container (operator_runtime service).
It polls `t_dm_auto_annotation_tasks` for pending tasks and performs YOLO
inference using the ImageObjectDetectionBoundingBox operator, updating
progress back to the same table so that the datamate-python backend and
frontend can display real-time status.

Design goals (minimal viable version):
- A single worker instance processes `pending` tasks serially.
- Object detection runs on every completed file of the target dataset, one file at a time.
- `processed_images`, `progress`, `detected_objects`, `status`, etc. are updated as files are processed.
- On failure, the task is marked `failed` and `error_message` is recorded.

Caveat:
- To keep things simple, there is no recovery logic for tasks in the "running"
  state yet; after a container restart, tasks already in `running` are not
  picked up again. This needs a follow-up extension.
"""

from __future__ import annotations

import importlib
import json
import os
import sys
import threading
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from loguru import logger
from sqlalchemy import text

try:
    import datamate.ops  # noqa: F401
except Exception as import_ops_err:  # pragma: no cover - fallback logging
    logger.warning("Failed to import datamate.ops package for operator registry: {}", import_ops_err)

try:
    import ops.annotation  # type: ignore # noqa: F401
except Exception as import_annotation_ops_err:  # pragma: no cover - fallback logging
    logger.warning(
        "Failed to import ops.annotation package for operator registry: {}",
        import_annotation_ops_err,
    )

try:
    from datamate.core.base_op import OPERATORS
except Exception:  # pragma: no cover - fallback
    OPERATORS = None  # type: ignore

from datamate.sql_manager.sql_manager import SQLManager

# Try multiple import paths to accommodate different packaging/installation layouts.
ImageObjectDetectionBoundingBox = None  # type: ignore
try:
    # Prefer the datamate.ops path (source COPY'd into /opt/runtime/datamate/ops).
    from datamate.ops.annotation.image_object_detection_bounding_box.process import (  # type: ignore
        ImageObjectDetectionBoundingBox,
    )

    logger.info(
        "Imported ImageObjectDetectionBoundingBox from datamate.ops.annotation.image_object_detection_bounding_box",
    )
except Exception as e1:  # pragma: no cover - log import failures instead of crashing
    logger.error(
        "Failed to import ImageObjectDetectionBoundingBox via datamate.ops: {}",
        e1,
    )
    try:
        # Also support a top-level ops package install (exposed via ops.pth).
        from ops.annotation.image_object_detection_bounding_box.process import (  # type: ignore
            ImageObjectDetectionBoundingBox,
        )

        logger.info(
            "Imported ImageObjectDetectionBoundingBox from top-level ops.annotation.image_object_detection_bounding_box",
        )
    except Exception as e2:
        logger.error(
            "Failed to import ImageObjectDetectionBoundingBox via top-level ops package: {}",
            e2,
        )
        ImageObjectDetectionBoundingBox = None


# Last resort: load the operator directly from the local runtime/ops directory
# (a common development-environment layout).
if ImageObjectDetectionBoundingBox is None:
    try:
        project_root = Path(__file__).resolve().parents[2]
        ops_root = project_root / "ops"
        if ops_root.is_dir():
            # Make sure the parent of ops/ is on sys.path so "ops.xxx" imports work.
            if str(project_root) not in sys.path:
                sys.path.insert(0, str(project_root))

            from ops.annotation.image_object_detection_bounding_box.process import (  # type: ignore
                ImageObjectDetectionBoundingBox,
            )

            logger.info(
                "Imported ImageObjectDetectionBoundingBox from local runtime/ops.annotation.image_object_detection_bounding_box",
            )
        else:
            logger.warning(
                "Local runtime/ops directory not found when trying to import ImageObjectDetectionBoundingBox: {}",
                ops_root,
            )
    except Exception as e3:  # pragma: no cover - log the final fallback failure only
        logger.error(
            "Failed to import ImageObjectDetectionBoundingBox from local runtime/ops: {}",
            e3,
        )
        ImageObjectDetectionBoundingBox = None


POLL_INTERVAL_SECONDS = float(os.getenv("AUTO_ANNOTATION_POLL_INTERVAL", "5"))

DEFAULT_OUTPUT_ROOT = os.getenv(
    "AUTO_ANNOTATION_OUTPUT_ROOT", "/dataset"
)

DEFAULT_OPERATOR_WHITELIST = os.getenv(
    "AUTO_ANNOTATION_OPERATOR_WHITELIST",
    "ImageObjectDetectionBoundingBox,test_annotation_marker",
)


def _fetch_pending_task() -> Optional[Dict[str, Any]]:
    """Atomically claim one pending task and return its details."""

    def _parse_json_field(value: Any, default: Any) -> Any:
        if value is None:
            return default
        if isinstance(value, (dict, list)):
            return value
        if isinstance(value, str):
            text_value = value.strip()
            if not text_value:
                return default
            try:
                return json.loads(text_value)
            except Exception:
                return default
        return default

    run_token = str(uuid.uuid4())
    now = datetime.now()

    claim_sql = text(
        """
        UPDATE t_dm_auto_annotation_tasks
        SET status = 'running',
            run_token = :run_token,
            started_at = COALESCE(started_at, :now),
            heartbeat_at = :now,
            updated_at = :now,
            error_message = NULL
        WHERE id = (
            SELECT id FROM (
                SELECT id
                FROM t_dm_auto_annotation_tasks
                WHERE status = 'pending'
                  AND deleted_at IS NULL
                  AND COALESCE(stop_requested, 0) = 0
                ORDER BY created_at ASC
                LIMIT 1
            ) AS pending_task
        )
        AND status = 'pending'
        AND deleted_at IS NULL
        AND COALESCE(stop_requested, 0) = 0
        """
    )

    query_sql = text(
        """
        SELECT id, name, dataset_id, dataset_name, dataset_type, created_by,
               config, file_ids, pipeline,
               task_mode, executor_type,
               status, stop_requested, run_token,
               total_images, processed_images, detected_objects,
               output_path, output_dataset_id
        FROM t_dm_auto_annotation_tasks
        WHERE run_token = :run_token
        LIMIT 1
        """
    )

    with SQLManager.create_connect() as conn:
        claim_result = conn.execute(claim_sql, {"run_token": run_token, "now": now})
        if not claim_result or int(getattr(claim_result, "rowcount", 0) or 0) <= 0:
            return None

        result = conn.execute(query_sql, {"run_token": run_token}).fetchone()
        if not result:
            return None

        row = dict(result._mapping)  # type: ignore[attr-defined]

        row["config"] = _parse_json_field(row.get("config"), {})

        parsed_file_ids = _parse_json_field(row.get("file_ids"), None)
        row["file_ids"] = parsed_file_ids if parsed_file_ids else None

        parsed_pipeline = _parse_json_field(row.get("pipeline"), None)
        row["pipeline"] = parsed_pipeline if parsed_pipeline else None
        return row
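
# Shape of a claimed task dict, with illustrative values (not taken from a
# real row; the defaults mirror _build_legacy_pipeline below):
#   {"id": "...", "dataset_id": "...", "dataset_type": "IMAGE",
#    "task_mode": "legacy_yolo", "config": {"modelSize": "l", "confThreshold": 0.7},
#    "file_ids": None, "pipeline": None, "run_token": "<uuid>", ...}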


def _update_task_status(
    task_id: str,
    *,
    status: str,
    run_token: Optional[str] = None,
    progress: Optional[int] = None,
    processed_images: Optional[int] = None,
    detected_objects: Optional[int] = None,
    total_images: Optional[int] = None,
    output_path: Optional[str] = None,
    output_dataset_id: Optional[str] = None,
    error_message: Optional[str] = None,
    completed: bool = False,
    clear_run_token: bool = False,
) -> None:
    """Update the task's status and statistics fields."""

    fields: List[str] = ["status = :status", "updated_at = :updated_at"]
    params: Dict[str, Any] = {
        "task_id": task_id,
        "status": status,
        "updated_at": datetime.now(),
    }

    if progress is not None:
        fields.append("progress = :progress")
        params["progress"] = int(progress)
    if processed_images is not None:
        fields.append("processed_images = :processed_images")
        params["processed_images"] = int(processed_images)
    if detected_objects is not None:
        fields.append("detected_objects = :detected_objects")
        params["detected_objects"] = int(detected_objects)
    if total_images is not None:
        fields.append("total_images = :total_images")
        params["total_images"] = int(total_images)
    if output_path is not None:
        fields.append("output_path = :output_path")
        params["output_path"] = output_path
    if output_dataset_id is not None:
        fields.append("output_dataset_id = :output_dataset_id")
        params["output_dataset_id"] = output_dataset_id
    if error_message is not None:
        fields.append("error_message = :error_message")
        params["error_message"] = error_message[:2000]
    if status == "running":
        fields.append("heartbeat_at = :heartbeat_at")
        params["heartbeat_at"] = datetime.now()
    if completed:
        fields.append("completed_at = :completed_at")
        params["completed_at"] = datetime.now()
    if clear_run_token:
        fields.append("run_token = NULL")

    where_clause = "id = :task_id"
    if run_token:
        where_clause += " AND run_token = :run_token"
        params["run_token"] = run_token

    sql = text(
        f"""
        UPDATE t_dm_auto_annotation_tasks
        SET {', '.join(fields)}
        WHERE {where_clause}
        """
    )

    with SQLManager.create_connect() as conn:
        result = conn.execute(sql, params)
        if int(getattr(result, "rowcount", 0) or 0) <= 0:
            logger.warning(
                "No rows updated for task status change: task_id={}, status={}, run_token={}",
                task_id,
                status,
                run_token,
            )
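
# Typical mid-flight call (mirrors the calls made from _process_single_task):
#   _update_task_status(task_id, run_token=token, status="running",
#                       progress=42, processed_images=42, detected_objects=97)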


def _is_stop_requested(task_id: str, run_token: Optional[str] = None) -> bool:
    """Check whether a stop has been requested for the task."""

    where_clause = "id = :task_id"
    params: Dict[str, Any] = {"task_id": task_id}
    if run_token:
        where_clause += " AND run_token = :run_token"
        params["run_token"] = run_token

    sql = text(
        f"""
        SELECT COALESCE(stop_requested, 0)
        FROM t_dm_auto_annotation_tasks
        WHERE {where_clause}
        LIMIT 1
        """
    )

    with SQLManager.create_connect() as conn:
        row = conn.execute(sql, params).fetchone()
        if not row:
            # If the task cannot be found (or the run_token is stale), stop conservatively.
            return True
        return bool(row[0])


def _extract_step_overrides(step: Dict[str, Any]) -> Dict[str, Any]:
    """Merge the parameter overrides carried by a pipeline step."""

    overrides: Dict[str, Any] = {}
    for key in ("overrides", "settingsOverride", "settings_override"):
        value = step.get(key)
        if value is None:
            continue
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except Exception:
                continue
        if isinstance(value, dict):
            overrides.update(value)
    return overrides


def _build_legacy_pipeline(config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Map a legacy_yolo config onto a single-step pipeline."""

    return [
        {
            "operatorId": "ImageObjectDetectionBoundingBox",
            "overrides": {
                "modelSize": config.get("modelSize", "l"),
                "confThreshold": float(config.get("confThreshold", 0.7)),
                "targetClasses": config.get("targetClasses", []) or [],
            },
        }
    ]


def _get_output_dataset_name(
    task_id: str,
    dataset_id: str,
    source_dataset_name: str,
    task_name: str,
    config: Dict[str, Any],
    pipeline_raw: Optional[List[Any]],
) -> str:
    """Determine the name of the output dataset."""

    output_name = config.get("outputDatasetName")
    if output_name:
        return str(output_name)

    if pipeline_raw:
        for step in pipeline_raw:
            if not isinstance(step, dict):
                continue
            overrides = _extract_step_overrides(step)
            output_name = overrides.get("outputDatasetName") or overrides.get("output_dataset_name")
            if output_name:
                return str(output_name)

    base_name = source_dataset_name or task_name or f"dataset-{dataset_id[:8]}"
    return f"{base_name}_auto_{task_id[:8]}"


def _normalize_pipeline(
    task_mode: str,
    config: Dict[str, Any],
    pipeline_raw: Optional[List[Any]],
    output_dir: str,
) -> List[Dict[str, Any]]:
    """Normalize the pipeline structure and inject outputDir."""

    source_pipeline = pipeline_raw
    if task_mode == "legacy_yolo" or not source_pipeline:
        source_pipeline = _build_legacy_pipeline(config)

    normalized: List[Dict[str, Any]] = []
    for step in source_pipeline:
        if not isinstance(step, dict):
            continue

        operator_id: Optional[str] = None
        overrides: Dict[str, Any] = {}

        # Also accept the [{"OpName": {...}}] style.
        if (
            "operatorId" not in step
            and "operator_id" not in step
            and "id" not in step
            and len(step) == 1
        ):
            first_key = next(iter(step.keys()))
            first_value = step.get(first_key)
            if isinstance(first_key, str):
                operator_id = first_key
                if isinstance(first_value, dict):
                    overrides.update(first_value)

        operator_id = operator_id or step.get("operatorId") or step.get("operator_id") or step.get("id")
        if not operator_id:
            continue

        overrides.update(_extract_step_overrides(step))
        overrides.setdefault("outputDir", output_dir)

        normalized.append(
            {
                "operator_id": str(operator_id),
                "overrides": overrides,
            }
        )

    return normalized
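
# Both accepted input shapes normalize to the same step (illustrative):
#   [{"operatorId": "ImageObjectDetectionBoundingBox", "overrides": {"confThreshold": 0.5}}]
#   [{"ImageObjectDetectionBoundingBox": {"confThreshold": 0.5}}]
# -> [{"operator_id": "ImageObjectDetectionBoundingBox",
#      "overrides": {"confThreshold": 0.5, "outputDir": "<output_dir>"}}]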


def _resolve_operator_class(operator_id: str):
    """Resolve the operator class for an operator_id."""

    if operator_id == "ImageObjectDetectionBoundingBox":
        if ImageObjectDetectionBoundingBox is None:
            raise ImportError("ImageObjectDetectionBoundingBox is not available")
        return ImageObjectDetectionBoundingBox

    registry_item = OPERATORS.get(operator_id) if OPERATORS is not None else None
    if registry_item is None:
        try:
            from core.base_op import OPERATORS as relative_operators  # type: ignore
            registry_item = relative_operators.get(operator_id)
        except Exception:
            registry_item = None

    if registry_item is None:
        raise ImportError(f"Operator not found in registry: {operator_id}")

    if isinstance(registry_item, str):
        submodule = importlib.import_module(registry_item)
        operator_cls = getattr(submodule, operator_id, None)
        if operator_cls is None:
            raise ImportError(
                f"Operator class {operator_id} not found in module {registry_item}"
            )
        return operator_cls

    return registry_item


def _build_operator_chain(pipeline: List[Dict[str, Any]]) -> List[Tuple[str, Any]]:
    """Initialize the operator chain."""

    chain: List[Tuple[str, Any]] = []
    for step in pipeline:
        operator_id = step.get("operator_id")
        overrides = dict(step.get("overrides") or {})
        if not operator_id:
            continue

        operator_cls = _resolve_operator_class(str(operator_id))
        operator = operator_cls(**overrides)
        chain.append((str(operator_id), operator))
    return chain


def _run_pipeline_sample(sample: Dict[str, Any], chain: List[Tuple[str, Any]]) -> Dict[str, Any]:
    """Run the pipeline on a single sample."""

    current_sample: Dict[str, Any] = dict(sample)
    for operator_id, operator in chain:
        if hasattr(operator, "execute") and callable(getattr(operator, "execute")):
            result = operator.execute(current_sample)
        elif callable(operator):
            result = operator(current_sample)
        else:
            raise RuntimeError(f"Operator {operator_id} is not executable")

        if result is None:
            continue

        if isinstance(result, dict):
            current_sample.update(result)
            continue

        if isinstance(result, list):
            # Take only the first dict result, for operators that return a list.
            if result and isinstance(result[0], dict):
                current_sample.update(result[0])
            continue

        logger.debug(
            "Operator {} returned unsupported result type: {}",
            operator_id,
            type(result).__name__,
        )
    return current_sample
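
# A minimal operator satisfying this contract (illustrative sketch only; real
# operators come from the registry or the ops packages imported above):
#
#     class NoopMarker:
#         def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]:
#             return {"detection_count": 0}  # merged into the running sample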


def _count_detections(sample: Dict[str, Any]) -> int:
    """Extract the number of detections from a sample."""

    annotations = sample.get("annotations")
    if isinstance(annotations, dict):
        detections = annotations.get("detections")
        if isinstance(detections, list):
            return len(detections)

    detection_count = sample.get("detection_count")
    if detection_count is None:
        return 0
    try:
        return max(int(detection_count), 0)
    except Exception:
        return 0
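
# Recognized sample shapes (illustrative):
#   {"annotations": {"detections": [...]}}  -> len(detections)
#   {"detection_count": 3}                  -> 3
#   anything else                           -> 0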


# ---------------------------------------------------------------------------
# Dataset type -> sample key / output subdirectory mappings
# ---------------------------------------------------------------------------

DATASET_TYPE_SAMPLE_KEY: Dict[str, str] = {
    "IMAGE": "image",
    "TEXT": "text",
    "AUDIO": "audio",
    "VIDEO": "video",
}

DATASET_TYPE_DATA_DIR: Dict[str, str] = {
    "IMAGE": "images",
    "TEXT": "data",
    "AUDIO": "data",
    "VIDEO": "data",
}


def _get_sample_key(dataset_type: str) -> str:
    """Return the sample-dict key that holds the primary data for a dataset type."""
    return DATASET_TYPE_SAMPLE_KEY.get(dataset_type.upper(), "image")


def _get_data_dir_name(dataset_type: str) -> str:
    """Return the output subdirectory name for a dataset type."""
    return DATASET_TYPE_DATA_DIR.get(dataset_type.upper(), "images")
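
# Examples (case-insensitive; unknown types fall back to the IMAGE defaults):
#   _get_sample_key("text")    -> "text"      _get_data_dir_name("text")    -> "data"
#   _get_sample_key("unknown") -> "image"     _get_data_dir_name("unknown") -> "images"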


def _get_operator_whitelist() -> Optional[set[str]]:
    """Return the rollout whitelist; None means all operators are allowed."""

    raw = str(DEFAULT_OPERATOR_WHITELIST or "").strip()
    if not raw:
        return None

    normalized = raw.lower()
    if normalized in {"*", "all", "any"}:
        return None

    allow_set = {
        item.strip()
        for item in raw.split(",")
        if item and item.strip()
    }
    return allow_set or None
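
# Example AUTO_ANNOTATION_OPERATOR_WHITELIST values ("my_operator" is a
# hypothetical operator id used only for illustration):
#   "*" / "all" / "any" / ""                       -> no filtering (returns None)
#   "ImageObjectDetectionBoundingBox, my_operator" -> only those two operator ids pass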


def _validate_pipeline_whitelist(pipeline: List[Dict[str, Any]]) -> None:
    """Verify that every pipeline step is covered by the rollout whitelist."""

    allow_set = _get_operator_whitelist()
    if allow_set is None:
        return

    blocked = []
    for step in pipeline:
        operator_id = str(step.get("operator_id") or "")
        if not operator_id:
            continue
        if operator_id not in allow_set:
            blocked.append(operator_id)

    if blocked:
        raise RuntimeError(
            "Operator not in whitelist: " + ", ".join(sorted(set(blocked)))
        )


def _load_dataset_files(dataset_id: str) -> List[Tuple[str, str, str]]:
    """Load all completed files under the given dataset."""

    sql = text(
        """
        SELECT id, file_path, file_name
        FROM t_dm_dataset_files
        WHERE dataset_id = :dataset_id
          AND status = 'ACTIVE'
        ORDER BY created_at ASC
        """
    )

    with SQLManager.create_connect() as conn:
        rows = conn.execute(sql, {"dataset_id": dataset_id}).fetchall()
        return [(str(r[0]), str(r[1]), str(r[2])) for r in rows]


def _load_dataset_meta(dataset_id: str) -> Optional[Dict[str, Any]]:
    """Load basic dataset information (including parent ID and path)."""

    sql = text(
        """
        SELECT id, name, parent_dataset_id, path, dataset_type
        FROM t_dm_datasets
        WHERE id = :dataset_id
        """
    )
    with SQLManager.create_connect() as conn:
        row = conn.execute(sql, {"dataset_id": dataset_id}).fetchone()
        if not row:
            return None
        return dict(row._mapping)  # type: ignore[attr-defined]


def _resolve_output_parent(source_dataset_id: str) -> Tuple[Optional[str], str]:
    """Determine the output dataset's parent and base path from the source dataset
    (the output dataset is attached under the source's parent)."""

    base_path = DEFAULT_OUTPUT_ROOT.rstrip("/") or "/dataset"
    source_meta = _load_dataset_meta(source_dataset_id)
    if not source_meta:
        return None, base_path

    parent_dataset_id = source_meta.get("parent_dataset_id")
    if not parent_dataset_id:
        return None, base_path

    parent_meta = _load_dataset_meta(str(parent_dataset_id))
    parent_path = parent_meta.get("path") if parent_meta else None
    if not parent_path:
        return None, base_path
    return str(parent_dataset_id), str(parent_path)
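
# Illustrative outcome: if the source dataset has a parent whose path is
# "/dataset/<parent_id>", the output dataset is created under that path;
# otherwise this falls back to (None, DEFAULT_OUTPUT_ROOT).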


def _load_files_by_ids(file_ids: List[str]) -> List[Tuple[str, str, str]]:
    """Load file records by ID; the IDs may span multiple datasets."""

    if not file_ids:
        return []

    placeholders = ", ".join(f":id{i}" for i in range(len(file_ids)))
    sql = text(
        f"""
        SELECT id, file_path, file_name
        FROM t_dm_dataset_files
        WHERE id IN ({placeholders})
          AND status = 'ACTIVE'
        ORDER BY created_at ASC
        """
    )
    params = {f"id{i}": str(fid) for i, fid in enumerate(file_ids)}

    with SQLManager.create_connect() as conn:
        rows = conn.execute(sql, params).fetchall()
        return [(str(r[0]), str(r[1]), str(r[2])) for r in rows]


def _ensure_output_dir(output_dir: str, dataset_type: str = "IMAGE") -> str:
    """Ensure the output directory and its data/annotations subdirectories exist."""

    os.makedirs(output_dir, exist_ok=True)
    data_dir_name = _get_data_dir_name(dataset_type)
    os.makedirs(os.path.join(output_dir, data_dir_name), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "annotations"), exist_ok=True)
    return output_dir
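
# Resulting layout (illustrative):
#   IMAGE:                <output_dir>/images/ and <output_dir>/annotations/
#   TEXT / AUDIO / VIDEO: <output_dir>/data/   and <output_dir>/annotations/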


def _create_output_dataset(
    source_dataset_id: str,
    source_dataset_name: str,
    output_dataset_name: str,
    dataset_type: str = "IMAGE",
) -> Tuple[str, str]:
    """Create a new dataset for the auto-annotation results and return (dataset_id, path)."""

    new_dataset_id = str(uuid.uuid4())
    parent_dataset_id, dataset_base_path = _resolve_output_parent(source_dataset_id)
    output_dir = os.path.join(dataset_base_path, new_dataset_id)

    description = (
        f"Auto annotations for dataset {source_dataset_name or source_dataset_id}"[:255]
    )

    sql = text(
        """
        INSERT INTO t_dm_datasets (id, parent_dataset_id, name, description, dataset_type, path, status)
        VALUES (:id, :parent_dataset_id, :name, :description, :dataset_type, :path, :status)
        """
    )
    params = {
        "id": new_dataset_id,
        "parent_dataset_id": parent_dataset_id,
        "name": output_dataset_name,
        "description": description,
        "dataset_type": dataset_type,
        "path": output_dir,
        "status": "ACTIVE",
    }

    with SQLManager.create_connect() as conn:
        conn.execute(sql, params)

    return new_dataset_id, output_dir


def _register_output_dataset(
    task_id: str,
    output_dataset_id: str,
    output_dir: str,
    output_dataset_name: str,
    total_images: int,
    dataset_type: str = "IMAGE",
) -> None:
    """Register the auto-annotation results into the newly created dataset."""

    data_dir_name = _get_data_dir_name(dataset_type)
    data_dir = os.path.join(output_dir, data_dir_name)
    # Stay compatible with older tasks and IMAGE operators (they write into images/).
    if not os.path.isdir(data_dir):
        fallback_dir = os.path.join(output_dir, "images")
        if os.path.isdir(fallback_dir):
            data_dir = fallback_dir
        else:
            logger.warning(
                "Auto-annotation data directory not found for task {}: {}",
                task_id,
                data_dir,
            )
            return

    data_files: List[Tuple[str, str, int]] = []
    annotation_files: List[Tuple[str, str, int]] = []
    total_size = 0

    for file_name in sorted(os.listdir(data_dir)):
        file_path = os.path.join(data_dir, file_name)
        if not os.path.isfile(file_path):
            continue
        try:
            file_size = os.path.getsize(file_path)
        except OSError:
            file_size = 0
        data_files.append((file_name, file_path, int(file_size)))
        total_size += int(file_size)

    annotations_dir = os.path.join(output_dir, "annotations")
    if os.path.isdir(annotations_dir):
        for file_name in sorted(os.listdir(annotations_dir)):
            file_path = os.path.join(annotations_dir, file_name)
            if not os.path.isfile(file_path):
                continue
            try:
                file_size = os.path.getsize(file_path)
            except OSError:
                file_size = 0
            annotation_files.append((file_name, file_path, int(file_size)))
            total_size += int(file_size)

    if not data_files:
        logger.warning(
            "No data files found in auto-annotation output for task {}: {}",
            task_id,
            data_dir,
        )
        return

    insert_file_sql = text(
        """
        INSERT INTO t_dm_dataset_files (
            id, dataset_id, file_name, file_path, logical_path, version, file_type, file_size, status
        ) VALUES (
            :id, :dataset_id, :file_name, :file_path, :logical_path, :version, :file_type, :file_size, :status
        )
        """
    )
    update_dataset_stat_sql = text(
        """
        UPDATE t_dm_datasets
        SET file_count = COALESCE(file_count, 0) + :add_count,
            size_bytes = COALESCE(size_bytes, 0) + :add_size
        WHERE id = :dataset_id
        """
    )

    with SQLManager.create_connect() as conn:
        added_count = 0

        for file_name, file_path, file_size in data_files:
            ext = os.path.splitext(file_name)[1].lstrip(".").upper() or None
            logical_path = os.path.relpath(file_path, output_dir).replace("\\", "/")
            conn.execute(
                insert_file_sql,
                {
                    "id": str(uuid.uuid4()),
                    "dataset_id": output_dataset_id,
                    "file_name": file_name,
                    "file_path": file_path,
                    "logical_path": logical_path,
                    "version": 1,
                    "file_type": ext,
                    "file_size": int(file_size),
                    "status": "ACTIVE",
                },
            )
            added_count += 1

        for file_name, file_path, file_size in annotation_files:
            ext = os.path.splitext(file_name)[1].lstrip(".").upper() or None
            logical_path = os.path.relpath(file_path, output_dir).replace("\\", "/")
            conn.execute(
                insert_file_sql,
                {
                    "id": str(uuid.uuid4()),
                    "dataset_id": output_dataset_id,
                    "file_name": file_name,
                    "file_path": file_path,
                    "logical_path": logical_path,
                    "version": 1,
                    "file_type": ext,
                    "file_size": int(file_size),
                    "status": "ACTIVE",
                },
            )
            added_count += 1

        if added_count > 0:
            conn.execute(
                update_dataset_stat_sql,
                {
                    "dataset_id": output_dataset_id,
                    "add_count": added_count,
                    "add_size": int(total_size),
                },
            )

    logger.info(
        "Registered auto-annotation output into dataset: dataset_id={}, name={}, added_files={}, added_size_bytes={}, task_id={}, output_dir={}",
        output_dataset_id,
        output_dataset_name,
        len(data_files) + len(annotation_files),
        total_size,
        task_id,
        output_dir,
    )


def _process_single_task(task: Dict[str, Any]) -> None:
    """Execute a single auto-annotation task."""

    task_id = str(task["id"])
    dataset_id = str(task["dataset_id"])
    task_name = str(task.get("name") or "")
    source_dataset_name = str(task.get("dataset_name") or "")
    run_token = str(task.get("run_token") or "")
    task_mode = str(task.get("task_mode") or "legacy_yolo")
    executor_type = str(task.get("executor_type") or "annotation_local")
    cfg: Dict[str, Any] = task.get("config") or {}
    pipeline_raw = task.get("pipeline")
    selected_file_ids: Optional[List[str]] = task.get("file_ids") or None

    # Resolve the dataset type, falling back to the dataset metadata.
    dataset_type = str(task.get("dataset_type") or "").upper() or "IMAGE"
    if dataset_type == "IMAGE" and not task.get("dataset_type"):
        source_meta = _load_dataset_meta(dataset_id)
        if source_meta and source_meta.get("dataset_type"):
            dataset_type = str(source_meta["dataset_type"]).upper()

    output_dataset_name = _get_output_dataset_name(
        task_id=task_id,
        dataset_id=dataset_id,
        source_dataset_name=source_dataset_name,
        task_name=task_name,
        config=cfg,
        pipeline_raw=pipeline_raw if isinstance(pipeline_raw, list) else None,
    )

    logger.info(
        "Start processing auto-annotation task: id={}, dataset_id={}, task_mode={}, executor_type={}, output_dataset_name={}",
        task_id,
        dataset_id,
        task_mode,
        executor_type,
        output_dataset_name,
    )

    if _is_stop_requested(task_id, run_token):
        logger.info("Task stop requested before processing started: {}", task_id)
        _update_task_status(
            task_id,
            run_token=run_token,
            status="stopped",
            completed=True,
            clear_run_token=True,
            error_message="Task stopped before start",
        )
        return

    _update_task_status(task_id, run_token=run_token, status="running", progress=0)

    if selected_file_ids:
        all_files = _load_files_by_ids(selected_file_ids)
    else:
        all_files = _load_dataset_files(dataset_id)

    files = [(path, name) for _, path, name in all_files]

    total_images = len(files)
    if total_images == 0:
        logger.warning("No files found for dataset {} when running auto-annotation task {}", dataset_id, task_id)
        _update_task_status(
            task_id,
            run_token=run_token,
            status="completed",
            progress=100,
            total_images=0,
            processed_images=0,
            detected_objects=0,
            completed=True,
            output_path=None,
            clear_run_token=True,
        )
        return

    output_dataset_id, output_dir = _create_output_dataset(
        source_dataset_id=dataset_id,
        source_dataset_name=source_dataset_name,
        output_dataset_name=output_dataset_name,
        dataset_type=dataset_type,
    )
    output_dir = _ensure_output_dir(output_dir, dataset_type=dataset_type)

    _update_task_status(
        task_id,
        run_token=run_token,
        status="running",
        total_images=total_images,
        output_path=output_dir,
        output_dataset_id=output_dataset_id,
    )

    try:
        normalized_pipeline = _normalize_pipeline(
            task_mode=task_mode,
            config=cfg,
            pipeline_raw=pipeline_raw if isinstance(pipeline_raw, list) else None,
            output_dir=output_dir,
        )

        if not normalized_pipeline:
            raise RuntimeError("Pipeline is empty after normalization")

        _validate_pipeline_whitelist(normalized_pipeline)

        chain = _build_operator_chain(normalized_pipeline)
        if not chain:
            raise RuntimeError("No valid operator instances initialized")
    except Exception as e:
        logger.error("Failed to init operator pipeline for task {}: {}", task_id, e)
        _update_task_status(
            task_id,
            run_token=run_token,
            status="failed",
            total_images=total_images,
            processed_images=0,
            detected_objects=0,
            error_message=f"Init pipeline failed: {e}",
            clear_run_token=True,
        )
        return

    processed = 0
    detected_total = 0

    try:
        for file_path, file_name in files:
            if _is_stop_requested(task_id, run_token):
                logger.info("Task stop requested during processing: {}", task_id)
                _update_task_status(
                    task_id,
                    run_token=run_token,
                    status="stopped",
                    progress=int(processed * 100 / total_images) if total_images > 0 else 0,
                    processed_images=processed,
                    detected_objects=detected_total,
                    total_images=total_images,
                    output_path=output_dir,
                    output_dataset_id=output_dataset_id,
                    completed=True,
                    clear_run_token=True,
                    error_message="Task stopped by request",
                )
                return

            try:
                sample_key = _get_sample_key(dataset_type)
                sample = {
                    sample_key: file_path,
                    "filename": file_name,
                }

                result = _run_pipeline_sample(sample, chain)
                detected_total += _count_detections(result)
                processed += 1

                progress = int(processed * 100 / total_images) if total_images > 0 else 100

                _update_task_status(
                    task_id,
                    run_token=run_token,
                    status="running",
                    progress=progress,
                    processed_images=processed,
                    detected_objects=detected_total,
                    total_images=total_images,
                    output_path=output_dir,
                    output_dataset_id=output_dataset_id,
                )
            except Exception as e:
                logger.error(
                    "Failed to process file for task {}: file_path={}, error={}",
                    task_id,
                    file_path,
                    e,
                )
                continue

        _update_task_status(
            task_id,
            run_token=run_token,
            status="completed",
            progress=100,
            processed_images=processed,
            detected_objects=detected_total,
            total_images=total_images,
            output_path=output_dir,
            output_dataset_id=output_dataset_id,
            completed=True,
            clear_run_token=True,
        )

        logger.info(
            "Completed auto-annotation task: id={}, total_images={}, processed={}, detected_objects={}, output_path={}",
            task_id,
            total_images,
            processed,
            detected_total,
            output_dir,
        )

        if output_dataset_name and output_dataset_id:
            try:
                _register_output_dataset(
                    task_id=task_id,
                    output_dataset_id=output_dataset_id,
                    output_dir=output_dir,
                    output_dataset_name=output_dataset_name,
                    total_images=total_images,
                    dataset_type=dataset_type,
                )
            except Exception as e:  # pragma: no cover - defensive logging
                logger.error(
                    "Failed to register auto-annotation output as dataset for task {}: {}",
                    task_id,
                    e,
                )
    except Exception as e:
        logger.error("Task execution failed for task {}: {}", task_id, e)
        _update_task_status(
            task_id,
            run_token=run_token,
            status="failed",
            progress=int(processed * 100 / total_images) if total_images > 0 else 0,
            processed_images=processed,
            detected_objects=detected_total,
            total_images=total_images,
            output_path=output_dir,
            output_dataset_id=output_dataset_id,
            error_message=f"Execute pipeline failed: {e}",
            clear_run_token=True,
        )


def _worker_loop() -> None:
    """Main worker loop; runs in a dedicated thread."""

    logger.info(
        "Auto-annotation worker started with poll interval {} seconds, output root {}",
        POLL_INTERVAL_SECONDS,
        DEFAULT_OUTPUT_ROOT,
    )

    while True:
        try:
            task = _fetch_pending_task()
            if not task:
                time.sleep(POLL_INTERVAL_SECONDS)
                continue

            _process_single_task(task)
        except Exception as e:  # pragma: no cover - defensive logging
            logger.error("Auto-annotation worker loop error: {}", e)
            time.sleep(POLL_INTERVAL_SECONDS)


def start_auto_annotation_worker() -> None:
    """Start the auto-annotation worker in a background thread."""

    thread = threading.Thread(target=_worker_loop, name="auto-annotation-worker", daemon=True)
    thread.start()
    logger.info("Auto-annotation worker thread started: {}", thread.name)
|