"""Device-selection helpers for the ``cebmf_torch.utils.device`` module."""

from typing import TypeVar

import torch


def get_device(prefer_gpu: bool = True) -> torch.device:
    """Return the most capable device that is currently usable.

    When GPUs are preferred, accelerators are probed in this order and the
    first available one wins:

    1. CUDA (NVIDIA GPUs)
    2. MPS (Apple Silicon GPUs)

    If neither is available, or GPUs are not preferred, the CPU is used.

    Args:
        prefer_gpu: If False, skip accelerator probing and return the CPU.

    Returns:
        torch.device: The chosen device.
    """
    if prefer_gpu:
        # MPS is only probed when CUDA is absent, so CUDA takes priority.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
    # Shared fallback: GPU not requested, or no accelerator found.
    return torch.device("cpu")
[docs] def to_device(x: torch.Tensor, device: torch.device | None = None): """ Move a tensor or module to the specified device. Parameters ---------- x : object Tensor, module, or object supporting the .to() method. device : torch.device or None, optional Target device. If None, uses the default device from get_device(). Returns ------- object The input object moved to the specified device, or unchanged if not supported. """ if device is None: device = get_device() return x.to(device)