Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import streamlit as st | |
| import torchvision.utils as vutils | |
| import matplotlib.pyplot as plt | |
class Generator(nn.Module):
    """DCGAN-style generator that upsamples a latent tensor to a 64x64 image.

    Input is ``N x channels_noise x 1 x 1``. Four transposed-convolution
    blocks grow the spatial size 1 -> 4 -> 8 -> 16 -> 32. A final transposed
    convolution produces ``N x channels_img x 64 x 64``, and ``tanh``
    squashes it into [-1, 1].

    Args:
        channels_noise: Number of latent channels in the input.
        channels_img: Number of channels in the generated image.
        features_g: Base width; the hidden blocks use 16x, 8x, 4x and 2x it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        # Channel widths along the hidden path: latent -> 16f -> 8f -> 4f -> 2f.
        widths = [channels_noise] + [features_g * mult for mult in (16, 8, 4, 2)]

        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The first block turns 1x1 into 4x4 (stride 1, no padding);
            # every later block doubles the spatial size (stride 2, padding 1).
            stride, pad = (1, 0) if idx == 0 else (2, 1)
            layers.append(self._block(c_in, c_out, 4, stride, pad))

        # Final upsample 32x32 -> 64x64 to image channels, then tanh to [-1, 1].
        layers.append(
            nn.ConvTranspose2d(
                widths[-1], channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one upsampling stage: ConvTranspose2d -> BatchNorm2d -> ReLU."""
        # The conv bias is dropped because the following BatchNorm adds its own shift.
        upsample = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Map latent ``x`` (N x channels_noise x 1 x 1) to images in [-1, 1]."""
        return self.net(x)
# Load the trained model
def load_model(
    model_path="gan_final.pth",
    noise_dim=100,
    device="cpu",
    channels_img=3,
    features_g=64,
):
    """Rebuild a :class:`Generator`, load trained weights into it, and set eval mode.

    Args:
        model_path: Path to a checkpoint saved with ``torch.save``. It must
            contain a dict with a ``"generator"`` entry holding the state dict.
        noise_dim: Latent channel count the generator was trained with.
        device: Device the weights are mapped to and the model is moved to.
        channels_img: Number of image channels the generator outputs
            (defaults to the previously hard-coded 3).
        features_g: Base feature width used at training time
            (defaults to the previously hard-coded 64).

    Returns:
        The generator on ``device``, in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        KeyError: If the checkpoint has no ``"generator"`` entry.
        RuntimeError: If the stored weights do not match the architecture.
    """
    # NOTE(security): torch.load unpickles the file, so only load trusted
    # checkpoints (newer PyTorch versions also accept weights_only=True).
    checkpoint = torch.load(model_path, map_location=device)
    # The architecture hyper-parameters must match training time, or
    # load_state_dict rejects the weights on key/shape mismatch.
    gen = Generator(
        channels_noise=noise_dim,
        channels_img=channels_img,
        features_g=features_g,
    ).to(device)
    gen.load_state_dict(checkpoint["generator"])
    # Switch BatchNorm to its running statistics for inference.
    gen.eval()
    return gen
# Function to generate images
def generate_images(generator, num_images=1, noise_dim=100, device="cpu"):
    """Sample images from ``generator`` and rescale them from [-1, 1] to [0, 1].

    Args:
        generator: Callable (e.g. a ``Generator``) applied to a latent batch of
            shape ``(num_images, noise_dim, 1, 1)``.
        num_images: Number of latent vectors to draw.
        noise_dim: Channel size of each latent vector.
        device: Device on which the latent noise is created.

    Returns:
        A CPU tensor holding ``generator``'s output mapped by ``x * 0.5 + 0.5``.
    """
    latent_shape = (num_images, noise_dim, 1, 1)
    z = torch.randn(*latent_shape, device=device)
    # Sampling never needs gradients, so skip building the autograd graph.
    with torch.no_grad():
        samples = generator(z)
        samples = samples.cpu()
    # Undo the tanh output range [-1, 1] -> [0, 1] for display.
    return samples.mul(0.5).add(0.5)
# Streamlit UI
# Latent size shared by model loading and sampling so the two cannot drift apart.
NOISE_DIM = 100


@st.cache_resource
def _get_generator(device):
    """Load the generator once per process and device.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, the checkpoint would be re-read from disk on each rerun.
    """
    return load_model(noise_dim=NOISE_DIM, device=device)


st.title("GAN Image Generator 🎨")
st.write("Generate images using a trained GAN model.")

# Load the model (cached across reruns)
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    generator = _get_generator(device)
except FileNotFoundError as err:
    # Show a readable message instead of a raw traceback, then halt this run.
    st.error(f"Could not load the model checkpoint: {err}")
    st.stop()

# User input for number of images
num_images = st.slider("Select number of images", 1, 8, 4)

# Generate button
if st.button("Generate Images"):
    # NOTE(review): emoji restored from mojibake ("ποΈ"); confirm intended glyph.
    st.write("🖌️ Generating images...")
    fake_images = generate_images(
        generator, num_images=num_images, noise_dim=NOISE_DIM, device=device
    )
    # make_grid's default nrow=8 is >= the slider maximum, so every image lands
    # in a single row. Size the figure as a wide strip, not a square.
    grid = vutils.make_grid(fake_images, padding=2, normalize=False)
    fig, ax = plt.subplots(figsize=(2 * num_images, 2))
    ax.axis("off")
    # imshow expects H x W x C.
    ax.imshow(grid.permute(1, 2, 0).numpy())
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed. Without this, each
    # rerun would leak another figure.
    plt.close(fig)