import os
import random
import datetime

import numpy as np
import torch
import torch.distributed as dist

def setup_ddp():
    """
    Initializes the distributed data parallel environment.

    This function relies on environment variables set by `torchrun` or a
    similar launcher. It initializes the process group and sets the CUDA
    device for the current process.

    Returns:
        tuple: A tuple containing (rank, world_size, local_rank).
    """
    if not dist.is_available():
        raise RuntimeError("torch.distributed is not available.")

    dist.init_process_group(backend="nccl")
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    print(
        f"[DDP Setup] Global Rank: {rank}/{world_size}, "
        f"Local Rank (GPU): {local_rank} on device {torch.cuda.current_device()}"
    )
    return rank, world_size, local_rank

def cleanup_ddp():
    """Cleans up the distributed process group."""
    if dist.is_initialized():
        dist.destroy_process_group()
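
# Usage sketch (an assumption for illustration, not part of this module): a
# training script launched with e.g. `torchrun --nproc_per_node=4 train.py`
# would typically pair setup_ddp() with cleanup_ddp(). The names `train.py`
# and `run_training` are hypothetical.
#
#   def main():
#       rank, world_size, local_rank = setup_ddp()
#       try:
#           run_training(rank, world_size, local_rank)
#       finally:
#           cleanup_ddp()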

def set_seed(seed: int, rank: int = 0):
    """
    Sets the random seed for reproducibility across all relevant libraries.

    Args:
        seed (int): The base seed value.
        rank (int): The process rank, used to ensure different processes have
            different seeds, which can be important for data loading.
    """
    actual_seed = seed + rank
    random.seed(actual_seed)
    np.random.seed(actual_seed)
    torch.manual_seed(actual_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(actual_seed)
    # The two lines below can impact performance, so they are often
    # reserved for final experiments where reproducibility is critical.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
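
# Usage sketch (illustrative, not prescribed by this module): passing the rank
# as the offset after setup_ddp() gives each process its own random stream for
# data shuffling and augmentation. The base seed 42 is an arbitrary example.
#
#   rank, world_size, local_rank = setup_ddp()
#   set_seed(42, rank=rank)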

def get_model_size(model: torch.nn.Module) -> str:
    """
    Calculates the number of trainable parameters in a PyTorch model and
    returns it as a human-readable string.

    Args:
        model (torch.nn.Module): The PyTorch model.

    Returns:
        str: A string representing the model size (e.g., "175.0B", "7.1M", "50.5K").
    """
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if total_params >= 1e9:
        return f"{total_params / 1e9:.1f}B"  # Billions
    elif total_params >= 1e6:
        return f"{total_params / 1e6:.1f}M"  # Millions
    else:
        return f"{total_params / 1e3:.1f}K"  # Thousands
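
# Example (added for illustration; values can be verified locally): a
# torch.nn.Linear(10, 10) layer has 10 * 10 + 10 = 110 trainable parameters,
# so it falls into the thousands bucket.
#
#   >>> get_model_size(torch.nn.Linear(10, 10))
#   '0.1K'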

def reduce_tensor(tensor: torch.Tensor, world_size: int, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """
    Reduces a tensor's value across all processes in a distributed setup.

    Args:
        tensor (torch.Tensor): The tensor to be reduced.
        world_size (int): The total number of processes.
        op (dist.ReduceOp, optional): The reduction operation (SUM, AVG, etc.).
            Defaults to dist.ReduceOp.SUM.

    Returns:
        torch.Tensor: The reduced tensor, which will be identical on all processes.
    """
    rt = tensor.clone()
    # Note: `dist.ReduceOp.AVG` is only supported by newer torch versions (and
    # only on some backends). For compatibility, averaging is implemented here
    # as a SUM across processes followed by a manual division, rather than
    # passing AVG to `all_reduce` and then dividing a second time.
    if op == dist.ReduceOp.AVG:
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= world_size
    else:
        dist.all_reduce(rt, op=op)
    return rt
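
# Usage sketch (assumes a DDP training loop with a scalar `loss` tensor, which
# is not defined in this module): averaging the detached loss across ranks so
# that rank 0 can log a value that reflects every process.
#
#   loss_for_logging = reduce_tensor(loss.detach(), world_size,
#                                    op=dist.ReduceOp.AVG)
#   if rank == 0:
#       print(f"avg loss: {loss_for_logging.item():.4f}")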

def format_time(seconds: float) -> str:
    """
    Formats a duration in seconds into a human-readable H:M:S string.

    Args:
        seconds (float): The total seconds.

    Returns:
        str: The formatted time string (e.g., "0:15:32").
    """
    return str(datetime.timedelta(seconds=int(seconds)))
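
# Minimal self-check for the helpers that do not require a distributed
# launcher. This block is an illustrative addition, not required by the
# module; the expected outputs follow directly from the functions above.
if __name__ == "__main__":
    import torch.nn as nn

    tiny = nn.Linear(10, 10)        # 110 trainable parameters
    print(get_model_size(tiny))     # -> "0.1K"
    print(format_time(932))         # -> "0:15:32"
    print(format_time(3661.7))      # -> "1:01:01"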