Spaces:
Runtime error
Runtime error
import argparse | |
def parse_args(argv=None):
    """
    Build the command-line parser of the script and parse the arguments

    Args:
        argv (list[str], optional): Arguments to parse instead of the process'
            command line. Defaults to None, in which case argparse reads
            `sys.argv[1:]`.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=4,
        help="Batch size to process input images to events. Defaults to 4",
    )
    parser.add_argument(
        "-i",
        "--images_paths",
        type=str,
        required=True,
        help="Path to a directory with image files",
    )
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        default=None,
        help="Path to a directory where events should be written. "
        + "Will NOT write anything to disk if this flag is not used.",
    )
    parser.add_argument(
        "-s",
        "--save_input",
        action="store_true",
        default=False,
        help="Binary flag to include the input image to the model (after crop and"
        + " resize) in the images written or uploaded (depending on saving options.)",
    )
    parser.add_argument(
        "-r",
        "--resume_path",
        type=str,
        default=None,
        help="Path to a directory containing the trainer to resume."
        + " In particular it must contain `opts.yaml` and `checkpoints/`."
        + " Typically this points to a Masker, which holds the path to a"
        + " Painter in its opts",
    )
    parser.add_argument(
        "--no_time",
        action="store_true",
        default=False,
        help="Binary flag to prevent the timing of operations.",
    )
    parser.add_argument(
        "-f",
        "--flood_mask_binarization",
        type=float,
        default=0.5,
        help="Value to use to binarize masks (mask > value). "
        + "Set to -1 to use soft masks (not binarized). Defaults to 0.5.",
    )
    parser.add_argument(
        "-t",
        "--target_size",
        type=int,
        default=640,
        help="Output image size (when not using `keep_ratio_128`): images are resized"
        + " such that their smallest side is `target_size` then cropped in the middle"
        + " of the largest side such that the resulting input image (and output images)"
        + " has height and width `target_size x target_size`. **Must** be a multiple of"
        + " 2^7=128 (up/downscaling inside the models). Defaults to 640.",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        default=False,
        help="Binary flag to use half precision (float16). Defaults to False.",
    )
    parser.add_argument(
        "-n",
        "--n_images",
        default=-1,
        type=int,
        help="Limit the number of images processed (if you have 100 images in "
        + "a directory but n is 10 then only the first 10 images will be loaded"
        + " for processing)",
    )
    parser.add_argument(
        "--no_conf",
        action="store_true",
        default=False,
        help="disable writing the apply_events hash and command in the output folder",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Do not check for existing outdir, i.e. force overwrite"
        + " potentially existing files in the output path",
    )
    parser.add_argument(
        "--no_cloudy",
        action="store_true",
        default=False,
        help="Prevent the use of the cloudy intermediate"
        + " image to create the flood image. Rendering will"
        + " be more colorful but may seem less realistic",
    )
    parser.add_argument(
        "--keep_ratio_128",
        action="store_true",
        default=False,
        help="When loading the input images, resize and crop them in order for their "
        + "dimensions to match the closest multiples"
        + " of 128. Will force a batch size of 1 since images"
        + " now have different dimensions. "
        + "Use --max_im_width to cap the resulting dimensions.",
    )
    parser.add_argument(
        "--fuse",
        action="store_true",
        default=False,
        help="Use batch norm fusion to speed up inference",
    )
    parser.add_argument(
        "-m",
        "--max_im_width",
        type=int,
        default=-1,
        help="When using --keep_ratio_128, some images may still be too large. Use "
        + "--max_im_width to cap the resized image's width. Defaults to -1 (no cap).",
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload to comet.ml in a project called `climategan-apply`",
    )
    parser.add_argument(
        "--zip_outdir",
        "-z",
        action="store_true",
        help="Zip the output directory as '{outdir.parent}/{outdir.name}.zip'",
    )

    # argv=None lets argparse fall back to sys.argv[1:], i.e. the historical behavior
    return parser.parse_args(argv)
# The command line is parsed at import time, before the remaining imports run.
# NOTE(review): presumably so that `--help` and argument errors are reported
# without waiting for the imports below — confirm before reordering.
args = parse_args()

print("\n• Imports\n")
import time

# Start of the import timing window: everything imported below is measured.
import_time = time.time()
import sys
import shutil
from collections import OrderedDict
from pathlib import Path

import comet_ml  # noqa: F401
import torch
import numpy as np
import skimage.io as io
from skimage.color import rgba2rgb
from skimage.transform import resize
from tqdm import tqdm

from climategan.trainer import Trainer
from climategan.bn_fusion import bn_fuse
from climategan.tutils import print_num_parameters
from climategan.utils import Timer, find_images, get_git_revision_hash, to_128, resolve

# Elapsed seconds for the imports above; the main script passes it to
# get_time_stores(), which records it under the "imports" key.
import_time = time.time() - import_time
def to_m1_p1(img, i):
    """
    Map an image whose values lie in [0, 1] to the [-1, +1] range

    Args:
        img (np.array): numpy array of an image with values in [0, 1]
        i (int): Index of the image, only used in the error message

    Raises:
        ValueError: If some value of the image lies outside of [0, 1]

    Returns:
        np.array(np.float32): the rescaled array, in [-1, +1]
    """
    within_unit_range = img.min() >= 0 and img.max() <= 1
    if not within_unit_range:
        raise ValueError(f"Data range mismatch for image {i} : ({img.min()}, {img.max()})")

    centered = img.astype(np.float32) - 0.5
    return centered * 2
def uint8(array):
    """
    Cast an array to np.uint8

    Only the dtype changes: values are neither rescaled nor clipped beforehand.

    Args:
        array (np.array): array to cast

    Returns:
        np.array(np.uint8): the array with dtype np.uint8
    """
    as_bytes = array.astype(np.uint8)
    return as_bytes
def resize_and_crop(img, to=640):
    """
    Resize an image, keeping its aspect ratio, so that its smallest side is `to`,
    then crop the result around its center to obtain a `to x to` image without
    aspect ratio distortion

    Args:
        img (np.array): np.uint8 image with values in [0, 255]; it is indexed
            with 3 axes when cropping
        to (int, optional): side of the square output image. Defaults to 640.

    Returns:
        np.array: float image with values in [0, 1]
    """
    height, width = img.shape[:2]

    # the smallest side becomes exactly `to`, the other one is scaled accordingly
    if height < width:
        new_shape = (to, int(to * width / height))
    else:
        new_shape = (int(to * height / width), to)

    resized = uint8(resize(img, new_shape, preserve_range=True, anti_aliasing=True))

    # centered window of `to` rows and `to` columns
    top = (resized.shape[0] - to) // 2
    left = (resized.shape[1] - to) // 2
    rows = slice(top, top + to)
    cols = slice(left, left + to)

    return resized[rows, cols, :] / 255.0
def print_time(text, time_series, purge=-1):
    """
    Print the mean of a series of timings, and its standard deviation when the
    series holds more than one value, after a dot-padded label

    Nothing is printed for an empty series.

    Args:
        text (str): label of the time series
        time_series (list): list of timings
        purge (int, optional): number of leading values to drop; only applied
            when it is positive and smaller than the series' length.
            Defaults to -1.
    """
    if not time_series:
        return

    series = time_series
    if 0 < purge < len(series):
        series = series[purge:]

    m = np.mean(series)
    s = np.std(series)

    line = f"{text.capitalize() + ' ':.<26} {m:.5f}"
    if len(series) > 1:
        # a standard deviation is only meaningful with several measurements
        line += f" +/- {s:.5f}"
    print(line)
def print_store(store, purge=-1):
    """
    Pretty-print a store of time series

    Series holding a single value are printed first, then the series holding
    several values under a "Unit: s/batch" header. The keys of empty series are
    only listed as ignored.

    Args:
        store (dict): maps string keys to lists of times
        purge (int, optional): forwarded to `print_time`: ignore first n values
            of time series. Defaults to -1.
    """
    singles = OrderedDict()
    multiples = OrderedDict()
    empties = {}

    # split the store by series length, keeping the store's key order
    for key, series in store.items():
        count = len(series)
        if count == 0:
            empties[key] = series
        elif count == 1:
            singles[key] = series
        else:
            multiples[key] = series

    if empties:
        print("Ignoring empty stores ", ", ".join(empties.keys()))
        print()

    for key, series in singles.items():
        print_time(key, series, purge)

    print()
    print("Unit: s/batch")
    for key, series in multiples.items():
        print_time(key, series, purge)
    print()
def write_apply_config(out):
    """
    Write the command used to run the script and the current git revision hash
    to text files, for future reference

    Two files are written in `out`: `command.txt` (a `cd` to the current working
    directory followed by the space-joined `sys.argv`) and `hash.txt`.

    Args:
        out (pathlib.Path): directory in which the two files are written
    """
    working_dir = Path.cwd().expanduser().resolve()
    invocation = " ".join(sys.argv)
    contents = {
        "command.txt": f"cd {str(working_dir)}\n" + invocation,
        "hash.txt": get_git_revision_hash(),
    }

    for file_name, text in contents.items():
        (out / file_name).write_text(text)
def get_outdir_name(half, keep_ratio, max_im_width, target_size, bin_value, cloudy):
    """
    Create the output directory's name based on user-provided arguments

    The name is the "-"-joined list of the items matching the options in use,
    e.g. "half-AR-1024" or "S-640-no_cloudy".

    Args:
        half (bool): whether half precision is used
        keep_ratio (bool): whether images keep their aspect ratio
        max_im_width (int): cap on the images' width when `keep_ratio` is True;
            non-positive values mean "no cap" and are left out of the name
        target_size (int): side of the square images when `keep_ratio` is False
        bin_value (float): mask binarization value, omitted from the name when
            it is 0.5
        cloudy (bool): whether the cloudy intermediate image is used

    Returns:
        str: name of the output directory
    """
    name_items = []
    if half:
        name_items.append("half")
    if keep_ratio:
        name_items.append("AR")
        # -1 is the "no cap" sentinel but it is truthy: checking the sign
        # prevents names such as "AR--1"
        if max_im_width and max_im_width > 0:
            name_items.append(f"{max_im_width}")
    elif target_size:
        name_items.append("S")
        name_items.append(f"{target_size}")
    if bin_value != 0.5:
        name_items.append(f"bin{bin_value}")
    if not cloudy:
        name_items.append("no_cloudy")

    return "-".join(name_items)
def make_outdir(
    outdir, overwrite, half, keep_ratio, max_im_width, target_size, bin_value, cloudy
):
    """
    Creates the output directory if it does not exist. If it does exist,
    prompts the user for confirmation (except if `overwrite` is True).
    If the output directory's name is "_auto_" then it is created as:
    outdir.parent / get_outdir_name(...)

    Args:
        outdir (pathlib.Path): requested output directory
        overwrite (bool): skip the confirmation prompt when the directory exists
        half, keep_ratio, max_im_width, target_size, bin_value, cloudy:
            forwarded to `get_outdir_name` when the name is "_auto_"

    Raises:
        SystemExit: if the user answers the prompt with anything containing
            the letter "n" (in any case)

    Returns:
        pathlib.Path: the (possibly renamed) output directory, which exists
            on return
    """
    if outdir.name == "_auto_":
        outdir = outdir.parent / get_outdir_name(
            half, keep_ratio, max_im_width, target_size, bin_value, cloudy
        )

    if outdir.exists() and not overwrite:
        print(
            f"\nWARNING: outdir ({str(outdir)}) already exists."
            + " Files with existing names will be overwritten"
        )
        answer = input(">>> Continue anyway? [y / n] (default: y) : ")
        # lower-case the answer so that "N" / "NO" also interrupt the run
        # instead of silently overwriting existing files
        if "n" in answer.lower():
            print("Interrupting execution from user input.")
            sys.exit()
        print()

    outdir.mkdir(exist_ok=True, parents=True)
    return outdir
def get_time_stores(import_time):
    """
    Create the ordered store of timings filled in by the script

    Args:
        import_time (float): measured import time, stored as the only value of
            the "imports" series

    Returns:
        OrderedDict: maps each timed step to its own list of timings; every
            list but the "imports" one starts empty
    """
    empty_steps = (
        "setup",
        "data pre-processing",
        "encode",
        "mask",
        "flood",
        "depth",
        "segmentation",
        "smog",
        "wildfire",
        "all events",
        "numpy",
        "inference on all images",
        "write",
    )

    stores = OrderedDict()
    stores["imports"] = [import_time]
    for step in empty_steps:
        # a fresh list per step: series must not share storage
        stores[step] = []
    return stores
# Script entry point: reads images from a directory, infers all events with a
# resumed trainer, optionally writes / uploads / zips the resulting images,
# then prints timings and stores the run's command and git hash.
if __name__ == "__main__":

    # -----------------------------------------
    # -----  Initialize script variables  -----
    # -----------------------------------------
    print(
        "• Using args\n\n"
        + "\n".join(["{:25}: {}".format(k, v) for k, v in vars(args).items()]),
    )

    batch_size = args.batch_size
    bin_value = args.flood_mask_binarization
    cloudy = not args.no_cloudy
    fuse = args.fuse
    half = args.half
    images_paths = resolve(args.images_paths)
    keep_ratio = args.keep_ratio_128
    max_im_width = args.max_im_width
    n_images = args.n_images
    outdir = resolve(args.output_path) if args.output_path is not None else None
    resume_path = args.resume_path
    target_size = args.target_size
    time_inference = not args.no_time
    upload = args.upload
    zip_outdir = args.zip_outdir

    # -------------------------------------
    # -----  Validate size arguments  -----
    # -------------------------------------
    if keep_ratio:
        if target_size != 640:
            print(
                "\nWARNING: using --keep_ratio_128 overwrites target_size"
                + " which is ignored."
            )
        if batch_size != 1:
            print("\nWARNING: batch_size overwritten to 1 when using keep_ratio_128")
            batch_size = 1
        if max_im_width > 0 and max_im_width % 128 != 0:
            # round the cap down to the closest multiple of 128
            new_im_width = int(max_im_width / 128) * 128
            print("\nWARNING: max_im_width should be <0 or a multiple of 128.")
            print(
                " Was {} but is now overwritten to {}".format(
                    max_im_width, new_im_width
                )
            )
            max_im_width = new_im_width
    else:
        if target_size % 128 != 0:
            print(f"\nWarning: target size {target_size} is not a multiple of 128.")
            # round the target size down to the closest multiple of 128
            target_size = target_size - (target_size % 128)
            print(f"Setting target_size to {target_size}.")

    # -------------------------------------
    # -----  Create output directory  -----
    # -------------------------------------
    if outdir is not None:
        outdir = make_outdir(
            outdir,
            args.overwrite,
            half,
            keep_ratio,
            max_im_width,
            target_size,
            bin_value,
            cloudy,
        )

    # -------------------------------
    # -----  Create time store  -----
    # -------------------------------
    stores = get_time_stores(import_time)

    # -----------------------------------
    # -----  Load Trainer instance  -----
    # -----------------------------------
    # NOTE(review): `ignore=time_inference` looks inverted (timings are wanted
    # exactly when `time_inference` is True) — confirm against
    # climategan.utils.Timer before changing it. Same remark for every Timer
    # created below.
    with Timer(store=stores.get("setup", []), ignore=time_inference):
        print("\n• Initializing trainer\n")
        torch.set_grad_enabled(False)
        trainer = Trainer.resume_from_path(
            resume_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        print()
        print_num_parameters(trainer, True)
        if fuse:
            trainer.G = bn_fuse(trainer.G)
        if half:
            trainer.G.half()

    # --------------------------------------------
    # -----  Read data from input directory  -----
    # --------------------------------------------
    print("\n• Reading & Pre-processing Data\n")

    # find all images
    data_paths = find_images(images_paths)
    if not data_paths:
        # without images the repeat logic below divides by zero and the saving
        # section indexes an empty list: stop early with a clear message
        sys.exit(f"No images found in {str(images_paths)}")
    base_data_paths = data_paths
    # filter images
    if 0 < n_images < len(data_paths):
        data_paths = data_paths[:n_images]
    # repeat data
    elif n_images > len(data_paths):
        repeats = n_images // len(data_paths) + 1
        data_paths = base_data_paths * repeats
        data_paths = data_paths[:n_images]

    with Timer(store=stores.get("data pre-processing", []), ignore=time_inference):
        # read images to numpy arrays
        data = [io.imread(str(d)) for d in data_paths]
        # rgba to rgb
        data = [im if im.shape[-1] == 3 else uint8(rgba2rgb(im) * 255) for im in data]
        # resize images to target_size or
        if keep_ratio:
            # to closest multiples of 128 <= max_im_width, keeping aspect ratio
            new_sizes = [to_128(d, max_im_width) for d in data]
            data = [resize(d, ns, anti_aliasing=True) for d, ns in zip(data, new_sizes)]
        else:
            # to args.target_size
            data = [resize_and_crop(d, target_size) for d in data]
            new_sizes = [(target_size, target_size) for _ in data]
        # resize() produces [0, 1] images, rescale to [-1, 1]
        data = [to_m1_p1(d, i) for i, d in enumerate(data)]

    # number of batches, counting the last (possibly incomplete) one
    n_batchs = len(data) // batch_size
    if len(data) % batch_size != 0:
        n_batchs += 1

    print("Found", len(base_data_paths), "images. Inferring on", len(data), "images.")

    # --------------------------------------------
    # -----  Batch-process images to events  -----
    # --------------------------------------------
    print(f"\n• Using device {str(trainer.device)}\n")

    all_events = []

    with Timer(store=stores.get("inference on all images", []), ignore=time_inference):
        for b in tqdm(range(n_batchs), desc="Infering events", unit="batch"):

            images = data[b * batch_size : (b + 1) * batch_size]
            if not images:
                continue

            # concatenate images in a batch batch_size x height x width x 3
            images = np.stack(images)
            # Retrieve numpy events as a dict {event: array[BxHxWxC]}
            events = trainer.infer_all(
                images,
                numpy=True,
                stores=stores,
                bin_value=bin_value,
                half=half,
                cloudy=cloudy,
            )

            # save resized and cropped image
            if args.save_input:
                events["input"] = uint8((images + 1) / 2 * 255)

            # store events to write after inference loop
            all_events.append(events)

    # --------------------------------------------
    # -----  Save (write/upload) inferences  -----
    # --------------------------------------------
    if outdir is not None or upload:

        if upload:
            print("\n• Creating comet Experiment")
            exp = comet_ml.Experiment(project_name="climategan-apply")
            exp.log_parameters(vars(args))

        # --------------------------------------------------------------
        # -----  Change inferred data structure to a list of dicts  -----
        # --------------------------------------------------------------
        to_write = []
        events_names = list(all_events[0].keys())
        for events_data in all_events:
            n_ims = len(events_data[events_names[0]])
            for i in range(n_ims):
                item = {event: events_data[event][i] for event in events_names}
                to_write.append(item)

        progress_bar_desc = ""
        if outdir is not None:
            print("\n• Output directory:\n")
            print(str(outdir), "\n")
            if upload:
                progress_bar_desc = "Writing & Uploading events"
            else:
                progress_bar_desc = "Writing events"
        else:
            if upload:
                progress_bar_desc = "Uploading events"

        # the file name suffix only depends on the run's options: compute it once
        suffix = "_AR" if keep_ratio else ""
        if args.no_cloudy:
            suffix += "_no_cloudy"

        # ------------------------------------
        # -----  Save individual images  -----
        # ------------------------------------
        with Timer(store=stores.get("write", []), ignore=time_inference):

            # for each image
            for t, event_dict in tqdm(
                enumerate(to_write),
                desc=progress_bar_desc,
                unit="input image",
                total=len(to_write),
            ):

                # when images are repeated (n_images > number of files), map the
                # item back to the file it was read from
                idx = t % len(base_data_paths)
                stem = Path(data_paths[idx]).stem
                width = new_sizes[idx][1]

                # for each event type
                event_bar = tqdm(
                    event_dict.items(),
                    leave=False,
                    total=len(events_names),
                    unit="event",
                )
                for event, im_data in event_bar:

                    event_bar.set_description(
                        f" {event.capitalize():<{len(progress_bar_desc) - 2}}"
                    )

                    im_path = Path(f"{stem}_{event}_{width}{suffix}.png")

                    if outdir is not None:
                        im_path = outdir / im_path
                        io.imsave(im_path, im_data)

                    if upload:
                        exp.log_image(im_data, name=im_path.name)

    if zip_outdir:
        if outdir is None:
            # nothing was written to disk: there is no directory to archive
            print("\nWARNING: --zip_outdir ignored since no --output_path was provided.")
        else:
            print("\n• Zipping output directory... ", end="", flush=True)
            # the archive is created in the current working directory, then
            # moved next to the output directory
            archive_path = Path(shutil.make_archive(outdir.name, "zip", root_dir=outdir))
            archive_path = archive_path.rename(outdir.parent / archive_path.name)
            print("Done:\n")
            print(str(archive_path))

    # ---------------------------
    # -----  Print timings  -----
    # ---------------------------
    if time_inference:
        print("\n• Timings\n")
        print_store(stores)

    # ---------------------------------------------
    # -----  Save apply_events.py run config  -----
    # ---------------------------------------------
    if not args.no_conf and outdir is not None:
        write_apply_config(outdir)