import argparse


def parse_args():
    """
    Parse this script's command-line arguments.

    This is called at module level before the heavy imports below (torch,
    comet_ml, climategan...), presumably so that `--help` and argument errors
    are reported without paying their import cost.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=4,
        help="Batch size to process input images to events. Defaults to 4",
    )
    parser.add_argument(
        "-i",
        "--images_paths",
        type=str,
        required=True,
        help="Path to a directory with image files",
    )
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        default=None,
        help="Path to a directory where events should be written. "
        + "Will NOT write anything to disk if this flag is not used.",
    )
    parser.add_argument(
        "-s",
        "--save_input",
        action="store_true",
        default=False,
        help="Binary flag to include the input image to the model (after crop and"
        + " resize) in the images written or uploaded (depending on saving options.)",
    )
    parser.add_argument(
        "-r",
        "--resume_path",
        type=str,
        default=None,
        help="Path to a directory containing the trainer to resume."
        + " In particular it must contain `opts.yaml` and `checkpoints/`."
        + " Typically this points to a Masker, which holds the path to a"
        + " Painter in its opts",
    )
    parser.add_argument(
        "--no_time",
        action="store_true",
        default=False,
        help="Binary flag to prevent the timing of operations.",
    )
    parser.add_argument(
        "-f",
        "--flood_mask_binarization",
        type=float,
        default=0.5,
        help="Value to use to binarize masks (mask > value). "
        + "Set to -1 to use soft masks (not binarized). Defaults to 0.5.",
    )
    parser.add_argument(
        "-t",
        "--target_size",
        type=int,
        default=640,
        help="Output image size (when not using `keep_ratio_128`): images are resized"
        + " such that their smallest side is `target_size` then cropped in the middle"
        + " of the largest side such that the resulting input image (and output images)"
        + " has height and width `target_size x target_size`. **Must** be a multiple of"
        + " 2^7=128 (up/downscaling inside the models). Defaults to 640.",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        default=False,
        help="Binary flag to use half precision (float16). Defaults to False.",
    )
    parser.add_argument(
        "-n",
        "--n_images",
        default=-1,
        type=int,
        help="Limit the number of images processed (if you have 100 images in "
        + "a directory but n is 10 then only the first 10 images will be loaded"
        + " for processing)",
    )
    parser.add_argument(
        "--no_conf",
        action="store_true",
        default=False,
        help="disable writing the apply_events hash and command in the output folder",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Do not check for existing outdir, i.e. force overwrite"
        + " potentially existing files in the output path",
    )
    parser.add_argument(
        "--no_cloudy",
        action="store_true",
        default=False,
        help="Prevent the use of the cloudy intermediate"
        + " image to create the flood image. Rendering will"
        + " be more colorful but may seem less realistic",
    )
    parser.add_argument(
        "--keep_ratio_128",
        action="store_true",
        default=False,
        help="When loading the input images, resize and crop them in order for their "
        + "dimensions to match the closest multiples"
        + " of 128. Will force a batch size of 1 since images"
        + " now have different dimensions. "
        + "Use --max_im_width to cap the resulting dimensions.",
    )
    parser.add_argument(
        "--fuse",
        action="store_true",
        default=False,
        help="Use batch norm fusion to speed up inference",
    )
    parser.add_argument(
        "-m",
        "--max_im_width",
        type=int,
        default=-1,
        help="When using --keep_ratio_128, some images may still be too large. Use "
        + "--max_im_width to cap the resized image's width. Defaults to -1 (no cap).",
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload to comet.ml in a project called `climategan-apply`",
    )
    parser.add_argument(
        "--zip_outdir",
        "-z",
        action="store_true",
        help="Zip the output directory as '{outdir.parent}/{outdir.name}.zip'",
    )
    return parser.parse_args()


args = parse_args()


print("\n• Imports\n")
import time

import_time = time.time()
import shlex
import shutil
import sys
from collections import OrderedDict
from pathlib import Path

# comet_ml is imported before torch, as comet_ml recommends
import comet_ml  # noqa: F401
import torch
import numpy as np
import skimage.io as io
from skimage.color import rgba2rgb
from skimage.transform import resize
from tqdm import tqdm

from climategan.trainer import Trainer
from climategan.bn_fusion import bn_fuse
from climategan.tutils import print_num_parameters
from climategan.utils import Timer, find_images, get_git_revision_hash, to_128, resolve

import_time = time.time() - import_time


def to_m1_p1(img, i):
    """
    Rescales a [0, 1] image to [-1, +1]

    Args:
        img (np.array): numpy array of an image in [0, 1]
        i (int): Index of the image being rescaled (used in the error message)

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    img_min, img_max = img.min(), img.max()
    if img_min >= 0 and img_max <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image {i} : ({img_min}, {img_max})")


