Spaces:
Running
Running
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. | |
import argparse | |
import binascii | |
import os | |
import os.path as osp | |
import imageio | |
import torch | |
import torchvision | |
__all__ = ['cache_video', 'cache_image', 'str2bool'] | |
def rand_name(length=8, suffix=''):
    """
    Build a random hexadecimal file name.

    Args:
        length (int): Number of random bytes to draw; the hex stem is
            twice this many characters long.
        suffix (str): Optional extension. A leading '.' is added when
            it is missing. An empty suffix yields the bare hex stem.

    Returns:
        str: The random hex stem followed by the normalized suffix.
    """
    # Each random byte becomes two lowercase hex digits.
    stem = os.urandom(length).hex()
    if not suffix:
        return stem
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return stem + extension
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """
    Encode a tensor as a video file, retrying on failure.

    The tensor is clamped to ``value_range``, every slice along dim 2 is
    tiled into one image with ``torchvision.utils.make_grid``, and the
    resulting frames are written with imageio using the libx264 codec.

    Args:
        tensor (torch.Tensor): Video batch; dim 2 is iterated as the frame
            axis (presumably shaped (B, C, T, H, W) -- confirm with callers).
        save_file (str | None): Output path. When None, a random file name
            with ``suffix`` is created under '/tmp'.
        fps (int): Frame rate passed to the video writer.
        suffix (str): Extension used only when ``save_file`` is None.
        nrow (int): Forwarded to ``make_grid``.
        normalize (bool): Forwarded to ``make_grid``.
        value_range (tuple): Clamp bounds, also forwarded to ``make_grid``.
        retry (int): Maximum number of attempts.

    Returns:
        str | None: Path of the written file, or None when every attempt
        failed (the last error is printed).
    """
    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    error = None
    frames = None
    for _ in range(retry):
        try:
            # Preprocess into a separate local, and only once: rebinding
            # `tensor` here would make a retry after a failed write re-run
            # the preprocessing on the already-converted uint8 frames.
            if frames is None:
                clamped = tensor.clamp(min(value_range), max(value_range))
                grids = torch.stack([
                    torchvision.utils.make_grid(
                        u,
                        nrow=nrow,
                        normalize=normalize,
                        value_range=value_range)
                    for u in clamped.unbind(2)
                ],
                                    dim=1).permute(1, 2, 3, 0)
                frames = (grids * 255).type(torch.uint8).cpu().numpy()

            # write video
            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                # Release the encoder even when a frame fails to append.
                writer.close()
            return cache_file
        except Exception as e:
            # Best-effort: remember the failure and try again.
            error = e

    print(f'cache_video failed, error: {error}', flush=True)
    return None
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """
    Save a tensor as an image file, retrying on failure.

    The tensor is clamped to ``value_range`` and handed to
    ``torchvision.utils.save_image``.

    Args:
        tensor (torch.Tensor): Image or batch of images to save.
        save_file (str): Output path.
        nrow (int): Forwarded to ``save_image``.
        normalize (bool): Forwarded to ``save_image``.
        value_range (tuple): Clamp bounds, also forwarded to ``save_image``.
        retry (int): Maximum number of attempts.

    Returns:
        str | None: ``save_file`` on success, or None when every attempt
        failed (the last error is printed).
    """
    # NOTE(review): an earlier version computed a '.png' fallback suffix for
    # unrecognized extensions but never applied it; that dead code is gone
    # and the output format is left to save_image / the file name.
    error = None
    clamped = None
    for _ in range(retry):
        try:
            # Clamp once and keep the caller's tensor binding untouched.
            if clamped is None:
                clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            # Best-effort: remember the failure and try again.
            error = e

    # Report the failure the same way cache_video does instead of
    # silently falling through.
    print(f'cache_image failed, error: {error}', flush=True)
    return None
def str2bool(v):
    """
    Convert a string to a boolean.

    Matching is case-insensitive and ignores surrounding whitespace.
    Booleans are returned unchanged; any other non-string value is
    compared by its ``str()`` form, so ``1`` and ``0`` are accepted.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str): String to convert.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v

    # Go through str() so non-string inputs (e.g. None, ints) end in the
    # documented ArgumentTypeError rather than an AttributeError on .lower().
    v_lower = str(v).strip().lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    if v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')