import torch

import data_utils as du


def _encode_decode(image, autoencoder, device):
    """Encode *image* with *autoencoder*, then decode the latent mean.

    The numpy array is wrapped as a float32 tensor with a new leading batch
    axis of size 1 and moved to *device*. The model is moved to *device* too.
    Encoding and decoding run under ``torch.no_grad()``.

    The mean of ``latent_dist`` is decoded rather than a sample, so the result
    does not depend on random sampling. The ``encode(...).latent_dist`` /
    ``decode(...).sample`` interface looks like diffusers' ``AutoencoderKL``
    (TODO confirm).

    Args:
        image: numpy array to feed to the model (must be accepted by
            ``torch.from_numpy``).
        autoencoder: model exposing ``encode`` and ``decode`` as used below.
        device: torch device (or device string) to run on.

    Returns:
        The decoded tensor (``decode(...).sample``). It is still on *device*
        and still has the leading batch axis.
    """
    tensor = torch.from_numpy(image)[None, ...].to(device=device, dtype=torch.float32)
    # If this is an nn.Module, .to() moves its parameters in place, so the
    # caller's model also ends up on `device`.
    autoencoder = autoencoder.to(device)
    with torch.no_grad():
        latents = autoencoder.encode(tensor).latent_dist.mean
        return autoencoder.decode(latents).sample


def run_inference_two_channels(coil_complex_image, autoencoder, device="cuda"):
    """Reconstruct a complex coil image through a two-channel autoencoder.

    The image is normalized with ``du.normalize_complex_coil_image`` and split
    into a two-channel real array with ``du.complex_to_two_channel_image``.
    It is then passed through the autoencoder (encode -> latent mean -> decode).
    The output is converted back with ``du.two_channel_to_complex_image``.

    Args:
        coil_complex_image: complex coil image in whatever form the ``du``
            helpers accept.
        autoencoder: model exposing ``encode``/``decode`` (see ``_encode_decode``).
        device: torch device to run inference on. Defaults to ``"cuda"``.

    Returns:
        ``(input_image, recon)``:
        - ``input_image`` is the normalized complex image that was fed in.
        - ``recon`` is the reconstruction returned by
          ``du.two_channel_to_complex_image``.
    """
    input_image = du.normalize_complex_coil_image(coil_complex_image)
    two_channel_image = du.complex_to_two_channel_image(input_image)
    decoded = _encode_decode(two_channel_image, autoencoder, device)
    # NOTE(review): the decoded array still carries the leading batch axis
    # here, whereas run_inference_three_channels strips it with [0]. Confirm
    # that du.two_channel_to_complex_image expects the batched shape.
    recon = du.two_channel_to_complex_image(decoded.detach().cpu().numpy())
    return input_image, recon


def run_inference_three_channels(coil_complex_image, autoencoder, device="cuda"):
    """Reconstruct a complex coil image through a three-channel autoencoder.

    The image is normalized with ``du.normalize_complex_coil_image`` and
    converted with ``du.create_three_channel_image``. It is then passed
    through the autoencoder (encode -> latent mean -> decode).

    Args:
        coil_complex_image: complex coil image in whatever form the ``du``
            helpers accept.
        autoencoder: model exposing ``encode``/``decode`` (see ``_encode_decode``).
        device: torch device to run inference on. Defaults to ``"cuda"``.

    Returns:
        ``(input_image, recon)``:
        - ``input_image`` is the three-channel array fed to the model.
        - ``recon`` is the first (only) batch element of the decoded output,
          as a numpy array.
    """
    normalized = du.normalize_complex_coil_image(coil_complex_image)
    input_image = du.create_three_channel_image(normalized)
    decoded = _encode_decode(input_image, autoencoder, device)
    # Drop the batch axis added in _encode_decode.
    recon = decoded[0].detach().cpu().numpy()
    return input_image, recon