feat(base): add per-task concurrent file transfers

- Introduce ThreadPoolExecutor-based parallel downloading and uploading
- Add download_files_parallel and upload_files_parallel methods
- Add per-task transfer concurrency options TASK_DOWNLOAD_CONCURRENCY and TASK_UPLOAD_CONCURRENCY
- Implement environment-variable parsing and validation for the concurrency settings
- Apply parallel downloads in multiple handlers to speed up file fetching
- Update the .env.example configuration template
- Remove the FFmpeg command log length limit
2026-02-07 00:38:43 +08:00
parent d955def63c
commit 88aa3adca1
7 changed files with 435 additions and 88 deletions
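
The .env.example hunk is not shown below; a plausible sketch of the added entries, using only the variable names, defaults, and bounds visible elsewhere in this commit:

# Per-task transfer concurrency (values are clamped to the range 1-16)
TASK_DOWNLOAD_CONCURRENCY=4
TASK_UPLOAD_CONCURRENCY=2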

@@ -12,6 +12,7 @@ import shutil
import tempfile
import subprocess
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from abc import ABC
from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING
@@ -23,7 +24,13 @@ from domain.result import TaskResult, ErrorCode
from domain.config import WorkerConfig
from services import storage
from services.cache import MaterialCache
from util.tracing import get_current_task_context, mark_span_error, start_span
from util.tracing import (
bind_trace_context,
capture_otel_context,
get_current_task_context,
mark_span_error,
start_span,
)
from constant import (
HW_ACCEL_NONE, HW_ACCEL_QSV, HW_ACCEL_CUDA,
VIDEO_ENCODE_PARAMS, VIDEO_ENCODE_PARAMS_QSV, VIDEO_ENCODE_PARAMS_CUDA
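
bind_trace_context and capture_otel_context are new imports whose definitions live in util.tracing, outside this diff. Judging from how the worker threads below use them, they presumably snapshot the OpenTelemetry context on the submitting thread and re-attach it inside each pool thread. A minimal sketch of that pattern, assuming the standard opentelemetry-api (the bodies here are guesses, not the project's code):

from contextlib import contextmanager
from opentelemetry import context as otel_context

def capture_otel_context():
    # Snapshot the current OTel context on the submitting thread
    return otel_context.get_current()

@contextmanager
def bind_trace_context(parent_ctx, task_ctx=None):
    # Re-attach the snapshot in a worker thread so spans started there
    # parent correctly; rebinding task_ctx is elided in this sketch
    token = otel_context.attach(parent_ctx)
    try:
        yield
    finally:
        otel_context.detach(token)
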
@@ -274,6 +281,9 @@ class BaseHandler(TaskHandler, ABC):
    # Thread-local storage: holds the GPU device index for the current thread
_thread_local = threading.local()
DEFAULT_TASK_DOWNLOAD_CONCURRENCY = 4
DEFAULT_TASK_UPLOAD_CONCURRENCY = 2
MAX_TASK_TRANSFER_CONCURRENCY = 16
def __init__(self, config: WorkerConfig, api_client: 'APIClientV2'):
"""
@@ -290,6 +300,251 @@ class BaseHandler(TaskHandler, ABC):
enabled=config.cache_enabled,
max_size_gb=config.cache_max_size_gb
)
self.task_download_concurrency = self._resolve_task_transfer_concurrency(
"TASK_DOWNLOAD_CONCURRENCY",
self.DEFAULT_TASK_DOWNLOAD_CONCURRENCY
)
self.task_upload_concurrency = self._resolve_task_transfer_concurrency(
"TASK_UPLOAD_CONCURRENCY",
self.DEFAULT_TASK_UPLOAD_CONCURRENCY
)
def _resolve_task_transfer_concurrency(self, env_name: str, default_value: int) -> int:
"""读取并规范化任务内传输并发数配置。"""
raw_value = os.getenv(env_name)
if raw_value is None or not raw_value.strip():
return default_value
try:
parsed_value = int(raw_value.strip())
except ValueError:
logger.warning(
f"Invalid {env_name} value '{raw_value}', using default {default_value}"
)
return default_value
if parsed_value < 1:
logger.warning(f"{env_name} must be >= 1, forcing to 1")
return 1
if parsed_value > self.MAX_TASK_TRANSFER_CONCURRENCY:
logger.warning(
f"{env_name}={parsed_value} exceeds limit {self.MAX_TASK_TRANSFER_CONCURRENCY}, "
f"using {self.MAX_TASK_TRANSFER_CONCURRENCY}"
)
return self.MAX_TASK_TRANSFER_CONCURRENCY
return parsed_value
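
    # Illustrative outcomes of _resolve_task_transfer_concurrency
    # (annotation, not part of the committed file):
    #   env unset or blank -> default (4 for downloads, 2 for uploads)
    #   env="abc"          -> default, with a warning logged
    #   env="0"            -> 1 (floored at the minimum)
    #   env="32"           -> 16 (capped at MAX_TASK_TRANSFER_CONCURRENCY)
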
def download_files_parallel(
self,
download_jobs: List[Dict[str, Any]],
timeout: Optional[int] = None
) -> Dict[str, Dict[str, Any]]:
"""
单任务内并行下载多个文件。
Args:
download_jobs: 下载任务列表。每项字段:
- key: 唯一标识
- url: 下载地址
- dest: 目标文件路径
- required: 是否关键文件(可选,默认 True)
- use_cache: 是否使用缓存(可选,默认 True)
timeout: 单文件下载超时(秒)
Returns:
key -> 结果字典:
- success: 是否成功
- url: 原始 URL
- dest: 目标文件路径
- required: 是否关键文件
"""
if not download_jobs:
return {}
normalized_jobs: List[Dict[str, Any]] = []
seen_keys = set()
for download_job in download_jobs:
job_key = str(download_job.get("key", "")).strip()
job_url = str(download_job.get("url", "")).strip()
job_dest = str(download_job.get("dest", "")).strip()
if not job_key or not job_url or not job_dest:
raise ValueError("Each download job must include non-empty key/url/dest")
if job_key in seen_keys:
raise ValueError(f"Duplicate download job key: {job_key}")
seen_keys.add(job_key)
normalized_jobs.append({
"key": job_key,
"url": job_url,
"dest": job_dest,
"required": bool(download_job.get("required", True)),
"use_cache": bool(download_job.get("use_cache", True)),
})
if timeout is None:
timeout = self.config.download_timeout
parent_otel_context = capture_otel_context()
task_context = get_current_task_context()
task_prefix = f"[task:{task_context.task_id}] " if task_context else ""
results: Dict[str, Dict[str, Any]] = {}
def _run_download_job(download_job: Dict[str, Any]) -> bool:
with bind_trace_context(parent_otel_context, task_context):
return self.download_file(
download_job["url"],
download_job["dest"],
timeout=timeout,
use_cache=download_job["use_cache"],
)
max_workers = min(self.task_download_concurrency, len(normalized_jobs))
if max_workers <= 1:
for download_job in normalized_jobs:
is_success = _run_download_job(download_job)
results[download_job["key"]] = {
"success": is_success,
"url": download_job["url"],
"dest": download_job["dest"],
"required": download_job["required"],
}
else:
with ThreadPoolExecutor(
max_workers=max_workers,
thread_name_prefix="TaskDownload",
) as executor:
future_to_job = {
executor.submit(_run_download_job, download_job): download_job
for download_job in normalized_jobs
}
for completed_future in as_completed(future_to_job):
download_job = future_to_job[completed_future]
is_success = False
try:
is_success = bool(completed_future.result())
except Exception as exc:
logger.error(
f"{task_prefix}Parallel download raised exception for "
f"key={download_job['key']}: {exc}"
)
results[download_job["key"]] = {
"success": is_success,
"url": download_job["url"],
"dest": download_job["dest"],
"required": download_job["required"],
}
success_count = sum(1 for item in results.values() if item["success"])
logger.debug(
f"{task_prefix}Parallel download completed: {success_count}/{len(normalized_jobs)}"
)
return results
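
    # Caller-facing note (annotation, not part of the committed file):
    # exceptions from individual downloads are caught, logged, and recorded
    # as success=False; deciding whether a failed "required" file aborts the
    # task is left to the caller via the returned mapping.
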
def upload_files_parallel(
self,
upload_jobs: List[Dict[str, Any]]
) -> Dict[str, Dict[str, Any]]:
"""
单任务内并行上传多个文件。
Args:
upload_jobs: 上传任务列表。每项字段:
- key: 唯一标识
- task_id: 任务 ID
- file_type: 文件类型(video/audio/ts/mp4)
- file_path: 本地文件路径
- file_name: 文件名(可选)
- required: 是否关键文件(可选,默认 True)
Returns:
key -> 结果字典:
- success: 是否成功
- url: 上传后的访问 URL(失败为 None)
- file_path: 本地文件路径
- required: 是否关键文件
"""
if not upload_jobs:
return {}
normalized_jobs: List[Dict[str, Any]] = []
seen_keys = set()
for upload_job in upload_jobs:
job_key = str(upload_job.get("key", "")).strip()
task_id = str(upload_job.get("task_id", "")).strip()
file_type = str(upload_job.get("file_type", "")).strip()
file_path = str(upload_job.get("file_path", "")).strip()
if not job_key or not task_id or not file_type or not file_path:
raise ValueError(
"Each upload job must include non-empty key/task_id/file_type/file_path"
)
if job_key in seen_keys:
raise ValueError(f"Duplicate upload job key: {job_key}")
seen_keys.add(job_key)
normalized_jobs.append({
"key": job_key,
"task_id": task_id,
"file_type": file_type,
"file_path": file_path,
"file_name": upload_job.get("file_name"),
"required": bool(upload_job.get("required", True)),
})
parent_otel_context = capture_otel_context()
task_context = get_current_task_context()
task_prefix = f"[task:{task_context.task_id}] " if task_context else ""
results: Dict[str, Dict[str, Any]] = {}
def _run_upload_job(upload_job: Dict[str, Any]) -> Optional[str]:
with bind_trace_context(parent_otel_context, task_context):
return self.upload_file(
upload_job["task_id"],
upload_job["file_type"],
upload_job["file_path"],
upload_job.get("file_name")
)
max_workers = min(self.task_upload_concurrency, len(normalized_jobs))
if max_workers <= 1:
for upload_job in normalized_jobs:
result_url = _run_upload_job(upload_job)
results[upload_job["key"]] = {
"success": bool(result_url),
"url": result_url,
"file_path": upload_job["file_path"],
"required": upload_job["required"],
}
else:
with ThreadPoolExecutor(
max_workers=max_workers,
thread_name_prefix="TaskUpload",
) as executor:
future_to_job = {
executor.submit(_run_upload_job, upload_job): upload_job
for upload_job in normalized_jobs
}
for completed_future in as_completed(future_to_job):
upload_job = future_to_job[completed_future]
result_url = None
try:
result_url = completed_future.result()
except Exception as exc:
logger.error(
f"{task_prefix}Parallel upload raised exception for "
f"key={upload_job['key']}: {exc}"
)
results[upload_job["key"]] = {
"success": bool(result_url),
"url": result_url,
"file_path": upload_job["file_path"],
"required": upload_job["required"],
}
success_count = sum(1 for item in results.values() if item["success"])
logger.debug(
f"{task_prefix}Parallel upload completed: {success_count}/{len(normalized_jobs)}"
)
return results
    # ========== GPU device management ==========
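
How a handler might drive the two new methods, using only the job fields documented above (the handler, task, and path names here are hypothetical):

jobs = [
    {"key": "video", "url": task.video_url, "dest": "/tmp/in.mp4"},
    {"key": "subtitle", "url": task.srt_url, "dest": "/tmp/in.srt", "required": False},
]
dl_results = handler.download_files_parallel(jobs)
failed_required = [k for k, r in dl_results.items() if r["required"] and not r["success"]]
if failed_required:
    raise RuntimeError(f"Required downloads failed: {failed_required}")

up_results = handler.upload_files_parallel([
    {"key": "out", "task_id": task.id, "file_type": "mp4", "file_path": "/tmp/out.mp4"},
])
output_url = up_results["out"]["url"] if up_results["out"]["success"] else None
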
@@ -538,10 +793,8 @@ class BaseHandler(TaskHandler, ABC):
if cmd_to_run and cmd_to_run[0] == 'ffmpeg' and '-loglevel' not in cmd_to_run:
cmd_to_run[1:1] = ['-loglevel', FFMPEG_LOGLEVEL]
        # Log the command (the 500-character truncation is removed in this commit)
        cmd_str = ' '.join(cmd_to_run)
logger.info(f"[task:{task_id}] FFmpeg: {cmd_str}")
with start_span(