Files
DataMate/runtime/python-executor/datamate/scheduler/func_task_scheduler.py
2025-10-21 23:00:48 +08:00

134 lines
4.2 KiB
Python

import asyncio
from datetime import datetime
from typing import Callable, Optional, List
from loguru import logger
from .scheduler import TaskStatus, TaskResult, Task, TaskScheduler
class CallableTask(Task):
    """Task wrapper that runs an awaitable-returning callable as an asyncio task.

    ``start()`` schedules the call on the running event loop and
    ``_execute()`` records the outcome (result, error, status, timestamps)
    on the instance.
    """

    def __init__(self, task_id: str, func: Callable, *args, **kwargs):
        """Store ``func`` and forward everything else to the base ``Task``.

        Args:
            task_id: Identifier of this task.
            func: Callable invoked by ``_execute()``; its return value is
                awaited, so it must return an awaitable.
            *args: Forwarded to the base initializer; ``_execute()`` later
                reads them back as ``self.args`` (presumably stored there by
                the base class -- confirm against ``Task``).
            **kwargs: Forwarded likewise and read back as ``self.kwargs``.
        """
        super().__init__(task_id, *args, **kwargs)
        self.func = func

    def start(self) -> 'CallableTask':
        """Schedule the task on the running event loop if it is still pending.

        Calling ``start()`` on a task that is not ``PENDING`` is a no-op.

        Returns:
            This task, so calls can be chained.
        """
        if self.status == TaskStatus.PENDING:
            self.status = TaskStatus.RUNNING
            self.started_at = datetime.now()
            self._task = asyncio.create_task(self._execute())
            # If the asyncio task is cancelled before its first step,
            # ``_execute`` never enters its ``try`` block and would leave the
            # status stuck at RUNNING; the callback reconciles that case.
            self._task.add_done_callback(self._finalize_if_never_ran)
        return self

    def _finalize_if_never_ran(self, fut: 'asyncio.Future') -> None:
        """Mark the task cancelled when asyncio cancelled it before it ran."""
        if fut.cancelled() and self.status == TaskStatus.RUNNING:
            self.status = TaskStatus.CANCELLED
            self._cancelled = True
            self.completed_at = datetime.now()

    async def _execute(self):
        """Await the wrapped callable and record its outcome.

        Sets ``result`` and ``COMPLETED`` on success, ``CANCELLED`` on
        cancellation, or ``FAILED`` plus ``error`` on any other exception.
        ``completed_at`` is always stamped.
        """
        try:
            self.result = await self.func(*self.args, **self.kwargs)
            self.status = TaskStatus.COMPLETED
        except asyncio.CancelledError:
            # Deliberately not re-raised: the asyncio task then finishes
            # normally, so code awaiting it reads the CANCELLED status
            # instead of receiving a CancelledError.
            self.status = TaskStatus.CANCELLED
            self._cancelled = True
        except Exception as e:
            self.status = TaskStatus.FAILED
            # Some exceptions stringify to "" (e.g. ``KeyError()``); fall
            # back to repr so a failed task never carries an empty error.
            self.error = str(e) or repr(e)
        finally:
            self.completed_at = datetime.now()

    def cancel(self) -> bool:
        """Request cancellation of the underlying asyncio task.

        Returns:
            True if a cancellation request was issued, False if the task was
            never started or has already finished.
        """
        if self._task and not self._task.done():
            self._task.cancel()
            return True
        return False
