feat(tracing): 集成 OpenTelemetry 链路追踪功能

- 在 base.py 中添加文件下载、上传和 FFmpeg 执行的链路追踪
- 在 api_client.py 中实现 API 请求的链路追踪和错误标记
- 在 lease_service.py 中添加租约续期的链路追踪支持
- 在 task_executor.py 中集成任务执行的完整链路追踪
- 新增 util/tracing.py 工具模块提供统一的追踪上下文管理
- 在 .env.example 中添加 OTEL 配置选项
- 在 index.py 中初始化和关闭链路追踪功能
This commit is contained in:
2026-02-07 00:11:01 +08:00
parent c9a6133be9
commit 9b373dea34
8 changed files with 549 additions and 149 deletions

View File

@@ -63,3 +63,9 @@ HW_ACCEL=none
#UPLOAD_METHOD=rclone #UPLOAD_METHOD=rclone
#RCLONE_CONFIG_FILE= # rclone 配置文件路径 #RCLONE_CONFIG_FILE= # rclone 配置文件路径
#RCLONE_REPLACE_MAP="https://oss.example.com|alioss://bucket" #RCLONE_REPLACE_MAP="https://oss.example.com|alioss://bucket"
# ===================
# OTel 链路追踪
# ===================
# 是否启用 OTel 追踪(默认 true)
#OTEL_ENABLED=true

View File

