import sys

import joblib
import numpy as np
import torch
import torchvision.transforms as T

sys.path.append('pytorch-caney')
# from pytorch_caney.models.mim.mim import build_mim_model


class Transform:
    """
    torchvision transform which converts the input imagery to a resized
    tensor, in addition to generating a MiM mask.
    """

    def __init__(self, config):
        """Build the image pipeline and the mask generator from `config`.

        Args:
            config: object exposing DATA.IMG_SIZE, DATA.MASK_PATCH_SIZE,
                DATA.MASK_RATIO and MODEL.SWINV2.PATCH_SIZE.
        """
        self.transform_img = T.Compose([
            T.ToTensor(),
            T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)),
        ])

        model_patch_size = config.MODEL.SWINV2.PATCH_SIZE

        self.mask_generator = SimmimMaskGenerator(
            input_size=config.DATA.IMG_SIZE,
            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
            model_patch_size=model_patch_size,
            mask_ratio=config.DATA.MASK_RATIO,
        )

    def __call__(self, img):
        """Return `(transformed_img, mask)` for a single input image.

        The mask is freshly drawn on every call and does not depend on the
        image content.
        """
        img = self.transform_img(img)
        mask = self.mask_generator()
        return img, mask


class SimmimMaskGenerator:
    """
    Generates the masks for masked-image-modeling
    """

    def __init__(self,
                 input_size=192,
                 mask_patch_size=32,
                 model_patch_size=4,
                 mask_ratio=0.6):
        """Pre-compute the mask geometry.

        Args:
            input_size: side length of the (square) input image in pixels.
            mask_patch_size: side length of one maskable patch in pixels.
            model_patch_size: side length of one model patch in pixels.
            mask_ratio: fraction of mask patches set to 1.

        Raises:
            ValueError: if `input_size` is not a multiple of
                `mask_patch_size`, or `mask_patch_size` is not a multiple
                of `model_patch_size`.
        """
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # Explicit raises instead of `assert`: asserts are stripped under
        # `python -O`, which would silently yield wrongly sized masks.
        if self.input_size % self.mask_patch_size != 0:
            raise ValueError(
                f'input_size ({self.input_size}) must be divisible by '
                f'mask_patch_size ({self.mask_patch_size})')
        if self.mask_patch_size % self.model_patch_size != 0:
            raise ValueError(
                f'mask_patch_size ({self.mask_patch_size}) must be '
                f'divisible by model_patch_size ({self.model_patch_size})')

        # Number of mask patches along one side of the image.
        self.rand_size = self.input_size // self.mask_patch_size
        # Number of model patches along one side of a mask patch.
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size ** 2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Draw a random mask at model-patch resolution.

        Returns:
            2-D int64 array of 0/1 with side length
            `rand_size * scale` (i.e. `input_size // model_patch_size`).
        """
        mask = self.make_simmim_mask(self.token_count, self.mask_count,
                                     self.rand_size, self.scale)
        # Expand every mask patch to cover `scale` x `scale` model patches.
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
        return mask

    @staticmethod
    def make_simmim_mask(token_count: int, mask_count: int,
                         rand_size: int, scale: int) -> np.ndarray:
        """Random mask generation (plain NumPy, global `np.random` state).

        Args:
            token_count: total number of mask patches (`rand_size ** 2`).
            mask_count: number of patches to set to 1.
            rand_size: side length of the returned square mask.
            scale: unused here; kept so existing callers keep working.

        Returns:
            `(rand_size, rand_size)` int64 array with `mask_count` ones
            (fewer only if `mask_count` exceeds `token_count`).
        """
        mask_idx = np.random.permutation(token_count)[:mask_count]
        mask = np.zeros(token_count, dtype=np.int64)
        mask[mask_idx] = 1
        mask = mask.reshape((rand_size, rand_size))
        return mask
class InferenceModel(object):
    """Loads a pickled MiM model plus checkpoint and runs masked inference.

    The constructor reads three files from disk (config, architecture and
    checkpoint); `infer` then takes a single image and returns RGB
    composites of the input, the masked input and the reconstruction.
    """

    # Channel indices stacked (in this order) to build the RGB composites.
    # NOTE(review): presumably these are the red/green/blue bands of the
    # input product -- confirm against the data source.
    RGB_BAND_INDICES = (0, 3, 2)

    def __init__(self,
                 checkpoint_path='ckpt_epoch_800.pth',
                 config_path=(
                     'simmim_pretrain__satnet_swinv2_base__'
                     'img192_window12__800ep_v3_no_norm.config.sav'
                 ),
                 architecture_path='model.sav'):
        """Load config, model architecture and weights.

        Args:
            checkpoint_path: file passed to `torch.load`; must contain a
                'model' entry holding the state dict.
            config_path: joblib file holding the config object.
            architecture_path: joblib file holding the model object.
        """
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.architecture_path = architecture_path

        # NOTE: joblib.load is pickle-based and can execute arbitrary code;
        # only point these paths at trusted files.
        self.config = joblib.load(self.config_path)
        self.model = joblib.load(self.architecture_path)

        self.load_checkpoint()

        self.transform = Transform(self.config)

    # -------------------------------------------------------------------------
    # load_checkpoint
    # -------------------------------------------------------------------------
    def load_checkpoint(self):
        """Load the checkpoint weights into `self.model` (non-strict)."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model']

        # re-map keys due to name change (only for loading provided models)
        rpe_mlp_keys = [k for k in state_dict if "rpe_mlp" in k]
        for old_key in rpe_mlp_keys:
            new_key = old_key.replace('rpe_mlp', 'cpb_mlp')
            state_dict[new_key] = state_dict.pop(old_key)

        msg = self.model.load_state_dict(state_dict, strict=False)
        print(msg)

        # This class only performs inference: make sure train-time layers
        # (dropout, stochastic depth, ...) are disabled.
        self.model.eval()

        del checkpoint
        torch.cuda.empty_cache()

    # -------------------------------------------------------------------------
    # minmax_norm
    # -------------------------------------------------------------------------
    @staticmethod
    def minmax_norm(img_arr):
        """Linearly rescale `img_arr` to [0, 255] and cast to uint8.

        Args:
            img_arr: non-empty NumPy array.

        Returns:
            uint8 array of the same shape. A constant input (max == min)
            yields all zeros instead of dividing by zero.
        """
        arr_min = img_arr.min()
        value_range = img_arr.max() - arr_min
        if value_range == 0:
            # Nothing to stretch; avoid NaNs from a 0/0 division.
            return np.zeros(img_arr.shape, dtype=np.uint8)
        img_arr_scaled = (img_arr - arr_min) / value_range * 255
        return img_arr_scaled.astype(np.uint8)

    # -------------------------------------------------------------------------
    # preprocess
    # -------------------------------------------------------------------------
    def preprocess(self, image):
        """Apply `self.transform` and add a leading batch dimension.

        Returns:
            `(image, mask)` as torch tensors, each with batch size 1.
        """
        image, mask = self.transform(image)
        image = image.unsqueeze(0)
        mask = torch.tensor(mask).unsqueeze(0)
        print(image.size())
        print(mask.shape)
        return image, mask

    # -------------------------------------------------------------------------
    # predict
    # -------------------------------------------------------------------------
    def predict(self, image, mask):
        """Run encoder then decoder without gradient tracking.

        Returns:
            The first batch element of the decoder output as a NumPy array.
        """
        with torch.no_grad():
            logits = self.model.encoder(image, mask)
            image_recon = self.model.decoder(logits)
        image_recon = image_recon.numpy()[0, :, :, :]
        return image_recon

    # -------------------------------------------------------------------------
    # process_mask
    # -------------------------------------------------------------------------
    @staticmethod
    def process_mask(mask, patch_size=4):
        """Up-sample a batched patch mask and replicate it to 3 channels.

        Args:
            mask: 3-D torch tensor `(batch, h, w)`; only the first batch
                element is used.
            patch_size: integer up-sampling factor applied to both spatial
                axes (the model patch size). Defaults to 4.

        Returns:
            NumPy array of shape `(h * patch_size, w * patch_size, 3)`.
        """
        upsampled = (mask.repeat_interleave(patch_size, 1)
                     .repeat_interleave(patch_size, 2))
        first = upsampled[0].cpu().numpy()
        return np.stack([first, first, first], axis=-1)

    # -------------------------------------------------------------------------
    # _to_rgb
    # -------------------------------------------------------------------------
    @classmethod
    def _to_rgb(cls, bands):
        """Stack the `RGB_BAND_INDICES` channels of `bands` along a last axis."""
        return np.stack([bands[i, :, :] for i in cls.RGB_BAND_INDICES],
                        axis=-1)

    # -------------------------------------------------------------------------
    # infer
    # -------------------------------------------------------------------------
    def infer(self, image):
        """Mask, reconstruct and render a single image.

        Args:
            image: input accepted by `self.transform`.

        Returns:
            Tuple `(rgb_image, rgb_image_masked, rgb_recon_masked)`:
            the input composite, the input with masked pixels set to 0, and
            the input with masked pixels replaced by the reconstruction.
        """
        image, mask = self.preprocess(image)

        img_recon = self.predict(image, mask)

        # Up-sample by the model patch size so the mask matches the image.
        mask = self.process_mask(
            mask, patch_size=self.config.MODEL.SWINV2.PATCH_SIZE)

        img_normed = self.minmax_norm(image.numpy()[0, :, :, :])
        print(img_normed.shape)
        rgb_image = self._to_rgb(img_normed)

        # Input and reconstruction are each normalised on their own range.
        img_recon = self.minmax_norm(img_recon)
        rgb_image_recon = self._to_rgb(img_recon)

        rgb_recon_masked = np.where(mask == 0, rgb_image, rgb_image_recon)
        rgb_image_masked = np.where(mask == 1, 0, rgb_image)

        return rgb_image, rgb_image_masked, rgb_recon_masked


def infer(array_input: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Placeholder inference entry point.

    `array_input` is ignored; two independent random arrays of shape
    (256, 256, 3) drawn with `np.random.rand` are returned.
    """
    masked_input = np.random.rand(256, 256, 3)
    output = np.random.rand(256, 256, 3)
    return masked_input, output


if __name__ == '__main__':
    inference_model = InferenceModel()

    image = np.load('data/images/sv-demo-mod09ga-11.npy')
    print(image.shape)
    # Move the first axis to the end before handing the array to the model.
    image = np.moveaxis(image, 0, 2)
    print(image.shape)

    inference = inference_model.infer(image)