fpramunno commited on
Commit
edfd4bd
1 Parent(s): f41fc94

Create app_backup.py

Browse files
Files changed (1) hide show
  1. app_backup.py +52 -0
app_backup.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pathlib import Path
3
+ import os
4
+ from PIL import Image
5
+ import torch
6
+ import torchvision.transforms as transforms
7
+ import requests
8
+ import numpy as np
9
+
10
+ # Preprocessing
11
+ from modules import PaletteModelV2
12
+ from diffusion import Diffusion_cond
13
+
14
+
15
+ # Check for GPU availability, else use CPU
16
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
+
18
+ model = PaletteModelV2(c_in=2, c_out=1, num_classes=5, image_size=256, device=device, true_img_size=64).to(device)
19
+ ckpt = torch.load('ema_ckpt_cond.pt', map_location=torch.device(device))
20
+ model.load_state_dict(ckpt)
21
+
22
+ diffusion = Diffusion_cond(img_size=256, device=device)
23
+ model.eval()
24
+
25
+ transform_hmi = transforms.Compose([
26
+ transforms.ToTensor(),
27
+ transforms.Resize((256, 256)),
28
+ transforms.RandomVerticalFlip(p=1.0),
29
+ transforms.Normalize(mean=(0.5,), std=(0.5,))
30
+ ])
31
+
32
+ def generate_image(seed_image):
33
+ seed_image_tensor = transform_hmi(Image.open(seed_image)).reshape(1, 1, 256, 256).to(device)
34
+ generated_image = diffusion.sample(model, y=seed_image_tensor, labels=None, n=1)
35
+ # generated_image_pil = transforms.ToPILImage()(generated_image.squeeze().cpu())
36
+ img = generated_image[0].reshape(1, 256, 256).permute(1, 2, 0) # Permute dimensions to height x width x channels
37
+ img = np.squeeze(img.cpu().numpy())
38
+ v = Image.fromarray(img) # Create a PIL Image from array
39
+ v = v.transpose(Image.FLIP_TOP_BOTTOM)
40
+
41
+ return v
42
+
43
+ # Create Gradio interface
44
+ iface = gr.Interface(
45
+ fn=generate_image,
46
+ inputs="file",
47
+ outputs="image",
48
+ title="Magnetogram-to-Magnetogram: Generative Forecasting of Solar Evolution",
49
+ description="Upload a LoS magnetogram and predict how it is going to be in 24 hours."
50
+ )
51
+
52
+ iface.launch()