zmelumian commited on
Commit
f5895e7
1 Parent(s): f2f8259

Added NIE on xla device

Browse files
xora/pipelines/pipeline_xora_video.py CHANGED
@@ -1010,7 +1010,10 @@ class XoraVideoPipeline(DiffusionPipeline):
1010
  current_timestep = current_timestep * (1 - conditioning_mask)
1011
  # Choose the appropriate context manager based on `mixed_precision`
1012
  if mixed_precision:
1013
- context_manager = torch.autocast("cuda" if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16)
 
 
 
1014
  else:
1015
  context_manager = nullcontext() # Dummy context manager
1016
 
 
1010
  current_timestep = current_timestep * (1 - conditioning_mask)
1011
  # Choose the appropriate context manager based on `mixed_precision`
1012
  if mixed_precision:
1013
+ if 'xla' in device.type:
1014
+ raise NotImplementedError("Mixed precision is not supported yet on XLA devices.")
1015
+
1016
+ context_manager = torch.autocast(device, dtype=torch.bfloat16)
1017
  else:
1018
  context_manager = nullcontext() # Dummy context manager
1019