You've already forked FrameTour-RenderWorker
288 lines
10 KiB
Python
288 lines
10 KiB
Python
import json
|
|
import os
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, Optional
|
|
|
|
from opentelemetry.trace import Status, StatusCode
|
|
|
|
from util.exceptions import (
|
|
TemplateError,
|
|
TemplateNotFoundError,
|
|
TemplateValidationError,
|
|
)
|
|
from util import api, oss
|
|
from config.settings import get_storage_config
|
|
from telemetry import get_tracer
|
|
|
|
# Module-level logger named after this module, so its records can be
# configured/filtered through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)


class TemplateService(ABC):
    """Abstract interface for template services.

    Concrete implementations decide where templates come from and how they
    are cached; callers rely only on the four operations declared here.
    """

    @abstractmethod
    def get_template(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Look up a template by its ID.

        Args:
            template_id: Identifier of the template.

        Returns:
            The template information dict, or ``None`` when the template
            does not exist.
        """
        ...

    @abstractmethod
    def load_local_templates(self):
        """Load the templates that are available locally."""
        ...

    @abstractmethod
    def download_template(self, template_id: str) -> bool:
        """Download a template.

        Args:
            template_id: Identifier of the template.

        Returns:
            ``True`` if the download succeeded, ``False`` otherwise.
        """
        ...

    @abstractmethod
    def validate_template(self, template_info: Dict[str, Any]) -> bool:
        """Validate a template.

        Args:
            template_info: Template information dict.

        Returns:
            ``True`` if validation passed.
        """
        ...


class DefaultTemplateService(TemplateService):
    """Default template service implementation.

    Keeps an in-memory cache of template definitions. Templates found on disk
    are read from ``<template_dir>/<name>/template.json``. Templates that are
    not cached are fetched via ``api.get_template_info``, and their remote
    assets are downloaded with ``oss.download_from_oss``.
    """

    def __init__(self):
        # cache key (template ID or local directory name) -> parsed
        # template.json contents, with "local_path" filled in
        self.templates: Dict[str, Dict[str, Any]] = {}
        self.storage_config = get_storage_config()

    def get_template(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Return the template info for *template_id*, downloading it on a cache miss.

        Args:
            template_id: Template ID.

        Returns:
            The cached template info dict, or ``None`` if the template was not
            cached and could not be downloaded.
        """
        if template_id not in self.templates:
            # Cache miss: try to fetch it from the remote service.
            if not self.download_template(template_id):
                return None
        return self.templates.get(template_id)

    def load_local_templates(self):
        """Load every template directory found under the configured template dir.

        Entries whose names start with ``_`` or ``.`` are skipped, as are
        non-directories. A template that fails to load is logged and skipped,
        so one bad template does not prevent the others from loading.
        """
        template_dir = self.storage_config.template_dir
        # isdir (not exists): a regular file here would make listdir raise.
        if not os.path.isdir(template_dir):
            logger.warning("Template directory does not exist: %s", template_dir)
            return

        for template_name in os.listdir(template_dir):
            if template_name.startswith(("_", ".")):
                continue

            target_path = os.path.join(template_dir, template_name)
            if not os.path.isdir(target_path):
                continue

            try:
                self._load_template(template_name, target_path)
            except Exception as e:
                logger.error("Failed to load template %s: %s", template_name, e)

    def download_template(self, template_id: str) -> bool:
        """Fetch a template definition, download its assets and load it.

        Remote asset URLs in the definition are replaced by local file names.
        The rewritten definition is saved to ``<local_path>/template.json`` and
        then loaded into the in-memory cache.

        Args:
            template_id: Template ID.

        Returns:
            ``True`` on success, ``False`` otherwise. Errors raised during the
            download are logged and recorded on the tracing span instead of
            being propagated.
        """
        tracer = get_tracer(__name__)
        with tracer.start_as_current_span("download_template") as span:
            try:
                span.set_attribute("template.id", template_id)

                # Fetch the remote template definition.
                template_info = api.get_template_info(template_id)
                if template_info is None:
                    logger.warning("Failed to get template info: %s", template_id)
                    return False

                local_path = template_info.get("local_path")
                if not local_path:
                    local_path = os.path.join(
                        self.storage_config.template_dir, str(template_id)
                    )
                    template_info["local_path"] = local_path

                # exist_ok avoids a check-then-create race; it still raises if
                # local_path exists as a regular file.
                os.makedirs(local_path, exist_ok=True)

                # Download assets. Parts are mutated in place, so template_info
                # ends up referencing local files.
                overall_template = template_info.get("overall_template") or {}
                self._download_template_assets(overall_template, template_info)
                for video_part in template_info.get("video_parts") or []:
                    self._download_template_assets(video_part, template_info)

                # Persist the (rewritten) template definition.
                template_file = os.path.join(local_path, "template.json")
                with open(template_file, "w", encoding="utf-8") as f:
                    json.dump(template_info, f, ensure_ascii=False, indent=2)

                # A fresh copy was just downloaded, so a validation failure here
                # must not trigger another download: _load_template would call
                # back into download_template and recurse without bound.
                self._load_template(template_id, local_path, allow_redownload=False)

                span.set_status(Status(StatusCode.OK))
                logger.info("Template downloaded successfully: %s", template_id)
                return True

            except Exception as e:
                span.record_exception(e)
                span.set_status(Status(StatusCode.ERROR, str(e)))
                logger.error("Failed to download template %s: %s", template_id, e)
                return False

    def validate_template(self, template_info: Dict[str, Any]) -> bool:
        """Check that the local files referenced by the template exist.

        Args:
            template_info: Template info dict. It must contain ``local_path``,
                against which relative asset paths are resolved.

        Returns:
            ``True`` when validation passes.

        Raises:
            TemplateValidationError: If ``local_path`` is missing, a referenced
                file does not exist, or any other error occurs while validating.
                In the last case the original error is chained as the cause.
        """
        try:
            local_path = template_info.get("local_path")
            if not local_path:
                raise TemplateValidationError("Template missing local_path")

            # Video parts
            for video_part in template_info.get("video_parts") or []:
                self._validate_template_part(video_part, local_path)

            # Overall template
            overall_template = template_info.get("overall_template")
            if overall_template:
                self._validate_template_part(overall_template, local_path)

            return True

        except TemplateValidationError:
            raise
        except Exception as e:
            raise TemplateValidationError(f"Template validation failed: {e}") from e

    def _load_template(
        self, template_name: str, local_path: str, allow_redownload: bool = True
    ):
        """Read, validate and cache ``<local_path>/template.json`` under *template_name*.

        Args:
            template_name: Key under which the template is cached.
            local_path: Directory containing ``template.json`` and its assets.
            allow_redownload: When validation fails, call ``download_template``
                once to repair the local copy. ``download_template`` itself
                passes ``False``, so a freshly downloaded but still invalid
                template fails instead of being re-downloaded repeatedly.

        Raises:
            TemplateNotFoundError: If ``template.json`` does not exist.
            TemplateError: If ``template.json`` is not a valid JSON object.
            TemplateValidationError: If validation fails and the template could
                not be re-downloaded, or re-downloading was not allowed.
        """
        logger.info("Loading template: %s (%s)", template_name, local_path)

        template_def_file = os.path.join(local_path, "template.json")
        if not os.path.exists(template_def_file):
            raise TemplateNotFoundError(
                f"Template definition file not found: {template_def_file}"
            )

        try:
            with open(template_def_file, "r", encoding="utf-8") as f:
                template_info = json.load(f)
        except json.JSONDecodeError as e:
            raise TemplateError(f"Invalid template JSON: {e}") from e

        if not isinstance(template_info, dict):
            raise TemplateError(
                f"Invalid template JSON: expected an object in {template_def_file}"
            )

        template_info["local_path"] = local_path

        try:
            self.validate_template(template_info)
        except TemplateValidationError as e:
            if not allow_redownload:
                raise
            logger.error(
                "Template validation failed for %s: %s. Attempting to re-download.",
                template_name,
                e,
            )
            # A successful re-download caches the template via its own
            # (non-redownloading) _load_template call.
            if self.download_template(template_name):
                logger.info("Template re-downloaded successfully: %s", template_name)
                return
            logger.error("Failed to re-download template: %s", template_name)
            raise

        self.templates[template_name] = template_info
        logger.info("Template loaded successfully: %s", template_name)

    def _download_template_assets(
        self, template_part: Dict[str, Any], template_info: Dict[str, Any]
    ):
        """Download the remote assets of *template_part* into ``template_info["local_path"]``.

        A ``source`` string starting with ``http`` is downloaded and replaced by
        its path relative to ``local_path``. Downloaded ``.mp4`` sources are
        first passed through ``re_encode_and_annexb``. Every ``http`` entry in
        ``overlays``, ``luts`` and ``audios`` is downloaded and replaced by its
        file name. *template_part* is mutated in place.
        """
        local_path = template_info["local_path"]

        # Source file
        source = template_part.get("source")
        if self._is_remote(source):
            filename = self._filename_from_url(source)
            new_file_path = os.path.join(local_path, filename)
            oss.download_from_oss(source, new_file_path)

            if filename.endswith(".mp4"):
                from util.ffmpeg import re_encode_and_annexb

                new_file_path = re_encode_and_annexb(new_file_path)

            template_part["source"] = os.path.relpath(new_file_path, local_path)

        # Overlays, LUTs, audios (same processing order as before)
        for key in ("overlays", "luts", "audios"):
            if key in template_part:
                self._download_asset_list(template_part[key] or [], local_path)

    def _download_asset_list(self, assets: list, local_path: str):
        """Download each remote URL in *assets* and replace it in place by its file name."""
        for i, asset in enumerate(assets):
            if self._is_remote(asset):
                filename = self._filename_from_url(asset)
                oss.download_from_oss(asset, os.path.join(local_path, filename))
                assets[i] = filename

    @staticmethod
    def _is_remote(value: Any) -> bool:
        """Return True if *value* is a string URL (``http...``) that must be downloaded."""
        return isinstance(value, str) and value.startswith("http")

    @staticmethod
    def _filename_from_url(url: str) -> str:
        """Return the file name of *url*'s path, ignoring any query string or fragment.

        Splitting the raw URL would keep ``?query`` in the file name (and
        defeat the ``.mp4`` suffix check), so the URL is parsed first.

        Raises:
            TemplateError: If the URL path has no file-name component.
        """
        from urllib.parse import urlparse

        filename = os.path.basename(urlparse(url).path)
        if not filename:
            raise TemplateError(f"Cannot derive a file name from URL: {url}")
        return filename

    def _validate_template_part(self, template_part: Dict[str, Any], base_dir: str):
        """Raise TemplateValidationError if a local file referenced by *template_part* is missing.

        ``source`` values that are URLs (``http...``) or ``PLACEHOLDER_...``
        markers are not checked. Every ``audios``, ``luts`` and ``overlays``
        entry is checked. Relative paths are resolved against *base_dir*.
        """
        # Source file
        source_file = template_part.get("source", "")
        if source_file and not source_file.startswith(("http", "PLACEHOLDER_")):
            self._require_file(source_file, base_dir, "Source")

        # Audio, LUT and overlay files (same check order as before)
        for key, label in (("audios", "Audio"), ("luts", "LUT"), ("overlays", "Overlay")):
            for path in template_part.get(key) or []:
                self._require_file(path, base_dir, label)

    @staticmethod
    def _require_file(path: str, base_dir: str, label: str):
        """Raise TemplateValidationError unless *path* (relative to *base_dir*) exists."""
        if not os.path.isabs(path):
            path = os.path.join(base_dir, path)
        if not os.path.exists(path):
            raise TemplateValidationError(f"{label} file not found: {path}")