You've already forked FrameTour-RenderWorker
feat(重构): 实现新的渲染服务架构
- 新增 RenderTask
This commit is contained in:
12
services/__init__.py
Normal file
12
services/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from .render_service import RenderService, DefaultRenderService
|
||||
from .task_service import TaskService, DefaultTaskService
|
||||
from .template_service import TemplateService, DefaultTemplateService
|
||||
|
||||
__all__ = [
|
||||
'RenderService',
|
||||
'DefaultRenderService',
|
||||
'TaskService',
|
||||
'DefaultTaskService',
|
||||
'TemplateService',
|
||||
'DefaultTemplateService'
|
||||
]
|
237
services/render_service.py
Normal file
237
services/render_service.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import subprocess
|
||||
import os
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from entity.render_task import RenderTask
|
||||
from entity.ffmpeg_command_builder import FFmpegCommandBuilder
|
||||
from util.exceptions import RenderError, FFmpegError
|
||||
from util.ffmpeg import probe_video_info, fade_out_audio, handle_ffmpeg_output, subprocess_args
|
||||
from telemetry import get_tracer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _convert_ffmpeg_task_to_render_task(ffmpeg_task):
|
||||
"""将旧的FfmpegTask转换为新的RenderTask"""
|
||||
from entity.render_task import RenderTask, TaskType
|
||||
|
||||
# 获取输入文件
|
||||
input_files = []
|
||||
for inp in ffmpeg_task.input_file:
|
||||
if hasattr(inp, 'get_output_file'):
|
||||
input_files.append(inp.get_output_file())
|
||||
else:
|
||||
input_files.append(str(inp))
|
||||
|
||||
# 确定任务类型
|
||||
task_type = TaskType.COPY
|
||||
if ffmpeg_task.task_type == 'concat':
|
||||
task_type = TaskType.CONCAT
|
||||
elif ffmpeg_task.task_type == 'encode':
|
||||
task_type = TaskType.ENCODE
|
||||
|
||||
# 创建新任务
|
||||
render_task = RenderTask(
|
||||
input_files=input_files,
|
||||
output_file=ffmpeg_task.output_file,
|
||||
task_type=task_type,
|
||||
resolution=ffmpeg_task.resolution,
|
||||
frame_rate=ffmpeg_task.frame_rate,
|
||||
annexb=ffmpeg_task.annexb,
|
||||
center_cut=ffmpeg_task.center_cut,
|
||||
zoom_cut=ffmpeg_task.zoom_cut,
|
||||
ext_data=getattr(ffmpeg_task, 'ext_data', {})
|
||||
)
|
||||
|
||||
# 复制各种资源
|
||||
render_task.effects = getattr(ffmpeg_task, 'effects', [])
|
||||
render_task.luts = getattr(ffmpeg_task, 'luts', [])
|
||||
render_task.audios = getattr(ffmpeg_task, 'audios', [])
|
||||
render_task.overlays = getattr(ffmpeg_task, 'overlays', [])
|
||||
render_task.subtitles = getattr(ffmpeg_task, 'subtitles', [])
|
||||
|
||||
return render_task
|
||||
|
||||
class RenderService(ABC):
|
||||
"""渲染服务抽象接口"""
|
||||
|
||||
@abstractmethod
|
||||
def render(self, task: Union[RenderTask, 'FfmpegTask']) -> bool:
|
||||
"""
|
||||
执行渲染任务
|
||||
|
||||
Args:
|
||||
task: 渲染任务
|
||||
|
||||
Returns:
|
||||
bool: 渲染是否成功
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_video_info(self, file_path: str) -> tuple[int, int, float]:
|
||||
"""
|
||||
获取视频信息
|
||||
|
||||
Args:
|
||||
file_path: 视频文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (width, height, duration)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fade_out_audio(self, file_path: str, duration: float, fade_seconds: float = 2.0) -> str:
|
||||
"""
|
||||
音频淡出处理
|
||||
|
||||
Args:
|
||||
file_path: 音频文件路径
|
||||
duration: 音频总时长
|
||||
fade_seconds: 淡出时长
|
||||
|
||||
Returns:
|
||||
str: 处理后的文件路径
|
||||
"""
|
||||
pass
|
||||
|
||||
class DefaultRenderService(RenderService):
|
||||
"""默认渲染服务实现"""
|
||||
|
||||
def render(self, task: Union[RenderTask, 'FfmpegTask']) -> bool:
|
||||
"""执行渲染任务"""
|
||||
# 兼容旧的FfmpegTask
|
||||
if hasattr(task, 'get_ffmpeg_args'): # 这是FfmpegTask
|
||||
# 使用旧的方式执行
|
||||
return self._render_legacy_ffmpeg_task(task)
|
||||
|
||||
tracer = get_tracer(__name__)
|
||||
with tracer.start_as_current_span("render_task") as span:
|
||||
try:
|
||||
# 验证任务
|
||||
task.validate()
|
||||
span.set_attribute("task.type", task.task_type.value)
|
||||
span.set_attribute("task.input_files", len(task.input_files))
|
||||
span.set_attribute("task.output_file", task.output_file)
|
||||
|
||||
# 检查是否需要处理
|
||||
if not task.need_processing():
|
||||
if len(task.input_files) == 1:
|
||||
task.output_file = task.input_files[0]
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
return True
|
||||
|
||||
# 构建FFmpeg命令
|
||||
builder = FFmpegCommandBuilder(task)
|
||||
ffmpeg_args = builder.build_command()
|
||||
|
||||
if not ffmpeg_args:
|
||||
# 不需要处理,直接返回
|
||||
if len(task.input_files) == 1:
|
||||
task.output_file = task.input_files[0]
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
return True
|
||||
|
||||
# 执行FFmpeg命令
|
||||
return self._execute_ffmpeg(ffmpeg_args, span)
|
||||
|
||||
except Exception as e:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
logger.error(f"Render failed: {e}", exc_info=True)
|
||||
raise RenderError(f"Render failed: {e}") from e
|
||||
|
||||
def _execute_ffmpeg(self, args: list[str], span) -> bool:
|
||||
"""执行FFmpeg命令"""
|
||||
span.set_attribute("ffmpeg.args", " ".join(args))
|
||||
logger.info("Executing FFmpeg: %s", " ".join(args))
|
||||
|
||||
try:
|
||||
# 执行FFmpeg进程
|
||||
process = subprocess.run(
|
||||
["ffmpeg", "-progress", "-", "-loglevel", "error"] + args[1:],
|
||||
stderr=subprocess.PIPE,
|
||||
**subprocess_args(True)
|
||||
)
|
||||
|
||||
span.set_attribute("ffmpeg.return_code", process.returncode)
|
||||
|
||||
# 处理输出
|
||||
if process.stdout:
|
||||
output = handle_ffmpeg_output(process.stdout)
|
||||
span.set_attribute("ffmpeg.output", output)
|
||||
logger.info("FFmpeg output: %s", output)
|
||||
|
||||
# 检查返回码
|
||||
if process.returncode != 0:
|
||||
error_msg = process.stderr.decode() if process.stderr else "Unknown error"
|
||||
span.set_attribute("ffmpeg.error", error_msg)
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
logger.error("FFmpeg failed with return code %d: %s", process.returncode, error_msg)
|
||||
raise FFmpegError(
|
||||
f"FFmpeg execution failed",
|
||||
command=args,
|
||||
return_code=process.returncode,
|
||||
stderr=error_msg
|
||||
)
|
||||
|
||||
# 检查输出文件
|
||||
output_file = args[-1] # 输出文件总是最后一个参数
|
||||
if not os.path.exists(output_file):
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
raise RenderError(f"Output file not created: {output_file}")
|
||||
|
||||
# 检查文件大小
|
||||
file_size = os.path.getsize(output_file)
|
||||
span.set_attribute("output.file_size", file_size)
|
||||
|
||||
if file_size < 4096: # 文件过小
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
raise RenderError(f"Output file too small: {file_size} bytes")
|
||||
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
logger.info("FFmpeg execution completed successfully")
|
||||
return True
|
||||
|
||||
except subprocess.SubprocessError as e:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
logger.error("Subprocess error: %s", e)
|
||||
raise FFmpegError(f"Subprocess error: {e}") from e
|
||||
|
||||
def get_video_info(self, file_path: str) -> tuple[int, int, float]:
|
||||
"""获取视频信息"""
|
||||
return probe_video_info(file_path)
|
||||
|
||||
def fade_out_audio(self, file_path: str, duration: float, fade_seconds: float = 2.0) -> str:
|
||||
"""音频淡出处理"""
|
||||
return fade_out_audio(file_path, duration, fade_seconds)
|
||||
|
||||
def _render_legacy_ffmpeg_task(self, ffmpeg_task) -> bool:
|
||||
"""兼容处理旧的FfmpegTask"""
|
||||
tracer = get_tracer(__name__)
|
||||
with tracer.start_as_current_span("render_legacy_ffmpeg_task") as span:
|
||||
try:
|
||||
# 处理依赖任务
|
||||
for sub_task in ffmpeg_task.analyze_input_render_tasks():
|
||||
if not self.render(sub_task):
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
return False
|
||||
|
||||
# 获取FFmpeg参数
|
||||
ffmpeg_args = ffmpeg_task.get_ffmpeg_args()
|
||||
|
||||
if not ffmpeg_args:
|
||||
# 不需要处理,直接返回
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
return True
|
||||
|
||||
# 执行FFmpeg命令
|
||||
return self._execute_ffmpeg(ffmpeg_args, span)
|
||||
|
||||
except Exception as e:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
logger.error(f"Legacy FFmpeg task render failed: {e}", exc_info=True)
|
||||
raise RenderError(f"Legacy render failed: {e}") from e
|
289
services/task_service.py
Normal file
289
services/task_service.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from entity.render_task import RenderTask
|
||||
from services.render_service import RenderService
|
||||
from services.template_service import TemplateService
|
||||
from util.exceptions import TaskError, TaskValidationError
|
||||
from util import api, oss
|
||||
from telemetry import get_tracer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TaskService(ABC):
|
||||
"""任务服务抽象接口"""
|
||||
|
||||
@abstractmethod
|
||||
def process_task(self, task_info: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
处理任务
|
||||
|
||||
Args:
|
||||
task_info: 任务信息
|
||||
|
||||
Returns:
|
||||
bool: 处理是否成功
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_render_task(self, task_info: Dict[str, Any], template_info: Dict[str, Any]) -> RenderTask:
|
||||
"""
|
||||
创建渲染任务
|
||||
|
||||
Args:
|
||||
task_info: 任务信息
|
||||
template_info: 模板信息
|
||||
|
||||
Returns:
|
||||
RenderTask: 渲染任务对象
|
||||
"""
|
||||
pass
|
||||
|
||||
class DefaultTaskService(TaskService):
|
||||
"""默认任务服务实现"""
|
||||
|
||||
def __init__(self, render_service: RenderService, template_service: TemplateService):
|
||||
self.render_service = render_service
|
||||
self.template_service = template_service
|
||||
|
||||
def process_task(self, task_info: Dict[str, Any]) -> bool:
|
||||
"""处理任务"""
|
||||
tracer = get_tracer(__name__)
|
||||
with tracer.start_as_current_span("process_task") as span:
|
||||
try:
|
||||
# 标准化任务信息
|
||||
task_info = api.normalize_task(task_info)
|
||||
span.set_attribute("task.id", task_info.get("id", "unknown"))
|
||||
span.set_attribute("task.template_id", task_info.get("templateId", "unknown"))
|
||||
|
||||
# 获取模板信息
|
||||
template_id = task_info.get("templateId")
|
||||
template_info = self.template_service.get_template(template_id)
|
||||
if not template_info:
|
||||
raise TaskError(f"Template not found: {template_id}")
|
||||
|
||||
# 报告任务开始
|
||||
api.report_task_start(task_info)
|
||||
|
||||
# 创建渲染任务
|
||||
render_task = self.create_render_task(task_info, template_info)
|
||||
|
||||
# 执行渲染
|
||||
success = self.render_service.render(render_task)
|
||||
if not success:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
api.report_task_failed(task_info, "Render failed")
|
||||
return False
|
||||
|
||||
# 获取视频信息
|
||||
width, height, duration = self.render_service.get_video_info(render_task.output_file)
|
||||
span.set_attribute("video.width", width)
|
||||
span.set_attribute("video.height", height)
|
||||
span.set_attribute("video.duration", duration)
|
||||
|
||||
# 音频淡出
|
||||
new_file = self.render_service.fade_out_audio(render_task.output_file, duration)
|
||||
render_task.output_file = new_file
|
||||
|
||||
# 上传文件 - 创建一个兼容对象
|
||||
class TaskCompat:
|
||||
def __init__(self, output_file):
|
||||
self.output_file = output_file
|
||||
def get_output_file(self):
|
||||
return self.output_file
|
||||
|
||||
task_compat = TaskCompat(render_task.output_file)
|
||||
upload_success = api.upload_task_file(task_info, task_compat)
|
||||
if not upload_success:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
api.report_task_failed(task_info, "Upload failed")
|
||||
return False
|
||||
|
||||
# 清理临时文件
|
||||
self._cleanup_temp_files(render_task)
|
||||
|
||||
# 报告任务成功
|
||||
api.report_task_success(task_info, videoInfo={
|
||||
"width": width,
|
||||
"height": height,
|
||||
"duration": duration
|
||||
})
|
||||
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
logger.error(f"Task processing failed: {e}", exc_info=True)
|
||||
api.report_task_failed(task_info, str(e))
|
||||
return False
|
||||
|
||||
def create_render_task(self, task_info: Dict[str, Any], template_info: Dict[str, Any]) -> RenderTask:
|
||||
"""创建渲染任务"""
|
||||
tracer = get_tracer(__name__)
|
||||
with tracer.start_as_current_span("create_render_task") as span:
|
||||
# 解析任务参数
|
||||
task_params_str = task_info.get("taskParams", "{}")
|
||||
span.set_attribute("task_params", task_params_str)
|
||||
|
||||
try:
|
||||
task_params = json.loads(task_params_str)
|
||||
task_params_orig = json.loads(task_params_str)
|
||||
except json.JSONDecodeError as e:
|
||||
raise TaskValidationError(f"Invalid task params JSON: {e}")
|
||||
|
||||
# 并行下载资源
|
||||
self._download_resources(task_params)
|
||||
|
||||
# 创建子任务列表
|
||||
sub_tasks = []
|
||||
only_if_usage_count = {}
|
||||
|
||||
for part in template_info.get("video_parts", []):
|
||||
source, ext_data = self._parse_video_source(
|
||||
part.get('source'), task_params, template_info
|
||||
)
|
||||
if not source:
|
||||
logger.warning("No video found for part: %s", part)
|
||||
continue
|
||||
|
||||
# 检查only_if条件
|
||||
only_if = part.get('only_if', '')
|
||||
if only_if:
|
||||
only_if_usage_count[only_if] = only_if_usage_count.get(only_if, 0) + 1
|
||||
required_count = only_if_usage_count[only_if]
|
||||
if not self._check_placeholder_exist_with_count(only_if, task_params_orig, required_count):
|
||||
logger.info("Skipping part due to only_if condition: %s (need %d)", only_if, required_count)
|
||||
continue
|
||||
|
||||
# 创建子任务
|
||||
sub_task = self._create_sub_task(part, source, ext_data, template_info)
|
||||
sub_tasks.append(sub_task)
|
||||
|
||||
# 创建主任务
|
||||
output_file = f"out_{task_info.get('id', 'unknown')}.mp4"
|
||||
main_task = RenderTask(
|
||||
input_files=[task.output_file for task in sub_tasks],
|
||||
output_file=output_file,
|
||||
resolution=template_info.get("video_size", ""),
|
||||
frame_rate=template_info.get("frame_rate", 25),
|
||||
center_cut=template_info.get("crop_mode"),
|
||||
zoom_cut=template_info.get("zoom_cut")
|
||||
)
|
||||
|
||||
# 应用整体模板设置
|
||||
overall_template = template_info.get("overall_template", {})
|
||||
self._apply_template_settings(main_task, overall_template, template_info)
|
||||
|
||||
# 设置扩展数据
|
||||
main_task.ext_data = task_info
|
||||
|
||||
span.set_attribute("render_task.sub_tasks", len(sub_tasks))
|
||||
span.set_attribute("render_task.effects", len(main_task.effects))
|
||||
|
||||
return main_task
|
||||
|
||||
def _download_resources(self, task_params: Dict[str, Any]):
|
||||
"""并行下载资源"""
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
for param_list in task_params.values():
|
||||
if isinstance(param_list, list):
|
||||
for param in param_list:
|
||||
url = param.get("url", "")
|
||||
if url.startswith("http"):
|
||||
_, filename = os.path.split(url)
|
||||
executor.submit(oss.download_from_oss, url, filename, True)
|
||||
|
||||
def _parse_video_source(self, source: str, task_params: Dict[str, Any],
|
||||
template_info: Dict[str, Any]) -> tuple[Optional[str], Dict[str, Any]]:
|
||||
"""解析视频源"""
|
||||
if source.startswith('PLACEHOLDER_'):
|
||||
placeholder_id = source.replace('PLACEHOLDER_', '')
|
||||
new_sources = task_params.get(placeholder_id, [])
|
||||
pick_source = {}
|
||||
|
||||
if isinstance(new_sources, list):
|
||||
if len(new_sources) == 0:
|
||||
logger.debug("No video found for placeholder: %s", placeholder_id)
|
||||
return None, pick_source
|
||||
else:
|
||||
pick_source = new_sources.pop(0)
|
||||
new_sources = pick_source.get("url", "")
|
||||
|
||||
if new_sources.startswith("http"):
|
||||
_, source_name = os.path.split(new_sources)
|
||||
oss.download_from_oss(new_sources, source_name, True)
|
||||
return source_name, pick_source
|
||||
return new_sources, pick_source
|
||||
|
||||
return os.path.join(template_info.get("local_path", ""), source), {}
|
||||
|
||||
def _check_placeholder_exist_with_count(self, placeholder_id: str, task_params: Dict[str, Any],
|
||||
required_count: int = 1) -> bool:
|
||||
"""检查占位符是否存在足够数量的片段"""
|
||||
if placeholder_id in task_params:
|
||||
new_sources = task_params.get(placeholder_id, [])
|
||||
if isinstance(new_sources, list):
|
||||
return len(new_sources) >= required_count
|
||||
return required_count <= 1
|
||||
return False
|
||||
|
||||
def _create_sub_task(self, part: Dict[str, Any], source: str, ext_data: Dict[str, Any],
|
||||
template_info: Dict[str, Any]) -> RenderTask:
|
||||
"""创建子任务"""
|
||||
sub_task = RenderTask(
|
||||
input_files=[source],
|
||||
resolution=template_info.get("video_size", ""),
|
||||
frame_rate=template_info.get("frame_rate", 25),
|
||||
annexb=True,
|
||||
center_cut=part.get("crop_mode"),
|
||||
zoom_cut=part.get("zoom_cut"),
|
||||
ext_data=ext_data
|
||||
)
|
||||
|
||||
# 应用部分模板设置
|
||||
self._apply_template_settings(sub_task, part, template_info)
|
||||
|
||||
return sub_task
|
||||
|
||||
def _apply_template_settings(self, task: RenderTask, template_part: Dict[str, Any],
|
||||
template_info: Dict[str, Any]):
|
||||
"""应用模板设置到任务"""
|
||||
# 添加效果
|
||||
for effect in template_part.get('effects', []):
|
||||
task.add_effect(effect)
|
||||
|
||||
# 添加LUT
|
||||
for lut in template_part.get('luts', []):
|
||||
full_path = os.path.join(template_info.get("local_path", ""), lut)
|
||||
task.add_lut(full_path.replace("\\", "/"))
|
||||
|
||||
# 添加音频
|
||||
for audio in template_part.get('audios', []):
|
||||
full_path = os.path.join(template_info.get("local_path", ""), audio)
|
||||
task.add_audios(full_path)
|
||||
|
||||
# 添加覆盖层
|
||||
for overlay in template_part.get('overlays', []):
|
||||
full_path = os.path.join(template_info.get("local_path", ""), overlay)
|
||||
task.add_overlay(full_path)
|
||||
|
||||
def _cleanup_temp_files(self, task: RenderTask):
|
||||
"""清理临时文件"""
|
||||
try:
|
||||
template_dir = os.getenv("TEMPLATE_DIR", "")
|
||||
if template_dir and template_dir not in task.output_file:
|
||||
if os.path.exists(task.output_file):
|
||||
os.remove(task.output_file)
|
||||
logger.info("Cleaned up temp file: %s", task.output_file)
|
||||
else:
|
||||
logger.info("Skipped cleanup of template file: %s", task.output_file)
|
||||
except OSError as e:
|
||||
logger.warning("Failed to cleanup temp file %s: %s", task.output_file, e)
|
266
services/template_service.py
Normal file
266
services/template_service.py
Normal file
@@ -0,0 +1,266 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from util.exceptions import TemplateError, TemplateNotFoundError, TemplateValidationError
|
||||
from util import api, oss
|
||||
from config.settings import get_storage_config
|
||||
from telemetry import get_tracer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TemplateService(ABC):
|
||||
"""模板服务抽象接口"""
|
||||
|
||||
@abstractmethod
|
||||
def get_template(self, template_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取模板信息
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 模板信息,如果不存在则返回None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_local_templates(self):
|
||||
"""加载本地模板"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def download_template(self, template_id: str) -> bool:
|
||||
"""
|
||||
下载模板
|
||||
|
||||
Args:
|
||||
template_id: 模板ID
|
||||
|
||||
Returns:
|
||||
bool: 下载是否成功
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_template(self, template_info: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
验证模板
|
||||
|
||||
Args:
|
||||
template_info: 模板信息
|
||||
|
||||
Returns:
|
||||
bool: 验证是否通过
|
||||
"""
|
||||
pass
|
||||
|
||||
class DefaultTemplateService(TemplateService):
|
||||
"""默认模板服务实现"""
|
||||
|
||||
def __init__(self):
|
||||
self.templates: Dict[str, Dict[str, Any]] = {}
|
||||
self.storage_config = get_storage_config()
|
||||
|
||||
def get_template(self, template_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取模板信息"""
|
||||
if template_id not in self.templates:
|
||||
# 尝试下载模板
|
||||
if not self.download_template(template_id):
|
||||
return None
|
||||
return self.templates.get(template_id)
|
||||
|
||||
def load_local_templates(self):
|
||||
"""加载本地模板"""
|
||||
template_dir = self.storage_config.template_dir
|
||||
if not os.path.exists(template_dir):
|
||||
logger.warning("Template directory does not exist: %s", template_dir)
|
||||
return
|
||||
|
||||
for template_name in os.listdir(template_dir):
|
||||
if template_name.startswith("_") or template_name.startswith("."):
|
||||
continue
|
||||
|
||||
target_path = os.path.join(template_dir, template_name)
|
||||
if os.path.isdir(target_path):
|
||||
try:
|
||||
self._load_template(template_name, target_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load template %s: %s", template_name, e)
|
||||
|
||||
def download_template(self, template_id: str) -> bool:
|
||||
"""下载模板"""
|
||||
tracer = get_tracer(__name__)
|
||||
with tracer.start_as_current_span("download_template") as span:
|
||||
try:
|
||||
span.set_attribute("template.id", template_id)
|
||||
|
||||
# 获取远程模板信息
|
||||
template_info = api.get_template_info(template_id)
|
||||
if template_info is None:
|
||||
logger.warning("Failed to get template info: %s", template_id)
|
||||
return False
|
||||
|
||||
local_path = template_info.get('local_path')
|
||||
if not local_path:
|
||||
local_path = os.path.join(self.storage_config.template_dir, str(template_id))
|
||||
template_info['local_path'] = local_path
|
||||
|
||||
# 创建本地目录
|
||||
if not os.path.isdir(local_path):
|
||||
os.makedirs(local_path)
|
||||
|
||||
# 下载模板资源
|
||||
overall_template = template_info.get('overall_template', {})
|
||||
video_parts = template_info.get('video_parts', [])
|
||||
|
||||
self._download_template_assets(overall_template, template_info)
|
||||
for video_part in video_parts:
|
||||
self._download_template_assets(video_part, template_info)
|
||||
|
||||
# 保存模板定义文件
|
||||
template_file = os.path.join(local_path, 'template.json')
|
||||
with open(template_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(template_info, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 加载到内存
|
||||
self._load_template(template_id, local_path)
|
||||
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
logger.info("Template downloaded successfully: %s", template_id)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
logger.error("Failed to download template %s: %s", template_id, e)
|
||||
return False
|
||||
|
||||
def validate_template(self, template_info: Dict[str, Any]) -> bool:
|
||||
"""验证模板"""
|
||||
try:
|
||||
local_path = template_info.get("local_path")
|
||||
if not local_path:
|
||||
raise TemplateValidationError("Template missing local_path")
|
||||
|
||||
# 验证视频部分
|
||||
for video_part in template_info.get("video_parts", []):
|
||||
self._validate_template_part(video_part, local_path)
|
||||
|
||||
# 验证整体模板
|
||||
overall_template = template_info.get("overall_template", {})
|
||||
if overall_template:
|
||||
self._validate_template_part(overall_template, local_path)
|
||||
|
||||
return True
|
||||
|
||||
except TemplateValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise TemplateValidationError(f"Template validation failed: {e}")
|
||||
|
||||
def _load_template(self, template_name: str, local_path: str):
|
||||
"""加载单个模板"""
|
||||
logger.info("Loading template: %s (%s)", template_name, local_path)
|
||||
|
||||
template_def_file = os.path.join(local_path, "template.json")
|
||||
if not os.path.exists(template_def_file):
|
||||
raise TemplateNotFoundError(f"Template definition file not found: {template_def_file}")
|
||||
|
||||
try:
|
||||
with open(template_def_file, 'r', encoding='utf-8') as f:
|
||||
template_info = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
raise TemplateError(f"Invalid template JSON: {e}")
|
||||
|
||||
template_info["local_path"] = local_path
|
||||
|
||||
try:
|
||||
self.validate_template(template_info)
|
||||
self.templates[template_name] = template_info
|
||||
logger.info("Template loaded successfully: %s", template_name)
|
||||
except TemplateValidationError as e:
|
||||
logger.error("Template validation failed for %s: %s. Attempting to re-download.", template_name, e)
|
||||
# 模板验证失败,尝试重新下载
|
||||
if self.download_template(template_name):
|
||||
logger.info("Template re-downloaded successfully: %s", template_name)
|
||||
else:
|
||||
logger.error("Failed to re-download template: %s", template_name)
|
||||
raise
|
||||
|
||||
def _download_template_assets(self, template_part: Dict[str, Any], template_info: Dict[str, Any]):
|
||||
"""下载模板资源"""
|
||||
local_path = template_info['local_path']
|
||||
|
||||
# 下载源文件
|
||||
if 'source' in template_part:
|
||||
source = template_part['source']
|
||||
if isinstance(source, str) and source.startswith("http"):
|
||||
_, filename = os.path.split(source)
|
||||
new_file_path = os.path.join(local_path, filename)
|
||||
oss.download_from_oss(source, new_file_path)
|
||||
|
||||
if filename.endswith(".mp4"):
|
||||
from util.ffmpeg import re_encode_and_annexb
|
||||
new_file_path = re_encode_and_annexb(new_file_path)
|
||||
|
||||
template_part['source'] = os.path.relpath(new_file_path, local_path)
|
||||
|
||||
# 下载覆盖层
|
||||
if 'overlays' in template_part:
|
||||
for i, overlay in enumerate(template_part['overlays']):
|
||||
if isinstance(overlay, str) and overlay.startswith("http"):
|
||||
_, filename = os.path.split(overlay)
|
||||
oss.download_from_oss(overlay, os.path.join(local_path, filename))
|
||||
template_part['overlays'][i] = filename
|
||||
|
||||
# 下载LUT
|
||||
if 'luts' in template_part:
|
||||
for i, lut in enumerate(template_part['luts']):
|
||||
if isinstance(lut, str) and lut.startswith("http"):
|
||||
_, filename = os.path.split(lut)
|
||||
oss.download_from_oss(lut, os.path.join(local_path, filename))
|
||||
template_part['luts'][i] = filename
|
||||
|
||||
# 下载音频
|
||||
if 'audios' in template_part:
|
||||
for i, audio in enumerate(template_part['audios']):
|
||||
if isinstance(audio, str) and audio.startswith("http"):
|
||||
_, filename = os.path.split(audio)
|
||||
oss.download_from_oss(audio, os.path.join(local_path, filename))
|
||||
template_part['audios'][i] = filename
|
||||
|
||||
def _validate_template_part(self, template_part: Dict[str, Any], base_dir: str):
|
||||
"""验证模板部分"""
|
||||
# 验证源文件
|
||||
source_file = template_part.get("source", "")
|
||||
if source_file and not source_file.startswith("http") and not source_file.startswith("PLACEHOLDER_"):
|
||||
if not os.path.isabs(source_file):
|
||||
source_file = os.path.join(base_dir, source_file)
|
||||
if not os.path.exists(source_file):
|
||||
raise TemplateValidationError(f"Source file not found: {source_file}")
|
||||
|
||||
# 验证音频文件
|
||||
for audio in template_part.get("audios", []):
|
||||
if not os.path.isabs(audio):
|
||||
audio = os.path.join(base_dir, audio)
|
||||
if not os.path.exists(audio):
|
||||
raise TemplateValidationError(f"Audio file not found: {audio}")
|
||||
|
||||
# 验证LUT文件
|
||||
for lut in template_part.get("luts", []):
|
||||
if not os.path.isabs(lut):
|
||||
lut = os.path.join(base_dir, lut)
|
||||
if not os.path.exists(lut):
|
||||
raise TemplateValidationError(f"LUT file not found: {lut}")
|
||||
|
||||
# 验证覆盖层文件
|
||||
for overlay in template_part.get("overlays", []):
|
||||
if not os.path.isabs(overlay):
|
||||
overlay = os.path.join(base_dir, overlay)
|
||||
if not os.path.exists(overlay):
|
||||
raise TemplateValidationError(f"Overlay file not found: {overlay}")
|
Reference in New Issue
Block a user