You've already forked FrameTour-RenderWorker
feat(base): 添加单任务内文件传输并发功能
- 引入 ThreadPoolExecutor 实现并行下载和上传 - 新增 download_files_parallel 和 upload_files_parallel 方法 - 添加任务传输并发数配置选项 TASK_DOWNLOAD_CONCURRENCY 和 TASK_UPLOAD_CONCURRENCY - 实现并发数配置的环境变量解析和验证逻辑 - 在多个处理器中应用并行下载优化文件获取性能 - 更新 .env.example 配置文件模板 - 移除 FFmpeg 命令日志长度限制
This commit is contained in:
261
handlers/base.py
261
handlers/base.py
@@ -12,6 +12,7 @@ import shutil
|
||||
import tempfile
|
||||
import subprocess
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from abc import ABC
|
||||
from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING
|
||||
|
||||
@@ -23,7 +24,13 @@ from domain.result import TaskResult, ErrorCode
|
||||
from domain.config import WorkerConfig
|
||||
from services import storage
|
||||
from services.cache import MaterialCache
|
||||
from util.tracing import get_current_task_context, mark_span_error, start_span
|
||||
from util.tracing import (
|
||||
bind_trace_context,
|
||||
capture_otel_context,
|
||||
get_current_task_context,
|
||||
mark_span_error,
|
||||
start_span,
|
||||
)
|
||||
from constant import (
|
||||
HW_ACCEL_NONE, HW_ACCEL_QSV, HW_ACCEL_CUDA,
|
||||
VIDEO_ENCODE_PARAMS, VIDEO_ENCODE_PARAMS_QSV, VIDEO_ENCODE_PARAMS_CUDA
|
||||
@@ -274,6 +281,9 @@ class BaseHandler(TaskHandler, ABC):
|
||||
|
||||
# Thread-local storage: holds the current thread's GPU device index.
_thread_local = threading.local()
# Default per-task transfer concurrency. Overridable via the
# TASK_DOWNLOAD_CONCURRENCY / TASK_UPLOAD_CONCURRENCY environment variables,
# which are parsed and clamped by _resolve_task_transfer_concurrency().
DEFAULT_TASK_DOWNLOAD_CONCURRENCY = 4
DEFAULT_TASK_UPLOAD_CONCURRENCY = 2
# Hard upper bound applied to both download and upload concurrency settings.
MAX_TASK_TRANSFER_CONCURRENCY = 16
|
||||
|
||||
def __init__(self, config: WorkerConfig, api_client: 'APIClientV2'):
|
||||
"""
|
||||
@@ -290,6 +300,251 @@ class BaseHandler(TaskHandler, ABC):
|
||||
enabled=config.cache_enabled,
|
||||
max_size_gb=config.cache_max_size_gb
|
||||
)
|
||||
self.task_download_concurrency = self._resolve_task_transfer_concurrency(
|
||||
"TASK_DOWNLOAD_CONCURRENCY",
|
||||
self.DEFAULT_TASK_DOWNLOAD_CONCURRENCY
|
||||
)
|
||||
self.task_upload_concurrency = self._resolve_task_transfer_concurrency(
|
||||
"TASK_UPLOAD_CONCURRENCY",
|
||||
self.DEFAULT_TASK_UPLOAD_CONCURRENCY
|
||||
)
|
||||
|
||||
def _resolve_task_transfer_concurrency(self, env_name: str, default_value: int) -> int:
|
||||
"""读取并规范化任务内传输并发数配置。"""
|
||||
raw_value = os.getenv(env_name)
|
||||
if raw_value is None or not raw_value.strip():
|
||||
return default_value
|
||||
|
||||
try:
|
||||
parsed_value = int(raw_value.strip())
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid {env_name} value '{raw_value}', using default {default_value}"
|
||||
)
|
||||
return default_value
|
||||
|
||||
if parsed_value < 1:
|
||||
logger.warning(f"{env_name} must be >= 1, forcing to 1")
|
||||
return 1
|
||||
|
||||
if parsed_value > self.MAX_TASK_TRANSFER_CONCURRENCY:
|
||||
logger.warning(
|
||||
f"{env_name}={parsed_value} exceeds limit {self.MAX_TASK_TRANSFER_CONCURRENCY}, "
|
||||
f"using {self.MAX_TASK_TRANSFER_CONCURRENCY}"
|
||||
)
|
||||
return self.MAX_TASK_TRANSFER_CONCURRENCY
|
||||
|
||||
return parsed_value
|
||||
|
||||
def download_files_parallel(
    self,
    download_jobs: List[Dict[str, Any]],
    timeout: Optional[int] = None
) -> Dict[str, Dict[str, Any]]:
    """
    Download multiple files in parallel within a single task.

    Args:
        download_jobs: List of download jobs. Each item's fields:
            - key: unique identifier for the job
            - url: download URL
            - dest: destination file path
            - required: whether the file is critical (optional, default True)
            - use_cache: whether to use the material cache (optional, default True)
        timeout: per-file download timeout in seconds
            (defaults to self.config.download_timeout when None)

    Returns:
        Mapping of key -> result dict:
            - success: whether the download succeeded
            - url: original URL
            - dest: destination file path
            - required: whether the file is critical

    Raises:
        ValueError: if any job is missing key/url/dest, or keys are duplicated.
    """
    if not download_jobs:
        return {}

    # Validate and normalize all jobs up front so bad input fails
    # before any network I/O starts.
    normalized_jobs: List[Dict[str, Any]] = []
    seen_keys = set()
    for download_job in download_jobs:
        job_key = str(download_job.get("key", "")).strip()
        job_url = str(download_job.get("url", "")).strip()
        job_dest = str(download_job.get("dest", "")).strip()
        if not job_key or not job_url or not job_dest:
            raise ValueError("Each download job must include non-empty key/url/dest")
        if job_key in seen_keys:
            raise ValueError(f"Duplicate download job key: {job_key}")
        seen_keys.add(job_key)
        normalized_jobs.append({
            "key": job_key,
            "url": job_url,
            "dest": job_dest,
            "required": bool(download_job.get("required", True)),
            "use_cache": bool(download_job.get("use_cache", True)),
        })

    if timeout is None:
        timeout = self.config.download_timeout

    # Capture the current tracing/task context so worker threads
    # can re-bind it and report under the same task span.
    parent_otel_context = capture_otel_context()
    task_context = get_current_task_context()
    task_prefix = f"[task:{task_context.task_id}] " if task_context else ""
    results: Dict[str, Dict[str, Any]] = {}

    def _run_download_job(download_job: Dict[str, Any]) -> bool:
        # Runs on a pool thread: re-bind the parent trace context first.
        with bind_trace_context(parent_otel_context, task_context):
            return self.download_file(
                download_job["url"],
                download_job["dest"],
                timeout=timeout,
                use_cache=download_job["use_cache"],
            )

    # Never spin up more workers than there are jobs.
    max_workers = min(self.task_download_concurrency, len(normalized_jobs))
    if max_workers <= 1:
        # Sequential fallback: skips thread-pool overhead for a single job.
        # NOTE(review): unlike the parallel branch below, exceptions from
        # download_file propagate here instead of being caught — confirm intended.
        for download_job in normalized_jobs:
            is_success = _run_download_job(download_job)
            results[download_job["key"]] = {
                "success": is_success,
                "url": download_job["url"],
                "dest": download_job["dest"],
                "required": download_job["required"],
            }
    else:
        with ThreadPoolExecutor(
            max_workers=max_workers,
            thread_name_prefix="TaskDownload",
        ) as executor:
            future_to_job = {
                executor.submit(_run_download_job, download_job): download_job
                for download_job in normalized_jobs
            }
            # Collect results as they complete; a raised exception is
            # logged and recorded as a failed download rather than aborting.
            for completed_future in as_completed(future_to_job):
                download_job = future_to_job[completed_future]
                is_success = False
                try:
                    is_success = bool(completed_future.result())
                except Exception as exc:
                    logger.error(
                        f"{task_prefix}Parallel download raised exception for "
                        f"key={download_job['key']}: {exc}"
                    )
                results[download_job["key"]] = {
                    "success": is_success,
                    "url": download_job["url"],
                    "dest": download_job["dest"],
                    "required": download_job["required"],
                }

    success_count = sum(1 for item in results.values() if item["success"])
    logger.debug(
        f"{task_prefix}Parallel download completed: {success_count}/{len(normalized_jobs)}"
    )
    return results
|
||||
|
||||
def upload_files_parallel(
    self,
    upload_jobs: List[Dict[str, Any]]
) -> Dict[str, Dict[str, Any]]:
    """
    Upload multiple files in parallel within a single task.

    Args:
        upload_jobs: List of upload jobs. Each item's fields:
            - key: unique identifier for the job
            - task_id: task ID
            - file_type: file type (video/audio/ts/mp4)
            - file_path: local file path
            - file_name: file name (optional)
            - required: whether the file is critical (optional, default True)

    Returns:
        Mapping of key -> result dict:
            - success: whether the upload succeeded
            - url: access URL after upload (None on failure)
            - file_path: local file path
            - required: whether the file is critical

    Raises:
        ValueError: if any job is missing key/task_id/file_type/file_path,
            or keys are duplicated.
    """
    if not upload_jobs:
        return {}

    # Validate and normalize all jobs up front so bad input fails
    # before any upload starts.
    normalized_jobs: List[Dict[str, Any]] = []
    seen_keys = set()
    for upload_job in upload_jobs:
        job_key = str(upload_job.get("key", "")).strip()
        task_id = str(upload_job.get("task_id", "")).strip()
        file_type = str(upload_job.get("file_type", "")).strip()
        file_path = str(upload_job.get("file_path", "")).strip()
        if not job_key or not task_id or not file_type or not file_path:
            raise ValueError(
                "Each upload job must include non-empty key/task_id/file_type/file_path"
            )
        if job_key in seen_keys:
            raise ValueError(f"Duplicate upload job key: {job_key}")
        seen_keys.add(job_key)
        normalized_jobs.append({
            "key": job_key,
            "task_id": task_id,
            "file_type": file_type,
            "file_path": file_path,
            "file_name": upload_job.get("file_name"),
            "required": bool(upload_job.get("required", True)),
        })

    # Capture the current tracing/task context so worker threads
    # can re-bind it and report under the same task span.
    parent_otel_context = capture_otel_context()
    task_context = get_current_task_context()
    task_prefix = f"[task:{task_context.task_id}] " if task_context else ""
    results: Dict[str, Dict[str, Any]] = {}

    def _run_upload_job(upload_job: Dict[str, Any]) -> Optional[str]:
        # Runs on a pool thread: re-bind the parent trace context first.
        # Returns the uploaded file's URL, or a falsy value on failure.
        with bind_trace_context(parent_otel_context, task_context):
            return self.upload_file(
                upload_job["task_id"],
                upload_job["file_type"],
                upload_job["file_path"],
                upload_job.get("file_name")
            )

    # Never spin up more workers than there are jobs.
    max_workers = min(self.task_upload_concurrency, len(normalized_jobs))
    if max_workers <= 1:
        # Sequential fallback: skips thread-pool overhead for a single job.
        # NOTE(review): unlike the parallel branch below, exceptions from
        # upload_file propagate here instead of being caught — confirm intended.
        for upload_job in normalized_jobs:
            result_url = _run_upload_job(upload_job)
            results[upload_job["key"]] = {
                "success": bool(result_url),
                "url": result_url,
                "file_path": upload_job["file_path"],
                "required": upload_job["required"],
            }
    else:
        with ThreadPoolExecutor(
            max_workers=max_workers,
            thread_name_prefix="TaskUpload",
        ) as executor:
            future_to_job = {
                executor.submit(_run_upload_job, upload_job): upload_job
                for upload_job in normalized_jobs
            }
            # Collect results as they complete; a raised exception is
            # logged and recorded as a failed upload rather than aborting.
            for completed_future in as_completed(future_to_job):
                upload_job = future_to_job[completed_future]
                result_url = None
                try:
                    result_url = completed_future.result()
                except Exception as exc:
                    logger.error(
                        f"{task_prefix}Parallel upload raised exception for "
                        f"key={upload_job['key']}: {exc}"
                    )
                results[upload_job["key"]] = {
                    "success": bool(result_url),
                    "url": result_url,
                    "file_path": upload_job["file_path"],
                    "required": upload_job["required"],
                }

    success_count = sum(1 for item in results.values() if item["success"])
    logger.debug(
        f"{task_prefix}Parallel upload completed: {success_count}/{len(normalized_jobs)}"
    )
    return results
|
||||
|
||||
# ========== GPU 设备管理 ==========
|
||||
|
||||
@@ -538,10 +793,8 @@ class BaseHandler(TaskHandler, ABC):
|
||||
if cmd_to_run and cmd_to_run[0] == 'ffmpeg' and '-loglevel' not in cmd_to_run:
|
||||
cmd_to_run[1:1] = ['-loglevel', FFMPEG_LOGLEVEL]
|
||||
|
||||
# 日志记录命令(限制长度)
|
||||
# 日志记录命令(不限制长度)
|
||||
cmd_str = ' '.join(cmd_to_run)
|
||||
if len(cmd_str) > 500:
|
||||
cmd_str = cmd_str[:500] + '...'
|
||||
logger.info(f"[task:{task_id}] FFmpeg: {cmd_str}")
|
||||
|
||||
with start_span(
|
||||
|
||||
Reference in New Issue
Block a user