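"""Gradio demo for MAG2MAG (https://github.com/fpramunno/MAG2MAG).

Takes a line-of-sight (LoS) HMI magnetogram (.jp2 or .fits) and uses a
conditional diffusion model (PaletteModelV2 + Diffusion_cond) to predict
the magnetogram 24 hours ahead.
"""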
import os

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from astropy.io import fits  # required to read .fits input files
from PIL import Image

from modules import PaletteModelV2
from diffusion import Diffusion_cond

DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center; flex-direction: column; font-size: 36px; margin-top: 20px;">
<h1><a href="https://github.com/fpramunno/MAG2MAG" target="_blank" style="color: black; text-decoration: none;">MAG2MAG</a></h1>
<img src="https://raw.githubusercontent.com/fpramunno/MAG2MAG/main/pred.png" alt="teaser" style="width: 100%; max-width: 800px; height: auto;">
</div>'''

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Palette-style conditional model: the noisy target and the conditioning
# magnetogram are stacked into a 2-channel input (c_in=2) and denoised into
# a 1-channel prediction (c_out=1).
model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256, device=device, true_img_size=64).to(device)

# Load the EMA weights (expected next to this script) and switch to eval mode.
ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
model.load_state_dict(ckpt)
model.eval()

diffusion = Diffusion_cond(img_size=256, device=device)

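# Preprocessing for the input magnetogram: resize to 256x256, flip vertically
# (presumably to match the orientation seen during training; the display path
# below flips back), and normalize pixel values to [-1, 1].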
transform_hmi = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.RandomVerticalFlip(p=1.0),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

def generate_image(seed_image):
    _, file_ext = os.path.splitext(seed_image)

    if file_ext.lower() == '.jp2':
        input_img = Image.open(seed_image)
        input_img_pil = transform_hmi(input_img).reshape(1, 1, 256, 256).to(device)
    elif file_ext.lower() == '.fits':
        with fits.open(seed_image) as hdul:
            data = hdul[0].data
        input_img_pil = transform_hmi(data).reshape(1, 1, 256, 256).to(device)
    else:
        raise gr.Error('Unsupported file type: please upload a .jp2 or .fits file.')

    # Sample one image from the diffusion model, conditioned on the input magnetogram.
    generated_image = diffusion.sample(model, y=input_img_pil, labels=None, n=1)

    # Input preview: undo the [-1, 1] normalization and the vertical flip for display.
    inp_img = input_img_pil.reshape(1, 256, 256).permute(1, 2, 0)
    inp_img = np.squeeze(inp_img.cpu().numpy())
    inp_img = ((inp_img + 1) / 2 * 255).clip(0, 255).astype(np.uint8)
    inp = Image.fromarray(inp_img)
    inp = inp.transpose(Image.FLIP_TOP_BOTTOM)

    # Prediction: sample() is assumed to return images already mapped to [0, 255]
    # (as in common DDPM implementations); cast defensively to uint8 for display.
    img = generated_image[0].reshape(1, 256, 256).permute(1, 2, 0)
    img = np.squeeze(img.cpu().numpy()).astype(np.uint8)
    v = Image.fromarray(img)
    v = v.transpose(Image.FLIP_TOP_BOTTOM)

    return inp, v

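# Gradio interface: upload a .jp2 or .fits LoS magnetogram; the app displays the
# preprocessed input alongside the model's prediction for 24 hours later.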
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # type='filepath' hands generate_image a plain path string.
        input_image = gr.File(label='Input Image', type='filepath')
        output_image1 = gr.Image(label='Input LoS Magnetogram', type='pil', interactive=False)
        output_image2 = gr.Image(label='Predicted LoS Magnetogram in 24 hours', type='pil', interactive=False)

    with gr.Row():
        clear_button = gr.Button('Clear')
        process_button = gr.Button('Generate')

    process_button.click(
        fn=generate_image,
        inputs=input_image,
        outputs=[output_image1, output_image2]
    )

    clear_button.click(
        fn=lambda: None,
        inputs=None,
        outputs=input_image
    )

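# Start the local server (pass share=True to launch() for a public link).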
demo.launch()