|
import numpy as np |
|
|
|
import torch |
|
import joblib |
|
import numpy as np |
|
|
|
import torchvision.transforms as T |
|
import sys |
|
|
|
# Make modules inside the relative directory 'pytorch-caney' importable.
# A relative sys.path entry is resolved against the current working
# directory, not this file's location. Presumably needed so joblib can
# unpickle the config/model objects loaded by InferenceModel (their classes
# would live in that checkout) — verify.
sys.path.append('pytorch-caney')
|
|
|
|
|
|
|
class Transform:
    """Callable that prepares a single image for masked-image-modeling.

    The image is converted with ``ToTensor`` and resized to a square of side
    ``config.DATA.IMG_SIZE``. Each call also draws a fresh SimMIM mask built
    from ``config.DATA.MASK_PATCH_SIZE``, ``config.DATA.MASK_RATIO`` and the
    model patch size ``config.MODEL.SWINV2.PATCH_SIZE``.
    """

    def __init__(self, config):
        side = config.DATA.IMG_SIZE

        # ToTensor runs first so that Resize receives a tensor.
        self.transform_img = T.Compose([T.ToTensor(), T.Resize((side, side))])

        self.mask_generator = SimmimMaskGenerator(
            input_size=side,
            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
            model_patch_size=config.MODEL.SWINV2.PATCH_SIZE,
            mask_ratio=config.DATA.MASK_RATIO,
        )

    def __call__(self, img):
        """Return ``(transformed_img, mask)``; the image is processed first."""
        return self.transform_img(img), self.mask_generator()
