import torch
import torch.nn as nn
import streamlit as st
import torchvision.utils as vutils
import matplotlib.pyplot as plt


class Generator(nn.Module):
    """DCGAN-style generator that maps a noise tensor to an image.

    The layer structure must stay identical to the one used at training time,
    because ``load_model`` restores weights via ``load_state_dict``, which
    matches parameters by their position-derived names (``net.0.0.weight``...).

    Args:
        channels_noise: Number of channels of the input noise tensor.
        channels_img: Number of channels of the produced image (3 for RGB).
        features_g: Base width; the hidden layers use multiples of it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            self._block(channels_noise, features_g * 16, 4, 1, 0),  # img: 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # img: 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),  # img: 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),  # img: 32x32
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # Output: N x channels_img x 64 x 64
            # Tanh bounds pixel values to [-1, 1].
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one upsampling stage: ConvTranspose2d -> BatchNorm2d -> ReLU.

        The convolution has no bias because the following BatchNorm2d adds
        its own learned shift.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the noise tensor ``x`` through the generator network."""
        return self.net(x)


@st.cache_resource
def load_model(model_path="gan_final.pth", noise_dim=100, device="cpu"):
    """Load the trained generator from a checkpoint and switch it to eval mode.

    ``st.cache_resource`` makes Streamlit reuse the returned model across
    script reruns instead of re-reading the checkpoint on every interaction.

    Args:
        model_path: Path to a checkpoint holding the generator weights under
            the ``"generator"`` key.
        noise_dim: Channel count of the noise input the generator was trained with.
        device: Device to map the weights to and place the model on.

    Returns:
        A ``Generator`` in eval mode on ``device``.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        KeyError: If the checkpoint has no ``"generator"`` entry.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. If this checkpoint contains only tensors/state dicts,
    # consider passing weights_only=True.
    checkpoint = torch.load(model_path, map_location=device)

    # Recreate the generator with the training-time hyperparameters.
    gen = Generator(channels_noise=noise_dim, channels_img=3, features_g=64).to(device)
    gen.load_state_dict(checkpoint["generator"])
    # Eval mode makes BatchNorm use its stored running statistics.
    gen.eval()
    return gen


def generate_images(generator, num_images=1, noise_dim=100, device="cpu"):
    """Sample ``num_images`` images from ``generator``.

    Args:
        generator: The model to sample from.
        num_images: How many images to generate.
        noise_dim: Channel count of the noise tensor; must match the model.
        device: Device on which the noise is created.

    Returns:
        A CPU tensor of images rescaled from the Tanh range [-1, 1] to [0, 1].
    """
    noise = torch.randn(num_images, noise_dim, 1, 1, device=device)
    with torch.no_grad():
        fake_images = generator(noise).cpu()
    # Denormalize from [-1, 1] to [0, 1] for display.
    return fake_images * 0.5 + 0.5


# Streamlit UI
st.title("GAN Image Generator 🎨")
st.write("Generate images using a trained GAN model.")

# Load the model (cached across reruns by st.cache_resource).
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    generator = load_model(device=device)
except FileNotFoundError as err:
    # Show a readable message instead of a raw traceback, then halt this run.
    st.error(f"Could not load model checkpoint: {err}")
    st.stop()

# User input for number of images
num_images = st.slider("Select number of images", 1, 8, 4)

if st.button("Generate Images"):
    st.write("🖌️ Generating images...")
    fake_images = generate_images(generator, num_images=num_images, device=device)

    # Tile the batch into one image; values are already in [0, 1].
    grid = vutils.make_grid(fake_images, padding=2, normalize=False)

    fig, ax = plt.subplots(figsize=(num_images, num_images))
    ax.axis("off")
    # make_grid returns C x H x W; imshow expects H x W x C.
    ax.imshow(grid.permute(1, 2, 0))
    st.pyplot(fig)
    # Streamlit re-executes this script on every interaction, and st.pyplot
    # does not close the figure; without this each click leaks one figure
    # in pyplot's global figure manager.
    plt.close(fig)