feat(cache): 添加素材缓存功能以避免重复下载

- 新增素材缓存配置选项包括启用状态、缓存目录和最大缓存大小
- 实现 MaterialCache 类提供缓存存储和检索功能
- 修改 download_file 方法支持缓存下载模式
- 添加缓存清理机制使用 LRU 策略管理磁盘空间
- 配置默认值优化本地开发体验
- 实现缓存统计和监控功能
This commit is contained in:
2026-01-17 15:07:12 +08:00
parent d5cd0dca03
commit fe757408b6
3 changed files with 321 additions and 4 deletions

View File

@@ -59,6 +59,11 @@ class WorkerConfig:
# 硬件加速配置 # 硬件加速配置
hw_accel: str = HW_ACCEL_NONE # 硬件加速类型: none, qsv, cuda hw_accel: str = HW_ACCEL_NONE # 硬件加速类型: none, qsv, cuda
# 素材缓存配置
cache_enabled: bool = True # 是否启用素材缓存
cache_dir: str = "" # 缓存目录,默认为 temp_dir/cache
cache_max_size_gb: float = 0 # 最大缓存大小(GB),0 表示不限制
@classmethod @classmethod
def from_env(cls) -> 'WorkerConfig': def from_env(cls) -> 'WorkerConfig':
"""从环境变量创建配置""" """从环境变量创建配置"""
@@ -108,6 +113,11 @@ class WorkerConfig:
if hw_accel not in HW_ACCEL_TYPES: if hw_accel not in HW_ACCEL_TYPES:
hw_accel = HW_ACCEL_NONE hw_accel = HW_ACCEL_NONE
# 素材缓存配置
cache_enabled = os.getenv('CACHE_ENABLED', 'true').lower() in ('true', '1', 'yes')
cache_dir = os.getenv('CACHE_DIR', '') # 空字符串表示使用默认路径
cache_max_size_gb = float(os.getenv('CACHE_MAX_SIZE_GB', '0'))
return cls( return cls(
api_endpoint=api_endpoint, api_endpoint=api_endpoint,
access_key=access_key, access_key=access_key,
@@ -121,7 +131,10 @@ class WorkerConfig:
ffmpeg_timeout=ffmpeg_timeout, ffmpeg_timeout=ffmpeg_timeout,
download_timeout=download_timeout, download_timeout=download_timeout,
upload_timeout=upload_timeout, upload_timeout=upload_timeout,
hw_accel=hw_accel hw_accel=hw_accel,
cache_enabled=cache_enabled,
cache_dir=cache_dir if cache_dir else os.path.join(temp_dir, 'cache'),
cache_max_size_gb=cache_max_size_gb
) )
def get_work_dir_path(self, task_id: str) -> str: def get_work_dir_path(self, task_id: str) -> str:

View File

