diff --git a/.env.example b/.env.example index 351a3eb..3692fac 100644 --- a/.env.example +++ b/.env.example @@ -63,3 +63,9 @@ HW_ACCEL=none #UPLOAD_METHOD=rclone #RCLONE_CONFIG_FILE= # rclone 配置文件路径 #RCLONE_REPLACE_MAP="https://oss.example.com|alioss://bucket" + +# =================== +# OTel 链路追踪 +# =================== +# 是否启用 OTel 追踪(默认 true) +#OTEL_ENABLED=true diff --git a/handlers/base.py b/handlers/base.py index e573c3b..ba03a18 100644 --- a/handlers/base.py +++ b/handlers/base.py @@ -15,12 +15,15 @@ import threading from abc import ABC from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING +from opentelemetry.trace import SpanKind + from core.handler import TaskHandler from domain.task import Task from domain.result import TaskResult, ErrorCode from domain.config import WorkerConfig from services import storage from services.cache import MaterialCache +from util.tracing import mark_span_error, start_span from constant import ( HW_ACCEL_NONE, HW_ACCEL_QSV, HW_ACCEL_CUDA, VIDEO_ENCODE_PARAMS, VIDEO_ENCODE_PARAMS_QSV, VIDEO_ENCODE_PARAMS_CUDA @@ -410,21 +413,30 @@ class BaseHandler(TaskHandler, ABC): if timeout is None: timeout = self.config.download_timeout - try: - if use_cache: - # 使用缓存下载 - result = self.material_cache.get_or_download(url, dest, timeout=timeout) - else: - # 直接下载(不走缓存) - result = storage.download_file(url, dest, timeout=timeout) + with start_span( + "render.task.file.download", + kind=SpanKind.CLIENT, + attributes={ + "render.file.destination": dest, + "render.file.use_cache": use_cache, + }, + ) as span: + try: + if use_cache: + result = self.material_cache.get_or_download(url, dest, timeout=timeout) + else: + result = storage.download_file(url, dest, timeout=timeout) - if result: - file_size = os.path.getsize(dest) if os.path.exists(dest) else 0 - logger.debug(f"Downloaded: {url} -> {dest} ({file_size} bytes)") - return result - except Exception as e: - logger.error(f"Download failed: {url} -> {e}") - return False + if result: + file_size = os.path.getsize(dest) if os.path.exists(dest) else 0 + logger.debug(f"Downloaded: {url} -> {dest} ({file_size} bytes)") + if span is not None: + span.set_attribute("render.file.size_bytes", file_size) + return result + except Exception as e: + mark_span_error(span, str(e), ErrorCode.E_INPUT_UNAVAILABLE.value) + logger.error(f"Download failed: {url} -> {e}") + return False def upload_file( self, @@ -445,37 +457,45 @@ class BaseHandler(TaskHandler, ABC): Returns: 访问 URL,失败返回 None """ - # 获取上传 URL - upload_info = self.api_client.get_upload_url(task_id, file_type, file_name) - if not upload_info: - logger.error(f"[task:{task_id}] Failed to get upload URL") - return None + with start_span( + "render.task.file.upload", + kind=SpanKind.CLIENT, + attributes={ + "render.file.type": file_type, + "render.file.path": file_path, + }, + ) as span: + upload_info = self.api_client.get_upload_url(task_id, file_type, file_name) + if not upload_info: + logger.error(f"[task:{task_id}] Failed to get upload URL") + return None - upload_url = upload_info.get('uploadUrl') - access_url = upload_info.get('accessUrl') + upload_url = upload_info.get('uploadUrl') + access_url = upload_info.get('accessUrl') - if not upload_url: - logger.error(f"[task:{task_id}] Invalid upload URL response") - return None + if not upload_url: + logger.error(f"[task:{task_id}] Invalid upload URL response") + return None - # 上传文件 - try: - result = storage.upload_file(upload_url, file_path, timeout=self.config.upload_timeout) - if result: - file_size = os.path.getsize(file_path) - logger.info(f"[task:{task_id}] Uploaded: {file_path} ({file_size} bytes)") + try: + result = storage.upload_file(upload_url, file_path, timeout=self.config.upload_timeout) + if result: + file_size = os.path.getsize(file_path) + logger.info(f"[task:{task_id}] Uploaded: {file_path} ({file_size} bytes)") + if span is not None: + span.set_attribute("render.file.size_bytes", file_size) - # 将上传成功的文件加入缓存 - if access_url: - self.material_cache.add_to_cache(access_url, file_path) + if access_url: + self.material_cache.add_to_cache(access_url, file_path) + + return access_url - return access_url - else: logger.error(f"[task:{task_id}] Upload failed: {file_path}") return None - except Exception as e: - logger.error(f"[task:{task_id}] Upload error: {e}") - return None + except Exception as e: + mark_span_error(span, str(e), ErrorCode.E_UPLOAD_FAILED.value) + logger.error(f"[task:{task_id}] Upload error: {e}") + return None def run_ffmpeg( self, @@ -507,29 +527,42 @@ class BaseHandler(TaskHandler, ABC): cmd_str = cmd_str[:500] + '...' logger.info(f"[task:{task_id}] FFmpeg: {cmd_str}") - try: - run_args = subprocess_args(False) - run_args['stdout'] = subprocess.DEVNULL - run_args['stderr'] = subprocess.PIPE - result = subprocess.run( - cmd_to_run, - timeout=timeout, - **run_args - ) + with start_span( + "render.task.ffmpeg.run", + attributes={ + "render.ffmpeg.timeout_seconds": timeout, + "render.ffmpeg.command": cmd_str, + }, + ) as span: + try: + run_args = subprocess_args(False) + run_args['stdout'] = subprocess.DEVNULL + run_args['stderr'] = subprocess.PIPE + result = subprocess.run( + cmd_to_run, + timeout=timeout, + **run_args + ) - if result.returncode != 0: - stderr = (result.stderr or b'').decode('utf-8', errors='replace')[:1000] - logger.error(f"[task:{task_id}] FFmpeg failed (code={result.returncode}): {stderr}") + if span is not None: + span.set_attribute("render.ffmpeg.return_code", result.returncode) + + if result.returncode != 0: + stderr = (result.stderr or b'').decode('utf-8', errors='replace')[:1000] + logger.error(f"[task:{task_id}] FFmpeg failed (code={result.returncode}): {stderr}") + mark_span_error(span, stderr or "ffmpeg failed", ErrorCode.E_FFMPEG_FAILED.value) + return False + + return True + + except subprocess.TimeoutExpired: + logger.error(f"[task:{task_id}] FFmpeg timeout after {timeout}s") + mark_span_error(span, f"timeout after {timeout}s", ErrorCode.E_TIMEOUT.value) + return False + except Exception as e: + logger.error(f"[task:{task_id}] FFmpeg error: {e}") + mark_span_error(span, str(e), ErrorCode.E_FFMPEG_FAILED.value) return False - - return True - - except subprocess.TimeoutExpired: - logger.error(f"[task:{task_id}] FFmpeg timeout after {timeout}s") - return False - except Exception as e: - logger.error(f"[task:{task_id}] FFmpeg error: {e}") - return False def probe_duration(self, file_path: str) -> Optional[float]: """ diff --git a/index.py b/index.py index d010cb5..1faf86c 100644 --- a/index.py +++ b/index.py @@ -34,6 +34,7 @@ from domain.config import WorkerConfig from services.api_client import APIClientV2 from services.task_executor import TaskExecutor from constant import SOFTWARE_VERSION +from util.tracing import initialize_tracing, shutdown_tracing # 日志配置 def setup_logging(): @@ -113,6 +114,9 @@ class WorkerV2: logger.error(f"Configuration error: {e}") sys.exit(1) + tracing_enabled = initialize_tracing(self.config.worker_id, SOFTWARE_VERSION) + logger.info("OTel tracing %s", "enabled" if tracing_enabled else "disabled") + # 初始化 API 客户端 self.api_client = APIClientV2(self.config) @@ -212,6 +216,7 @@ class WorkerV2: # 关闭 API 客户端 self.api_client.close() + shutdown_tracing() logger.info("Worker stopped") diff --git a/services/api_client.py b/services/api_client.py index 7b99fb6..14e2bd3 100644 --- a/services/api_client.py +++ b/services/api_client.py @@ -10,10 +10,14 @@ import subprocess import time import requests from typing import Dict, List, Optional, Any +from urllib.parse import urlparse + +from opentelemetry.trace import SpanKind, Status, StatusCode from domain.task import Task from domain.config import WorkerConfig from util.system import get_hw_accel_info_str +from util.tracing import inject_trace_headers, mark_span_error, start_span logger = logging.getLogger(__name__) @@ -55,6 +59,45 @@ class APIClientV2: 'Accept': 'application/json' }) + def _request_with_trace( + self, + method: str, + url: str, + *, + task_id: Optional[str] = None, + span_name: str = "", + **kwargs: Any, + ) -> requests.Response: + request_kwargs = dict(kwargs) + headers = request_kwargs.pop("headers", None) + if task_id: + request_kwargs["headers"] = inject_trace_headers(headers) + elif headers: + request_kwargs["headers"] = headers + + parsed_url = urlparse(url) + attributes = { + "http.request.method": method.upper(), + "url.path": parsed_url.path, + "server.address": parsed_url.hostname or "", + } + if parsed_url.port: + attributes["server.port"] = parsed_url.port + + name = span_name or f"render.api.{method.lower()}" + with start_span(name, task_id=task_id, kind=SpanKind.CLIENT, attributes=attributes) as span: + try: + response = self.session.request(method=method, url=url, **request_kwargs) + except Exception as exc: + mark_span_error(span, str(exc), "HTTP_REQUEST_ERROR") + raise + + if span is not None: + span.set_attribute("http.response.status_code", response.status_code) + if response.status_code >= 400: + span.set_status(Status(StatusCode.ERROR, f"HTTP {response.status_code}")) + return response + def sync(self, current_task_ids: List[str]) -> List[Task]: """ 心跳同步并拉取任务 @@ -128,10 +171,13 @@ class APIClientV2: url = f"{self.base_url}/render/v2/task/{task_id}/start" try: - resp = self.session.post( - url, + resp = self._request_with_trace( + method="POST", + url=url, + task_id=task_id, + span_name="render.task.api.report_start", json={'workerId': self.worker_id}, - timeout=10 + timeout=10, ) if resp.status_code == 200: logger.debug(f"[task:{task_id}] Start reported") @@ -157,13 +203,16 @@ class APIClientV2: url = f"{self.base_url}/render/v2/task/{task_id}/success" try: - resp = self.session.post( - url, + resp = self._request_with_trace( + method="POST", + url=url, + task_id=task_id, + span_name="render.task.api.report_success", json={ 'workerId': self.worker_id, 'result': result }, - timeout=10 + timeout=10, ) if resp.status_code == 200: logger.debug(f"[task:{task_id}] Success reported") @@ -190,14 +239,17 @@ class APIClientV2: url = f"{self.base_url}/render/v2/task/{task_id}/fail" try: - resp = self.session.post( - url, + resp = self._request_with_trace( + method="POST", + url=url, + task_id=task_id, + span_name="render.task.api.report_fail", json={ 'workerId': self.worker_id, 'errorCode': error_code, 'errorMessage': error_message[:1000] # 限制长度 }, - timeout=10 + timeout=10, ) if resp.status_code == 200: logger.debug(f"[task:{task_id}] Failure reported") @@ -228,7 +280,14 @@ class APIClientV2: payload['fileName'] = file_name try: - resp = self.session.post(url, json=payload, timeout=10) + resp = self._request_with_trace( + method="POST", + url=url, + task_id=task_id, + span_name="render.task.api.get_upload_url", + json=payload, + timeout=10, + ) if resp.status_code == 200: data = resp.json() if data.get('code') == 200: @@ -256,13 +315,16 @@ class APIClientV2: url = f"{self.base_url}/render/v2/task/{task_id}/extend-lease" try: - resp = self.session.post( - url, + resp = self._request_with_trace( + method="POST", + url=url, + task_id=task_id, + span_name="render.task.api.extend_lease", params={ 'workerId': self.worker_id, 'extension': extension }, - timeout=10 + timeout=10, ) if resp.status_code == 200: logger.debug(f"[task:{task_id}] Lease extended by {extension}s") @@ -287,7 +349,13 @@ class APIClientV2: url = f"{self.base_url}/render/v2/task/{task_id}" try: - resp = self.session.get(url, timeout=10) + resp = self._request_with_trace( + method="GET", + url=url, + task_id=task_id, + span_name="render.task.api.get_task_info", + timeout=10, + ) if resp.status_code == 200: data = resp.json() if data.get('code') == 200: diff --git a/services/lease_service.py b/services/lease_service.py index 6e40dd9..5de7b6e 100644 --- a/services/lease_service.py +++ b/services/lease_service.py @@ -8,10 +8,13 @@ import logging import threading import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from services.api_client import APIClientV2 + from util.tracing import TaskTraceContext + +from util.tracing import bind_trace_context, start_span logger = logging.getLogger(__name__) @@ -29,7 +32,9 @@ class LeaseService: api_client: 'APIClientV2', task_id: str, interval: int = 60, - extension: int = 300 + extension: int = 300, + parent_otel_context: Any = None, + task_trace_context: Optional['TaskTraceContext'] = None, ): """ 初始化租约服务 @@ -44,6 +49,8 @@ class LeaseService: self.task_id = task_id self.interval = interval self.extension = extension + self.parent_otel_context = parent_otel_context + self.task_trace_context = task_trace_context self.running = False self.thread: threading.Thread = None self._stop_event = threading.Event() @@ -79,25 +86,29 @@ class LeaseService: def _run(self): """续期线程主循环""" - while self.running: - # 等待指定间隔或收到停止信号 - if self._stop_event.wait(timeout=self.interval): - # 收到停止信号 - break + with bind_trace_context(self.parent_otel_context, self.task_trace_context): + while self.running: + if self._stop_event.wait(timeout=self.interval): + break - if self.running: - self._extend_lease() + if self.running: + self._extend_lease() def _extend_lease(self): """执行租约续期""" - try: - success = self.api_client.extend_lease(self.task_id, self.extension) - if success: - logger.debug(f"[task:{self.task_id}] Lease extended by {self.extension}s") - else: - logger.warning(f"[task:{self.task_id}] Failed to extend lease") - except Exception as e: - logger.warning(f"[task:{self.task_id}] Lease extension error: {e}") + with start_span( + "render.task.lease.extend", + task_id=self.task_id, + attributes={"render.lease.extension_seconds": self.extension}, + ): + try: + success = self.api_client.extend_lease(self.task_id, self.extension) + if success: + logger.debug(f"[task:{self.task_id}] Lease extended by {self.extension}s") + else: + logger.warning(f"[task:{self.task_id}] Failed to extend lease") + except Exception as e: + logger.warning(f"[task:{self.task_id}] Lease extension error: {e}") def __enter__(self): """上下文管理器入口""" diff --git a/services/storage.py b/services/storage.py index 4353f53..5baab41 100644 --- a/services/storage.py +++ b/services/storage.py @@ -151,6 +151,10 @@ def _upload_with_rclone(url: str, file_path: str) -> bool: if new_url == url: return False + if new_url.startswith(("http://", "https://")): + logger.warning(f"rclone upload skipped: URL still starts with http after replace: {new_url}") + return False + cmd = [ "rclone", "copyto", diff --git a/services/task_executor.py b/services/task_executor.py index 5d4038e..d4255c9 100644 --- a/services/task_executor.py +++ b/services/task_executor.py @@ -11,7 +11,6 @@ from concurrent.futures import ThreadPoolExecutor, Future from typing import Dict, Optional, TYPE_CHECKING from domain.task import Task, TaskType -from domain.result import TaskResult, ErrorCode # 需要 GPU 加速的任务类型 GPU_REQUIRED_TASK_TYPES = { @@ -22,6 +21,13 @@ from domain.config import WorkerConfig from core.handler import TaskHandler from services.lease_service import LeaseService from services.gpu_scheduler import GPUScheduler +from util.tracing import ( + capture_otel_context, + get_current_task_context, + mark_span_error, + start_span, + task_trace_scope, +) if TYPE_CHECKING: from services.api_client import APIClientV2 @@ -174,77 +180,84 @@ class TaskExecutor: task: 任务实体 """ task_id = task.task_id - logger.info(f"[task:{task_id}] Starting {task.task_type.value}") - - # 启动租约续期服务 - lease_service = LeaseService( - self.api_client, - task_id, - interval=self.config.lease_extension_threshold, - extension=self.config.lease_extension_duration - ) - lease_service.start() - - # 获取 GPU 设备(仅对需要 GPU 的任务类型) - device_index = None - needs_gpu = task.task_type in GPU_REQUIRED_TASK_TYPES - if needs_gpu and self.gpu_scheduler.enabled: - device_index = self.gpu_scheduler.acquire() - if device_index is not None: - logger.info(f"[task:{task_id}] Assigned to GPU device {device_index}") - - # 获取处理器(需要在设置 GPU 设备前获取) handler = self.handlers.get(task.task_type) + device_index = None + lease_service = None - try: - # 报告任务开始 - self.api_client.report_start(task_id) + with task_trace_scope(task, span_name="render.task.execute") as task_span: + logger.info(f"[task:{task_id}] Starting {task.task_type.value}") - if not handler: - raise ValueError(f"No handler for task type: {task.task_type}") + lease_service = LeaseService( + self.api_client, + task_id, + interval=self.config.lease_extension_threshold, + extension=self.config.lease_extension_duration, + parent_otel_context=capture_otel_context(), + task_trace_context=get_current_task_context(), + ) + with start_span("render.task.lease.start"): + lease_service.start() - # 设置 GPU 设备(线程本地存储) - if device_index is not None: - handler.set_gpu_device(device_index) + needs_gpu = task.task_type in GPU_REQUIRED_TASK_TYPES + if needs_gpu and self.gpu_scheduler.enabled: + with start_span("render.task.gpu.acquire"): + device_index = self.gpu_scheduler.acquire() + if device_index is not None: + logger.info(f"[task:{task_id}] Assigned to GPU device {device_index}") - # 执行前钩子 - handler.before_handle(task) + try: + with start_span("render.task.report.start"): + self.api_client.report_start(task_id) - # 执行任务 - result = handler.handle(task) + if not handler: + raise ValueError(f"No handler for task type: {task.task_type}") - # 执行后钩子 - handler.after_handle(task, result) + if device_index is not None: + handler.set_gpu_device(device_index) - # 上报结果 - if result.success: - self.api_client.report_success(task_id, result.data) - logger.info(f"[task:{task_id}] Completed successfully") - else: - error_code = result.error_code.value if result.error_code else 'E_UNKNOWN' - self.api_client.report_fail(task_id, error_code, result.error_message or '') - logger.error(f"[task:{task_id}] Failed: {result.error_message}") + with start_span("render.task.handler.before"): + handler.before_handle(task) - except Exception as e: - logger.error(f"[task:{task_id}] Exception: {e}", exc_info=True) - self.api_client.report_fail(task_id, 'E_UNKNOWN', str(e)) + with start_span("render.task.handler.execute"): + result = handler.handle(task) - finally: - # 清除 GPU 设备设置 - if handler: - handler.clear_gpu_device() + with start_span("render.task.handler.after"): + handler.after_handle(task, result) - # 释放 GPU 设备(仅当实际分配了设备时) - if device_index is not None: - self.gpu_scheduler.release(device_index) + if result.success: + with start_span("render.task.report.success"): + self.api_client.report_success(task_id, result.data) + if task_span is not None: + task_span.set_attribute("render.task.result", "success") + logger.info(f"[task:{task_id}] Completed successfully") + else: + error_code = result.error_code.value if result.error_code else 'E_UNKNOWN' + with start_span("render.task.report.fail"): + self.api_client.report_fail(task_id, error_code, result.error_message or '') + mark_span_error(task_span, result.error_message or "task failed", error_code) + logger.error(f"[task:{task_id}] Failed: {result.error_message}") - # 停止租约续期 - lease_service.stop() + except Exception as e: + mark_span_error(task_span, str(e), "E_UNKNOWN") + logger.error(f"[task:{task_id}] Exception: {e}", exc_info=True) + with start_span("render.task.report.exception"): + self.api_client.report_fail(task_id, 'E_UNKNOWN', str(e)) - # 从当前任务中移除 - with self.lock: - self.current_tasks.pop(task_id, None) - self.current_futures.pop(task_id, None) + finally: + if handler: + handler.clear_gpu_device() + + if device_index is not None: + with start_span("render.task.gpu.release"): + self.gpu_scheduler.release(device_index) + + if lease_service is not None: + with start_span("render.task.lease.stop"): + lease_service.stop() + + with self.lock: + self.current_tasks.pop(task_id, None) + self.current_futures.pop(task_id, None) def shutdown(self, wait: bool = True): """ diff --git a/util/tracing.py b/util/tracing.py new file mode 100644 index 0000000..23ae181 --- /dev/null +++ b/util/tracing.py @@ -0,0 +1,260 @@ +# -*- coding: utf-8 -*- +""" +OTel 链路追踪工具。 + +提供统一的 tracing 初始化、任务上下文管理与 Span 创建能力。 +""" + +import logging +import os +from contextlib import contextmanager, nullcontext +from contextvars import ContextVar +from dataclasses import dataclass +from typing import Any, Dict, Iterator, Mapping, Optional + +from opentelemetry import context as otel_context +from opentelemetry import propagate, trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.trace import Span, SpanKind, Status, StatusCode + +logger = logging.getLogger(__name__) + +_DEFAULT_SERVICE_NAME = "RenderWorkerNext" +_DEFAULT_TRACER_NAME = "render.worker" +_OTEL_EXPORTER_OTLP_ENDPOINT = "https://oltp.jerryyan.top/v1/traces" +_TASK_ID_ATTR = "render.task.id" +_TASK_TYPE_ATTR = "render.task.type" +_JOB_ID_ATTR = "render.job.id" +_SEGMENT_ID_ATTR = "render.segment.id" +_ERROR_CODE_ATTR = "render.error.code" +_ERROR_MESSAGE_ATTR = "render.error.message" +_TRUE_VALUES = {"1", "true", "yes", "on"} + +_TRACING_INITIALIZED = False +_TRACING_ENABLED = False +_TRACER_PROVIDER: Optional[TracerProvider] = None +_CURRENT_TASK_CONTEXT: ContextVar[Optional["TaskTraceContext"]] = ContextVar( + "render_worker_task_trace_context", + default=None, +) + + +@dataclass(frozen=True) +class TaskTraceContext: + """任务维度的 tracing 上下文。""" + + task_id: str + task_type: str + job_id: str = "" + segment_id: str = "" + + def to_attributes(self) -> Dict[str, str]: + attributes = { + _TASK_ID_ATTR: self.task_id, + _TASK_TYPE_ATTR: self.task_type, + } + if self.job_id: + attributes[_JOB_ID_ATTR] = self.job_id + if self.segment_id: + attributes[_SEGMENT_ID_ATTR] = self.segment_id + return attributes + + +def _parse_bool(value: str, default: bool) -> bool: + if value is None: + return default + return value.strip().lower() in _TRUE_VALUES + + +def is_tracing_enabled() -> bool: + return _TRACING_ENABLED + + +def initialize_tracing(worker_id: str, service_version: str) -> bool: + """ + 初始化 OTel tracing。 + """ + global _TRACING_INITIALIZED + global _TRACING_ENABLED + global _TRACER_PROVIDER + + if _TRACING_INITIALIZED: + return _TRACING_ENABLED + + _TRACING_INITIALIZED = True + if not _parse_bool(os.getenv("OTEL_ENABLED"), default=True): + logger.info("OTel tracing disabled by OTEL_ENABLED") + _TRACING_ENABLED = False + return False + + service_name = _DEFAULT_SERVICE_NAME + attributes: Dict[str, str] = { + SERVICE_NAME: service_name, + SERVICE_VERSION: service_version, + "render.worker.id": str(worker_id), + } + + resource = Resource.create(attributes) + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor( + BatchSpanProcessor( + OTLPSpanExporter(endpoint=_OTEL_EXPORTER_OTLP_ENDPOINT) + ) + ) + trace.set_tracer_provider(tracer_provider) + + _TRACING_ENABLED = True + if trace.get_tracer_provider() is tracer_provider: + _TRACER_PROVIDER = tracer_provider + + logger.info("OTel tracing initialized (service=%s, worker=%s)", service_name, worker_id) + return True + + +def shutdown_tracing() -> None: + """优雅关闭 tracing provider,刷新剩余 span。""" + global _TRACING_ENABLED + + if not _TRACING_ENABLED: + return + + provider = _TRACER_PROVIDER + if provider is not None: + try: + provider.shutdown() + except Exception as exc: + logger.warning("Failed to shutdown tracing provider: %s", exc) + + _TRACING_ENABLED = False + + +def build_task_trace_context(task: Any) -> TaskTraceContext: + task_id = str(getattr(task, "task_id", "")) + task_type_obj = getattr(task, "task_type", "") + task_type = str(getattr(task_type_obj, "value", task_type_obj)) + + job_id = "" + if hasattr(task, "get_job_id"): + job_id = str(task.get_job_id() or "") + + segment_id = "" + if hasattr(task, "get_segment_id"): + segment_value = task.get_segment_id() + segment_id = str(segment_value) if segment_value is not None else "" + + return TaskTraceContext( + task_id=task_id, + task_type=task_type, + job_id=job_id, + segment_id=segment_id, + ) + + +def get_current_task_context() -> Optional[TaskTraceContext]: + return _CURRENT_TASK_CONTEXT.get() + + +def capture_otel_context() -> Any: + return otel_context.get_current() + + +@contextmanager +def bind_trace_context(parent_otel_context: Any, task_context: Optional[TaskTraceContext]) -> Iterator[None]: + """ + 在当前线程绑定父 OTel 上下文与任务上下文。 + + 用于跨线程延续任务链路(例如租约续期线程)。 + """ + otel_token = None + task_token = None + + if parent_otel_context is not None: + otel_token = otel_context.attach(parent_otel_context) + if task_context is not None: + task_token = _CURRENT_TASK_CONTEXT.set(task_context) + + try: + yield + finally: + if task_token is not None: + _CURRENT_TASK_CONTEXT.reset(task_token) + if otel_token is not None: + otel_context.detach(otel_token) + + +@contextmanager +def task_trace_scope(task: Any, span_name: str = "render.task.process") -> Iterator[Optional[Span]]: + """创建任务根 Span 并绑定任务上下文。""" + task_context = build_task_trace_context(task) + task_token = _CURRENT_TASK_CONTEXT.set(task_context) + + span_cm = nullcontext(None) + if _TRACING_ENABLED: + tracer = trace.get_tracer(_DEFAULT_TRACER_NAME) + span_cm = tracer.start_as_current_span(span_name, kind=SpanKind.CONSUMER) + + try: + with span_cm as span: + if span is not None: + for key, value in task_context.to_attributes().items(): + span.set_attribute(key, value) + yield span + finally: + _CURRENT_TASK_CONTEXT.reset(task_token) + + +@contextmanager +def start_span( + name: str, + *, + attributes: Optional[Mapping[str, Any]] = None, + kind: SpanKind = SpanKind.INTERNAL, + task_id: Optional[str] = None, +) -> Iterator[Optional[Span]]: + """ + 创建任务内子 Span。 + + 当 tracing 未启用,或当前不在任务上下文中且未显式传入 task_id 时,返回空上下文。 + """ + task_context = get_current_task_context() + should_trace = _TRACING_ENABLED and (task_context is not None or bool(task_id)) + if not should_trace: + with nullcontext(None) as span: + yield span + return + + tracer = trace.get_tracer(_DEFAULT_TRACER_NAME) + with tracer.start_as_current_span(name, kind=kind) as span: + if task_context is not None: + for key, value in task_context.to_attributes().items(): + span.set_attribute(key, value) + if task_id and (task_context is None or task_context.task_id != task_id): + span.set_attribute(_TASK_ID_ATTR, task_id) + if attributes: + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, value) + yield span + + +def mark_span_error(span: Optional[Span], message: str, error_code: str = "") -> None: + """标记 Span 为错误状态。""" + if span is None: + return + + if error_code: + span.set_attribute(_ERROR_CODE_ATTR, error_code) + if message: + span.set_attribute(_ERROR_MESSAGE_ATTR, message[:500]) + span.set_status(Status(StatusCode.ERROR, message[:200])) + + +def inject_trace_headers(headers: Optional[Mapping[str, str]] = None) -> Dict[str, str]: + """向 HTTP 头注入当前 trace 上下文。""" + carrier = dict(headers) if headers else {} + if _TRACING_ENABLED: + propagate.inject(carrier) + return carrier