@@ -15,12 +15,15 @@ import threading
from abc import ABC from abc import ABC
from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING
from opentelemetry.trace import SpanKind
from core.handler import TaskHandler from core.handler import TaskHandler
from domain.task import Task from domain.task import Task
from domain.result import TaskResult, ErrorCode from domain.result import TaskResult, ErrorCode
from domain.config import WorkerConfig from domain.config import WorkerConfig
from services import storage from services import storage
from services.cache import MaterialCache from services.cache import MaterialCache
from util.tracing import mark_span_error, start_span
from constant import ( from constant import (
HW_ACCEL_NONE, HW_ACCEL_QSV, HW_ACCEL_CUDA, HW_ACCEL_NONE, HW_ACCEL_QSV, HW_ACCEL_CUDA,
VIDEO_ENCODE_PARAMS, VIDEO_ENCODE_PARAMS_QSV, VIDEO_ENCODE_PARAMS_CUDA VIDEO_ENCODE_PARAMS, VIDEO_ENCODE_PARAMS_QSV, VIDEO_ENCODE_PARAMS_CUDA
@@ -410,21 +413,30 @@ class BaseHandler(TaskHandler, ABC):
if timeout is None: if timeout is None:
timeout = self.config.download_timeout timeout = self.config.download_timeout
try: with start_span(
if use_cache: "render.task.file.download",
# 使用缓存下载 kind=SpanKind.CLIENT,
result = self.material_cache.get_or_download(url, dest, timeout=timeout) attributes={
else: "render.file.destination": dest,
# 直接下载(不走缓存) "render.file.use_cache": use_cache,
result = storage.download_file(url, dest, timeout=timeout) },
) as span:
try:
if use_cache:
result = self.material_cache.get_or_download(url, dest, timeout=timeout)
else:
result = storage.download_file(url, dest, timeout=timeout)
if result: if result:
file_size = os.path.getsize(dest) if os.path.exists(dest) else 0 file_size = os.path.getsize(dest) if os.path.exists(dest) else 0
logger.debug(f"Downloaded: {url} -> {dest} ({file_size} bytes)") logger.debug(f"Downloaded: {url} -> {dest} ({file_size} bytes)")
return result if span is not None:
except Exception as e: span.set_attribute("render.file.size_bytes", file_size)
logger.error(f"Download failed: {url} -> {e}") return result
return False except Exception as e:
mark_span_error(span, str(e), ErrorCode.E_INPUT_UNAVAILABLE.value)
logger.error(f"Download failed: {url} -> {e}")
return False
def upload_file( def upload_file(
self, self,
@@ -445,37 +457,45 @@ class BaseHandler(TaskHandler, ABC):
Returns: Returns:
访问 URL,失败返回 None 访问 URL,失败返回 None
""" """
# 获取上传 URL with start_span(
upload_info = self.api_client.get_upload_url(task_id, file_type, file_name) "render.task.file.upload",
if not upload_info: kind=SpanKind.CLIENT,
logger.error(f"[task:{task_id}] Failed to get upload URL") attributes={
return None "render.file.type": file_type,
"render.file.path": file_path,
},
) as span:
upload_info = self.api_client.get_upload_url(task_id, file_type, file_name)
if not upload_info:
logger.error(f"[task:{task_id}] Failed to get upload URL")
return None
upload_url = upload_info.get('uploadUrl') upload_url = upload_info.get('uploadUrl')
access_url = upload_info.get('accessUrl') access_url = upload_info.get('accessUrl')
if not upload_url: if not upload_url:
logger.error(f"[task:{task_id}] Invalid upload URL response") logger.error(f"[task:{task_id}] Invalid upload URL response")
return None return None
# 上传文件 try:
try: result = storage.upload_file(upload_url, file_path, timeout=self.config.upload_timeout)
result = storage.upload_file(upload_url, file_path, timeout=self.config.upload_timeout) if result:
if result: file_size = os.path.getsize(file_path)
file_size = os.path.getsize(file_path) logger.info(f"[task:{task_id}] Uploaded: {file_path} ({file_size} bytes)")
logger.info(f"[task:{task_id}] Uploaded: {file_path} ({file_size} bytes)") if span is not None:
span.set_attribute("render.file.size_bytes", file_size)
# 将上传成功的文件加入缓存 if access_url:
if access_url: self.material_cache.add_to_cache(access_url, file_path)
self.material_cache.add_to_cache(access_url, file_path)
return access_url
return access_url
else:
logger.error(f"[task:{task_id}] Upload failed: {file_path}") logger.error(f"[task:{task_id}] Upload failed: {file_path}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"[task:{task_id}] Upload error: {e}") mark_span_error(span, str(e), ErrorCode.E_UPLOAD_FAILED.value)
return None logger.error(f"[task:{task_id}] Upload error: {e}")
return None
def run_ffmpeg( def run_ffmpeg(
self, self,
@@ -507,29 +527,42 @@ class BaseHandler(TaskHandler, ABC):
cmd_str = cmd_str[:500] + '...' cmd_str = cmd_str[:500] + '...'
logger.info(f"[task:{task_id}] FFmpeg: {cmd_str}") logger.info(f"[task:{task_id}] FFmpeg: {cmd_str}")
try: with start_span(
run_args = subprocess_args(False) "render.task.ffmpeg.run",
run_args['stdout'] = subprocess.DEVNULL attributes={
run_args['stderr'] = subprocess.PIPE "render.ffmpeg.timeout_seconds": timeout,
result = subprocess.run( "render.ffmpeg.command": cmd_str,
cmd_to_run, },
timeout=timeout, ) as span:
**run_args try:
) run_args = subprocess_args(False)
run_args['stdout'] = subprocess.DEVNULL
run_args['stderr'] = subprocess.PIPE
result = subprocess.run(
cmd_to_run,
timeout=timeout,
**run_args
)
if result.returncode != 0: if span is not None:
stderr = (result.stderr or b'').decode('utf-8', errors='replace')[:1000] span.set_attribute("render.ffmpeg.return_code", result.returncode)
logger.error(f"[task:{task_id}] FFmpeg failed (code={result.returncode}): {stderr}")
if result.returncode != 0:
stderr = (result.stderr or b'').decode('utf-8', errors='replace')[:1000]
logger.error(f"[task:{task_id}] FFmpeg failed (code={result.returncode}): {stderr}")
mark_span_error(span, stderr or "ffmpeg failed", ErrorCode.E_FFMPEG_FAILED.value)
return False
return True
except subprocess.TimeoutExpired:
logger.error(f"[task:{task_id}] FFmpeg timeout after {timeout}s")
mark_span_error(span, f"timeout after {timeout}s", ErrorCode.E_TIMEOUT.value)
return False
except Exception as e:
logger.error(f"[task:{task_id}] FFmpeg error: {e}")
mark_span_error(span, str(e), ErrorCode.E_FFMPEG_FAILED.value)
return False return False
return True
except subprocess.TimeoutExpired:
logger.error(f"[task:{task_id}] FFmpeg timeout after {timeout}s")
return False
except Exception as e:
logger.error(f"[task:{task_id}] FFmpeg error: {e}")
return False
def probe_duration(self, file_path: str) -> Optional[float]: def probe_duration(self, file_path: str) -> Optional[float]:
""" """

View File

@@ -34,6 +34,7 @@ from domain.config import WorkerConfig
from services.api_client import APIClientV2 from services.api_client import APIClientV2
from services.task_executor import TaskExecutor from services.task_executor import TaskExecutor
from constant import SOFTWARE_VERSION from constant import SOFTWARE_VERSION
from util.tracing import initialize_tracing, shutdown_tracing
# 日志配置 # 日志配置
def setup_logging(): def setup_logging():
@@ -113,6 +114,9 @@ class WorkerV2:
logger.error(f"Configuration error: {e}") logger.error(f"Configuration error: {e}")
sys.exit(1) sys.exit(1)
tracing_enabled = initialize_tracing(self.config.worker_id, SOFTWARE_VERSION)
logger.info("OTel tracing %s", "enabled" if tracing_enabled else "disabled")
# 初始化 API 客户端 # 初始化 API 客户端
self.api_client = APIClientV2(self.config) self.api_client = APIClientV2(self.config)
@@ -212,6 +216,7 @@ class WorkerV2:
# 关闭 API 客户端 # 关闭 API 客户端
self.api_client.close() self.api_client.close()
shutdown_tracing()
logger.info("Worker stopped") logger.info("Worker stopped")

View File

@@ -10,10 +10,14 @@ import subprocess
import time import time
import requests import requests
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
from urllib.parse import urlparse
from opentelemetry.trace import SpanKind, Status, StatusCode
from domain.task import Task from domain.task import Task
from domain.config import WorkerConfig from domain.config import WorkerConfig
from util.system import get_hw_accel_info_str from util.system import get_hw_accel_info_str
from util.tracing import inject_trace_headers, mark_span_error, start_span
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -55,6 +59,45 @@ class APIClientV2:
'Accept': 'application/json' 'Accept': 'application/json'
}) })
def _request_with_trace(
    self,
    method: str,
    url: str,
    *,
    task_id: Optional[str] = None,
    span_name: str = "",
    **kwargs: Any,
) -> requests.Response:
    """
    Send an HTTP request through the shared session inside a CLIENT span.

    Args:
        method: HTTP method name, e.g. "GET" or "POST".
        url: Absolute request URL; its path, host and port become span attributes.
        task_id: Task the request belongs to. When given, trace context
            headers are injected into the outgoing request.
        span_name: Span name; defaults to "render.api.<method>".
        **kwargs: Forwarded unchanged to ``self.session.request``
            (``headers`` is extracted and merged with the trace headers).

    Returns:
        The response object returned by the session.

    Raises:
        Exception: Whatever ``self.session.request`` raises is re-raised
            after the span has been marked as failed.
    """
    request_kwargs = dict(kwargs)
    headers = request_kwargs.pop("headers", None)

    parsed_url = urlparse(url)
    attributes = {
        "http.request.method": method.upper(),
        "url.path": parsed_url.path,
        "server.address": parsed_url.hostname or "",
    }
    if parsed_url.port:
        attributes["server.port"] = parsed_url.port

    name = span_name or f"render.api.{method.lower()}"
    with start_span(name, task_id=task_id, kind=SpanKind.CLIENT, attributes=attributes) as span:
        # Inject the headers only after the span is current, so the propagated
        # context refers to this CLIENT span instead of its parent.
        if task_id:
            request_kwargs["headers"] = inject_trace_headers(headers)
        elif headers:
            request_kwargs["headers"] = headers

        try:
            response = self.session.request(method=method, url=url, **request_kwargs)
        except Exception as exc:
            mark_span_error(span, str(exc), "HTTP_REQUEST_ERROR")
            raise

        if span is not None:
            span.set_attribute("http.response.status_code", response.status_code)
            # Any 4xx/5xx answer is recorded as a failed span.
            if response.status_code >= 400:
                span.set_status(Status(StatusCode.ERROR, f"HTTP {response.status_code}"))
        return response
def sync(self, current_task_ids: List[str]) -> List[Task]: def sync(self, current_task_ids: List[str]) -> List[Task]:
""" """
心跳同步并拉取任务 心跳同步并拉取任务
@@ -128,10 +171,13 @@ class APIClientV2:
url = f"{self.base_url}/render/v2/task/{task_id}/start" url = f"{self.base_url}/render/v2/task/{task_id}/start"
try: try:
resp = self.session.post( resp = self._request_with_trace(
url, method="POST",
url=url,
task_id=task_id,
span_name="render.task.api.report_start",
json={'workerId': self.worker_id}, json={'workerId': self.worker_id},
timeout=10 timeout=10,
) )
if resp.status_code == 200: if resp.status_code == 200:
logger.debug(f"[task:{task_id}] Start reported") logger.debug(f"[task:{task_id}] Start reported")
@@ -157,13 +203,16 @@ class APIClientV2:
url = f"{self.base_url}/render/v2/task/{task_id}/success" url = f"{self.base_url}/render/v2/task/{task_id}/success"
try: try:
resp = self.session.post( resp = self._request_with_trace(
url, method="POST",
url=url,
task_id=task_id,
span_name="render.task.api.report_success",
json={ json={
'workerId': self.worker_id, 'workerId': self.worker_id,
'result': result 'result': result
}, },
timeout=10 timeout=10,
) )
if resp.status_code == 200: if resp.status_code == 200:
logger.debug(f"[task:{task_id}] Success reported") logger.debug(f"[task:{task_id}] Success reported")
@@ -190,14 +239,17 @@ class APIClientV2:
url = f"{self.base_url}/render/v2/task/{task_id}/fail" url = f"{self.base_url}/render/v2/task/{task_id}/fail"
try: try:
resp = self.session.post( resp = self._request_with_trace(
url, method="POST",
url=url,
task_id=task_id,
span_name="render.task.api.report_fail",
json={ json={
'workerId': self.worker_id, 'workerId': self.worker_id,
'errorCode': error_code, 'errorCode': error_code,
'errorMessage': error_message[:1000] # 限制长度 'errorMessage': error_message[:1000] # 限制长度
}, },
timeout=10 timeout=10,
) )
if resp.status_code == 200: if resp.status_code == 200:
logger.debug(f"[task:{task_id}] Failure reported") logger.debug(f"[task:{task_id}] Failure reported")
@@ -228,7 +280,14 @@ class APIClientV2:
payload['fileName'] = file_name payload['fileName'] = file_name
try: try:
resp = self.session.post(url, json=payload, timeout=10) resp = self._request_with_trace(
method="POST",
url=url,
task_id=task_id,
span_name="render.task.api.get_upload_url",
json=payload,
timeout=10,
)
if resp.status_code == 200: if resp.status_code == 200:
data = resp.json() data = resp.json()
if data.get('code') == 200: if data.get('code') == 200:
@@ -256,13 +315,16 @@ class APIClientV2:
url = f"{self.base_url}/render/v2/task/{task_id}/extend-lease" url = f"{self.base_url}/render/v2/task/{task_id}/extend-lease"
try: try:
resp = self.session.post( resp = self._request_with_trace(
url, method="POST",
url=url,
task_id=task_id,
span_name="render.task.api.extend_lease",
params={ params={
'workerId': self.worker_id, 'workerId': self.worker_id,
'extension': extension 'extension': extension
}, },
timeout=10 timeout=10,
) )
if resp.status_code == 200: if resp.status_code == 200:
logger.debug(f"[task:{task_id}] Lease extended by {extension}s") logger.debug(f"[task:{task_id}] Lease extended by {extension}s")
@@ -287,7 +349,13 @@ class APIClientV2:
url = f"{self.base_url}/render/v2/task/{task_id}" url = f"{self.base_url}/render/v2/task/{task_id}"
try: try:
resp = self.session.get(url, timeout=10) resp = self._request_with_trace(
method="GET",
url=url,
task_id=task_id,
span_name="render.task.api.get_task_info",
timeout=10,
)
if resp.status_code == 200: if resp.status_code == 200:
data = resp.json() data = resp.json()
if data.get('code') == 200: if data.get('code') == 200:

View File

@@ -8,10 +8,13 @@
import logging import logging
import threading import threading
import time import time
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from services.api_client import APIClientV2 from services.api_client import APIClientV2
from util.tracing import TaskTraceContext
from util.tracing import bind_trace_context, start_span
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,7 +32,9 @@ class LeaseService:
api_client: 'APIClientV2', api_client: 'APIClientV2',
task_id: str, task_id: str,
interval: int = 60, interval: int = 60,
extension: int = 300 extension: int = 300,
parent_otel_context: Any = None,
task_trace_context: Optional['TaskTraceContext'] = None,
): ):
""" """
初始化租约服务 初始化租约服务
@@ -44,6 +49,8 @@ class LeaseService:
self.task_id = task_id self.task_id = task_id
self.interval = interval self.interval = interval
self.extension = extension self.extension = extension
self.parent_otel_context = parent_otel_context
self.task_trace_context = task_trace_context
self.running = False self.running = False
self.thread: threading.Thread = None self.thread: threading.Thread = None
self._stop_event = threading.Event() self._stop_event = threading.Event()
@@ -79,25 +86,29 @@ class LeaseService:
def _run(self): def _run(self):
"""续期线程主循环""" """续期线程主循环"""
while self.running: with bind_trace_context(self.parent_otel_context, self.task_trace_context):
# 等待指定间隔或收到停止信号 while self.running:
if self._stop_event.wait(timeout=self.interval): if self._stop_event.wait(timeout=self.interval):
# 收到停止信号 break
break
if self.running: if self.running:
self._extend_lease() self._extend_lease()
def _extend_lease(self): def _extend_lease(self):
"""执行租约续期""" """执行租约续期"""
try: with start_span(
success = self.api_client.extend_lease(self.task_id, self.extension) "render.task.lease.extend",
if success: task_id=self.task_id,
logger.debug(f"[task:{self.task_id}] Lease extended by {self.extension}s") attributes={"render.lease.extension_seconds": self.extension},
else: ):
logger.warning(f"[task:{self.task_id}] Failed to extend lease") try:
except Exception as e: success = self.api_client.extend_lease(self.task_id, self.extension)
logger.warning(f"[task:{self.task_id}] Lease extension error: {e}") if success:
logger.debug(f"[task:{self.task_id}] Lease extended by {self.extension}s")
else:
logger.warning(f"[task:{self.task_id}] Failed to extend lease")
except Exception as e:
logger.warning(f"[task:{self.task_id}] Lease extension error: {e}")
def __enter__(self): def __enter__(self):
"""上下文管理器入口""" """上下文管理器入口"""

