# Source file from the FrameTour-RenderWorker repository (Python, 266 lines, 11 KiB).
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from urllib.parse import urlparse

from opentelemetry.trace import Status, StatusCode

from config.settings import get_storage_config
from telemetry import get_tracer
from util import api, oss
from util.exceptions import TemplateError, TemplateNotFoundError, TemplateValidationError


# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)


class TemplateService(ABC):
    """Abstract interface for template services.

    Concrete implementations look templates up, load the ones available
    locally, download missing ones and validate template definitions.
    """

    @abstractmethod
    def get_template(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Look up a template by its ID.

        Args:
            template_id: Identifier of the template.

        Returns:
            The template information dict, or ``None`` if the template does
            not exist.
        """
        ...

    @abstractmethod
    def load_local_templates(self):
        """Load every template that is available locally."""
        ...

    @abstractmethod
    def download_template(self, template_id: str) -> bool:
        """Download a template.

        Args:
            template_id: Identifier of the template to fetch.

        Returns:
            ``True`` if the download succeeded, ``False`` otherwise.
        """
        ...

    @abstractmethod
    def validate_template(self, template_info: Dict[str, Any]) -> bool:
        """Validate a template definition.

        Args:
            template_info: The template information dict to check.

        Returns:
            ``True`` if the template passed validation.
        """
        ...


class DefaultTemplateService(TemplateService):
    """Default template service implementation.

    Loaded templates are cached in ``self.templates``, keyed by template ID
    (or by directory name for templates found on disk). Missing templates are
    fetched with ``api.get_template_info``; their remote assets are
    downloaded with ``oss.download_from_oss`` into the template's local
    directory (by default ``<storage_config.template_dir>/<template_id>``).
    """

    def __init__(self):
        # template ID / directory name -> template definition (with "local_path" set)
        self.templates: Dict[str, Dict[str, Any]] = {}
        self.storage_config = get_storage_config()

    def get_template(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Return a template, downloading it first if it is not cached.

        Args:
            template_id: Identifier of the template.

        Returns:
            The template information dict, or ``None`` if it was not cached
            and could not be downloaded.
        """
        if template_id not in self.templates:
            # Not cached yet: try to fetch it from the remote service.
            if not self.download_template(template_id):
                return None
        return self.templates.get(template_id)

    def load_local_templates(self):
        """Load every template directory found under ``template_dir``.

        Entries whose names start with ``_`` or ``.`` and non-directory
        entries are skipped. A template that fails to load is logged and
        skipped, so one bad template does not prevent loading the others.
        """
        template_dir = self.storage_config.template_dir
        # isdir (not exists): listdir would raise if template_dir were a file.
        if not os.path.isdir(template_dir):
            logger.warning("Template directory does not exist: %s", template_dir)
            return

        # Sorted for a deterministic load order (and readable logs).
        for template_name in sorted(os.listdir(template_dir)):
            if template_name.startswith(("_", ".")):
                continue
            target_path = os.path.join(template_dir, template_name)
            if not os.path.isdir(target_path):
                continue
            try:
                self._load_template(template_name, target_path)
            except Exception as e:
                logger.error("Failed to load template %s: %s", template_name, e)

    def download_template(self, template_id: str) -> bool:
        """Fetch a template definition and its assets, then load it.

        The definition comes from ``api.get_template_info``. Remote assets are
        downloaded into the template's local directory (URLs in the definition
        are rewritten to local file names), the definition is saved as
        ``template.json`` and the template is loaded into the cache.

        Args:
            template_id: Identifier of the template to download.

        Returns:
            ``True`` on success; ``False`` if the template info could not be
            retrieved or any step failed (the error is logged and recorded on
            the tracing span).
        """
        tracer = get_tracer(__name__)
        with tracer.start_as_current_span("download_template") as span:
            try:
                span.set_attribute("template.id", template_id)

                # Fetch the remote template definition.
                template_info = api.get_template_info(template_id)
                if template_info is None:
                    logger.warning("Failed to get template info: %s", template_id)
                    return False

                local_path = template_info.get('local_path')
                if not local_path:
                    local_path = os.path.join(self.storage_config.template_dir, str(template_id))
                    template_info['local_path'] = local_path

                # exist_ok: tolerate the directory appearing concurrently.
                os.makedirs(local_path, exist_ok=True)

                # Download assets; this rewrites URLs inside template_info in place.
                self._download_template_assets(template_info.get('overall_template') or {}, template_info)
                for video_part in template_info.get('video_parts') or []:
                    self._download_template_assets(video_part, template_info)

                self._write_template_file(local_path, template_info)

                # The template was just downloaded: if it is still invalid,
                # downloading it again would repeat the same work. Without this
                # flag, _load_template -> download_template recursed unboundedly.
                self._load_template(template_id, local_path, allow_redownload=False)

                span.set_status(Status(StatusCode.OK))
                logger.info("Template downloaded successfully: %s", template_id)
                return True

            except Exception as e:
                span.record_exception(e)
                span.set_status(Status(StatusCode.ERROR, str(e)))
                logger.error("Failed to download template %s: %s", template_id, e)
                return False

    def validate_template(self, template_info: Dict[str, Any]) -> bool:
        """Check that every local file referenced by a template exists.

        Args:
            template_info: Template definition; must contain ``local_path``.

        Returns:
            ``True`` when validation passes.

        Raises:
            TemplateValidationError: ``local_path`` is missing, a referenced
                file does not exist, or an unexpected error occurred while
                checking (e.g. a malformed entry).
        """
        try:
            local_path = template_info.get("local_path")
            if not local_path:
                raise TemplateValidationError("Template missing local_path")

            # Validate the video parts.
            for video_part in template_info.get("video_parts") or []:
                self._validate_template_part(video_part, local_path)

            # Validate the overall template, if present.
            overall_template = template_info.get("overall_template")
            if overall_template:
                self._validate_template_part(overall_template, local_path)

            return True

        except TemplateValidationError:
            raise
        except Exception as e:
            # Normalize unexpected failures into the one type callers handle.
            raise TemplateValidationError(f"Template validation failed: {e}") from e

    def _load_template(self, template_name: str, local_path: str, allow_redownload: bool = True):
        """Load, validate and cache a single template from ``local_path``.

        Args:
            template_name: Cache key the template is stored under.
            local_path: Directory containing ``template.json`` and its assets.
            allow_redownload: When ``True`` (default) and validation fails,
                ``download_template`` is tried once to replace the local copy.
                ``download_template`` passes ``False`` so a template that is
                still invalid right after a fresh download is not downloaded
                again.

        Raises:
            TemplateNotFoundError: ``template.json`` does not exist.
            TemplateError: ``template.json`` is not valid JSON.
            TemplateValidationError: validation failed and the template was
                not (or could not be) re-downloaded.
        """
        logger.info("Loading template: %s (%s)", template_name, local_path)

        template_def_file = os.path.join(local_path, "template.json")
        if not os.path.exists(template_def_file):
            raise TemplateNotFoundError(f"Template definition file not found: {template_def_file}")

        try:
            with open(template_def_file, 'r', encoding='utf-8') as f:
                template_info = json.load(f)
        except json.JSONDecodeError as e:
            raise TemplateError(f"Invalid template JSON: {e}") from e

        # The directory actually loaded from wins over any path stored in the file.
        template_info["local_path"] = local_path

        try:
            self.validate_template(template_info)
        except TemplateValidationError as e:
            if not allow_redownload:
                raise
            logger.error("Template validation failed for %s: %s. Attempting to re-download.", template_name, e)
            # Validation failed: try replacing the local copy with a fresh download,
            # which caches the template itself on success.
            if self.download_template(template_name):
                logger.info("Template re-downloaded successfully: %s", template_name)
            else:
                logger.error("Failed to re-download template: %s", template_name)
                raise
        else:
            self.templates[template_name] = template_info
            logger.info("Template loaded successfully: %s", template_name)

    @staticmethod
    def _write_template_file(local_path: str, template_info: Dict[str, Any]):
        """Write ``template_info`` to ``<local_path>/template.json`` atomically.

        The JSON is written to a temporary file that is then renamed over the
        target, so a crash mid-write never leaves a truncated
        ``template.json`` (which ``_load_template`` would reject).
        """
        template_file = os.path.join(local_path, 'template.json')
        tmp_file = template_file + '.tmp'
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(template_info, f, ensure_ascii=False, indent=2)
        os.replace(tmp_file, template_file)

    @staticmethod
    def _is_remote(value: Any) -> bool:
        """Return ``True`` if ``value`` is a string starting with ``http``."""
        return isinstance(value, str) and value.startswith("http")

    @staticmethod
    def _filename_from_url(url: str) -> str:
        """Return the file-name component of ``url``'s path.

        The query string and fragment are ignored, so a signed URL such as
        ``.../clip.mp4?Expires=...`` maps to ``clip.mp4`` instead of a name
        that contains ``?`` and fails the ``.mp4`` extension check.

        Raises:
            TemplateError: the URL path has no file-name component.
        """
        filename = os.path.basename(urlparse(url).path)
        if not filename:
            raise TemplateError(f"Cannot derive a file name from URL: {url}")
        return filename

    def _download_asset_list(self, assets: Optional[list], local_path: str):
        """Download each remote URL in ``assets`` and replace it with its file name.

        Non-URL entries are left untouched; ``None``/empty lists are ignored.
        """
        for i, asset in enumerate(assets or []):
            if self._is_remote(asset):
                filename = self._filename_from_url(asset)
                oss.download_from_oss(asset, os.path.join(local_path, filename))
                assets[i] = filename

    def _download_template_assets(self, template_part: Dict[str, Any], template_info: Dict[str, Any]):
        """Download the remote assets of one template part into the template dir.

        ``source`` and every ``http`` URL in ``overlays``, ``luts`` and
        ``audios`` are downloaded into ``template_info['local_path']``, and
        the entries are rewritten in place to local names relative to that
        directory. A downloaded ``.mp4`` source is passed through
        ``re_encode_and_annexb``, whose returned path becomes the new source.
        """
        local_path = template_info['local_path']

        # Source file.
        source = template_part.get('source')
        if self._is_remote(source):
            filename = self._filename_from_url(source)
            new_file_path = os.path.join(local_path, filename)
            oss.download_from_oss(source, new_file_path)

            if filename.endswith(".mp4"):
                # Imported lazily, as in the original code.
                from util.ffmpeg import re_encode_and_annexb
                new_file_path = re_encode_and_annexb(new_file_path)

            template_part['source'] = os.path.relpath(new_file_path, local_path)

        # Overlays, LUTs and audios (downloaded in this order).
        for key in ('overlays', 'luts', 'audios'):
            self._download_asset_list(template_part.get(key), local_path)

    @staticmethod
    def _resolve_path(path: str, base_dir: str) -> str:
        """Return ``path`` if it is absolute, otherwise ``path`` joined onto ``base_dir``."""
        return path if os.path.isabs(path) else os.path.join(base_dir, path)

    def _validate_template_part(self, template_part: Dict[str, Any], base_dir: str):
        """Check that the local files referenced by one template part exist.

        ``source`` is skipped when empty, an ``http`` URL or a
        ``PLACEHOLDER_`` value. Every entry of ``audios``, ``luts`` and
        ``overlays`` must exist. Relative paths resolve against ``base_dir``.

        Raises:
            TemplateValidationError: a referenced file does not exist.
        """
        # Source file.
        source_file = template_part.get("source", "")
        if source_file and not source_file.startswith(("http", "PLACEHOLDER_")):
            source_file = self._resolve_path(source_file, base_dir)
            if not os.path.exists(source_file):
                raise TemplateValidationError(f"Source file not found: {source_file}")

        # Audio, LUT and overlay files (checked in this order).
        for key, label in (("audios", "Audio"), ("luts", "LUT"), ("overlays", "Overlay")):
            for asset in template_part.get(key) or []:
                asset_path = self._resolve_path(asset, base_dir)
                if not os.path.exists(asset_path):
                    raise TemplateValidationError(f"{label} file not found: {asset_path}")