# -*- coding: utf-8 -*-
"""
GPU scheduler

Provides round-robin scheduling across multiple GPU devices.
"""

import logging
import threading
from typing import List, Optional

from domain.config import WorkerConfig
from domain.gpu import GPUDevice
from util.system import get_all_gpu_info, validate_gpu_device
from constant import HW_ACCEL_CUDA, HW_ACCEL_QSV

logger = logging.getLogger(__name__)


class GPUScheduler:
    """
    GPU scheduler

    Hands out GPU device indices in round-robin order. The rotation cursor
    is advanced while holding a lock, so concurrent callers of ``acquire()``
    each get their own turn in the rotation.

    Usage:
        scheduler = GPUScheduler(config)

        # While executing a task
        device_index = scheduler.acquire()
        try:
            # run the task
            pass
        finally:
            scheduler.release(device_index)
    """

    def __init__(self, config: WorkerConfig):
        """
        Set up the scheduler and build its device list.

        Args:
            config: Worker configuration; ``hw_accel`` and ``gpu_devices``
                are read from it.
        """
        self._config = config
        self._devices: List[GPUDevice] = []
        self._next_index: int = 0
        self._lock = threading.Lock()
        self._enabled = False

        # Populate the device list (may leave the scheduler disabled).
        self._init_devices()

    def _init_devices(self) -> None:
        """Build the GPU device list and enable the scheduler if any exist."""
        # Devices are only collected when CUDA or QSV acceleration is selected.
        accelerated = self._config.hw_accel in (HW_ACCEL_CUDA, HW_ACCEL_QSV)
        if not accelerated:
            logger.info("Hardware acceleration not enabled, GPU scheduler disabled")
            return

        # Explicitly configured devices win; otherwise fall back to detection.
        requested = self._config.gpu_devices
        self._devices = (
            self._validate_configured_devices(requested)
            if requested
            else self._auto_detect_devices()
        )

        if not self._devices:
            logger.warning("No GPU devices available, scheduler disabled")
            return

        self._enabled = True
        device_info = ', '.join(map(str, self._devices))
        logger.info(f"GPU scheduler initialized with {len(self._devices)} device(s): {device_info}")

    def _validate_configured_devices(self, indices: List[int]) -> List[GPUDevice]:
        """
        Turn configured device indices into device records.

        Indices rejected by ``validate_gpu_device`` are skipped with a
        warning; the rest keep their configured order.

        Args:
            indices: Device indices taken from the configuration.

        Returns:
            Device records for the indices that passed validation.
        """
        usable: List[GPUDevice] = []
        for index in indices:
            if not validate_gpu_device(index):
                logger.warning(f"GPU device {index} is not available, skipping")
                continue
            usable.append(
                GPUDevice(index=index, name=f"GPU-{index}", available=True)
            )
        return usable

    def _auto_detect_devices(self) -> List[GPUDevice]:
        """
        Detect GPUs via ``get_all_gpu_info``.

        Returns:
            The reported devices whose ``available`` attribute is truthy.
        """
        # Drop devices that are reported but not available.
        return [gpu for gpu in get_all_gpu_info() if gpu.available]

    @property
    def enabled(self) -> bool:
        """Whether the scheduler is enabled."""
        return self._enabled

    @property
    def device_count(self) -> int:
        """Number of devices in the rotation."""
        return len(self._devices)

    def acquire(self) -> Optional[int]:
        """
        Pick the next GPU device in round-robin order.

        Returns:
            The chosen device's index, or None when the scheduler is
            disabled or has no devices.
        """
        if not (self._enabled and self._devices):
            return None

        with self._lock:
            slot = self._next_index
            device = self._devices[slot]
            # Advance the cursor, wrapping around at the end of the list.
            self._next_index = (slot + 1) % len(self._devices)
            logger.debug(f"Acquired GPU device: {device.index}")
            return device.index

    def release(self, device_index: Optional[int]) -> None:
        """
        Release a GPU device.

        The rotation keeps no per-device state, so this only emits a debug
        log entry; a None index is ignored.

        Args:
            device_index: Index previously returned by ``acquire()``.
        """
        if device_index is None:
            return
        logger.debug(f"Released GPU device: {device_index}")

    def get_status(self) -> dict:
        """
        Report the scheduler's current state.

        Returns:
            A dict with the keys ``enabled``, ``device_count``, ``devices``
            (one dict per device with ``index``, ``name``, ``available``)
            and ``hw_accel``.
        """
        status: dict = {
            'enabled': self._enabled,
            'device_count': len(self._devices),
        }
        status['devices'] = [
            {'index': d.index, 'name': d.name, 'available': d.available}
            for d in self._devices
        ]
        status['hw_accel'] = self._config.hw_accel
        return status