|
|
|
|
|
class SimmimMaskGenerator:
    """Generate random SimMIM-style masks for masked image modeling.

    The square input is divided into a ``rand_size x rand_size`` grid of mask
    units (``rand_size = input_size // mask_patch_size``). A fraction
    ``mask_ratio`` of them, rounded up, is set to 1 and the rest to 0. The grid
    is then upsampled by ``scale = mask_patch_size // model_patch_size``, so the
    returned mask has one entry per model patch. Its shape is
    ``(input_size // model_patch_size,) * 2``.

    Args:
        input_size: Side length of the square input; must be divisible by
            ``mask_patch_size``.
        mask_patch_size: Side length of one mask unit; must be divisible by
            ``model_patch_size``.
        model_patch_size: Side length of one model patch.
        mask_ratio: Fraction of mask units to set to 1.

    Raises:
        ValueError: If either divisibility requirement does not hold.
    """

    def __init__(self,
                 input_size=192,
                 mask_patch_size=32,
                 model_patch_size=4,
                 mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # Explicit exceptions instead of ``assert``: asserts are stripped
        # under ``python -O``, which would let a bad config through silently.
        if self.input_size % self.mask_patch_size != 0:
            raise ValueError(
                f"input_size ({self.input_size}) must be divisible by "
                f"mask_patch_size ({self.mask_patch_size})")
        if self.mask_patch_size % self.model_patch_size != 0:
            raise ValueError(
                f"mask_patch_size ({self.mask_patch_size}) must be divisible "
                f"by model_patch_size ({self.model_patch_size})")

        # Side length of the grid of mask units.
        self.rand_size = self.input_size // self.mask_patch_size
        # Model patches per mask unit along each axis.
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size ** 2
        # Rounded up: on a non-empty grid, any positive ratio masks >= 1 unit.
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Draw a new mask of shape ``(rand_size * scale,) * 2`` (int64, 1 = masked)."""
        mask = self.make_simmim_mask(self.token_count, self.mask_count,
                                     self.rand_size, self.scale)
        # Expand each mask unit into a scale x scale block of model patches.
        return mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

    @staticmethod
    def make_simmim_mask(token_count, mask_count, rand_size, scale):
        """Draw a random ``(rand_size, rand_size)`` mask at mask-unit resolution.

        For non-negative ``mask_count``, exactly ``min(mask_count,
        token_count)`` entries are set to 1 and the rest are 0. The entries are
        chosen uniformly without replacement using NumPy's global RNG.

        Args:
            token_count: Total number of mask units. Must equal
                ``rand_size ** 2`` for the final reshape to succeed.
            mask_count: Number of units to mask.
            rand_size: Side length of the mask-unit grid.
            scale: Unused here (upsampling happens in ``__call__``). Kept for
                signature compatibility.

        Returns:
            ``np.int64`` array of shape ``(rand_size, rand_size)``.
        """
        mask_idx = np.random.permutation(token_count)[:mask_count]
        mask = np.zeros(token_count, dtype=np.int64)
        mask[mask_idx] = 1
        return mask.reshape((rand_size, rand_size))
|
|
|
|
|
class InferenceModel(object):
    """Load a pretrained SimMIM model and reconstruct masked imagery.

    Construction does the following, in order:

    1. Loads the pickled config and model object with ``joblib``.
    2. Loads the checkpoint weights into the model.
    3. Switches the model to evaluation mode.
    4. Builds a ``Transform`` from the config.

    :meth:`infer` then takes one image and returns RGB previews of the input,
    the masked input, and the reconstruction.

    Args:
        checkpoint_path: Checkpoint file read with ``torch.load``. Its
            ``'model'`` entry holds the state dict.
        config_path: joblib-pickled config object.
        architecture_path: joblib-pickled model object.

    NOTE(review): ``joblib.load`` and ``torch.load`` unpickle arbitrary
    objects, so only point these paths at trusted files.
    """

    def __init__(self,
                 checkpoint_path='ckpt_epoch_800.pth',
                 config_path=('simmim_pretrain__satnet_swinv2_base__img192'
                              '_window12__800ep_v3_no_norm.config.sav'),
                 architecture_path='model.sav'):
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.architecture_path = architecture_path

        self.config = joblib.load(self.config_path)
        self.model = joblib.load(self.architecture_path)
        self.load_checkpoint()
        # Inference only. The mode the pickled model was saved in is not
        # guaranteed, so switch train-only stochastic layers (e.g. dropout)
        # to their deterministic evaluation behaviour.
        self.model.eval()

        self.transform = Transform(self.config)

    def load_checkpoint(self):
        """Load weights from ``self.checkpoint_path`` into ``self.model``.

        Before loading, keys containing ``'rpe_mlp'`` are renamed to
        ``'cpb_mlp'``. Presumably the checkpoint predates a rename of that
        module in the model code; verify. Loading is non-strict, and the
        missing/unexpected-key report returned by ``load_state_dict`` is
        printed.
        """
        # NOTE(review): newer torch releases default ``torch.load`` to
        # ``weights_only=True``. That may reject checkpoints that also store
        # non-tensor objects; confirm against the installed torch version.
        checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model']

        # Snapshot the matching keys first, because the loop mutates the dict.
        for key in [k for k in state_dict.keys() if "rpe_mlp" in k]:
            state_dict[key.replace('rpe_mlp', 'cpb_mlp')] = state_dict.pop(key)

        msg = self.model.load_state_dict(state_dict, strict=False)
        print(msg)

        del checkpoint, state_dict
        torch.cuda.empty_cache()

    @staticmethod
    def minmax_norm(img_arr):
        """Linearly rescale ``img_arr`` to ``[0, 255]`` and cast to ``uint8``.

        The array minimum maps to 0 and the maximum to 255, with fractional
        values truncated by the cast. A constant array, which has no dynamic
        range, yields all zeros instead of the 0/0 NaNs the plain formula would
        produce.
        """
        arr_min = img_arr.min()
        arr_max = img_arr.max()
        span = arr_max - arr_min
        if span == 0:
            return np.zeros(img_arr.shape, dtype=np.uint8)
        img_arr_scaled = (img_arr - arr_min) / span * 255
        return img_arr_scaled.astype(np.uint8)

    def preprocess(self, image):
        """Transform one image and add a leading batch axis to it and its mask.

        Returns:
            ``(image, mask)``. ``image`` is the transformed tensor with a
            batch axis of size 1. ``mask`` is the generated mask converted to a
            tensor, also with a batch axis of size 1.
        """
        image, mask = self.transform(image)
        image = image.unsqueeze(0)
        mask = torch.tensor(mask).unsqueeze(0)
        print(image.size())
        print(mask.shape)
        return image, mask

    def predict(self, image, mask):
        """Run the model's encoder and decoder without tracking gradients.

        Args:
            image: Batched image tensor as returned by :meth:`preprocess`.
            mask: Batched mask tensor as returned by :meth:`preprocess`.

        Returns:
            NumPy array of the decoder output for batch element 0. The batch
            axis is dropped and the remaining three axes are kept.
        """
        with torch.no_grad():
            logits = self.model.encoder(image, mask)
            image_recon = self.model.decoder(logits)
            return image_recon.numpy()[0, :, :, :]

    @staticmethod
    def process_mask(mask, patch_size=4):
        """Upsample a patch-level mask to pixel resolution as an RGB-shaped array.

        Args:
            mask: Tensor of shape ``(B, h, w)``. Only batch element 0 is used.
            patch_size: Pixels per mask cell along each axis. Should equal the
                model patch size. Defaults to 4, the value previously
                hard-coded here.

        Returns:
            NumPy array of shape ``(h * patch_size, w * patch_size, 3)`` with
            the mask replicated along the last axis.
        """
        mask = mask[0].repeat_interleave(patch_size, 0)
        mask = mask.repeat_interleave(patch_size, 1).numpy()
        return np.stack([mask, mask, mask], axis=-1)

    @staticmethod
    def _to_rgb(arr):
        """Stack channels 0, 3 and 2 of a channel-first array into ``(H, W, 3)``."""
        # Presumably these indices select the red/green/blue bands of the
        # source product (e.g. MODIS MOD09GA bands 1/4/3). Verify before
        # changing them.
        return np.stack((arr[0], arr[3], arr[2]), axis=-1)

    def infer(self, image):
        """Mask one image, reconstruct it, and render RGB previews.

        Args:
            image: Any input accepted by the ``Transform``. The ``__main__``
                demo passes a channel-last NumPy array. The transformed image
                must have at least 4 channels, because channels 0, 3 and 2 are
                shown as RGB.

        Returns:
            ``(rgb_image, rgb_image_masked, rgb_recon_masked)``, each of shape
            ``(H, W, 3)``:

            - ``rgb_image``: the min-max-scaled input.
            - ``rgb_image_masked``: the same image with masked pixels set to 0.
            - ``rgb_recon_masked``: the input with masked pixels replaced by the
              separately min-max-scaled reconstruction.
        """
        image, mask = self.preprocess(image)
        img_recon = self.predict(image, mask)
        # Upsample by the same model patch size that Transform uses to build
        # the mask, so the mask lines up with the image pixels.
        mask = self.process_mask(mask, self.config.MODEL.SWINV2.PATCH_SIZE)

        img_normed = self.minmax_norm(image.numpy()[0, :, :, :])
        print(img_normed.shape)
        rgb_image = self._to_rgb(img_normed)

        rgb_image_recon = self._to_rgb(self.minmax_norm(img_recon))

        # Reconstruction where the mask is set, original input elsewhere.
        rgb_recon_masked = np.where(mask == 0, rgb_image, rgb_image_recon)
        rgb_image_masked = np.where(mask == 1, 0, rgb_image)

        return rgb_image, rgb_image_masked, rgb_recon_masked
|
|
|
|
|
def infer(array_input: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Placeholder inference hook that ignores its input.

    Returns two independent arrays of uniform noise in ``[0, 1)`` with shape
    ``(256, 256, 3)``, drawn from NumPy's global RNG. The first stands in for
    the masked input and the second for the model output. ``array_input`` is
    not used.
    """
    shape = (256, 256, 3)
    masked_input, output = [np.random.rand(*shape) for _ in range(2)]
    return masked_input, output
|
|
|
if __name__ == '__main__':
    model = InferenceModel()

    sample = np.load('data/images/sv-demo-mod09ga-11.npy')
    print(sample.shape)
    # Move axis 0 to position 2. The file is presumably stored band-first
    # (C, H, W), while ToTensor inside Transform expects (H, W, C); verify.
    sample = np.moveaxis(sample, 0, 2)
    print(sample.shape)

    rgb_image, rgb_image_masked, rgb_recon_masked = model.infer(sample)