Update src/parallelism_utils.py
src/parallelism_utils.py  +6 -12
@@ -9,10 +9,10 @@ def get_precision_fac(precision: str):
         raise ValueError("Precision must be either 'mixed' or 'single'")
 
 
-def get_params_fac(model_dtype: torch.dtype):
-    if model_dtype == torch.float16:
+def get_params_fac(model_dtype: str):
+    if model_dtype == "float16":
         return 2
-    elif model_dtype == torch.float32:
+    elif model_dtype == "float32":
         return 4
     else:
         raise ValueError("Model dtype must be either torch.float16 or torch.float32")
@@ -29,19 +29,13 @@ FP32_PARAM_FACTOR = 4
 MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR
 
 
-# TODO: check if params_fac is needed during full fp32 training.
-# Normally, mixed precision training results in 1.5x memory compared to FP32.
-# Currently, we are assuming 2x memory for FP32, as deepspeed's ZeRO-2 is optimized for FP16 training.
-
-
-
 def estimate_zero1_model_states_mem_needs(total_params,
                                           num_gpus_per_node=1,
                                           num_nodes=1,
                                           cpu_offload=True,
                                           additional_buffer_factor=1.5,
                                           precision="mixed",
-                                          model_dtype = torch.float16,
+                                          model_dtype = "float16",
                                           ):
 
     total_gpus = num_nodes * num_gpus_per_node
@@ -68,7 +62,7 @@ def estimate_zero2_model_states_mem_needs(total_params,
                                           cpu_offload=True,
                                           additional_buffer_factor=1.5,
                                           precision="mixed",
-                                          model_dtype = torch.float16,
+                                          model_dtype = "float16",
                                           ):
 
     total_gpus = num_nodes * num_gpus_per_node
@@ -98,7 +92,7 @@ def estimate_zero3_model_states_mem_needs(total_params,
                                           zero_init=True,
                                           additional_buffer_factor=1.5,
                                           precision="mixed",
-                                          model_dtype = torch.float16,
+                                          model_dtype = "float16",
                                           ):
 
     total_gpus = num_nodes * num_gpus_per_node
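The removed TODO ("mixed precision is normally ~1.5x FP32 memory; we currently assume 2x for FP32") can be sanity-checked against the byte factors defined in this file. The sketch below is not part of the commit: param_bytes_per_param is a hypothetical helper, and the 1.5x / 2x readings are my interpretation of FP32_PARAM_FACTOR, MASTER_PARAMS_FACTOR, and get_params_fac, not something the file states outright.

# Sketch only -- not part of this commit. Reproduces the updated helper and
# shows how the per-parameter byte factors defined in this file combine.

FP32_PARAM_FACTOR = 4                     # bytes per FP32 value (defined in the file)
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR  # FP32 master copy of the parameters


def get_params_fac(model_dtype: str):
    # Mirrors the updated function: the dtype is now passed as a plain string.
    if model_dtype == "float16":
        return 2
    elif model_dtype == "float32":
        return 4
    else:
        raise ValueError("Model dtype must be either torch.float16 or torch.float32")


def param_bytes_per_param(model_dtype: str) -> int:
    # Hypothetical helper (my assumption): count only the model copy plus the
    # FP32 master copy, ignoring gradients and optimizer moments.
    return get_params_fac(model_dtype) + MASTER_PARAMS_FACTOR


# float16: 2 B model copy + 4 B master copy = 6 B/param, i.e. 1.5x plain FP32 weights (4 B)
print(param_bytes_per_param("float16"))  # 6
# float32: 4 B model copy + 4 B master copy = 8 B/param, i.e. the "2x memory for FP32" assumption
print(param_bytes_per_param("float32"))  # 8

Read this way, "1.5x" is the float16 case (6 B vs 4 B for the parameter copies) and "2x" is the float32 case (8 B vs 4 B); whether the master copy should be counted at all for full-FP32 training is exactly the open question the removed TODO raised. The estimators in this diff take the dtype in the same string form, e.g. estimate_zero1_model_states_mem_needs(total_params, model_dtype="float32").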