You've already forked FrameTour-RenderWorker
289 lines
12 KiB
Python
289 lines
12 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Dict, Any, List, Optional
|
|
|
|
from opentelemetry.trace import Status, StatusCode
|
|
|
|
from entity.render_task import RenderTask
|
|
from services.render_service import RenderService
|
|
from services.template_service import TemplateService
|
|
from util.exceptions import TaskError, TaskValidationError
|
|
from util import api, oss
|
|
from telemetry import get_tracer
|
|
|
|
# Logger shared by the task-service classes defined in this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
class TaskService(ABC):
    """Abstract interface for task services."""

    @abstractmethod
    def process_task(self, task_info: Dict[str, Any]) -> bool:
        """Process a single task.

        Args:
            task_info: Task information.

        Returns:
            bool: Whether processing succeeded.
        """
        ...

    @abstractmethod
    def create_render_task(self, task_info: Dict[str, Any], template_info: Dict[str, Any]) -> RenderTask:
        """Build the render task for a task and its template.

        Args:
            task_info: Task information.
            template_info: Template information.

        Returns:
            RenderTask: The render task object.
        """
        ...
|
|
|
|
class DefaultTaskService(TaskService):
    """Default task service implementation.

    Drives one task end to end:
    - normalise the payload and look up its template;
    - build a :class:`RenderTask` with one sub-task per template video part;
    - render it and fade out its audio;
    - upload the result and report the outcome through ``util.api``.
    """

    # Number of worker threads used to prefetch remote resources.
    download_workers = 8

    class _UploadTarget:
        """Adapter handed to ``api.upload_task_file``.

        Exposes the file to upload both as the ``output_file`` attribute and via
        ``get_output_file()``. Presumably this is the interface
        ``api.upload_task_file`` reads; confirm against ``util.api``.
        """

        def __init__(self, output_file):
            self.output_file = output_file

        def get_output_file(self):
            return self.output_file

    def __init__(self, render_service: RenderService, template_service: TemplateService):
        """Store the collaborating services.

        Args:
            render_service: Service used to render, probe and post-process videos.
            template_service: Service used to look up templates by id.
        """
        self.render_service = render_service
        self.template_service = template_service

    def process_task(self, task_info: Dict[str, Any]) -> bool:
        """Run a task from template lookup through upload and report the outcome.

        Args:
            task_info: Raw task payload; normalised with ``api.normalize_task``.

        Returns:
            bool: ``True`` if the video was rendered, uploaded and reported as
            successful, otherwise ``False``. Failures are reported through
            ``api.report_task_failed`` and are not re-raised.
        """
        tracer = get_tracer(__name__)
        with tracer.start_as_current_span("process_task") as span:
            render_task: Optional[RenderTask] = None
            try:
                task_info = api.normalize_task(task_info)
                span.set_attribute("task.id", task_info.get("id", "unknown"))
                span.set_attribute("task.template_id", task_info.get("templateId", "unknown"))

                template_id = task_info.get("templateId")
                template_info = self.template_service.get_template(template_id)
                if not template_info:
                    raise TaskError(f"Template not found: {template_id}")

                api.report_task_start(task_info)

                render_task = self.create_render_task(task_info, template_info)
                if not self.render_service.render(render_task):
                    return self._report_failure(span, task_info, "Render failed")

                width, height, duration = self.render_service.get_video_info(render_task.output_file)
                span.set_attribute("video.width", width)
                span.set_attribute("video.height", height)
                span.set_attribute("video.duration", duration)

                # Fade out the audio; the service returns the path of the file to keep.
                render_task.output_file = self.render_service.fade_out_audio(render_task.output_file, duration)

                upload_target = self._UploadTarget(render_task.output_file)
                if not api.upload_task_file(task_info, upload_target):
                    return self._report_failure(span, task_info, "Upload failed")

                api.report_task_success(task_info, videoInfo={
                    "width": width,
                    "height": height,
                    "duration": duration
                })
                span.set_status(Status(StatusCode.OK))
                return True
            except Exception as e:
                logger.exception("Task processing failed: %s", e)
                span.record_exception(e)
                return self._report_failure(span, task_info, str(e))
            finally:
                # Remove the local output on every path (success or failure) so
                # failed tasks do not leave rendered files behind.
                if render_task is not None:
                    self._cleanup_temp_files(render_task)

    @staticmethod
    def _report_failure(span, task_info: Dict[str, Any], reason: str) -> bool:
        """Mark *span* as failed and report *reason* for the task; always returns ``False``.

        An error raised while reporting is logged rather than propagated, so
        ``process_task`` keeps its boolean contract.
        """
        span.set_status(Status(StatusCode.ERROR, reason))
        try:
            api.report_task_failed(task_info, reason)
        except Exception:
            logger.exception("Failed to report task failure: %s", reason)
        return False

    def create_render_task(self, task_info: Dict[str, Any], template_info: Dict[str, Any]) -> RenderTask:
        """Build the main render task and one sub-task per template video part.

        Args:
            task_info: Normalised task payload. ``taskParams`` must be a JSON
                object string. Its list values are lists of entries (dicts with
                a ``url`` key) looked up by placeholder id.
            template_info: Template description (``video_parts``, ``video_size``,
                ``frame_rate``, ``local_path``, ...).

        Returns:
            RenderTask: The main task, whose inputs are the sub-task outputs.

        Raises:
            TaskValidationError: If ``taskParams`` is not valid JSON or does not
                decode to a JSON object.
        """
        tracer = get_tracer(__name__)
        with tracer.start_as_current_span("create_render_task") as span:
            task_params_str = task_info.get("taskParams", "{}")
            if task_params_str is None:
                # An explicit null means "no parameters", same as a missing key.
                task_params_str = "{}"
            span.set_attribute("task_params", str(task_params_str))

            try:
                task_params = json.loads(task_params_str)
                # Independent copy: _parse_video_source pops entries off the lists
                # in task_params, but only_if checks need the original counts.
                task_params_orig = json.loads(task_params_str)
            except (TypeError, ValueError) as e:
                raise TaskValidationError(f"Invalid task params JSON: {e}") from e
            if not isinstance(task_params, dict):
                raise TaskValidationError(
                    f"Invalid task params JSON: expected an object, got {type(task_params).__name__}"
                )

            # Prefetch remote resources in parallel (best effort).
            self._download_resources(task_params)

            sub_tasks = []
            only_if_usage_count: Dict[str, int] = {}
            for part in template_info.get("video_parts", []):
                # NOTE(review): the source is resolved (consuming a placeholder
                # entry) before the only_if check, so a part skipped by only_if
                # still uses up an entry. Confirm this ordering is intended.
                source, ext_data = self._parse_video_source(
                    part.get('source'), task_params, template_info
                )
                if not source:
                    logger.warning("No video found for part: %s", part)
                    continue

                only_if = part.get('only_if', '')
                if only_if:
                    # The n-th resolved part guarded by the same only_if id needs
                    # at least n entries for that id in the original params.
                    only_if_usage_count[only_if] = only_if_usage_count.get(only_if, 0) + 1
                    required_count = only_if_usage_count[only_if]
                    if not self._check_placeholder_exist_with_count(only_if, task_params_orig, required_count):
                        logger.info("Skipping part due to only_if condition: %s (need %d)", only_if, required_count)
                        continue

                sub_tasks.append(self._create_sub_task(part, source, ext_data, template_info))

            output_file = f"out_{task_info.get('id', 'unknown')}.mp4"
            main_task = RenderTask(
                input_files=[task.output_file for task in sub_tasks],
                output_file=output_file,
                resolution=template_info.get("video_size", ""),
                frame_rate=template_info.get("frame_rate", 25),
                center_cut=template_info.get("crop_mode"),
                zoom_cut=template_info.get("zoom_cut")
            )

            overall_template = template_info.get("overall_template", {})
            self._apply_template_settings(main_task, overall_template, template_info)

            main_task.ext_data = task_info

            span.set_attribute("render_task.sub_tasks", len(sub_tasks))
            span.set_attribute("render_task.effects", len(main_task.effects))

            return main_task

    @staticmethod
    def _local_filename(url: str) -> str:
        """Return the local filename a downloaded *url* is stored under.

        The filename is the part of *url* after the final path separator.
        """
        return os.path.split(url)[1]

    def _download_resources(self, task_params: Dict[str, Any]):
        """Prefetch, in parallel, every remote ``url`` found in list-valued task params.

        Best effort: a failed download is logged, not raised.
        ``_parse_video_source`` downloads each source it actually uses again,
        synchronously.
        """
        # Key by local filename so two workers never write the same file at once.
        pending: Dict[str, str] = {}
        for param_list in task_params.values():
            if not isinstance(param_list, list):
                continue
            for param in param_list:
                url = param.get("url", "") if isinstance(param, dict) else ""
                if isinstance(url, str) and url.startswith("http"):
                    pending.setdefault(self._local_filename(url), url)

        if not pending:
            return

        with ThreadPoolExecutor(max_workers=self.download_workers) as executor:
            futures = {
                executor.submit(oss.download_from_oss, url, filename, True): url
                for filename, url in pending.items()
            }
        # The executor has waited for every download; surface failures in the log.
        for future, url in futures.items():
            error = future.exception()
            if error is not None:
                logger.warning("Failed to prefetch resource %s: %s", url, error)

    def _parse_video_source(self, source: str, task_params: Dict[str, Any],
                            template_info: Dict[str, Any]) -> tuple[Optional[str], Dict[str, Any]]:
        """Resolve a template part's ``source`` to a local file path.

        Sources are resolved as follows:
        - ``PLACEHOLDER_<id>`` with a list value in ``task_params``: takes (and
          removes) the first entry of ``task_params[<id>]`` and uses its ``url``.
          A remote ``url`` is downloaded first.
        - ``PLACEHOLDER_<id>`` with a string value: the string is used directly.
        - Any other non-empty source: a path relative to the template's
          ``local_path``.

        Returns:
            tuple: ``(path, ext_data)``. ``path`` is ``None`` or empty when
            nothing usable is available. ``ext_data`` is the picked entry, or
            ``{}``.
        """
        if not source:
            return None, {}
        if not source.startswith('PLACEHOLDER_'):
            return os.path.join(template_info.get("local_path", ""), source), {}

        placeholder_id = source.removeprefix('PLACEHOLDER_')
        candidate = task_params.get(placeholder_id, [])
        pick_source: Dict[str, Any] = {}

        if isinstance(candidate, list):
            if not candidate:
                logger.debug("No video found for placeholder: %s", placeholder_id)
                return None, pick_source
            picked = candidate.pop(0)
            if not isinstance(picked, dict):
                return None, {}
            pick_source = picked
            candidate = pick_source.get("url", "")

        if not isinstance(candidate, str):
            # null / numeric / object values cannot name a video.
            return None, pick_source

        if candidate.startswith("http"):
            local_name = self._local_filename(candidate)
            oss.download_from_oss(candidate, local_name, True)
            return local_name, pick_source

        return candidate, pick_source

    def _check_placeholder_exist_with_count(self, placeholder_id: str, task_params: Dict[str, Any],
                                            required_count: int = 1) -> bool:
        """Return whether ``task_params`` has at least *required_count* entries for *placeholder_id*.

        A list value counts its elements. Any other present value counts as a
        single entry. A missing key counts as none.
        """
        if placeholder_id not in task_params:
            return False
        value = task_params[placeholder_id]
        if isinstance(value, list):
            return len(value) >= required_count
        return required_count <= 1

    def _create_sub_task(self, part: Dict[str, Any], source: str, ext_data: Dict[str, Any],
                         template_info: Dict[str, Any]) -> RenderTask:
        """Build the sub-task that renders one template video part from *source*."""
        sub_task = RenderTask(
            input_files=[source],
            resolution=template_info.get("video_size", ""),
            frame_rate=template_info.get("frame_rate", 25),
            annexb=True,
            center_cut=part.get("crop_mode"),
            zoom_cut=part.get("zoom_cut"),
            ext_data=ext_data
        )

        # Apply the part's own effects / LUTs / audios / overlays.
        self._apply_template_settings(sub_task, part, template_info)

        return sub_task

    def _apply_template_settings(self, task: RenderTask, template_part: Dict[str, Any],
                                 template_info: Dict[str, Any]):
        """Add the effects, LUTs, audios and overlays listed in *template_part* to *task*.

        Asset paths are resolved relative to the template's ``local_path``.
        """
        local_path = template_info.get("local_path", "")

        for effect in template_part.get('effects', []):
            task.add_effect(effect)

        for lut in template_part.get('luts', []):
            # Forward slashes only. Presumably LUT paths are embedded in an ffmpeg
            # filter expression where backslashes are escapes; confirm.
            task.add_lut(os.path.join(local_path, lut).replace("\\", "/"))

        for audio in template_part.get('audios', []):
            task.add_audios(os.path.join(local_path, audio))

        for overlay in template_part.get('overlays', []):
            task.add_overlay(os.path.join(local_path, overlay))

    def _cleanup_temp_files(self, task: RenderTask):
        """Delete the task's output file unless its path contains ``TEMPLATE_DIR``.

        When ``TEMPLATE_DIR`` is unset, no file is treated as a template file.
        Best effort: errors are logged, not raised.
        """
        output_file = task.output_file
        if not output_file:
            return
        try:
            output_file = os.fspath(output_file)
            template_dir = os.getenv("TEMPLATE_DIR", "")
            # Substring match deliberately errs on the side of keeping files.
            if template_dir and template_dir in output_file:
                logger.info("Skipped cleanup of template file: %s", output_file)
                return
            if os.path.exists(output_file):
                os.remove(output_file)
                logger.info("Cleaned up temp file: %s", output_file)
        except (OSError, TypeError) as e:
            logger.warning("Failed to cleanup temp file %s: %s", output_file, e)