import asyncio from datetime import datetime from typing import Optional, List from loguru import logger from .scheduler import Task, TaskStatus, TaskResult, TaskScheduler class CommandTask(Task): """命令任务包装类""" def __init__(self, task_id: str, command: str, log_path = None, shell: bool = True, timeout: Optional[int] = None, *args, **kwargs): super().__init__(task_id, *args, **kwargs) self.log_path = log_path self.command = command self.shell = shell self.timeout = timeout self.return_code = None self._process = None def start(self) -> 'CommandTask': """启动任务""" if self.status == TaskStatus.PENDING: self.status = TaskStatus.RUNNING self.started_at = datetime.now() self._task = asyncio.create_task(self._execute()) return self async def _execute(self): """执行命令""" try: self.status = TaskStatus.RUNNING self.started_at = datetime.now() with open(self.log_path, 'a') as f: # 使用 asyncio.create_subprocess_shell 或 create_subprocess_exec if self.shell: process = await asyncio.create_subprocess_shell( self.command, stdout=f, stderr=asyncio.subprocess.STDOUT, **self.kwargs ) else: process = await asyncio.create_subprocess_exec( *self.command.split(), stdout=f, stderr=asyncio.subprocess.STDOUT, **self.kwargs ) self._process = process # 等待进程完成(带超时) try: if self.timeout: await asyncio.wait_for( process.wait(), timeout=self.timeout ) else: await process.wait() self.return_code = process.returncode if self._cancelled: self.status = TaskStatus.CANCELLED elif process.returncode == 0: self.status = TaskStatus.COMPLETED else: self.status = TaskStatus.FAILED except asyncio.TimeoutError: # 超时处理 self._process.terminate() try: await asyncio.wait_for(self._process.wait(), timeout=5.0) except asyncio.TimeoutError: self._process.kill() await self._process.wait() self.status = TaskStatus.FAILED f.write(f"\nCommand timed out after {self.timeout} seconds\n") except asyncio.CancelledError: # 任务被取消 if self._process: self._process.terminate() try: await asyncio.wait_for(self._process.wait(), timeout=5.0) except asyncio.TimeoutError: self._process.kill() await self._process.wait() self.status = TaskStatus.CANCELLED self._cancelled = True except Exception as e: self.status = TaskStatus.FAILED logger.error(f"Task(id: {self.task_id}) run failed. Cause: {e}") finally: self.completed_at = datetime.now() def cancel(self) -> bool: """取消任务""" if self._process and self.status == TaskStatus.RUNNING: try: # 尝试优雅终止 self._process.terminate() self._cancelled = True return True except Exception: # 如果无法终止,强制杀死 try: self._process.kill() self._cancelled = True return True except Exception: return False return False def to_result(self) -> TaskResult: """转换为结果对象""" self.result = { "command": self.command, "return_code": self.return_code, } return super().to_result() class CommandScheduler(TaskScheduler): """命令调度器""" def __init__(self, max_concurrent: int = 5): super().__init__(max_concurrent) async def submit(self, task_id, command: str, log_path = None, shell: bool = True, timeout: Optional[int] = None, **kwargs) -> str: if log_path is None: log_path = f"/flow/{task_id}/output.log" """提交命令任务""" task = CommandTask(task_id, command, log_path, shell, timeout, **kwargs) self.tasks[task_id] = task # 使用信号量限制并发 async with self.semaphore: # 异步执行任务 task.start() logger.info(f"命令任务 {task_id} 已提交并开始执行") return task_id def get_task_status(self, task_id: str) -> Optional[TaskResult]: """获取任务状态""" task = self.tasks.get(task_id) if task: return task.to_result() return None def get_all_tasks(self) -> List[TaskResult]: """获取所有任务状态""" return [task.to_result() for task in self.tasks.values()] def cancel_task(self, task_id: str) -> bool: """取消任务""" task = self.tasks.get(task_id) if not task: return True if task.status == TaskStatus.RUNNING: cancelled = task.cancel() if cancelled: logger.info(f"命令任务 {task_id} 已取消") return cancelled return False def get_tasks_by_status(self, status: TaskStatus) -> List[TaskResult]: """根据状态获取任务""" return [ task.to_result() for task in self.tasks.values() if task.status == status ] async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> TaskResult: """等待任务完成""" task = self.tasks.get(task_id) if not task: raise ValueError(f"任务 {task_id} 不存在") if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]: return task.to_result() # 对于运行中的任务,我们已经通过 await task.execute() 等待了 # 所以这里直接返回结果 return task.to_result() async def shutdown(self): """关闭调度器,取消所有运行中的任务""" logger.info("正在关闭命令调度器...") running_tasks = [ task for task in self.tasks.values() if task.status == TaskStatus.RUNNING ] for task in running_tasks: logger.info(f"取消运行中的命令任务: {task.task_id}") task.cancel() logger.info("命令调度器已关闭")