import numpy as np
import tensorflow as tf
import torch

# Default location of the CESBIO SR4RS SavedModel (relative to the working directory).
_CESBIO_MODEL_PATH = (
    "sr4rs/weights/cesbio_model/"
    "sr4rs_sentinel2_bands4328_france2020_savedmodel"
)

# Indices along the first axis of ``lr`` that are fed to the network, in this order.
# NOTE(review): presumably matches the "bands4328" in the model name -- confirm
# against the band layout of the caller's ``lr`` array.
_MODEL_INPUT_BANDS = [3, 2, 1, 7]

# Indices along the first axis of ``lr`` returned alongside the prediction.
_LR_OUTPUT_BANDS = [3, 2, 1]

# Key of the prediction tensor in the dictionary returned by the signature.
_OUTPUT_KEY = "output_32:0"

# Number of zero pixels added on every side of the prediction.
_BORDER_PAD = 32


def load_cesbio_sr(model_path: str = _CESBIO_MODEL_PATH) -> tf.function:
    """Load the CESBIO SR4RS SavedModel and return its first signature.

    Args:
        model_path (str, optional): Directory of the SavedModel. Defaults to
            the bundled CESBIO weights path.

    Returns:
        tf.function: The callable stored under the first signature key of the
            loaded model.
    """
    model = tf.saved_model.load(model_path)

    # The model is called through its first (and presumably only) signature.
    signature = list(model.signatures.keys())[0]

    # NOTE(review): only the signature function is returned and ``model`` goes
    # out of scope here -- confirm the installed TensorFlow version keeps the
    # loaded variables alive in that case.
    return model.signatures[signature]


def run_sr4rs(
    model: tf.function,
    lr: np.ndarray,
    hr: np.ndarray,
) -> dict:
    """Run the SR4RS model on one low resolution image.

    Args:
        model (tf.function): The signature function returned by
            ``load_cesbio_sr``.
        lr (np.ndarray): The low resolution image, laid out as
            (bands, height, width) with at least 8 bands. It must be a NumPy
            array because it is handed to ``torch.from_numpy``.
        hr (np.ndarray): The high resolution image; only its first three
            entries along the first axis are used. Presumably a NumPy array
            like ``lr`` -- it is only sliced here.

    Returns:
        dict: ``"lr"`` (bands 3, 2, 1 of ``lr``), ``"sr"`` (first three bands
            of the zero-padded prediction as ``uint16``) and ``"hr"``
            (``hr[0:3]``).
    """
    # Pick the input bands, add a batch axis, and move channels last
    # (1, bands, H, W) -> (1, H, W, bands) before handing over to TensorFlow.
    batch = torch.from_numpy(lr[_MODEL_INPUT_BANDS][None]).permute(0, 2, 3, 1)
    model_input = tf.convert_to_tensor(batch, dtype=tf.float32)
    prediction = model(model_input)

    # Back to channels-first: (1, H, W, bands) -> (1, bands, H, W).
    sr = torch.from_numpy(prediction[_OUTPUT_KEY].numpy()).permute(0, 3, 1, 2)

    # Zero-pad the two spatial axes on every side.
    # NOTE(review): presumably restores a border the network crops from its
    # output (cf. the "_32" in the output key) -- confirm.
    sr = torch.nn.functional.pad(
        sr,
        (_BORDER_PAD, _BORDER_PAD, _BORDER_PAD, _BORDER_PAD),
        mode="constant",
        value=0,
    )

    # Drop only the batch axis (a bare ``squeeze()`` would also drop any other
    # size-1 axis), then clamp to the uint16 range: casting out-of-range floats
    # to an unsigned integer is undefined and wraps negative values to large
    # positive ones.
    uint16_max = int(np.iinfo(np.uint16).max)
    sr_uint16 = sr[0].clamp(0, uint16_max).numpy().astype("uint16")

    return {
        "lr": lr[_LR_OUTPUT_BANDS],
        "sr": sr_uint16[0:3],
        "hr": hr[0:3],
    }