@@ -19,6 +19,7 @@ from domain.task import Task
from domain.result import TaskResult, ErrorCode from domain.result import TaskResult, ErrorCode
from domain.config import WorkerConfig from domain.config import WorkerConfig
from services import storage from services import storage
from services.cache import MaterialCache
from constant import ( from constant import (
HW_ACCEL_NONE, HW_ACCEL_QSV, HW_ACCEL_CUDA, HW_ACCEL_NONE, HW_ACCEL_QSV, HW_ACCEL_CUDA,
VIDEO_ENCODE_PARAMS, VIDEO_ENCODE_PARAMS_QSV, VIDEO_ENCODE_PARAMS_CUDA VIDEO_ENCODE_PARAMS, VIDEO_ENCODE_PARAMS_QSV, VIDEO_ENCODE_PARAMS_CUDA
@@ -260,6 +261,11 @@ class BaseHandler(TaskHandler, ABC):
""" """
self.config = config self.config = config
self.api_client = api_client self.api_client = api_client
self.material_cache = MaterialCache(
cache_dir=config.cache_dir,
enabled=config.cache_enabled,
max_size_gb=config.cache_max_size_gb
)
def get_video_encode_args(self) -> List[str]: def get_video_encode_args(self) -> List[str]:
""" """
@@ -333,14 +339,15 @@ class BaseHandler(TaskHandler, ABC):
except Exception as e: except Exception as e:
logger.warning(f"Failed to cleanup work directory {work_dir}: {e}") logger.warning(f"Failed to cleanup work directory {work_dir}: {e}")
def download_file(self, url: str, dest: str, timeout: int = None) -> bool: def download_file(self, url: str, dest: str, timeout: int = None, use_cache: bool = True) -> bool:
""" """
下载文件 下载文件(支持缓存)
Args: Args:
url: 文件 URL url: 文件 URL
dest: 目标路径 dest: 目标路径
timeout: 超时时间(秒) timeout: 超时时间(秒)
use_cache: 是否使用缓存(默认 True)
Returns: Returns:
是否成功 是否成功
@@ -349,7 +356,13 @@ class BaseHandler(TaskHandler, ABC):
timeout = self.config.download_timeout timeout = self.config.download_timeout
try: try:
if use_cache:
# 使用缓存下载
result = self.material_cache.get_or_download(url, dest, timeout=timeout)
else:
# 直接下载(不走缓存)
result = storage.download_file(url, dest, timeout=timeout) result = storage.download_file(url, dest, timeout=timeout)
if result: if result:
file_size = os.path.getsize(dest) if os.path.exists(dest) else 0 file_size = os.path.getsize(dest) if os.path.exists(dest) else 0
logger.debug(f"Downloaded: {url} -> {dest} ({file_size} bytes)") logger.debug(f"Downloaded: {url} -> {dest} ({file_size} bytes)")

291
services/cache.py Normal file
View File

@@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
"""
素材缓存服务
提供素材下载缓存功能,避免相同素材重复下载。
"""
import os
import hashlib
import logging
import shutil
import time
from typing import Optional, Tuple
from urllib.parse import urlparse, unquote
from services import storage
logger = logging.getLogger(__name__)
def _extract_cache_key(url: str) -> str:
    """Derive a stable cache key from a material URL.

    Query parameters (e.g. temporary signatures) are stripped so that the
    same object fetched through differently-signed URLs maps onto a single
    cache entry.

    Args:
        url: Full material URL.

    Returns:
        MD5 hex digest of "scheme://host/path" (path percent-decoded).
    """
    parts = urlparse(url)
    identity = "{0}://{1}{2}".format(parts.scheme, parts.netloc, unquote(parts.path))
    return hashlib.md5(identity.encode('utf-8')).hexdigest()
def _get_file_extension(url: str) -> str:
    """Return the lower-cased file extension from a URL's path.

    Args:
        url: Material URL.

    Returns:
        Extension including the leading dot (e.g. '.mp4', '.png'), or ''
        when the path carries no recognizable extension.
    """
    decoded_path = unquote(urlparse(url).path)
    extension = os.path.splitext(decoded_path)[1]
    # ''.lower() is still '' — no separate empty-string branch needed.
    return extension.lower()
class MaterialCache:
    """
    Material cache manager.

    Stores downloaded materials in a shared cache directory, keyed by URL
    (ignoring query parameters such as signatures), so that identical
    materials are not downloaded repeatedly across tasks.
    """

    def __init__(self, cache_dir: str, enabled: bool = True, max_size_gb: float = 0):
        """
        Initialize the cache manager.

        Args:
            cache_dir: Cache directory path.
            enabled: Whether caching is enabled.
            max_size_gb: Maximum cache size in GB; 0 means unlimited.
        """
        self.cache_dir = cache_dir
        self.enabled = enabled
        # 0 disables both the size limit and the LRU cleanup pass.
        self.max_size_bytes = int(max_size_gb * 1024 * 1024 * 1024) if max_size_gb > 0 else 0
        if self.enabled:
            os.makedirs(self.cache_dir, exist_ok=True)
            logger.info(f"Material cache initialized: {cache_dir}")

    def get_cache_path(self, url: str) -> str:
        """
        Build the cache file path for a material URL.

        Args:
            url: Material URL.

        Returns:
            Full cache file path (md5-of-url plus the original extension).
        """
        cache_key = _extract_cache_key(url)
        ext = _get_file_extension(url)
        return os.path.join(self.cache_dir, f"{cache_key}{ext}")

    def is_cached(self, url: str) -> Tuple[bool, str]:
        """
        Check whether a material is already cached.

        Args:
            url: Material URL.

        Returns:
            (cached, cache_path). Zero-byte files count as not cached.
        """
        if not self.enabled:
            return False, ''
        cache_path = self.get_cache_path(url)
        exists = os.path.exists(cache_path) and os.path.getsize(cache_path) > 0
        return exists, cache_path

    def get_or_download(
        self,
        url: str,
        dest: str,
        timeout: int = 300,
        max_retries: int = 5
    ) -> bool:
        """
        Fetch a material from cache, downloading (and caching) it on a miss.

        Args:
            url: Material URL.
            dest: Destination path (inside the task's work directory).
            timeout: Download timeout in seconds.
            max_retries: Maximum download retries.

        Returns:
            True on success.
        """
        # Ensure the destination directory exists.
        dest_dir = os.path.dirname(dest)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)

        # Caching disabled: download straight to the destination.
        if not self.enabled:
            return storage.download_file(url, dest, max_retries=max_retries, timeout=timeout)

        cached, cache_path = self.is_cached(url)
        if cached:
            # Cache hit: copy into the work directory.
            try:
                shutil.copy2(cache_path, dest)
                # Refresh the access time for the LRU cleanup.
                os.utime(cache_path, None)
                file_size = os.path.getsize(dest)
                logger.info(f"Cache hit: {url[:80]}... -> {dest} ({file_size} bytes)")
                return True
            except Exception as e:
                logger.warning(f"Failed to copy from cache: {e}, will re-download")
                # Drop the possibly-corrupted cache entry and fall through
                # to a fresh download.
                try:
                    os.remove(cache_path)
                except Exception:
                    pass

        logger.debug(f"Cache miss: {url[:80]}...")
        # Download to a per-process/per-call unique temp file: concurrent
        # workers fetching the same URL must not write into one shared
        # partial file. The name still ends in '.downloading' so the
        # cleanup and stats scans keep skipping in-flight files.
        temp_cache_path = f"{cache_path}.{os.getpid()}.{time.time_ns()}.downloading"
        try:
            if not storage.download_file(url, temp_cache_path, max_retries=max_retries, timeout=timeout):
                # Download failed: remove the partial file.
                if os.path.exists(temp_cache_path):
                    os.remove(temp_cache_path)
                return False

            # Atomic publish; os.replace overwrites an existing entry on
            # every platform (no remove-then-rename race window).
            os.replace(temp_cache_path, cache_path)

            # Copy into the task's work directory.
            shutil.copy2(cache_path, dest)
            file_size = os.path.getsize(dest)
            logger.info(f"Downloaded and cached: {url[:80]}... ({file_size} bytes)")

            # Enforce the size limit, if one is configured.
            if self.max_size_bytes > 0:
                self._cleanup_if_needed()
            return True
        except Exception as e:
            logger.error(f"Cache download error: {e}")
            # Best-effort removal of the temp file.
            if os.path.exists(temp_cache_path):
                try:
                    os.remove(temp_cache_path)
                except Exception:
                    pass
            return False

    def _cleanup_if_needed(self) -> None:
        """
        Evict least-recently-accessed files when the cache exceeds its limit.

        Files are removed oldest-access-first until the cache shrinks to at
        most 80% of the configured maximum.
        """
        if self.max_size_bytes <= 0:
            return
        try:
            # Collect (atime, path, size) for every completed cache file.
            # scandir stats each entry once; a per-entry guard tolerates
            # files deleted concurrently by another worker.
            entries = []
            total_size = 0
            with os.scandir(self.cache_dir) as it:
                for entry in it:
                    if entry.name.endswith('.downloading'):
                        continue
                    try:
                        if not entry.is_file():
                            continue
                        stat = entry.stat()
                    except OSError:
                        continue
                    entries.append((stat.st_atime, entry.path, stat.st_size))
                    total_size += stat.st_size

            # Under the limit: nothing to do.
            if total_size <= self.max_size_bytes:
                return

            # Oldest access first (LRU order).
            entries.sort()
            target_size = int(self.max_size_bytes * 0.8)
            deleted_count = 0
            for _, path, size in entries:
                if total_size <= target_size:
                    break
                try:
                    os.remove(path)
                    total_size -= size
                    deleted_count += 1
                except Exception as e:
                    logger.warning(f"Failed to delete cache file: {e}")
            if deleted_count > 0:
                logger.info(f"Cache cleanup: deleted {deleted_count} files, current size: {total_size / (1024*1024*1024):.2f} GB")
        except Exception as e:
            logger.warning(f"Cache cleanup error: {e}")

    def clear(self) -> None:
        """Remove all cached files and recreate the cache directory."""
        if not self.enabled:
            return
        try:
            if os.path.exists(self.cache_dir):
                shutil.rmtree(self.cache_dir)
            os.makedirs(self.cache_dir, exist_ok=True)
            logger.info("Cache cleared")
        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")

    def get_stats(self) -> dict:
        """
        Report cache statistics.

        Returns:
            Dict with the enabled flag, cache dir, file count, total size
            in MB and the configured limit in GB (0 = unlimited).
        """
        if not self.enabled or not os.path.exists(self.cache_dir):
            return {'enabled': False, 'file_count': 0, 'total_size_mb': 0}
        file_count = 0
        total_size = 0
        with os.scandir(self.cache_dir) as it:
            for entry in it:
                if entry.name.endswith('.downloading'):
                    continue
                try:
                    if entry.is_file():
                        file_count += 1
                        total_size += entry.stat().st_size
                except OSError:
                    # Deleted between scan and stat (e.g. by cleanup); skip.
                    continue
        return {
            'enabled': True,
            'cache_dir': self.cache_dir,
            'file_count': file_count,
            'total_size_mb': round(total_size / (1024 * 1024), 2),
            'max_size_gb': self.max_size_bytes / (1024 * 1024 * 1024) if self.max_size_bytes > 0 else 0
        }