ibrim committed on
Commit
ab422cd
1 Parent(s): 5b5b19a

Update model.py

Files changed (1)
  1. model.py +1 -1
model.py CHANGED
@@ -279,7 +279,7 @@ class GPT(nn.Module):
     print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
     # Create AdamW optimizer and use the fused version if it is available
     fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
-    use_fused = fused_available and device_type == 'cuda'
+    use_fused = fused_available and device_type == 'cpu'
     extra_args = dict(fused=True) if use_fused else dict()
     optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
     print(f"using fused AdamW: {use_fused}")