diff --git a/services/task_executor.py b/services/task_executor.py index efb8883..5d4038e 100644 --- a/services/task_executor.py +++ b/services/task_executor.py @@ -12,6 +12,12 @@ from typing import Dict, Optional, TYPE_CHECKING from domain.task import Task, TaskType from domain.result import TaskResult, ErrorCode + +# 需要 GPU 加速的任务类型 +GPU_REQUIRED_TASK_TYPES = { + TaskType.RENDER_SEGMENT_VIDEO, + TaskType.COMPOSE_TRANSITION, +} from domain.config import WorkerConfig from core.handler import TaskHandler from services.lease_service import LeaseService @@ -179,9 +185,10 @@ class TaskExecutor: ) lease_service.start() - # 获取 GPU 设备 + # 获取 GPU 设备(仅对需要 GPU 的任务类型) device_index = None - if self.gpu_scheduler.enabled: + needs_gpu = task.task_type in GPU_REQUIRED_TASK_TYPES + if needs_gpu and self.gpu_scheduler.enabled: device_index = self.gpu_scheduler.acquire() if device_index is not None: logger.info(f"[task:{task_id}] Assigned to GPU device {device_index}") @@ -227,8 +234,8 @@ class TaskExecutor: if handler: handler.clear_gpu_device() - # 释放 GPU 设备 - if self.gpu_scheduler.enabled: + # 释放 GPU 设备(仅当实际分配了设备时) + if device_index is not None: self.gpu_scheduler.release(device_index) # 停止租约续期