import os
import random
import datetime

import numpy as np
import torch
import torch.distributed as dist
def setup_ddp():
    """
    Initializes the distributed data parallel environment.

    This function relies on environment variables set by `torchrun` or a similar
    launcher. It sets the CUDA device for the current process and initializes
    the process group.

    Returns:
        tuple: A tuple containing (rank, world_size, local_rank).
    """
    if not dist.is_available():
        raise RuntimeError("torch.distributed is not available.")

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # Bind this process to its GPU before initializing NCCL so that every
    # rank does not default to device 0.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    print(
        f"[DDP Setup] Global Rank: {rank}/{world_size}, "
        f"Local Rank (GPU): {local_rank} on device {torch.cuda.current_device()}"
    )
    return rank, world_size, local_rank
def cleanup_ddp():
    """Cleans up the distributed process group."""
    if dist.is_initialized():
        dist.destroy_process_group()
def set_seed(seed: int, rank: int = 0):
    """
    Sets the random seed for reproducibility across all relevant libraries.

    Args:
        seed (int): The base seed value.
        rank (int): The process rank, used to ensure different processes have
            different seeds, which can be important for data loading.
    """
    actual_seed = seed + rank
    random.seed(actual_seed)
    np.random.seed(actual_seed)
    torch.manual_seed(actual_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(actual_seed)
    # The two lines below can impact performance, so they are often
    # reserved for final experiments where reproducibility is critical.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def get_model_size(model: torch.nn.Module) -> str:
    """
    Calculates the number of trainable parameters in a PyTorch model and returns
    it as a human-readable string.

    Args:
        model (torch.nn.Module): The PyTorch model.

    Returns:
        str: A string representing the model size (e.g., "175.0B", "7.1M", "50.5K").
    """
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if total_params >= 1e9:
        return f"{total_params / 1e9:.1f}B"  # Billions
    elif total_params >= 1e6:
        return f"{total_params / 1e6:.1f}M"  # Millions
    else:
        return f"{total_params / 1e3:.1f}K"  # Thousands
def reduce_tensor(tensor: torch.Tensor, world_size: int, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """
    Reduces a tensor's value across all processes in a distributed setup.

    Args:
        tensor (torch.Tensor): The tensor to be reduced.
        world_size (int): The total number of processes.
        op (dist.ReduceOp, optional): The reduction operation (SUM, AVG, etc.).
            Defaults to dist.ReduceOp.SUM.

    Returns:
        torch.Tensor: The reduced tensor, which will be identical on all processes.
    """
    rt = tensor.clone()
    # Note: `dist.ReduceOp.AVG` is only available in newer torch versions, so
    # averaging is emulated with a SUM followed by a manual division. This also
    # avoids dividing twice on backends that apply AVG natively.
    if op == dist.ReduceOp.AVG:
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= world_size
    else:
        dist.all_reduce(rt, op=op)
    return rt
def format_time(seconds: float) -> str:
    """
    Formats a duration in seconds into a human-readable H:M:S string.

    Args:
        seconds (float): The total seconds.

    Returns:
        str: The formatted time string (e.g., "0:15:32").
    """
    return str(datetime.timedelta(seconds=int(seconds)))
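

# Minimal usage sketch showing how these helpers fit together. It assumes the
# script is launched with something like `torchrun --nproc_per_node=4 <this_file>.py`
# (so RANK/WORLD_SIZE/LOCAL_RANK are set) and that CUDA GPUs are available; the
# `torch.nn.Linear` model below is a hypothetical stand-in purely for illustration.
if __name__ == "__main__":
    import time

    rank, world_size, local_rank = setup_ddp()
    set_seed(42, rank=rank)

    # Hypothetical model, only used to demonstrate get_model_size().
    model = torch.nn.Linear(1024, 1024).cuda(local_rank)
    if rank == 0:
        print(f"Trainable parameters: {get_model_size(model)}")

    start = time.time()
    # Each rank contributes its own "loss"; after reduction the value is
    # identical on every rank (here: the mean across ranks).
    local_loss = torch.tensor([float(rank)], device=local_rank)
    mean_loss = reduce_tensor(local_loss, world_size, op=dist.ReduceOp.AVG)
    if rank == 0:
        print(f"Mean loss across ranks: {mean_loss.item():.3f}")
        print(f"Elapsed: {format_time(time.time() - start)}")

    cleanup_ddp()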