BioMike commited on
Commit
59e7820
1 Parent(s): dfe21b4

Update vae.py

Browse files
Files changed (1) hide show
  1. vae.py +59 -58
vae.py CHANGED
@@ -1,58 +1,59 @@
1
- import torch
2
- import torch.nn as nn
3
- from torchvision import transforms
4
- from PIL import Image
5
- import gradio as gr
6
- import numpy as np
7
-
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
-
10
- transform1 = transforms.Compose([
11
- transforms.Resize((128, 128)), # Resize the image to 128x128 for the model
12
- transforms.ToTensor(),
13
- transforms.Normalize((0.5,), (0.5,))
14
- ])
15
-
16
- transform2 = transforms.Compose([
17
- transforms.Resize((512, 512)) # Resize the image to 512x512 for display
18
- ])
19
-
20
- def load_image(image):
21
- image = Image.fromarray(image).convert('RGB')
22
- image = transform1(image)
23
- return image.unsqueeze(0).to(device)
24
-
25
- def infer_image(image, noise_level):
26
- image = load_image(image)
27
- with torch.no_grad():
28
- mu, logvar = model.encode(image)
29
- std = torch.exp(0.5 * logvar)
30
- eps = torch.randn_like(std) * noise_level
31
- z = mu + eps * std
32
- decoded_image = model.decode(z)
33
-
34
- decoded_image = decoded_image.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.float32) * 0.5 + 0.5
35
- decoded_image = np.clip(decoded_image, 0, 1)
36
-
37
- decoded_image = Image.fromarray((decoded_image * 255).astype(np.uint8))
38
- decoded_image = transform2(decoded_image)
39
- return np.array(decoded_image)
40
-
41
- examples = [
42
- ["example_images/image1.jpg", 0.1],
43
- ["example_images/image2.png", 0.5],
44
- ["example_images/image3.jpg", 1.0],
45
- ]
46
-
47
- with gr.Blocks() as vae:
48
- noise_slider = gr.Slider(0, 10, value=0.01, step=0.01, label="Noise Level")
49
- with gr.Row():
50
- with gr.Column():
51
- input_image = gr.Image(label="Upload an image", type="numpy")
52
- with gr.Column():
53
- output_image = gr.Image(label="Reconstructed Image")
54
-
55
- input_image.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)
56
- noise_slider.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)
57
-
58
- gr.Examples(examples=examples, inputs=[input_image, noise_slider])
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+ import numpy as np
7
+ from model import model
8
+
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ transform1 = transforms.Compose([
12
+ transforms.Resize((128, 128)), # Resize the image to 128x128 for the model
13
+ transforms.ToTensor(),
14
+ transforms.Normalize((0.5,), (0.5,))
15
+ ])
16
+
17
+ transform2 = transforms.Compose([
18
+ transforms.Resize((512, 512)) # Resize the image to 512x512 for display
19
+ ])
20
+
21
+ def load_image(image):
22
+ image = Image.fromarray(image).convert('RGB')
23
+ image = transform1(image)
24
+ return image.unsqueeze(0).to(device)
25
+
26
+ def infer_image(image, noise_level):
27
+ image = load_image(image)
28
+ with torch.no_grad():
29
+ mu, logvar = model.encode(image)
30
+ std = torch.exp(0.5 * logvar)
31
+ eps = torch.randn_like(std) * noise_level
32
+ z = mu + eps * std
33
+ decoded_image = model.decode(z)
34
+
35
+ decoded_image = decoded_image.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.float32) * 0.5 + 0.5
36
+ decoded_image = np.clip(decoded_image, 0, 1)
37
+
38
+ decoded_image = Image.fromarray((decoded_image * 255).astype(np.uint8))
39
+ decoded_image = transform2(decoded_image)
40
+ return np.array(decoded_image)
41
+
42
+ examples = [
43
+ ["example_images/image1.jpg", 0.1],
44
+ ["example_images/image2.png", 0.5],
45
+ ["example_images/image3.jpg", 1.0],
46
+ ]
47
+
48
+ with gr.Blocks() as vae:
49
+ noise_slider = gr.Slider(0, 10, value=0.01, step=0.01, label="Noise Level")
50
+ with gr.Row():
51
+ with gr.Column():
52
+ input_image = gr.Image(label="Upload an image", type="numpy")
53
+ with gr.Column():
54
+ output_image = gr.Image(label="Reconstructed Image")
55
+
56
+ input_image.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)
57
+ noise_slider.change(fn=infer_image, inputs=[input_image, noise_slider], outputs=output_image)
58
+
59
+ gr.Examples(examples=examples, inputs=[input_image, noise_slider])