zmelumian committed on
Commit
85a3cf8
1 Parent(s): ba73063

Made CUDA optional: use CUDA when available, otherwise fall back to CPU

Browse files
inference.py CHANGED
@@ -55,7 +55,9 @@ def load_vae(vae_dir):
55
  vae = CausalVideoAutoencoder.from_config(vae_config)
56
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
57
  vae.load_state_dict(vae_state_dict)
58
- return vae.cuda().to(torch.bfloat16)
 
 
59
 
60
 
61
  def load_unet(unet_dir):
@@ -65,7 +67,9 @@ def load_unet(unet_dir):
65
  transformer = Transformer3DModel.from_config(transformer_config)
66
  unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
67
  transformer.load_state_dict(unet_state_dict, strict=True)
68
- return transformer.cuda()
 
 
69
 
70
 
71
  def load_scheduler(scheduler_dir):
@@ -254,7 +258,9 @@ def main():
254
  patchifier = SymmetricPatchifier(patch_size=1)
255
  text_encoder = T5EncoderModel.from_pretrained(
256
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
257
- ).to("cuda")
 
 
258
  tokenizer = T5Tokenizer.from_pretrained(
259
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
260
  )
@@ -272,7 +278,9 @@ def main():
272
  "vae": vae,
273
  }
274
 
275
- pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
 
 
276
 
277
  # Prepare input for the pipeline
278
  sample = {
@@ -286,8 +294,10 @@ def main():
286
  random.seed(args.seed)
287
  np.random.seed(args.seed)
288
  torch.manual_seed(args.seed)
289
- torch.cuda.manual_seed(args.seed)
290
- generator = torch.Generator(device="cuda").manual_seed(args.seed)
 
 
291
 
292
  images = pipeline(
293
  num_inference_steps=args.num_inference_steps,
 
55
  vae = CausalVideoAutoencoder.from_config(vae_config)
56
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
57
  vae.load_state_dict(vae_state_dict)
58
+ if torch.cuda.is_available():
59
+ vae = vae.cuda()
60
+ return vae.to(torch.bfloat16)
61
 
62
 
63
  def load_unet(unet_dir):
 
67
  transformer = Transformer3DModel.from_config(transformer_config)
68
  unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
69
  transformer.load_state_dict(unet_state_dict, strict=True)
70
+ if torch.cuda.is_available():
71
+ transformer = transformer.cuda()
72
+ return transformer
73
 
74
 
75
  def load_scheduler(scheduler_dir):
 
258
  patchifier = SymmetricPatchifier(patch_size=1)
259
  text_encoder = T5EncoderModel.from_pretrained(
260
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
261
+ )
262
+ if torch.cuda.is_available():
263
+ text_encoder = text_encoder.to("cuda")
264
  tokenizer = T5Tokenizer.from_pretrained(
265
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
266
  )
 
278
  "vae": vae,
279
  }
280
 
281
+ pipeline = XoraVideoPipeline(**submodel_dict)
282
+ if torch.cuda.is_available():
283
+ pipeline = pipeline.to("cuda")
284
 
285
  # Prepare input for the pipeline
286
  sample = {
 
294
  random.seed(args.seed)
295
  np.random.seed(args.seed)
296
  torch.manual_seed(args.seed)
297
+ if torch.cuda.is_available():
298
+ torch.cuda.manual_seed(args.seed)
299
+
300
+ generator = torch.Generator(device="cuda" if torch.cuda.is_available() else 'cpu').manual_seed(args.seed)
301
 
302
  images = pipeline(
303
  num_inference_steps=args.num_inference_steps,
xora/pipelines/pipeline_xora_video.py CHANGED
@@ -1010,7 +1010,7 @@ class XoraVideoPipeline(DiffusionPipeline):
1010
  current_timestep = current_timestep * (1 - conditioning_mask)
1011
  # Choose the appropriate context manager based on `mixed_precision`
1012
  if mixed_precision:
1013
- context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
1014
  else:
1015
  context_manager = nullcontext() # Dummy context manager
1016
 
 
1010
  current_timestep = current_timestep * (1 - conditioning_mask)
1011
  # Choose the appropriate context manager based on `mixed_precision`
1012
  if mixed_precision:
1013
+ context_manager = torch.autocast("cuda" if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16)
1014
  else:
1015
  context_manager = nullcontext() # Dummy context manager
1016