# satvision-base-demo / inference.py
# Author: cssprad1 -- "initial commit" ab687e7 (6.84 kB)
import numpy as np
import torch
import joblib
import numpy as np
import torchvision.transforms as T
import sys
sys.path.append('pytorch-caney')
# from pytorch_caney.models.mim.mim import build_mim_model
class Transform:
    """
    Callable preprocessing step for masked-image-modeling inference.

    Converts the input imagery to a tensor resized to a square of
    ``config.DATA.IMG_SIZE`` and pairs it with a freshly drawn SimMIM mask.
    """

    def __init__(self, config):
        patch = config.MODEL.SWINV2.PATCH_SIZE
        size = config.DATA.IMG_SIZE
        # ToTensor runs first, so the resize operates on the tensor.
        self.transform_img = T.Compose([T.ToTensor(), T.Resize((size, size))])
        self.mask_generator = SimmimMaskGenerator(
            input_size=size,
            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
            model_patch_size=patch,
            mask_ratio=config.DATA.MASK_RATIO,
        )

    def __call__(self, img):
        # A new random mask is generated on every call.
        return self.transform_img(img), self.mask_generator()
class SimmimMaskGenerator:
    """
    Generate random masks for masked-image-modeling (SimMIM).

    The input is divided into a ``rand_size x rand_size`` grid of mask
    patches. ``ceil(rand_size**2 * mask_ratio)`` of them are chosen uniformly
    at random, using NumPy's global RNG, and set to 1. The grid is then
    repeated ``mask_patch_size // model_patch_size`` times along both axes,
    so the returned mask has one entry per model patch.

    Args:
        input_size: side length of the square input.
        mask_patch_size: side length of one masking unit; must divide
            ``input_size``.
        model_patch_size: side length of one model patch; must divide
            ``mask_patch_size``.
        mask_ratio: fraction of mask patches to mask (rounded up).

    Raises:
        ValueError: if either divisibility constraint is violated.
    """

    def __init__(self,
                 input_size=192,
                 mask_patch_size=32,
                 model_patch_size=4,
                 mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # Explicit checks rather than ``assert``: asserts are removed under
        # ``python -O``, which would let a non-divisible size silently
        # produce a truncated mask grid.
        if self.input_size % self.mask_patch_size != 0:
            raise ValueError(
                f"input_size ({self.input_size}) must be divisible by "
                f"mask_patch_size ({self.mask_patch_size})")
        if self.mask_patch_size % self.model_patch_size != 0:
            raise ValueError(
                f"mask_patch_size ({self.mask_patch_size}) must be divisible "
                f"by model_patch_size ({self.model_patch_size})")

        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size
        self.token_count = self.rand_size ** 2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Return a new int64 mask of shape (rand_size*scale, rand_size*scale)."""
        mask = self.make_simmim_mask(self.token_count, self.mask_count,
                                     self.rand_size, self.scale)
        # Expand each mask patch into a scale x scale block of model patches.
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
        return mask

    @staticmethod
    def make_simmim_mask(token_count, mask_count, rand_size, scale):
        """Draw a random (rand_size, rand_size) int64 mask with
        ``mask_count`` ones.

        Plain NumPy (not JIT-compiled); uses the global NumPy RNG.

        Args:
            token_count: total number of mask patches (``rand_size ** 2``).
            mask_count: how many patches to set to 1.
            rand_size: side length of the returned square grid.
            scale: unused here; kept for interface compatibility.

        Returns:
            np.ndarray of dtype int64 and shape (rand_size, rand_size).
        """
        mask_idx = np.random.permutation(token_count)[:mask_count]
        mask = np.zeros(token_count, dtype=np.int64)
        mask[mask_idx] = 1
        mask = mask.reshape((rand_size, rand_size))
        return mask
class InferenceModel(object):
    """
    Load a pickled model and config, apply checkpoint weights, and produce
    RGB visualizations of masked-image reconstruction.

    All artifact paths below are relative to the current working directory.
    """

    def __init__(self):
        self.checkpoint_path = 'ckpt_epoch_800.pth'
        self.config_path = 'simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm.config.sav'
        self.architecture_path = 'model.sav'
        # NOTE(review): joblib.load unpickles arbitrary Python objects; these
        # files must come from a trusted source.
        self.config = joblib.load(self.config_path)
        self.model = joblib.load(self.architecture_path)
        self.load_checkpoint()
        # torch.no_grad() in predict() only disables autograd. eval() is what
        # switches layers such as dropout to deterministic inference
        # behavior, and the pickled model's train/eval state is unknown here.
        self.model.eval()
        self.transform = Transform(self.config)

    def load_checkpoint(self):
        """Load weights from ``self.checkpoint_path`` into ``self.model``.

        Keys containing ``rpe_mlp`` are renamed to ``cpb_mlp`` first (a name
        change that applies to the provided checkpoints). Loading is
        non-strict, and the missing/unexpected-key report is printed.
        """
        # NOTE(review): newer torch versions default torch.load to
        # weights_only=True; confirm the checkpoint loads under the pinned
        # torch version if it contains non-tensor objects.
        checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model']
        # Snapshot matching keys first so the dict is not mutated while it
        # is being iterated.
        rpe_mlp_keys = [k for k in state_dict if "rpe_mlp" in k]
        for k in rpe_mlp_keys:
            state_dict[k.replace('rpe_mlp', 'cpb_mlp')] = state_dict.pop(k)
        msg = self.model.load_state_dict(state_dict, strict=False)
        print(msg)
        del checkpoint, state_dict
        torch.cuda.empty_cache()

    @staticmethod
    def minmax_norm(img_arr):
        """Linearly rescale ``img_arr`` to 0..255 and cast to uint8.

        A constant array (max == min) has no range to stretch and maps to
        all zeros, instead of dividing by zero and casting NaNs.
        """
        arr_min = img_arr.min()
        arr_max = img_arr.max()
        value_range = arr_max - arr_min
        if value_range == 0:
            return np.zeros(img_arr.shape, dtype=np.uint8)
        img_arr_scaled = (img_arr - arr_min) / value_range
        img_arr_scaled = img_arr_scaled * 255
        return img_arr_scaled.astype(np.uint8)

    # -------------------------------------------------------------------------
    # preprocess
    # -------------------------------------------------------------------------
    def preprocess(self, image):
        """Apply ``self.transform`` and add a leading batch axis.

        Returns ``(image, mask)`` as tensors with a new axis 0. Their shapes
        are printed for debugging.
        """
        image, mask = self.transform(image)
        image = image.unsqueeze(0)
        mask = torch.tensor(mask).unsqueeze(0)
        print(image.size())
        print(mask.shape)
        return image, mask

    # -------------------------------------------------------------------------
    # predict
    # -------------------------------------------------------------------------
    def predict(self, image, mask):
        """Encode ``image`` with ``mask`` and decode a reconstruction.

        Runs without autograd and returns batch item 0 of the decoder output
        as a NumPy array.
        """
        with torch.no_grad():
            latent = self.model.encoder(image, mask)
            image_recon = self.model.decoder(latent)
        return image_recon.numpy()[0, :, :, :]

    # -------------------------------------------------------------------------
    # process_mask
    # -------------------------------------------------------------------------
    @staticmethod
    def process_mask(mask, patch_size=4):
        """Upsample a batched patch mask to pixel resolution for display.

        Args:
            mask: tensor of shape (batch, h, w); only batch item 0 is used.
            patch_size: repeat factor along both spatial axes. The default of
                4 is the value that was previously hard-coded.

        Returns:
            (h * patch_size, w * patch_size, 3) array with the mask copied
            into each of the three channels.
        """
        mask = mask.repeat_interleave(patch_size, 1)
        mask = mask.repeat_interleave(patch_size, 2)
        mask = mask[0].numpy()
        return np.stack([mask, mask, mask], axis=-1)

    @staticmethod
    def _to_rgb(chw):
        """Stack channels 0, 3, 2 of a (C, H, W) array into (H, W, 3)."""
        # NOTE(review): (0, 3, 2) presumably selects the R, G, B bands in
        # this sensor's band order -- confirm.
        return np.stack((chw[0, :, :], chw[3, :, :], chw[2, :, :]), axis=-1)

    # -------------------------------------------------------------------------
    # infer
    # -------------------------------------------------------------------------
    def infer(self, image):
        """Run preprocessing, reconstruction and visualization on one image.

        Returns:
            (rgb_image, rgb_image_masked, rgb_recon_masked):
            - the min-max normalized input as an (H, W, 3) uint8 RGB array;
            - the same image with masked pixels set to 0;
            - the input where visible, the reconstruction where masked.
        """
        image, mask = self.preprocess(image)
        img_recon = self.predict(image, mask)
        # Derive the mask upsampling factor from the tensors instead of
        # assuming 4, so the mask stays aligned with the image for any model
        # patch size.
        patch_size = image.shape[-1] // mask.shape[-1]
        mask = self.process_mask(mask, patch_size)
        img_normed = self.minmax_norm(image.numpy()[0, :, :, :])
        print(img_normed.shape)
        rgb_image = self._to_rgb(img_normed)
        rgb_image_recon = self._to_rgb(self.minmax_norm(img_recon))
        rgb_recon_masked = np.where(mask == 0, rgb_image, rgb_image_recon)
        rgb_image_masked = np.where(mask == 1, 0, rgb_image)
        return rgb_image, rgb_image_masked, rgb_recon_masked
def infer(array_input: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Placeholder inference that ignores ``array_input``.

    Returns two independent (256, 256, 3) arrays of uniform [0, 1) floats
    drawn from NumPy's global RNG: first the "masked input", then the
    "output". Presumably a stand-in until the real model is wired up.
    """
    shape = (256, 256, 3)
    fake_masked = np.random.rand(*shape)
    fake_output = np.random.rand(*shape)
    return fake_masked, fake_output
if __name__ == '__main__':
    # Demo run on a bundled sample; all paths are relative to the CWD.
    inference_model = InferenceModel()
    sample = np.load('data/images/sv-demo-mod09ga-11.npy')
    print(sample.shape)
    # Move the leading axis to the end (presumably channels-first on disk ->
    # (H, W, C) for ToTensor; confirm against the data files).
    sample = np.moveaxis(sample, 0, 2)
    print(sample.shape)
    outputs = inference_model.infer(sample)