from datetime import datetime import matplotlib.pyplot as plt import torch def freeze_module(module): for param in module.parameters(): param.requires_grad = False def get_device(): device = "cuda" if torch.cuda.is_available() else "cpu" if torch.backends.mps.is_available() and torch.backends.mps.is_built(): device = "mps" if device == "mps": print( "WARNING: MPS currently doesn't seem to work, and messes up backpropagation without any visible torch" " errors. I recommend using CUDA on a colab notebook or CPU instead if you're facing inexplicable issues" " with generations." ) return device def show_pil(img): fig = plt.imshow(img) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.show() def get_timestamp(): current_time = datetime.now() timestamp = current_time.strftime("%H:%M:%S") return timestamp