class CallableScheduler(TaskScheduler):
    """Asynchronous scheduler that runs callables as ``CallableTask`` objects.

    Tasks are kept in ``self.tasks`` keyed by task id, and the number of
    callables executing at once is bounded by ``self.semaphore`` (both are
    provided by the ``TaskScheduler`` base class).
    """

    def __init__(self, max_concurrent: int = 10):
        """Create the scheduler.

        Args:
            max_concurrent: Passed to the base ``TaskScheduler`` initializer.
        """
        super().__init__(max_concurrent)

    async def submit(self, task_id: str, func: Callable, *args, **kwargs) -> str:
        """Register a task and schedule it for execution.

        The callable is wrapped so that the semaphore is held for the whole
        duration of its execution. Holding it only around ``task.start()``
        would release it immediately, because ``start()`` merely schedules
        the asyncio task, and the concurrency limit would never apply.
        A task waiting for a free slot is already reported as running.

        Args:
            task_id: Key under which the task is stored in ``self.tasks``;
                an existing entry with the same id is replaced.
            func: Callable returning an awaitable.
            *args: Positional arguments passed to ``func``.
            **kwargs: Keyword arguments passed to ``func``.

        Returns:
            The ``task_id`` that was passed in.
        """
        semaphore = self.semaphore

        async def _run_in_slot(*call_args, **call_kwargs):
            # The slot is released when ``func`` returns, raises or is
            # cancelled, including while still queued on the semaphore.
            async with semaphore:
                return await func(*call_args, **call_kwargs)

        task = CallableTask(task_id, _run_in_slot, *args, **kwargs)
        self.tasks[task_id] = task
        task.start()
        logger.info(f"任务 {task_id} 已提交并开始执行")
        return task_id

    def get_task_status(self, task_id: str) -> Optional[TaskResult]:
        """Return the result snapshot of one task, or None if it is unknown."""
        task = self.tasks.get(task_id)
        if task:
            return task.to_result()
        return None

    def get_all_tasks(self) -> List[TaskResult]:
        """Return result snapshots for every registered task."""
        return [task.to_result() for task in self.tasks.values()]

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task that is currently running.

        Returns:
            True if a cancellation request was issued; False if the task is
            unknown, is not in the RUNNING state, or could not be cancelled.
        """
        task = self.tasks.get(task_id)
        if task and task.status == TaskStatus.RUNNING:
            cancelled = task.cancel()
            if cancelled:
                logger.info(f"任务 {task_id} 已取消")
            return cancelled
        return False

    def get_tasks_by_status(self, status: TaskStatus) -> List[TaskResult]:
        """Return result snapshots of all tasks whose status equals ``status``."""
        return [
            task.to_result()
            for task in self.tasks.values()
            if task.status == status
        ]

    async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> TaskResult:
        """Wait until a task finishes and return its result snapshot.

        Args:
            task_id: Id of the task to wait for.
            timeout: Maximum number of seconds to wait; None waits forever.

        Returns:
            The task's result snapshot.

        Raises:
            ValueError: If no task with ``task_id`` is registered.
            TimeoutError: If the task did not finish within ``timeout``.
                Cancellation of the task is requested before raising, which
                matches what ``asyncio.wait_for`` does on timeout.
        """
        task = self.tasks.get(task_id)
        if not task:
            raise ValueError(f"任务 {task_id} 不存在")

        if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
            return task.to_result()

        fut = task.get()
        if fut is not None:
            # ``asyncio.wait_for`` is unsuitable here: on timeout it cancels
            # the task, the task absorbs the CancelledError and completes
            # normally, and ``wait_for`` then returns instead of raising.
            # ``asyncio.wait`` reports the timeout reliably.
            done, _ = await asyncio.wait({fut}, timeout=timeout)
            if not done:
                fut.cancel()
                raise TimeoutError(f"任务 {task_id} 超时")
        return task.to_result()

    async def shutdown(self, timeout: float = 5.0):
        """Shut the scheduler down by cancelling all running tasks.

        Every task whose status is RUNNING is cancelled, then all of them are
        awaited together for at most ``timeout`` seconds in total. Tasks that
        are still not finished after that are logged and left behind.

        Args:
            timeout: Overall number of seconds to wait for the cancelled
                tasks to finish.
        """
        logger.info("正在关闭调度器...")

        running_tasks = [
            task for task in self.tasks.values()
            if task.status == TaskStatus.RUNNING
        ]

        for task in running_tasks:
            logger.info(f"取消运行中的任务: {task.task_id}")
            task.cancel()

        # Map each unfinished asyncio task back to its wrapper for logging.
        pending = {}
        for task in running_tasks:
            fut = task.get()
            if fut and not fut.done():
                pending[fut] = task

        # ``asyncio.wait`` rejects an empty collection, never raises for
        # cancelled tasks and never blocks longer than ``timeout``.
        if pending:
            _, unfinished = await asyncio.wait(set(pending), timeout=timeout)
            for fut, task in pending.items():
                if fut in unfinished:
                    logger.warning(f"任务 {task.task_id} 无法正常停止")

        logger.info("调度器已关闭")