File size: 2,863 Bytes
3b72de3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
import streamlit as st
import torchvision.utils as vutils
import matplotlib.pyplot as plt

class Generator(nn.Module):
    """Turn latent noise into images with a stack of transposed convolutions.

    The network expects input shaped ``N x channels_noise x 1 x 1``. Four hidden
    stages grow the spatial size 1 -> 4 -> 8 -> 16 -> 32, with channel widths of
    16x, 8x, 4x and 2x ``features_g``. A final transposed convolution produces a
    ``N x channels_img x 64 x 64`` tensor, and a Tanh squashes it into [-1, 1].

    Submodules are registered in a fixed order inside ``self.net``:
    ``net.0``..``net.3`` are the hidden stages, ``net.4`` is the output conv and
    ``net.5`` is the Tanh. Saved ``state_dict`` keys depend on this layout.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        # Hidden-stage channel widths, widest first.
        widths = [features_g * mult for mult in (16, 8, 4, 2)]

        # Stage 0: 1x1 noise -> 4x4 feature map (stride 1, no padding).
        stages = [self._block(channels_noise, widths[0], 4, 1, 0)]
        # Stages 1-3: each doubles height/width (4x4 -> 8x8 -> 16x16 -> 32x32).
        stages.extend(
            self._block(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths, widths[1:])
        )
        # Output head: 32x32 -> 64x64 with channels_img channels, then Tanh.
        stages.append(
            nn.ConvTranspose2d(
                widths[-1], channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        stages.append(nn.Tanh())

        self.net = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one upsampling stage: ConvTranspose2d -> BatchNorm2d -> ReLU."""
        # bias=False: a per-channel conv bias would be cancelled by the mean
        # subtraction in the BatchNorm2d that immediately follows.
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Run the latent tensor *x* through the network and return the images."""
        return self.net(x)
    

# Load the trained model
@st.cache_resource
def load_model(
    model_path="gan_final.pth",
    noise_dim=100,
    device="cpu",
    channels_img=3,
    features_g=64,
):
    """Rebuild the generator, load its trained weights and switch it to eval mode.

    Args:
        model_path: Checkpoint written by ``torch.save``. It must be a mapping
            whose ``"generator"`` entry holds the generator's ``state_dict``.
        noise_dim: Latent channel count (``channels_noise``) the generator was
            trained with.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.
        channels_img: Output image channels the generator was trained with.
            Defaults to 3, the value this loader previously hard-coded.
        features_g: Base feature width the generator was trained with.
            Defaults to 64, the value this loader previously hard-coded.

    Returns:
        A ``Generator`` on ``device`` in eval mode.

    Raises:
        KeyError: The checkpoint has no ``"generator"`` entry.
        RuntimeError: The saved weights do not match the architecture built
            from ``noise_dim`` / ``channels_img`` / ``features_g``
            (raised by ``load_state_dict``).
    """
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from a
    # trusted source; consider weights_only=True if the installed torch supports it.
    checkpoint = torch.load(model_path, map_location=device)

    # Recreate generator model with the same hyper-parameters it was trained with.
    gen = Generator(
        channels_noise=noise_dim,
        channels_img=channels_img,
        features_g=features_g,
    ).to(device)
    gen.load_state_dict(checkpoint["generator"])
    # eval() makes BatchNorm use its running statistics rather than batch statistics.
    gen.eval()

    return gen

# Sample a batch of images from a generator
def generate_images(generator, num_images=1, noise_dim=100, device="cpu"):
    """Draw random latent vectors and decode them into images with *generator*.

    Latents are sampled from a standard normal, shaped
    ``(num_images, noise_dim, 1, 1)`` and created on ``device``. The generator
    runs with gradient tracking disabled. Its output is moved to the CPU and
    rescaled by ``x * 0.5 + 0.5``, which maps the range [-1, 1] onto [0, 1].

    Args:
        generator: Callable taking the latent tensor and returning an image tensor.
        num_images: Number of latent vectors (batch size) to sample.
        noise_dim: Channel size of each latent vector.
        device: Device on which the latent noise is created.

    Returns:
        CPU tensor holding the rescaled generator output.
    """
    latent = torch.randn(num_images, noise_dim, 1, 1, device=device)

    with torch.no_grad():
        samples = generator(latent).cpu()

    # Assumes the generator emits values in [-1, 1] (e.g. a Tanh head).
    return samples.mul(0.5).add(0.5)

# Streamlit UI
st.title("GAN Image Generator 🎨")
st.write("Generate images using a trained GAN model.")

# Load the model (load_model is wrapped in @st.cache_resource).
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = load_model(device=device)

# User input for number of images: range 1-8, default 4.
num_images = st.slider("Select number of images", 1, 8, 4)

# Generate button
if st.button("Generate Images"):
    # A spinner disappears once generation finishes, so no stale
    # "Generating..." text is left on the page.
    with st.spinner("🖌️ Generating images..."):
        fake_images = generate_images(generator, num_images=num_images, device=device)

    # make_grid returns one C x H x W image; matplotlib expects H x W x C.
    grid = vutils.make_grid(fake_images, padding=2, normalize=False).permute(1, 2, 0).numpy()

    # Display images
    fig, ax = plt.subplots(figsize=(num_images, num_images))
    ax.axis("off")
    ax.imshow(grid)
    st.pyplot(fig)
    # Every click creates a new pyplot figure. Close it so pyplot does not keep
    # all past figures alive for the lifetime of the server process.
    plt.close(fig)