diff --git a/domain/config.py b/domain/config.py
index 8fbbdfd..7715479 100644
--- a/domain/config.py
+++ b/domain/config.py
@@ -59,6 +59,11 @@ class WorkerConfig:
     # 硬件加速配置
     hw_accel: str = HW_ACCEL_NONE  # 硬件加速类型: none, qsv, cuda
 
+    # 素材缓存配置
+    cache_enabled: bool = True  # 是否启用素材缓存
+    cache_dir: str = ""  # 缓存目录,默认为 temp_dir/cache
+    cache_max_size_gb: float = 0  # 最大缓存大小(GB),0 表示不限制
+
     @classmethod
     def from_env(cls) -> 'WorkerConfig':
         """从环境变量创建配置"""
@@ -108,6 +113,11 @@ class WorkerConfig:
         if hw_accel not in HW_ACCEL_TYPES:
             hw_accel = HW_ACCEL_NONE
 
+        # 素材缓存配置
+        cache_enabled = os.getenv('CACHE_ENABLED', 'true').lower() in ('true', '1', 'yes')
+        cache_dir = os.getenv('CACHE_DIR', '')  # 空字符串表示使用默认路径
+        cache_max_size_gb = float(os.getenv('CACHE_MAX_SIZE_GB', '0'))
+
         return cls(
             api_endpoint=api_endpoint,
             access_key=access_key,
@@ -121,7 +131,10 @@ class WorkerConfig:
             ffmpeg_timeout=ffmpeg_timeout,
             download_timeout=download_timeout,
             upload_timeout=upload_timeout,
-            hw_accel=hw_accel
+            hw_accel=hw_accel,
+            cache_enabled=cache_enabled,
+            cache_dir=cache_dir if cache_dir else os.path.join(temp_dir, 'cache'),
+            cache_max_size_gb=cache_max_size_gb
         )
 
     def get_work_dir_path(self, task_id: str) -> str:
diff --git a/handlers/base.py b/handlers/base.py
index 2542096..2742dc5 100644
--- a/handlers/base.py
+++ b/handlers/base.py
@@ -19,6 +19,7 @@ from domain.task import Task
 from domain.result import TaskResult, ErrorCode
 from domain.config import WorkerConfig
 from services import storage
