File size: 4,012 Bytes
0f1af34 be9674f 13afb1c 0f1af34 ff34208 05493b2 0f1af34 a5a784e 4eeda6c 0f1af34 0210cac 0f1af34 8576ba9 0f1af34 8576ba9 0f1af34 ff34208 9f24831 ff34208 8576ba9 78f112e ff34208 78f112e ff34208 eb240c8 0210cac ff34208 0f1af34 cf7ac47 0f1af34 cf7ac47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np
from astropy.io import fits
# Project-local model and diffusion sampler
from modules import PaletteModelV2
from diffusion import Diffusion_cond
# HTML banner rendered via gr.Markdown at the top of the page: project link
# plus a teaser image. Built line by line; the value starts with a newline.
DESCRIPTION = '\n'.join([
    '',
    '<div style="display: flex; justify-content: center; align-items: center; flex-direction: column; font-size: 36px; margin-top: 20px;">',
    '<h1><a href="https://github.com/fpramunno/MAG2MAG" target="_blank" style="color: black; text-decoration: none;">MAG2MAG</a></h1>',
    '<img src="https://raw.githubusercontent.com/fpramunno/MAG2MAG/main/pred.png" alt="teaser" style="width: 100%; max-width: 800px; height: auto;">',
    '</div>',
])
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Conditional generator (c_in=2, c_out=1), restored from 'ema_ckpt_cond.pt'
# (presumably EMA weights, judging by the filename) and put in eval mode.
model = PaletteModelV2(
    c_in=2,
    c_out=1,
    num_classes=5,
    image_size=256,
    device=device,
    true_img_size=64,
).to(device)
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
model.eval()

# Sampler that drives the model at the same 256x256 image size.
diffusion = Diffusion_cond(img_size=256, device=device)
from torchvision import transforms
# Custom transform that clamps tensor values into a fixed range.
class ClampTransform(object):
    """Clamp every element of a tensor into ``[min_value, max_value]``.

    Designed to be used as a step inside ``torchvision.transforms.Compose``.

    Args:
        min_value: Lower bound of the clamp (default -250).
        max_value: Upper bound of the clamp (default 250).
    """

    def __init__(self, min_value=-250, max_value=250):
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, tensor):
        """Return ``tensor`` with its values clamped to the configured range."""
        return torch.clamp(tensor, self.min_value, self.max_value)

    def __repr__(self):
        # Compose's repr lists each step's repr; without this the clamp step
        # shows up as an opaque object address.
        return (
            f'{type(self).__name__}'
            f'(min_value={self.min_value}, max_value={self.max_value})'
        )
def _hmi_pipeline(*pre_resize):
    """Build an HMI preprocessing pipeline as a ``transforms.Compose``.

    Steps, in order: ``ToTensor``, any ``pre_resize`` transforms, resize to
    256x256, a vertical flip (``p=1.0``, so it is always applied), and
    ``Normalize`` with mean 0.5 / std 0.5, i.e. ``(x - 0.5) / 0.5``.
    """
    steps = [transforms.ToTensor(), *pre_resize]
    steps.append(transforms.Resize((256, 256)))
    steps.append(transforms.RandomVerticalFlip(p=1.0))
    steps.append(transforms.Normalize(mean=(0.5,), std=(0.5,)))
    return transforms.Compose(steps)


# JPEG 2000 inputs: no value clamping.
transform_hmi_jp2 = _hmi_pipeline()
# FITS inputs: clamp raw values to [-250, 250] before resizing.
transform_hmi_fits = _hmi_pipeline(ClampTransform(-250, 250))
def _resolve_upload_path(upload):
    """Return the filesystem path behind a ``gr.File`` value."""
    # Depending on the Gradio version/config, gr.File delivers either a path
    # string or a tempfile-like object whose ``.name`` holds the path.
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def _load_input_tensor(path):
    """Load a ``.jp2`` or ``.fits`` file as a (1, 1, 256, 256) tensor on ``device``.

    Raises:
        gr.Error: if the extension is unsupported or a FITS file has no data.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == '.jp2':
        # Context manager closes the file handle once the pixels are read.
        with Image.open(path) as img:
            tensor = transform_hmi_jp2(img)
    elif ext == '.fits':
        with fits.open(path) as hdul:
            # Take the first HDU that carries data; in some FITS files the
            # primary HDU is empty and the image lives in an extension.
            data = next((hdu.data for hdu in hdul if hdu.data is not None), None)
            if data is None:
                raise gr.Error(f'No image data found in {os.path.basename(path)}')
            # astype copies into native-endian float32: FITS arrays may be
            # big-endian (torch.from_numpy, used by ToTensor, rejects those),
            # and Normalize/the model expect a float32 tensor.
            tensor = transform_hmi_fits(data.astype(np.float32))
    else:
        raise gr.Error(f'Format {ext} not supported')
    return tensor.reshape(1, 1, 256, 256).to(device)


def _tensor_to_pil(tensor):
    """Convert a 256*256-element tensor to a PIL image flipped top-to-bottom.

    The flip reverses the always-on ``RandomVerticalFlip(p=1.0)`` used by the
    preprocessing transforms.
    """
    array = tensor.detach().reshape(256, 256).cpu().numpy()
    return Image.fromarray(array).transpose(Image.FLIP_TOP_BOTTOM)


def generate_image(seed_image):
    """Run the diffusion model on an uploaded ``.jp2`` or ``.fits`` magnetogram.

    Args:
        seed_image: Value delivered by the ``gr.File`` input; either a path
            string or an object exposing the path as ``.name``.

    Returns:
        ``(input_image, predicted_image)`` as PIL images: the preprocessed
        input and the model's sample, both flipped top-to-bottom for display.

    Raises:
        gr.Error: if no file was uploaded, the extension is unsupported, or a
            FITS file contains no data; Gradio reports the message in the UI.
    """
    if seed_image is None:
        raise gr.Error('Please upload a .jp2 or .fits file first')
    input_tensor = _load_input_tensor(_resolve_upload_path(seed_image))
    generated_image = diffusion.sample(model, y=input_tensor, labels=None, n=1)
    return _tensor_to_pil(input_tensor), _tensor_to_pil(generated_image[0])
# Build and launch the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        input_image = gr.File(label='Input Image')
        output_image1 = gr.Image(label='Input LoS Magnetogram', type='pil', interactive=False)
        output_image2 = gr.Image(label='Predicted LoS Magnetogram in 24 hours', type='pil', interactive=False)
    # Action buttons sit in their own Row beneath the file/image row.
    with gr.Row():
        clear_button = gr.Button('Clear')
        process_button = gr.Button('Generate')
    # Generate: run the model on the uploaded file; show input and prediction.
    process_button.click(
        fn=generate_image,
        inputs=input_image,
        outputs=[output_image1, output_image2],
    )
    # Clear: reset the upload AND both displayed images, so no stale
    # prediction is left next to an empty input.
    clear_button.click(
        fn=lambda: (None, None, None),
        inputs=None,
        outputs=[input_image, output_image1, output_image2],
    )

demo.launch()
|