Commit: "ran the ruff formatter"
Browse files

Changed files:
- inference.py (+12 −6)
- xora/pipelines/pipeline_xora_video.py (+5 −3)

inference.py
CHANGED
@@ -240,10 +240,14 @@ def main():
Before (old lines 240-249):

 240      assert height % 32 == 0, f"Height ({height}) should be divisible by 32."
 241      assert width % 32 == 0, f"Width ({width}) should be divisible by 32."
 242      assert (
 243 -    (deleted line — content not captured in this rendering)
 244 -    (deleted line — content not captured in this rendering)
 245 -    (deleted line — content not captured in this rendering)
 246 -    (deleted line — content not captured in this rendering)
 247
 248      # Paths for the separate mode directories
 249      ckpt_dir = Path(args.ckpt_dir)
@@ -296,8 +300,10 @@ def main():
Before (old lines 296-303):

 296      torch.manual_seed(args.seed)
 297      if torch.cuda.is_available():
 298          torch.cuda.manual_seed(args.seed)
 299 -    (deleted line — content not captured in this rendering)
 300 -    generator = torch.Generator( ... (remainder of deleted line not captured; presumably the one-line form reformatted below)
 301
 302      images = pipeline(
 303          num_inference_steps=args.num_inference_steps,
After (new lines 240-253):

 240      assert height % 32 == 0, f"Height ({height}) should be divisible by 32."
 241      assert width % 32 == 0, f"Width ({width}) should be divisible by 32."
 242      assert (
 243 +        (
 244 +            height,
 245 +            width,
 246 +            args.num_frames,
 247 +        )
 248 +        in RECOMMENDED_RESOLUTIONS
 249 +        or args.custom_resolution
 250 +    ), f"The selected resolution + num frames combination is not supported, results would be suboptimal. Supported (h,w,f) are: {RECOMMENDED_RESOLUTIONS}. Use --custom_resolution to enable working with this resolution."
 251
 252      # Paths for the separate mode directories
 253      ckpt_dir = Path(args.ckpt_dir)
After (new lines 300-309):

 300      torch.manual_seed(args.seed)
 301      if torch.cuda.is_available():
 302          torch.cuda.manual_seed(args.seed)
 303 +
 304 +    generator = torch.Generator(
 305 +        device="cuda" if torch.cuda.is_available() else "cpu"
 306 +    ).manual_seed(args.seed)
 307
 308      images = pipeline(
 309          num_inference_steps=args.num_inference_steps,
xora/pipelines/pipeline_xora_video.py
CHANGED
@@ -1010,9 +1010,11 @@ class XoraVideoPipeline(DiffusionPipeline):
Before (old lines 1010-1018):

 1010     current_timestep = current_timestep * (1 - conditioning_mask)
 1011     # Choose the appropriate context manager based on `mixed_precision`
 1012     if mixed_precision:
 1013 -    if ... (deleted line begins "if"; remainder not captured — presumably the unformatted form of the XLA check added below)
 1014 -    raise NotImplementedError( ... (remainder of deleted line not captured)
 1015 -    (deleted line — content not captured in this rendering)
 1016         context_manager = torch.autocast(device, dtype=torch.bfloat16)
 1017     else:
 1018         context_manager = nullcontext()  # Dummy context manager
After (new lines 1010-1020):

 1010     current_timestep = current_timestep * (1 - conditioning_mask)
 1011     # Choose the appropriate context manager based on `mixed_precision`
 1012     if mixed_precision:
 1013 +        if "xla" in device.type:
 1014 +            raise NotImplementedError(
 1015 +                "Mixed precision is not supported yet on XLA devices."
 1016 +            )
 1017 +
 1018         context_manager = torch.autocast(device, dtype=torch.bfloat16)
 1019     else:
 1020         context_manager = nullcontext()  # Dummy context manager