# Dataset and evaluation configuration for the FDH dataset.
import functools
import os
from pathlib import Path

import torch
from tops.config import LazyCall as L

from dp2.data.datasets.fdh import get_dataloader_fdh_wds
from dp2.data.transforms.transforms import (
    CreateCondition,
    CreateEmbedding,
    Normalize,
    RandomHorizontalFlip,
    ToFloat,
)
from dp2.data.utils import get_coco_flipmap
from dp2.metrics.fid_clip import compute_fid_clip
from dp2.metrics.ppl import calculate_ppl
from dp2.metrics.torch_metrics import compute_metrics_iteratively

from .utils import train_eval_fn


def final_eval_fn(*args, **kwargs):
    result = compute_metrics_iteratively(*args, **kwargs)
    result2 = calculate_ppl(*args, **kwargs, upsample_size=(288, 160))
    result3 = compute_fid_clip(*args, **kwargs)
    # Metric names must be disjoint before merging the result dicts.
    assert all(key not in result for key in result2)
    result.update(result2)
    result.update(result3)
    return result


def get_cache_directory(imsize, subset):
    # Cache directories are keyed on subset name and image height.
    return Path(metrics_cache, f"{subset}{imsize[0]}")


dataset_base_dir = os.environ.get("BASE_DATASET_DIR", "data")
metrics_cache = os.environ.get("FBA_METRICS_CACHE", ".cache")
data_dir = Path(dataset_base_dir, "fdh")

data = dict(
    imsize=(288, 160),
    im_channels=3,
    cse_nc=16,
    n_keypoints=17,
    train=dict(
        loader=L(get_dataloader_fdh_wds)(
            path=data_dir.joinpath("train", "out-{000000..001423}.tar"),
            batch_size="${train.batch_size}",
            num_workers=6,
            transform=L(torch.nn.Sequential)(
                L(RandomHorizontalFlip)(p=0.5, flip_map=get_coco_flipmap()),
            ),
            gpu_transform=L(torch.nn.Sequential)(
                L(ToFloat)(norm=False, keys=["img", "mask", "E_mask", "maskrcnn_mask"]),
                L(CreateEmbedding)(embed_path=data_dir.joinpath("embed_map.torch")),
                L(Normalize)(
                    mean=[0.5 * 255, 0.5 * 255, 0.5 * 255],
                    std=[0.5 * 255, 0.5 * 255, 0.5 * 255],
                    inplace=True,
                ),
                L(CreateCondition)(),
            ),
            infinite=True,
            shuffle=True,
            partial_batches=False,
            load_embedding=True,
            keypoints_split="train",
            load_new_keypoints=False,
        )
    ),
    val=dict(
        loader=L(get_dataloader_fdh_wds)(
            path=data_dir.joinpath("val", "out-{000000..000023}.tar"),
            batch_size="${train.batch_size}",
            num_workers=6,
            transform=None,
            gpu_transform="${data.train.loader.gpu_transform}",
            infinite=False,
            shuffle=False,
            partial_batches=True,
            load_embedding=True,
            keypoints_split="val",
            load_new_keypoints="${data.train.loader.load_new_keypoints}",
        )
    ),
    # Training-time evaluation may apply optimizations (e.g. AMP) to reduce compute overhead.
    train_evaluation_fn=L(functools.partial)(
        train_eval_fn,
        cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh"),
        data_len=30_000,
    ),
    evaluation_fn=L(functools.partial)(
        final_eval_fn,
        cache_directory=L(get_cache_directory)(imsize="${data.imsize}", subset="fdh_eval"),
        data_len=30_000,
    ),
)
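
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the config). `L(...)` only records
# a deferred call; nothing is constructed until the config system resolves it.
# This assumes `tops.config` provides a detectron2-style `instantiate`, which
# is the API family `LazyCall` comes from; verify against your tops version.
#
#   from tops.config import instantiate
#
#   # The GPU-side transform pipeline has no "${...}" interpolations, so its
#   # node can be built directly:
#   gpu_transform = instantiate(data["train"]["loader"].gpu_transform)
#
#   # The full loader cannot be instantiated standalone: "${train.batch_size}"
#   # is an interpolation that is only resolved once this file is loaded as
#   # part of a complete training config.
# ---------------------------------------------------------------------------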