cocktailpeanut commited on
Commit
9aa6318
1 Parent(s): 57e2fd5
Files changed (1) hide show
  1. utils/pipeline_magictime.py +9 -1
utils/pipeline_magictime.py CHANGED
@@ -27,6 +27,13 @@ from .unet import UNet3DConditionModel
27
 
28
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
 
 
 
 
 
 
 
 
30
  @dataclass
31
  class MagicTimePipelineOutput(BaseOutput):
32
  videos: Union[torch.Tensor, np.ndarray]
@@ -120,7 +127,8 @@ class MagicTimePipeline(DiffusionPipeline):
120
  else:
121
  raise ImportError("Please install accelerate via `pip install accelerate`")
122
 
123
- device = torch.device(f"cuda:{gpu_id}")
 
124
 
125
  for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
126
  if cpu_offloaded_model is not None:
 
27
 
28
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
 
30
+ if torch.cuda.is_available():
31
+ device = "cuda"
32
+ elif torch.backends.mps.is_available():
33
+ device = "mps"
34
+ else:
35
+ device = "cpu"
36
+
37
  @dataclass
38
  class MagicTimePipelineOutput(BaseOutput):
39
  videos: Union[torch.Tensor, np.ndarray]
 
127
  else:
128
  raise ImportError("Please install accelerate via `pip install accelerate`")
129
 
130
+ #device = torch.device(f"cuda:{gpu_id}")
131
+ device = torch.device(device)
132
 
133
  for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
134
  if cpu_offloaded_model is not None: