Update model.py
Browse files
model.py
CHANGED
@@ -279,7 +279,7 @@ class GPT(nn.Module):
|
|
279 |
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
|
280 |
# Create AdamW optimizer and use the fused version if it is available
|
281 |
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
|
282 |
-
use_fused = fused_available and device_type == '
|
283 |
extra_args = dict(fused=True) if use_fused else dict()
|
284 |
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
|
285 |
print(f"using fused AdamW: {use_fused}")
|
|
|
279 |
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
|
280 |
# Create AdamW optimizer and use the fused version if it is available
|
281 |
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
|
282 |
+
use_fused = fused_available and device_type == 'cpu'
|
283 |
extra_args = dict(fused=True) if use_fused else dict()
|
284 |
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
|
285 |
print(f"using fused AdamW: {use_fused}")
|