fffiloni commited on
Commit
7eee4f5
1 Parent(s): eddc0c5

pixart special

Browse files
Files changed (1) hide show
  1. main.py +6 -2
main.py CHANGED
@@ -202,8 +202,12 @@ def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pip
202
  if args.task == "single":
203
  # Attempt to move the model to GPU if model is not Flux
204
  if args.model != "flux":
205
- if pipe.device != torch.device('cuda'):
206
- pipe.to(device, dtype)
 
 
 
 
207
  else:
208
  print(f"PIPE:{pipe}")
209
 
 
202
  if args.task == "single":
203
  # Attempt to move the model to GPU if model is not Flux
204
  if args.model != "flux":
205
+ if args.model != "pixart":
206
+ if pipe.device != torch.device('cuda'):
207
+ pipe.to(device, dtype)
208
+ else:
209
+ if pipe.device != torch.device('cuda'):
210
+ pipe.to(device)
211
  else:
212
  print(f"PIPE:{pipe}")
213