import gc
import os
import time
from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from torchvision.utils import make_grid


def get_output_folder(output_path, batch_folder):
    """Create (if needed) and return the output folder for the current month.

    The folder is ``<output_path>/<YYYY-MM>`` (local time), with
    ``batch_folder`` appended as a further sub-directory when it is non-empty.

    Args:
        output_path: Root output directory.
        batch_folder: Optional sub-directory name; an empty string (or None)
            means "no batch sub-directory".

    Returns:
        The path of the created or already-existing folder.
    """
    out_path = os.path.join(output_path, time.strftime('%Y-%m'))
    if batch_folder:
        out_path = os.path.join(out_path, batch_folder)
    os.makedirs(out_path, exist_ok=True)
    return out_path


def _chw_to_pil(chw: torch.Tensor) -> Image.Image:
    """Convert one channels-first (c, h, w) tensor to an 8-bit PIL image.

    Values are scaled by 255, so the input is presumably in [0, 1] — confirm
    against callers. The scaled values are clipped to [0, 255] before the
    uint8 cast, because casting out-of-range floats to uint8 wraps around
    (or is undefined) instead of saturating.
    """
    hwc = 255.0 * rearrange(chw.cpu().numpy(), "c h w -> h w c")
    return Image.fromarray(np.clip(hwc, 0, 255).astype(np.uint8))


def save_samples(
    args, x_samples: torch.Tensor, seed: int, n_rows: int
) -> Tuple[Optional[Image.Image], List[Image.Image]]:
    """Convert generated samples to PIL images and optionally save them to disk.

    Args:
        args: Stable deforum diffusion arguments. Reads ``display_samples``,
            ``save_samples``, ``display_grid``, ``save_grid``, ``outdir`` and
            ``timestring``.
        x_samples: Batch of samples shaped (batch, c, h, w).
        seed: Seed for the experiment; embedded in the output file names.
        n_rows: Number of images per row of the grid (passed to
            ``make_grid`` as ``nrow``).

    Returns:
        A tuple ``(grid_image, generated_images)``. ``grid_image`` is None
        unless a grid was requested and the batch is non-empty;
        ``generated_images`` is empty unless samples were requested.
    """
    images: List[Image.Image] = []
    grid_image: Optional[Image.Image] = None

    # Individual samples.
    if args.display_samples or args.save_samples:
        for index, x_sample in enumerate(x_samples):
            image = _chw_to_pil(x_sample)
            images.append(image)
            if args.save_samples:
                image.save(
                    os.path.join(
                        args.outdir, f"{args.timestring}_{index:02}_{seed}.png"
                    )
                )

    # Grid of all samples. make_grid divides by the batch size when laying out
    # columns, so an empty batch is skipped instead of crashing.
    if (args.display_grid or args.save_grid) and len(x_samples) > 0:
        grid = make_grid(x_samples, nrow=n_rows, padding=0)
        grid_image = _chw_to_pil(grid)
        if args.save_grid:
            grid_image.save(
                os.path.join(args.outdir, f"{args.timestring}_{seed}_grid.png")
            )

    return grid_image, images


def _write_image(path, image, image_type):
    """Write ``image`` to ``path`` with the writer that matches ``image_type``.

    ``'cv2'`` images go through ``cv2.imwrite``. That call reports failure by
    returning False rather than raising, so a failure is turned into an
    OSError. Any other type is assumed to expose a PIL-style ``save(path)``.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the write failed.
    """
    if image_type == 'cv2':
        if not cv2.imwrite(path, image):
            raise OSError(f"cv2.imwrite failed to write {path}")
    else:
        image.save(path)


def save_image(image, image_type, filename, args, video_args, root):
    """Save a frame to ``args.outdir/filename``, or defer it to the RAM cache.

    When ``video_args.store_frames_in_ram`` is set, the frame is appended to
    ``root.frames_cache`` (by reference, not copied) and later written by
    ``dump_frames_cache``. Otherwise it is written immediately using the same
    writer that ``dump_frames_cache`` would use for ``image_type``.
    """
    path = os.path.join(args.outdir, filename)
    if video_args.store_frames_in_ram:
        root.frames_cache.append({'path': path, 'image': image, 'image_type': image_type})
    else:
        _write_image(path, image, image_type)


def reset_frames_cache(root):
    """Empty ``root.frames_cache`` and then force a garbage-collection pass."""
    root.frames_cache = []
    gc.collect()


def dump_frames_cache(root):
    """Write every cached frame in ``root.frames_cache`` to its stored path.

    Only ``'cv2'`` and ``'PIL'`` entries are written; entries with any other
    ``image_type`` are skipped.
    """
    for entry in root.frames_cache:
        if entry['image_type'] in ('cv2', 'PIL'):
            _write_image(entry['path'], entry['image'], entry['image_type'])
    # TODO: the cache is deliberately not reset here; frame erasing is to be
    # added as a separate function later.