mri-autoencoder-v0.1 / inference.py
pidajay's picture
Committed model weights and demo code
83e314f
import torch
import data_utils as du
def run_inference_two_channels(coil_complex_image, autoencoder, device="cuda"):
    """Reconstruct a complex coil image through a two-channel autoencoder.

    The image is passed through ``du.normalize_complex_coil_image`` and then
    ``du.complex_to_two_channel_image``. A batch axis of size 1 is prepended,
    and the result is encoded and decoded without gradient tracking. Decoding
    uses the *mean* of ``latent_dist`` rather than a sample, so the
    reconstruction is deterministic for a given model.

    Args:
        coil_complex_image: Complex coil image accepted by the ``data_utils``
            helpers. After conversion it must be a NumPy array, because
            ``torch.from_numpy`` is applied to it. The expected shape is not
            visible here — TODO confirm against ``data_utils``.
        autoencoder: Model exposing ``encode(x).latent_dist.mean`` and
            ``decode(z).sample``, which is the interface used below. It is
            moved with ``.to(device)``. For a ``torch.nn.Module`` that move
            happens in place, so the caller's model ends up on ``device``.
        device: Torch device on which the input tensor and the model are placed.

    Returns:
        Tuple ``(normalized_image, recon)``:

        - ``normalized_image``: the normalized complex image that was encoded.
        - ``recon``: the result of ``du.two_channel_to_complex_image`` applied
          to the decoded batch as a NumPy array. The size-1 batch axis is
          still present.
    """
    normalized_image = du.normalize_complex_coil_image(coil_complex_image)
    two_channel_image = du.complex_to_two_channel_image(normalized_image)
    # Add a batch axis, then cast to float32 and move to the target device in one call.
    two_channel_tensor = torch.from_numpy(two_channel_image)[None, ...].to(
        device=device, dtype=torch.float32
    )
    autoencoder = autoencoder.to(device)
    with torch.no_grad():
        posterior = autoencoder.encode(two_channel_tensor).latent_dist
        # Use the posterior mean (no sampling) for a deterministic reconstruction.
        decoded_image = autoencoder.decode(posterior.mean).sample
    recon = du.two_channel_to_complex_image(decoded_image.detach().cpu().numpy())
    return normalized_image, recon
def run_inference_three_channels(coil_complex_image, autoencoder, device="cuda"):
    """Reconstruct a coil image through a three-channel autoencoder.

    The image is passed through ``du.normalize_complex_coil_image`` and then
    ``du.create_three_channel_image``. A batch axis of size 1 is prepended,
    and the result is encoded and decoded without gradient tracking. Decoding
    uses the *mean* of ``latent_dist`` rather than a sample, so the
    reconstruction is deterministic for a given model.

    Args:
        coil_complex_image: Complex coil image accepted by the ``data_utils``
            helpers. After conversion it must be a NumPy array, because
            ``torch.from_numpy`` is applied to it. The expected shape is not
            visible here — TODO confirm against ``data_utils``.
        autoencoder: Model exposing ``encode(x).latent_dist.mean`` and
            ``decode(z).sample``, which is the interface used below. It is
            moved with ``.to(device)``. For a ``torch.nn.Module`` that move
            happens in place, so the caller's model ends up on ``device``.
        device: Torch device on which the input tensor and the model are placed.

    Returns:
        Tuple ``(three_channel_image, recon)``:

        - ``three_channel_image``: the array produced by
          ``du.create_three_channel_image``. It has no batch axis and keeps
          its original dtype; it is not cast to float32.
        - ``recon``: the decoded output as a NumPy array, with the batch axis
          removed.
    """
    normalized_image = du.normalize_complex_coil_image(coil_complex_image)
    three_channel_image = du.create_three_channel_image(normalized_image)
    # Add a batch axis, then cast to float32 and move to the target device in one call.
    three_channel_tensor = torch.from_numpy(three_channel_image)[None, ...].to(
        device=device, dtype=torch.float32
    )
    autoencoder = autoencoder.to(device)
    with torch.no_grad():
        posterior = autoencoder.encode(three_channel_tensor).latent_dist
        # Use the posterior mean (no sampling) for a deterministic reconstruction.
        decoded_image = autoencoder.decode(posterior.mean).sample
    # Index 0 drops the size-1 batch axis added above.
    recon = decoded_image[0].detach().cpu().numpy()
    return three_channel_image, recon