andstor commited on
Commit
6ca6353
1 Parent(s): 55d1b4f

Update src/parallelism_utils.py

Browse files
Files changed (1) hide show
  1. src/parallelism_utils.py +6 -12
src/parallelism_utils.py CHANGED
@@ -9,10 +9,10 @@ def get_precision_fac(precision: str):
9
  raise ValueError("Precision must be either 'mixed' or 'single'")
10
 
11
 
12
- def get_params_fac(model_dtype: torch.dtype):
13
- if model_dtype == torch.float16:
14
  return 2
15
- elif model_dtype == torch.float32:
16
  return 4
17
  else:
18
  raise ValueError("Model dtype must be either torch.float16 or torch.float32")
@@ -29,19 +29,13 @@ FP32_PARAM_FACTOR = 4
29
  MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR
30
 
31
 
32
- # TODO: check if params_fac is needed during full fp32 training.
33
- # Normally, mixed precision training results in 1.5x memory compared to FP32.
34
- # Currently, we are assuming 2x memory for FP32, as deepspeed's ZeRO-2 is optimized for FP16 training.
35
-
36
-
37
-
38
  def estimate_zero1_model_states_mem_needs(total_params,
39
  num_gpus_per_node=1,
40
  num_nodes=1,
41
  cpu_offload=True,
42
  additional_buffer_factor=1.5,
43
  precision="mixed",
44
- model_dtype = torch.float16,
45
  ):
46
 
47
  total_gpus = num_nodes * num_gpus_per_node
@@ -68,7 +62,7 @@ def estimate_zero2_model_states_mem_needs(total_params,
68
  cpu_offload=True,
69
  additional_buffer_factor=1.5,
70
  precision="mixed",
71
- model_dtype = torch.float16,
72
  ):
73
 
74
  total_gpus = num_nodes * num_gpus_per_node
@@ -98,7 +92,7 @@ def estimate_zero3_model_states_mem_needs(total_params,
98
  zero_init=True,
99
  additional_buffer_factor=1.5,
100
  precision="mixed",
101
- model_dtype = torch.float16,
102
  ):
103
 
104
  total_gpus = num_nodes * num_gpus_per_node
 
9
  raise ValueError("Precision must be either 'mixed' or 'single'")
10
 
11
 
12
def get_params_fac(model_dtype: str):
    """Return the number of bytes used per model parameter for a dtype name.

    Args:
        model_dtype: Name of the model's parameter dtype; either
            ``"float16"`` (2 bytes per parameter) or ``"float32"``
            (4 bytes per parameter).

    Returns:
        int: Bytes per parameter (2 or 4).

    Raises:
        ValueError: If ``model_dtype`` is neither ``"float16"`` nor
            ``"float32"``.
    """
    if model_dtype == "float16":
        return 2
    elif model_dtype == "float32":
        return 4
    else:
        # The message must name the string values actually compared above;
        # the previous message still referenced torch.float16/torch.float32
        # even though this function now takes dtype *names* as strings.
        raise ValueError("Model dtype must be either 'float16' or 'float32'")
 
29
  MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR
30
 
31
 
 
 
 
 
 
 
32
  def estimate_zero1_model_states_mem_needs(total_params,
33
  num_gpus_per_node=1,
34
  num_nodes=1,
35
  cpu_offload=True,
36
  additional_buffer_factor=1.5,
37
  precision="mixed",
38
+ model_dtype = "float16",
39
  ):
40
 
41
  total_gpus = num_nodes * num_gpus_per_node
 
62
  cpu_offload=True,
63
  additional_buffer_factor=1.5,
64
  precision="mixed",
65
+ model_dtype = "float16",
66
  ):
67
 
68
  total_gpus = num_nodes * num_gpus_per_node
 
92
  zero_init=True,
93
  additional_buffer_factor=1.5,
94
  precision="mixed",
95
+ model_dtype = "float16",
96
  ):
97
 
98
  total_gpus = num_nodes * num_gpus_per_node