def uint8(array):
    """
    Convert an array to np.uint8 (does not rescale or anything else than
    changing dtype; float values are truncated by numpy's cast)

    Args:
        array (np.array): array to convert

    Returns:
        np.array(np.uint8): converted array
    """
    return array.astype(np.uint8)


def to_rgb(img, i=None):
    """
    Make sure an image has exactly 3 (RGB) channels.

    Args:
        img (np.array): image as returned by `skimage.io.imread`
        i (any, optional): identifier of the image (index or path), only used
            in error messages. Defaults to None.

    Raises:
        ValueError: If the image is neither 2D (grayscale) nor HxWxC with
            C in {1, 3, 4}

    Returns:
        np.array: HxWx3 image. 3-channel images are returned unchanged,
            single-channel ones are replicated along the channel axis (dtype
            kept) and 4-channel ones are converted with `rgba2rgb` then cast
            to np.uint8 in [0, 255].
    """
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.ndim != 3:
        raise ValueError(f"Unsupported shape {img.shape} for image {i}")
    n_channels = img.shape[-1]
    if n_channels == 3:
        return img
    if n_channels == 4:
        return uint8(rgba2rgb(img) * 255)
    if n_channels == 1:
        return np.repeat(img, 3, axis=-1)
    raise ValueError(f"Unsupported number of channels ({n_channels}) for image {i}")