View File

@@ -151,6 +151,10 @@ def _upload_with_rclone(url: str, file_path: str) -> bool:
if new_url == url: if new_url == url:
return False return False
if new_url.startswith(("http://", "https://")):
logger.warning(f"rclone upload skipped: URL still starts with http after replace: {new_url}")
return False
cmd = [ cmd = [
"rclone", "rclone",
"copyto", "copyto",

View File

@@ -11,7 +11,6 @@ from concurrent.futures import ThreadPoolExecutor, Future
from typing import Dict, Optional, TYPE_CHECKING from typing import Dict, Optional, TYPE_CHECKING
from domain.task import Task, TaskType from domain.task import Task, TaskType
from domain.result import TaskResult, ErrorCode
# 需要 GPU 加速的任务类型 # 需要 GPU 加速的任务类型
GPU_REQUIRED_TASK_TYPES = { GPU_REQUIRED_TASK_TYPES = {
@@ -22,6 +21,13 @@ from domain.config import WorkerConfig
from core.handler import TaskHandler from core.handler import TaskHandler
from services.lease_service import LeaseService from services.lease_service import LeaseService
from services.gpu_scheduler import GPUScheduler from services.gpu_scheduler import GPUScheduler
from util.tracing import (
capture_otel_context,
get_current_task_context,
mark_span_error,
start_span,
task_trace_scope,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from services.api_client import APIClientV2 from services.api_client import APIClientV2
@@ -174,77 +180,84 @@ class TaskExecutor:
task: 任务实体 task: 任务实体
""" """
task_id = task.task_id task_id = task.task_id
logger.info(f"[task:{task_id}] Starting {task.task_type.value}")
# 启动租约续期服务
lease_service = LeaseService(
self.api_client,
task_id,
interval=self.config.lease_extension_threshold,
extension=self.config.lease_extension_duration
)
lease_service.start()
# 获取 GPU 设备(仅对需要 GPU 的任务类型)
device_index = None
needs_gpu = task.task_type in GPU_REQUIRED_TASK_TYPES
if needs_gpu and self.gpu_scheduler.enabled:
device_index = self.gpu_scheduler.acquire()
if device_index is not None:
logger.info(f"[task:{task_id}] Assigned to GPU device {device_index}")
# 获取处理器(需要在设置 GPU 设备前获取)
handler = self.handlers.get(task.task_type) handler = self.handlers.get(task.task_type)
device_index = None
lease_service = None
try: with task_trace_scope(task, span_name="render.task.execute") as task_span:
# 报告任务开始 logger.info(f"[task:{task_id}] Starting {task.task_type.value}")
self.api_client.report_start(task_id)
if not handler: lease_service = LeaseService(
raise ValueError(f"No handler for task type: {task.task_type}") self.api_client,
task_id,
interval=self.config.lease_extension_threshold,
extension=self.config.lease_extension_duration,
parent_otel_context=capture_otel_context(),
task_trace_context=get_current_task_context(),
)
with start_span("render.task.lease.start"):
lease_service.start()
# 设置 GPU 设备(线程本地存储) needs_gpu = task.task_type in GPU_REQUIRED_TASK_TYPES
if device_index is not None: if needs_gpu and self.gpu_scheduler.enabled:
handler.set_gpu_device(device_index) with start_span("render.task.gpu.acquire"):
device_index = self.gpu_scheduler.acquire()
if device_index is not None:
logger.info(f"[task:{task_id}] Assigned to GPU device {device_index}")
# 执行前钩子 try:
handler.before_handle(task) with start_span("render.task.report.start"):
self.api_client.report_start(task_id)
# 执行任务 if not handler:
result = handler.handle(task) raise ValueError(f"No handler for task type: {task.task_type}")
# 执行后钩子 if device_index is not None:
handler.after_handle(task, result) handler.set_gpu_device(device_index)
# 上报结果 with start_span("render.task.handler.before"):
if result.success: handler.before_handle(task)
self.api_client.report_success(task_id, result.data)
logger.info(f"[task:{task_id}] Completed successfully")
else:
error_code = result.error_code.value if result.error_code else 'E_UNKNOWN'
self.api_client.report_fail(task_id, error_code, result.error_message or '')
logger.error(f"[task:{task_id}] Failed: {result.error_message}")
except Exception as e: with start_span("render.task.handler.execute"):
logger.error(f"[task:{task_id}] Exception: {e}", exc_info=True) result = handler.handle(task)
self.api_client.report_fail(task_id, 'E_UNKNOWN', str(e))
finally: with start_span("render.task.handler.after"):
# 清除 GPU 设备设置 handler.after_handle(task, result)
if handler:
handler.clear_gpu_device()
# 释放 GPU 设备(仅当实际分配了设备时) if result.success:
if device_index is not None: with start_span("render.task.report.success"):
self.gpu_scheduler.release(device_index) self.api_client.report_success(task_id, result.data)
if task_span is not None:
task_span.set_attribute("render.task.result", "success")
logger.info(f"[task:{task_id}] Completed successfully")
else:
error_code = result.error_code.value if result.error_code else 'E_UNKNOWN'
with start_span("render.task.report.fail"):
self.api_client.report_fail(task_id, error_code, result.error_message or '')
mark_span_error(task_span, result.error_message or "task failed", error_code)
logger.error(f"[task:{task_id}] Failed: {result.error_message}")
# 停止租约续期 except Exception as e:
lease_service.stop() mark_span_error(task_span, str(e), "E_UNKNOWN")
logger.error(f"[task:{task_id}] Exception: {e}", exc_info=True)
with start_span("render.task.report.exception"):
self.api_client.report_fail(task_id, 'E_UNKNOWN', str(e))
# 从当前任务中移除 finally:
with self.lock: if handler:
self.current_tasks.pop(task_id, None) handler.clear_gpu_device()
self.current_futures.pop(task_id, None)
if device_index is not None:
with start_span("render.task.gpu.release"):
self.gpu_scheduler.release(device_index)
if lease_service is not None:
with start_span("render.task.lease.stop"):
lease_service.stop()
with self.lock:
self.current_tasks.pop(task_id, None)
self.current_futures.pop(task_id, None)
def shutdown(self, wait: bool = True): def shutdown(self, wait: bool = True):
""" """

260
util/tracing.py Normal file
View File

@@ -0,0 +1,260 @@
# -*- coding: utf-8 -*-
"""
OTel 链路追踪工具。
提供统一的 tracing 初始化、任务上下文管理与 Span 创建能力。
"""
import logging
import os
from contextlib import contextmanager, nullcontext
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Any, Dict, Iterator, Mapping, Optional
from opentelemetry import context as otel_context
from opentelemetry import propagate, trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import Span, SpanKind, Status, StatusCode
logger = logging.getLogger(__name__)
_DEFAULT_SERVICE_NAME = "RenderWorkerNext"
_DEFAULT_TRACER_NAME = "render.worker"
_OTEL_EXPORTER_OTLP_ENDPOINT = "https://oltp.jerryyan.top/v1/traces"
_TASK_ID_ATTR = "render.task.id"
_TASK_TYPE_ATTR = "render.task.type"
_JOB_ID_ATTR = "render.job.id"
_SEGMENT_ID_ATTR = "render.segment.id"
_ERROR_CODE_ATTR = "render.error.code"
_ERROR_MESSAGE_ATTR = "render.error.message"
_TRUE_VALUES = {"1", "true", "yes", "on"}
_TRACING_INITIALIZED = False
_TRACING_ENABLED = False
_TRACER_PROVIDER: Optional[TracerProvider] = None
_CURRENT_TASK_CONTEXT: ContextVar[Optional["TaskTraceContext"]] = ContextVar(
"render_worker_task_trace_context",
default=None,
)
@dataclass(frozen=True)
class TaskTraceContext:
    """Immutable task-level tracing context.

    ``job_id`` and ``segment_id`` default to the empty string, which is
    treated as "not set" when span attributes are built.
    """

    task_id: str
    task_type: str
    job_id: str = ""
    segment_id: str = ""

    def to_attributes(self) -> Dict[str, str]:
        """Return the span attributes for this task, skipping empty optional ids."""
        optional_pairs = (
            (_JOB_ID_ATTR, self.job_id),
            (_SEGMENT_ID_ATTR, self.segment_id),
        )
        result = {_TASK_ID_ATTR: self.task_id, _TASK_TYPE_ATTR: self.task_type}
        result.update((key, value) for key, value in optional_pairs if value)
        return result
def _parse_bool(value: str, default: bool) -> bool:
if value is None:
return default
return value.strip().lower() in _TRUE_VALUES
def is_tracing_enabled() -> bool:
    """Return the module-level flag telling whether tracing is currently on."""
    return _TRACING_ENABLED
def initialize_tracing(worker_id: str, service_version: str) -> bool:
    """
    Initialize OTel tracing; only the first call per process does any work.

    Args:
        worker_id: Worker identifier, stored as the ``render.worker.id``
            resource attribute.
        service_version: Value for the ``service.version`` resource attribute.

    Returns:
        True when tracing is enabled, False when it was switched off through
        the ``OTEL_ENABLED`` environment variable. Repeated calls return the
        result of the first call.
    """
    global _TRACING_INITIALIZED
    global _TRACING_ENABLED
    global _TRACER_PROVIDER

    if _TRACING_INITIALIZED:
        return _TRACING_ENABLED
    _TRACING_INITIALIZED = True

    if not _parse_bool(os.getenv("OTEL_ENABLED"), default=True):
        logger.info("OTel tracing disabled by OTEL_ENABLED")
        _TRACING_ENABLED = False
        return False

    service_name = _DEFAULT_SERVICE_NAME
    attributes: Dict[str, str] = {
        SERVICE_NAME: service_name,
        SERVICE_VERSION: service_version,
        "render.worker.id": str(worker_id),
    }
    resource = Resource.create(attributes)
    tracer_provider = TracerProvider(resource=resource)
    # NOTE(review): the export endpoint is a fixed module constant and is not
    # configurable through the environment.
    tracer_provider.add_span_processor(
        BatchSpanProcessor(
            OTLPSpanExporter(endpoint=_OTEL_EXPORTER_OTLP_ENDPOINT)
        )
    )
    trace.set_tracer_provider(tracer_provider)
    _TRACING_ENABLED = True

    if trace.get_tracer_provider() is tracer_provider:
        # Remember the provider so shutdown_tracing() can shut it down.
        _TRACER_PROVIDER = tracer_provider
    else:
        # A different provider is installed globally, so the one built above
        # will never receive spans. Shut it down instead of leaving its span
        # processor alive for the rest of the process.
        logger.warning("OTel tracer provider was already set; discarding the new provider")
        try:
            tracer_provider.shutdown()
        except Exception as exc:
            logger.warning("Failed to shutdown unused tracing provider: %s", exc)

    logger.info("OTel tracing initialized (service=%s, worker=%s)", service_name, worker_id)
    return True
def shutdown_tracing() -> None:
    """Gracefully shut down the tracing provider so remaining spans are flushed.

    Does nothing when tracing is not enabled. A failing provider shutdown is
    logged as a warning; tracing is marked disabled either way.
    """
    global _TRACING_ENABLED

    if not _TRACING_ENABLED:
        return

    try:
        if _TRACER_PROVIDER is not None:
            _TRACER_PROVIDER.shutdown()
    except Exception as exc:
        logger.warning("Failed to shutdown tracing provider: %s", exc)

    _TRACING_ENABLED = False
def build_task_trace_context(task: Any) -> TaskTraceContext:
    """Build a TaskTraceContext from a task-like object using duck typing.

    Missing attributes fall back to empty strings. ``task_type`` is unwrapped
    through its ``value`` attribute when it has one. A falsy job id becomes
    "", whereas a segment id is only dropped when it is ``None``.
    """
    resolved_task_id = str(getattr(task, "task_id", ""))

    raw_type = getattr(task, "task_type", "")
    resolved_type = str(getattr(raw_type, "value", raw_type))

    resolved_job_id = str(task.get_job_id() or "") if hasattr(task, "get_job_id") else ""

    resolved_segment_id = ""
    if hasattr(task, "get_segment_id"):
        raw_segment = task.get_segment_id()
        if raw_segment is not None:
            resolved_segment_id = str(raw_segment)

    return TaskTraceContext(
        task_id=resolved_task_id,
        task_type=resolved_type,
        job_id=resolved_job_id,
        segment_id=resolved_segment_id,
    )
def get_current_task_context() -> Optional[TaskTraceContext]:
    """Return the task context stored in the context variable, or None if unset."""
    return _CURRENT_TASK_CONTEXT.get()
def capture_otel_context() -> Any:
    """Return the current OTel context so it can later be re-attached elsewhere."""
    return otel_context.get_current()
@contextmanager
def bind_trace_context(parent_otel_context: Any, task_context: Optional[TaskTraceContext]) -> Iterator[None]:
    """
    Bind a parent OTel context and a task context for the duration of the block.

    Intended for continuing a task's trace across threads (for example the
    lease renewal thread). Either argument may be None, in which case that
    part is left untouched. Bindings are undone in reverse order on exit.
    """
    otel_token = (
        otel_context.attach(parent_otel_context)
        if parent_otel_context is not None
        else None
    )
    task_token = (
        _CURRENT_TASK_CONTEXT.set(task_context)
        if task_context is not None
        else None
    )

    try:
        yield
    finally:
        # Restore the task context first, then detach the OTel context.
        if task_token is not None:
            _CURRENT_TASK_CONTEXT.reset(task_token)
        if otel_token is not None:
            otel_context.detach(otel_token)
@contextmanager
def task_trace_scope(task: Any, span_name: str = "render.task.process") -> Iterator[Optional[Span]]:
    """Open the task's root span and bind the task context while the block runs.

    Yields the CONSUMER span, or None when tracing is disabled. The task
    context is bound in both cases and reset on exit.
    """
    task_context = build_task_trace_context(task)
    context_token = _CURRENT_TASK_CONTEXT.set(task_context)

    if _TRACING_ENABLED:
        span_manager = trace.get_tracer(_DEFAULT_TRACER_NAME).start_as_current_span(
            span_name,
            kind=SpanKind.CONSUMER,
        )
    else:
        span_manager = nullcontext(None)

    try:
        with span_manager as root_span:
            if root_span is not None:
                for attr_name, attr_value in task_context.to_attributes().items():
                    root_span.set_attribute(attr_name, attr_value)
            yield root_span
    finally:
        _CURRENT_TASK_CONTEXT.reset(context_token)
@contextmanager
def start_span(
    name: str,
    *,
    attributes: Optional[Mapping[str, Any]] = None,
    kind: SpanKind = SpanKind.INTERNAL,
    task_id: Optional[str] = None,
) -> Iterator[Optional[Span]]:
    """
    Create a child span inside a task.

    Yields None without creating a span when tracing is disabled, or when
    there is no bound task context and no explicit ``task_id`` was passed.
    Otherwise the span gets the task context attributes, then ``task_id`` (if
    it differs from the context's), then the non-None ``attributes``.
    """
    task_context = get_current_task_context()

    if not _TRACING_ENABLED or (task_context is None and not task_id):
        yield None
        return

    tracer = trace.get_tracer(_DEFAULT_TRACER_NAME)
    with tracer.start_as_current_span(name, kind=kind) as span:
        if task_context is not None:
            for attr_name, attr_value in task_context.to_attributes().items():
                span.set_attribute(attr_name, attr_value)

        # Only record the explicit id when the bound context does not already carry it.
        context_has_same_id = task_context is not None and task_context.task_id == task_id
        if task_id and not context_has_same_id:
            span.set_attribute(_TASK_ID_ATTR, task_id)

        for attr_name, attr_value in (attributes or {}).items():
            if attr_value is not None:
                span.set_attribute(attr_name, attr_value)

        yield span
def mark_span_error(span: Optional[Span], message: str, error_code: str = "") -> None:
    """Mark a span as failed.

    Does nothing when ``span`` is None. A non-empty ``error_code`` and
    ``message`` are stored as span attributes (message cut to 500 characters);
    the span status is set to ERROR with the message cut to 200 characters.
    """
    if span is None:
        return

    if error_code:
        span.set_attribute(_ERROR_CODE_ATTR, error_code)
    if message:
        attribute_text = message[:500]
        span.set_attribute(_ERROR_MESSAGE_ATTR, attribute_text)

    status_text = message[:200]
    span.set_status(Status(StatusCode.ERROR, status_text))
def inject_trace_headers(headers: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of ``headers`` with the current trace context injected.

    The input mapping is never modified. When tracing is disabled the copy is
    returned as is (an empty dict if no headers were given).
    """
    carrier: Dict[str, str] = {} if not headers else dict(headers)
    if _TRACING_ENABLED:
        propagate.inject(carrier)
    return carrier