+from services.cache import MaterialCache
 from constant import (
     HW_ACCEL_NONE, HW_ACCEL_QSV, HW_ACCEL_CUDA,
     VIDEO_ENCODE_PARAMS, VIDEO_ENCODE_PARAMS_QSV, VIDEO_ENCODE_PARAMS_CUDA
@@ -260,6 +261,11 @@ class BaseHandler(TaskHandler, ABC):
         """
         self.config = config
         self.api_client = api_client
+        self.material_cache = MaterialCache(
+            cache_dir=config.cache_dir,
+            enabled=config.cache_enabled,
+            max_size_gb=config.cache_max_size_gb
+        )
 
     def get_video_encode_args(self) -> List[str]:
         """
@@ -333,14 +339,15 @@ class BaseHandler(TaskHandler, ABC):
             except Exception as e:
                 logger.warning(f"Failed to cleanup work directory {work_dir}: {e}")
 
-    def download_file(self, url: str, dest: str, timeout: int = None) -> bool:
+    def download_file(self, url: str, dest: str, timeout: int = None, use_cache: bool = True) -> bool:
         """
-        下载文件
+        下载文件(支持缓存)
 
         Args:
             url: 文件 URL
             dest: 目标路径
             timeout: 超时时间(秒)
+            use_cache: 是否使用缓存(默认 True)
 
         Returns:
             是否成功
@@ -349,7 +356,13 @@ class BaseHandler(TaskHandler, ABC):
             timeout = self.config.download_timeout
 
         try:
-            result = storage.download_file(url, dest, timeout=timeout)
+            if use_cache:
+                # 使用缓存下载
+                result = self.material_cache.get_or_download(url, dest, timeout=timeout)
+            else:
+                # 直接下载(不走缓存)
+                result = storage.download_file(url, dest, timeout=timeout)
+
             if result:
                 file_size = os.path.getsize(dest) if os.path.exists(dest) else 0
                 logger.debug(f"Downloaded: {url} -> {dest} ({file_size} bytes)")
diff --git a/services/cache.py b/services/cache.py
new file mode 100644
index 0000000..97499f1
--- /dev/null
+++ b/services/cache.py
@@ -0,0 +1,291 @@
+# -*- coding: utf-8 -*-
+"""
+素材缓存服务
+
+提供素材下载缓存功能,避免相同素材重复下载。
+"""
+
+import os
+import hashlib
+import logging
+import shutil
+import time
+from typing import Optional, Tuple
+from urllib.parse import urlparse, unquote
+
+from services import storage
+
+logger = logging.getLogger(__name__)
+
+
+def _extract_cache_key(url: str) -> str:
+    """
+    从 URL 提取缓存键
+
+    去除签名等查询参数,保留路径作为唯一标识。
+
+    Args:
+        url: 完整的素材 URL
+
+    Returns:
+        缓存键(URL 路径的 MD5 哈希)
+    """
+    parsed = urlparse(url)
+    # 使用 scheme + host + path 作为唯一标识(忽略签名等查询参数)
+    cache_key_source = f"{parsed.scheme}://{parsed.netloc}{unquote(parsed.path)}"
+    return hashlib.md5(cache_key_source.encode('utf-8')).hexdigest()
+
+
+def _get_file_extension(url: str) -> str:
+    """
+    从 URL 提取文件扩展名
+
+    Args:
+        url: 素材 URL
+
+    Returns:
+        文件扩展名(如 .mp4, .png),无法识别时返回空字符串
+    """
+    parsed = urlparse(url)
+    path = unquote(parsed.path)
+    _, ext = os.path.splitext(path)
+    return ext.lower() if ext else ''
+
+
+class MaterialCache:
+    """
+    素材缓存管理器
+
+    负责素材文件的缓存存储和检索。
+    """
+
+    def __init__(self, cache_dir: str, enabled: bool = True, max_size_gb: float = 0):
+        """
+        初始化缓存管理器
+
+        Args:
+            cache_dir: 缓存目录路径
+            enabled: 是否启用缓存
+            max_size_gb: 最大缓存大小(GB),0 表示不限制
+        """
+        self.cache_dir = cache_dir
+        self.enabled = enabled
+        self.max_size_bytes = int(max_size_gb * 1024 * 1024 * 1024) if max_size_gb > 0 else 0
+
+        if self.enabled:
+            os.makedirs(self.cache_dir, exist_ok=True)
+            logger.info(f"Material cache initialized: {cache_dir}")
+
+    def get_cache_path(self, url: str) -> str:
+        """
+        获取素材的缓存文件路径
+
+        Args:
+            url: 素材 URL
+
+        Returns:
+            缓存文件的完整路径
+        """
+        cache_key = _extract_cache_key(url)
+        ext = _get_file_extension(url)
+        filename = f"{cache_key}{ext}"
+        return os.path.join(self.cache_dir, filename)
+
+    def is_cached(self, url: str) -> Tuple[bool, str]:
+        """
+        检查素材是否已缓存
+
+        Args:
+            url: 素材 URL
+
+        Returns:
+            (是否已缓存, 缓存文件路径)
+        """
+        if not self.enabled:
+            return False, ''
+
+        cache_path = self.get_cache_path(url)
+        exists = os.path.exists(cache_path) and os.path.getsize(cache_path) > 0
+        return exists, cache_path
+
+    def get_or_download(
+        self,
+        url: str,
+        dest: str,
+        timeout: int = 300,
+        max_retries: int = 5
+    ) -> bool:
+        """
+        从缓存获取素材,若未缓存则下载并缓存
+
+        Args:
+            url: 素材 URL
+            dest: 目标文件路径(任务工作目录中的路径)
+            timeout: 下载超时时间(秒)
+            max_retries: 最大重试次数
+
+        Returns:
+            是否成功
+        """
+        # 确保目标目录存在
+        dest_dir = os.path.dirname(dest)
+        if dest_dir:
+            os.makedirs(dest_dir, exist_ok=True)
+
+        # 缓存未启用时直接下载
+        if not self.enabled:
+            return storage.download_file(url, dest, max_retries=max_retries, timeout=timeout)
+
+        # 检查缓存
+        cached, cache_path = self.is_cached(url)
+
+        if cached:
+            # 命中缓存,复制到目标路径
+            try:
+                shutil.copy2(cache_path, dest)
+                # 更新访问时间(用于 LRU 清理)
+                os.utime(cache_path, None)
+                file_size = os.path.getsize(dest)
+                logger.info(f"Cache hit: {url[:80]}... -> {dest} ({file_size} bytes)")
+                return True
+            except Exception as e:
+                logger.warning(f"Failed to copy from cache: {e}, will re-download")
+                # 缓存复制失败,删除可能损坏的缓存文件
+                try:
+                    os.remove(cache_path)
+                except Exception:
+                    pass
+
+        # 未命中缓存,下载到缓存目录
+        logger.debug(f"Cache miss: {url[:80]}...")
+
+        # 先下载到临时文件
+        temp_cache_path = cache_path + '.downloading'
+        try:
+            if not storage.download_file(url, temp_cache_path, max_retries=max_retries, timeout=timeout):
+                # 下载失败,清理临时文件
+                if os.path.exists(temp_cache_path):
+                    os.remove(temp_cache_path)
+                return False
+
+            # 下载成功,移动到正式缓存路径
+            if os.path.exists(cache_path):
+                os.remove(cache_path)
+            os.rename(temp_cache_path, cache_path)
+
+            # 复制到目标路径
+            shutil.copy2(cache_path, dest)
+            file_size = os.path.getsize(dest)
+            logger.info(f"Downloaded and cached: {url[:80]}... ({file_size} bytes)")
+
+            # 检查是否需要清理缓存
+            if self.max_size_bytes > 0:
+                self._cleanup_if_needed()
+
+            return True
+
+        except Exception as e:
+            logger.error(f"Cache download error: {e}")
+            # 清理临时文件
+            if os.path.exists(temp_cache_path):
+                try:
+                    os.remove(temp_cache_path)
+                except Exception:
+                    pass
+            return False
+
+    def _cleanup_if_needed(self) -> None:
+        """
+        检查并清理缓存(LRU 策略)
+
+        当缓存大小超过限制时,删除最久未访问的文件。
+        """
+        if self.max_size_bytes <= 0:
+            return
+
+        try:
+            # 获取所有缓存文件及其信息
+            cache_files = []
+            total_size = 0
+
+            for filename in os.listdir(self.cache_dir):
+                if filename.endswith('.downloading'):
+                    continue
+                file_path = os.path.join(self.cache_dir, filename)
+                if os.path.isfile(file_path):
+                    stat = os.stat(file_path)
+                    cache_files.append({
+                        'path': file_path,
+                        'size': stat.st_size,
+                        'atime': stat.st_atime
+                    })
+                    total_size += stat.st_size
+
+            # 如果未超过限制,无需清理
+            if total_size <= self.max_size_bytes:
+                return
+
+            # 按访问时间排序(最久未访问的在前)
+            cache_files.sort(key=lambda x: x['atime'])
+
+            # 删除文件直到低于限制的 80%
+            target_size = int(self.max_size_bytes * 0.8)
+            deleted_count = 0
+
+            for file_info in cache_files:
+                if total_size <= target_size:
+                    break
+                try:
+                    os.remove(file_info['path'])
+                    total_size -= file_info['size']
+                    deleted_count += 1
+                except Exception as e:
+                    logger.warning(f"Failed to delete cache file: {e}")
+
+            if deleted_count > 0:
+                logger.info(f"Cache cleanup: deleted {deleted_count} files, current size: {total_size / (1024*1024*1024):.2f} GB")
+
+        except Exception as e:
+            logger.warning(f"Cache cleanup error: {e}")
+
+    def clear(self) -> None:
+        """清空所有缓存"""
+        if not self.enabled:
+            return
+
+        try:
+            if os.path.exists(self.cache_dir):
+                shutil.rmtree(self.cache_dir)
+            os.makedirs(self.cache_dir, exist_ok=True)
+            logger.info("Cache cleared")
+        except Exception as e:
+            logger.error(f"Failed to clear cache: {e}")
+
+    def get_stats(self) -> dict:
+        """
+        获取缓存统计信息
+
+        Returns:
+            包含缓存统计的字典
+        """
+        if not self.enabled or not os.path.exists(self.cache_dir):
+            return {'enabled': False, 'file_count': 0, 'total_size_mb': 0}
+
+        file_count = 0
+        total_size = 0
+
+        for filename in os.listdir(self.cache_dir):
+            if filename.endswith('.downloading'):
+                continue
+            file_path = os.path.join(self.cache_dir, filename)
+            if os.path.isfile(file_path):
+                file_count += 1
+                total_size += os.path.getsize(file_path)
+
+        return {
+            'enabled': True,
+            'cache_dir': self.cache_dir,
+            'file_count': file_count,
+            'total_size_mb': round(total_size / (1024 * 1024), 2),
+            'max_size_gb': self.max_size_bytes / (1024 * 1024 * 1024) if self.max_size_bytes > 0 else 0
+        }