def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and the smallest dimension
    is `to`, then crops this resized image in its center so that the output is
    `to x to` without aspect ratio distortion

    Args:
        img (np.array): np.uint8 [0, 255] image
        to (int, optional): side of the square output image. Defaults to 640.

    Returns:
        np.array: [0, 1] np.float64 image of height and width `to`
    """
    # resize keeping aspect ratio: smallest dim becomes `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2
    rc_img = r_img[top : top + to, left : left + to]

    return rc_img / 255.0


def print_time(text, time_series, purge=-1):
    """
    Print a timeseries's mean and std with a label

    Args:
        text (str): label of the time series
        time_series (list): list of timings
        purge (int, optional): ignore first n values of time series. Defaults to -1.
    """
    if not time_series:
        return

    if purge > 0 and len(time_series) > purge:
        time_series = time_series[purge:]

    m = np.mean(time_series)
    s = np.std(time_series)

    print(
        f"{text.capitalize() + ' ':.<26}  {m:.5f}"
        + (f" +/- {s:.5f}" if len(time_series) > 1 else "")
    )


def print_store(store, purge=-1):
    """
    Pretty-print time series store: single-valued series first, then
    multi-valued ones (reported per batch). Empty series are only listed.

    Args:
        store (dict): maps string keys to lists of times
        purge (int, optional): ignore first n values of time series. Defaults to -1.
    """
    singles = OrderedDict({k: v for k, v in store.items() if len(v) == 1})
    multiples = OrderedDict({k: v for k, v in store.items() if len(v) > 1})
    empties = {k: v for k, v in store.items() if len(v) == 0}

    if empties:
        print("Ignoring empty stores ", ", ".join(empties.keys()))
        print()

    for k in singles:
        print_time(k, singles[k], purge)

    print()
    print("Unit: s/batch")
    for k in multiples:
        print_time(k, multiples[k], purge)
    print()


def write_apply_config(out):
    """
    Saves the command used to run `apply_events.py` (`out/command.txt`) and
    the current git revision hash (`out/hash.txt`) for future reference

    Args:
        out (pathlib.Path): existing output directory
    """
    cwd = Path.cwd().expanduser().resolve()
    # quote every argument so that the saved command can be replayed as-is,
    # even when paths contain spaces or shell special characters
    command = f"cd {shlex.quote(str(cwd))}\n"
    command += " ".join(shlex.quote(arg) for arg in sys.argv)
    git_hash = get_git_revision_hash()
    (out / "command.txt").write_text(command, encoding="utf-8")
    (out / "hash.txt").write_text(git_hash, encoding="utf-8")


def get_outdir_name(half, keep_ratio, max_im_width, target_size, bin_value, cloudy):
    """
    Create the output directory's name based on user-provided arguments

    Returns:
        str: "-"-joined name items, e.g. "half-AR-1024" or "S-640-no_cloudy"
    """
    name_items = []
    if half:
        name_items.append("half")
    if keep_ratio:
        name_items.append("AR")
        # only mention the width cap when one is actually used
        # (the default -1 means "no cap")
        if max_im_width is not None and max_im_width > 0:
            name_items.append(f"{max_im_width}")
    elif target_size:
        name_items.append("S")
        name_items.append(f"{target_size}")
    if bin_value != 0.5:
        name_items.append(f"bin{bin_value}")
    if not cloudy:
        name_items.append("no_cloudy")

    return "-".join(name_items)


def make_outdir(
    outdir, overwrite, half, keep_ratio, max_im_width, target_size, bin_value, cloudy
):
    """
    Creates the output directory if it does not exist. If it does exist,
    prompts the user for confirmation (except if `overwrite` is True).
    If the output directory's name is "_auto_" then it is created as:
        outdir.parent / get_outdir_name(...)

    Returns:
        pathlib.Path: the (possibly renamed) output directory, which exists
    """
    if outdir.name == "_auto_":
        outdir = outdir.parent / get_outdir_name(
            half, keep_ratio, max_im_width, target_size, bin_value, cloudy
        )
    if outdir.exists() and not overwrite:
        print(
            f"\nWARNING: outdir ({str(outdir)}) already exists."
            + " Files with existing names will be overwritten"
        )
        answer = input(">>> Continue anyway? [y / n] (default: y) : ")
        # any answer starting with "n" / "N" ("n", "no", "No"...) aborts
        if answer.strip().lower().startswith("n"):
            print("Interrupting execution from user input.")
            sys.exit()
        print()
    outdir.mkdir(exist_ok=True, parents=True)
    return outdir


def get_time_stores(import_time):
    """
    Create the ordered map of timing series filled during the run.

    Args:
        import_time (float): duration of the imports, stored under "imports"

    Returns:
        OrderedDict: maps operation names to lists of durations
    """
    return OrderedDict(
        {
            "imports": [import_time],
            "setup": [],
            "data pre-processing": [],
            "encode": [],
            "mask": [],
            "flood": [],
            "depth": [],
            "segmentation": [],
            "smog": [],
            "wildfire": [],
            "all events": [],
            "numpy": [],
            "inference on all images": [],
            "write": [],
        }
    )


if __name__ == "__main__":

    # -----------------------------------------
    # -----  Initialize script variables  -----
    # -----------------------------------------
    print(
        "• Using args\n\n"
        + "\n".join(["{:25}: {}".format(k, v) for k, v in vars(args).items()]),
    )

    batch_size = args.batch_size
    bin_value = args.flood_mask_binarization
    cloudy = not args.no_cloudy
    fuse = args.fuse
    half = args.half
    images_paths = resolve(args.images_paths)
    keep_ratio = args.keep_ratio_128
    max_im_width = args.max_im_width
    n_images = args.n_images
    outdir = resolve(args.output_path) if args.output_path is not None else None
    resume_path = args.resume_path
    target_size = args.target_size
    time_inference = not args.no_time
    upload = args.upload
    zip_outdir = args.zip_outdir

    # -------------------------------------
    # -----  Validate size arguments  -----
    # -------------------------------------
    if batch_size < 1:
        print(f"\nWARNING: batch_size must be >= 1 (was {batch_size}). Using 1.")
        batch_size = 1

    if keep_ratio:
        if target_size != 640:
            print(
                "\nWARNING: using --keep_ratio_128 overwrites target_size"
                + " which is ignored."
            )
        if batch_size != 1:
            print("\nWARNING: batch_size overwritten to 1 when using keep_ratio_128")
            batch_size = 1
        if max_im_width > 0 and max_im_width % 128 != 0:
            # round down to a multiple of 128, but never to 0 which is not
            # a usable width
            new_im_width = max(128, (max_im_width // 128) * 128)
            print("\nWARNING: max_im_width should be <0 or a multiple of 128.")
            print(
                "    Was {} but is now overwritten to {}".format(
                    max_im_width, new_im_width
                )
            )
            max_im_width = new_im_width
    else:
        if target_size < 128 or target_size % 128 != 0:
            # round down to a multiple of 128, but never below 128: images
            # would otherwise be resized to 0 pixels
            new_target_size = max(128, target_size - (target_size % 128))
            print(
                f"\nWarning: target size {target_size} is not a positive"
                + " multiple of 128."
            )
            target_size = new_target_size
            print(f"Setting target_size to {target_size}.")

    # -------------------------------------
    # -----  Create output directory  -----
    # -------------------------------------
    if outdir is not None:
        outdir = make_outdir(
            outdir,
            args.overwrite,
            half,
            keep_ratio,
            max_im_width,
            target_size,
            bin_value,
            cloudy,
        )

    # -------------------------------
    # -----  Create time store  -----
    # -------------------------------
    stores = get_time_stores(import_time)

    # -----------------------------------
    # -----  Load Trainer instance  -----
    # -----------------------------------
    # NOTE(review): every Timer in this script receives `ignore=time_inference`,
    # i.e. ignore=True exactly when timing is *requested*
    # (time_inference = not --no_time). This looks inverted; confirm the
    # meaning of Timer's `ignore` argument in climategan.utils before changing.
    with Timer(store=stores.get("setup", []), ignore=time_inference):
        print("\n• Initializing trainer\n")
        torch.set_grad_enabled(False)
        trainer = Trainer.resume_from_path(
            resume_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        print()
        print_num_parameters(trainer, True)
        if fuse:
            trainer.G = bn_fuse(trainer.G)
        if half:
            trainer.G.half()

    # --------------------------------------------
    # -----  Read data from input directory  -----
    # --------------------------------------------
    print("\n• Reading & Pre-processing Data\n")

    # find all images
    data_paths = find_images(images_paths)
    base_data_paths = data_paths
    if len(base_data_paths) == 0:
        # nothing to infer on: stop here instead of failing later with a
        # ZeroDivisionError (repeats below) or an IndexError (all_events[0])
        sys.exit(f"No images found in {str(images_paths)}. Nothing to do.")
    # filter images
    if 0 < n_images < len(data_paths):
        data_paths = data_paths[:n_images]
    # repeat data
    elif n_images > len(data_paths):
        repeats = n_images // len(data_paths) + 1
        data_paths = base_data_paths * repeats
        data_paths = data_paths[:n_images]

    with Timer(store=stores.get("data pre-processing", []), ignore=time_inference):
        # read images to numpy arrays
        data = [io.imread(str(d)) for d in data_paths]
        # grayscale / rgba to rgb
        data = [to_rgb(im, d) for im, d in zip(data, data_paths)]
        # resize images to target_size or
        if keep_ratio:
            # to closest multiples of 128 <= max_im_width, keeping aspect ratio
            new_sizes = [to_128(d, max_im_width) for d in data]
            data = [resize(d, ns, anti_aliasing=True) for d, ns in zip(data, new_sizes)]
        else:
            # to args.target_size
            data = [resize_and_crop(d, target_size) for d in data]
            new_sizes = [(target_size, target_size) for _ in data]
        # resize() produces [0, 1] images, rescale to [-1, 1]
        data = [to_m1_p1(d, i) for i, d in enumerate(data)]

    n_batchs = (len(data) + batch_size - 1) // batch_size  # ceil division
    print("Found", len(base_data_paths), "images. Inferring on", len(data), "images.")

    # --------------------------------------------
    # -----  Batch-process images to events  -----
    # --------------------------------------------
    print(f"\n• Using device {str(trainer.device)}\n")

    all_events = []

    with Timer(store=stores.get("inference on all images", []), ignore=time_inference):
        for b in tqdm(range(n_batchs), desc="Inferring events", unit="batch"):

            images = data[b * batch_size : (b + 1) * batch_size]
            if not images:
                continue

            # concatenate images in a batch batch_size x height x width x 3
            images = np.stack(images)

            # Retrieve numpy events as a dict {event: array[BxHxWxC]}
            events = trainer.infer_all(
                images,
                numpy=True,
                stores=stores,
                bin_value=bin_value,
                half=half,
                cloudy=cloudy,
            )

            # save resized and cropped image
            if args.save_input:
                events["input"] = uint8((images + 1) / 2 * 255)

            # store events to write after inference loop
            all_events.append(events)

    # --------------------------------------------
    # -----  Save (write/upload) inferences  -----
    # --------------------------------------------
    if outdir is not None or upload:

        if upload:
            print("\n• Creating comet Experiment")
            exp = comet_ml.Experiment(project_name="climategan-apply")
            exp.log_parameters(vars(args))

        # --------------------------------------------------------------
        # -----  Change inferred data structure to a list of dicts  -----
        # --------------------------------------------------------------
        to_write = []
        events_names = list(all_events[0].keys())
        for events_data in all_events:
            n_ims = len(events_data[events_names[0]])
            for i in range(n_ims):
                item = {event: events_data[event][i] for event in events_names}
                to_write.append(item)

        progress_bar_desc = ""
        if outdir is not None:
            print("\n• Output directory:\n")
            print(str(outdir), "\n")
            if upload:
                progress_bar_desc = "Writing & Uploading events"
            else:
                progress_bar_desc = "Writing events"
        else:
            if upload:
                progress_bar_desc = "Uploading events"

        # ------------------------------------
        # -----  Save individual images  -----
        # ------------------------------------
        with Timer(store=stores.get("write", []), ignore=time_inference):

            # for each image
            for t, event_dict in tqdm(
                enumerate(to_write),
                desc=progress_bar_desc,
                unit="input image",
                total=len(to_write),
            ):
                # repeated images (n_images > number of files) map back to
                # their source file
                idx = t % len(base_data_paths)
                stem = Path(data_paths[idx]).stem
                width = new_sizes[idx][1]

                ar = "_AR" if keep_ratio else ""
                suffix = ar if cloudy else ar + "_no_cloudy"

                # for each event type
                event_bar = tqdm(
                    event_dict.items(),
                    leave=False,
                    total=len(events_names),
                    unit="event",
                )
                for event, im_data in event_bar:
                    event_bar.set_description(
                        f"  {event.capitalize():<{len(progress_bar_desc) - 2}}"
                    )

                    im_path = Path(f"{stem}_{event}_{width}{suffix}.png")
                    # only write to disk when an output directory was given
                    if outdir is not None:
                        im_path = outdir / im_path
                        io.imsave(im_path, im_data)

                    if upload:
                        exp.log_image(im_data, name=im_path.name)

    # ---------------------------------------------
    # -----  Save apply_events.py run config  -----
    # ---------------------------------------------
    # written before zipping so that the archive also contains it
    if not args.no_conf and outdir is not None:
        write_apply_config(outdir)

    # ----------------------------------
    # -----  Zip output directory  -----
    # ----------------------------------
    if zip_outdir:
        if outdir is None:
            print("\nWARNING: --zip_outdir ignored: no output directory (-o) to zip.")
        else:
            print("\n• Zipping output directory... ", end="", flush=True)
            # create the archive directly next to outdir: building it in the
            # cwd then moving it fails when they are on different file systems
            archive_path = Path(
                shutil.make_archive(
                    str(outdir.parent / outdir.name), "zip", root_dir=outdir
                )
            )
            print("Done:\n")
            print(str(archive_path))

    # ---------------------------
    # -----  Print timings  -----
    # ---------------------------
    if time_inference:
        print("\n• Timings\n")
        print_store(stores)