added cuda as optional
- inference.py +16 -6
- xora/pipelines/pipeline_xora_video.py +1 -1
inference.py
CHANGED
@@ -55,7 +55,9 @@ def load_vae(vae_dir):
     vae = CausalVideoAutoencoder.from_config(vae_config)
     vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
     vae.load_state_dict(vae_state_dict)
-    return vae.cuda().to(torch.bfloat16)
+    if torch.cuda.is_available():
+        vae = vae.cuda()
+    return vae.to(torch.bfloat16)


 def load_unet(unet_dir):
@@ -65,7 +67,9 @@ def load_unet(unet_dir):
     transformer = Transformer3DModel.from_config(transformer_config)
     unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
     transformer.load_state_dict(unet_state_dict, strict=True)
-    return transformer.cuda()
+    if torch.cuda.is_available():
+        transformer = transformer.cuda()
+    return transformer


 def load_scheduler(scheduler_dir):
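Review note: the `if torch.cuda.is_available()` guard added here is repeated for every submodel in this commit (VAE, transformer, text encoder, pipeline). A minimal sketch of a shared helper that would express the CPU fallback once, hypothetical and not part of this commit:

import torch

def maybe_cuda(module: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: move the module to the GPU only when CUDA is
    # actually available; otherwise keep it on the CPU.
    return module.cuda() if torch.cuda.is_available() else module

With it, load_vae could end with `return maybe_cuda(vae).to(torch.bfloat16)` and load_unet with `return maybe_cuda(transformer)`.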
@@ -254,7 +258,9 @@ def main():
     patchifier = SymmetricPatchifier(patch_size=1)
     text_encoder = T5EncoderModel.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
-    ).to("cuda")
+    )
+    if torch.cuda.is_available():
+        text_encoder = text_encoder.to("cuda")
     tokenizer = T5Tokenizer.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
     )
@@ -272,7 +278,9 @@ def main():
         "vae": vae,
     }

-    pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
+    pipeline = XoraVideoPipeline(**submodel_dict)
+    if torch.cuda.is_available():
+        pipeline = pipeline.to("cuda")

     # Prepare input for the pipeline
     sample = {
@@ -286,8 +294,10 @@ def main():
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    torch.cuda.manual_seed(args.seed)
-    generator = torch.Generator(device="cuda").manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(args.seed)
+
+    generator = torch.Generator(device="cuda" if torch.cuda.is_available() else 'cpu').manual_seed(args.seed)

     images = pipeline(
         num_inference_steps=args.num_inference_steps,
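Review note: a `torch.Generator` is bound to a device at construction, so it needs the same fallback as the models. A minimal sketch of the resulting device-aware seeding, assuming `args.seed` as in the script; the generator is presumably passed on to the pipeline call for reproducible sampling:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.cuda.manual_seed(args.seed)  # seeds the RNG of the current CUDA device
generator = torch.Generator(device=device).manual_seed(args.seed)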
xora/pipelines/pipeline_xora_video.py
CHANGED
@@ -1010,7 +1010,7 @@ class XoraVideoPipeline(DiffusionPipeline):
 current_timestep = current_timestep * (1 - conditioning_mask)
 # Choose the appropriate context manager based on `mixed_precision`
 if mixed_precision:
-    context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
+    context_manager = torch.autocast("cuda" if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16)
 else:
     context_manager = nullcontext()  # Dummy context manager
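Review note: this change relies on `torch.autocast` accepting "cpu" as a device type; CPU autocast supports bfloat16, so the same dtype works on both backends. A self-contained sketch of the selection logic, with the helper name chosen for illustration:

import torch
from contextlib import nullcontext

def make_autocast_ctx(mixed_precision: bool):
    # Return a bfloat16 autocast context for whichever device is available,
    # or a do-nothing context when mixed precision is disabled.
    if not mixed_precision:
        return nullcontext()
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.autocast(device_type, dtype=torch.bfloat16)

with make_autocast_ctx(mixed_precision=True):
    pass  # the denoising forward pass would run here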