# mag2mag / app.py
# Author: fpramunno · commit eb240c8 "Update app.py" (verified) · 1.62 kB
import gradio as gr
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import requests
import numpy as np
# Preprocessing
from modules import PaletteModelV2
from diffusion import Diffusion_cond
# Prefer the GPU, but fall back to the CPU so the app can still start on hosts
# without CUDA (the original hard-coded 'cuda' crashed there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256, true_img_size=64).to(device)
# map_location lets a checkpoint whose tensors were saved on a GPU be loaded
# onto whichever device was selected above.
ckpt = torch.load('ema_ckpt_cond.pt', map_location=device)
model.load_state_dict(ckpt)
diffusion = Diffusion_cond(img_size=256, device=device)
# Inference only: put the network in evaluation mode (affects layers such as
# dropout / batch-norm, if the model contains any).
model.eval()
# Input pipeline for the uploaded magnetogram:
#   1. convert the PIL image to a tensor,
#   2. resize it to the 256x256 size the model is built for,
#   3. flip it vertically (p=1.0 makes the "random" flip unconditional),
#   4. normalize each value as (x - 0.5) / 0.5.
transform_hmi = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(256, 256)),
        transforms.RandomVerticalFlip(1.0),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
def generate_image(seed_image):
    """Run the conditional diffusion sampler on an uploaded magnetogram.

    Args:
        seed_image: The upload handed over by the Gradio ``"file"`` input. It is
            passed directly to ``PIL.Image.open``, so it must be a path or a
            readable binary file object.

    Returns:
        A ``PIL.Image.Image`` built from the first generated sample, flipped
        vertically to undo the vertical flip applied by ``transform_hmi``.

    Raises:
        RuntimeError: From ``reshape`` if the upload does not yield a
            single-channel 256x256 tensor after ``transform_hmi``.
    """
    # Image.open keeps the file handle open for lazy loading; the context
    # manager closes it. transform_hmi reads all pixels inside the block.
    with Image.open(seed_image) as seed:
        # Add the batch and channel axes; this also rejects multi-channel input.
        seed_image_tensor = transform_hmi(seed).reshape(1, 1, 256, 256).to(device)

    # The seed tensor is passed to the sampler as ``y``; n=1 requests one sample.
    generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)

    # First (only) sample as a 256x256 host array. detach() guards .numpy()
    # against a tensor that still tracks gradients.
    img = generated_image[0].detach().cpu().reshape(256, 256).numpy()

    # NOTE(review): the resulting PIL mode depends on the sampler's output
    # dtype (e.g. uint8 -> 'L', float32 -> 'F') — confirm it displays correctly.
    # transform_hmi flipped the input vertically, so flip the prediction back.
    return Image.fromarray(img).transpose(Image.FLIP_TOP_BOTTOM)
# Gradio front end: one uploaded file goes in, the forecast image comes out.
# The server is started immediately when this script runs.
iface = gr.Interface(
    generate_image,
    "file",
    "image",
    title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
    description="Upload a LoS magnetogram and predict how it is going to be in 24 hours.",
)
iface.launch()