Anime-GAN / app.py
Sohaib9920's picture
Uploaded files
3b72de3 verified
raw
history blame contribute delete
2.86 kB
import torch
import torch.nn as nn
import streamlit as st
import torchvision.utils as vutils
import matplotlib.pyplot as plt
class Generator(nn.Module):
    """Map latent noise to images with a stack of transposed convolutions.

    Input shape is ``(N, channels_noise, 1, 1)``; output shape is
    ``(N, channels_img, 64, 64)`` with values bounded to ``[-1, 1]`` by a
    final Tanh.

    Args:
        channels_noise: Channel count of the latent input.
        channels_img: Channel count of the generated image.
        features_g: Base width; the hidden stages use 16x, 8x, 4x and 2x it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        # Channel widths along the upsampling path. Each consecutive pair
        # becomes one conv-BN-ReLU stage; spatial size 1 -> 4 -> 8 -> 16 -> 32.
        widths = [
            channels_noise,
            features_g * 16,
            features_g * 8,
            features_g * 4,
            features_g * 2,
        ]
        stages = []
        for depth, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # The first stage grows 1x1 -> 4x4 (stride 1, no padding); each
            # later stage doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if depth == 0 else (2, 1)
            stages.append(self._block(c_in, c_out, 4, stride, padding))
        # Final upsample 32x32 -> 64x64 into image channels. No BatchNorm
        # follows, so this conv keeps its default bias.
        stages.append(
            nn.ConvTranspose2d(
                widths[-1], channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        stages.append(nn.Tanh())
        # Parameter names (``net.<i>...``) come from this flat layout and must
        # match the state_dict that ``load_model`` feeds to load_state_dict.
        self.net = nn.Sequential(*stages)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        # The conv bias is dropped because BatchNorm2d re-centres its output.
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Generate images from noise ``x`` of shape ``(N, channels_noise, 1, 1)``."""
        return self.net(x)
# Load the trained model
@st.cache_resource
def load_model(
    model_path="gan_final.pth",
    noise_dim=100,
    device="cpu",
    channels_img=3,
    features_g=64,
):
    """Build a ``Generator``, load checkpoint weights into it, and set eval mode.

    Decorated with ``st.cache_resource``, so Streamlit reuses the returned
    model across reruns for the same arguments instead of reloading the file.

    Args:
        model_path: Path of the checkpoint. It must be a mapping whose
            ``"generator"`` entry is the generator's state_dict.
        noise_dim: Latent channel count (``channels_noise``) of the generator.
        device: Device the weights are mapped to and the model is moved to.
        channels_img: Output image channels the checkpoint was trained with.
        features_g: Base feature width the checkpoint was trained with.

    Returns:
        The loaded ``Generator`` in evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``"generator"`` entry.
        RuntimeError: If the stored weights do not match the architecture
            built from ``noise_dim`` / ``channels_img`` / ``features_g``.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    # Recreate the generator with the architecture the checkpoint expects.
    gen = Generator(
        channels_noise=noise_dim,
        channels_img=channels_img,
        features_g=features_g,
    ).to(device)
    gen.load_state_dict(checkpoint["generator"])
    # Eval mode makes BatchNorm use its stored running statistics.
    gen.eval()
    return gen
# Sampling helper
def generate_images(generator, num_images=1, noise_dim=100, device="cpu"):
    """Draw ``num_images`` samples from ``generator`` and rescale them for display.

    Args:
        generator: Callable model applied to noise of shape
            ``(num_images, noise_dim, 1, 1)``.
        num_images: How many samples to draw.
        noise_dim: Channel count of the latent noise.
        device: Device on which the noise tensor is created.

    Returns:
        A CPU tensor equal to ``0.5 * output + 0.5``, so generator outputs in
        ``[-1, 1]`` land in ``[0, 1]``.
    """
    latent = torch.randn(num_images, noise_dim, 1, 1, device=device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        samples = generator(latent).cpu()
    # Shift from [-1, 1] to [0, 1].
    return samples.mul(0.5).add(0.5)
# Streamlit UI
st.title("GAN Image Generator 🎨")
st.write("Generate images using a trained GAN model.")

# Load the model (load_model is wrapped in st.cache_resource, so reruns reuse it)
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = load_model(device=device)

# User input for number of images: min 1, max 8, default 4
num_images = st.slider("Select number of images", 1, 8, 4)

# Generate button
if st.button("Generate Images"):
    st.write("🖌️ Generating images...")
    fake_images = generate_images(generator, num_images=num_images, device=device)

    # Display images: tile the batch into one (C, H, W) grid, then permute to
    # the (H, W, C) layout imshow expects.
    fig, ax = plt.subplots(figsize=(num_images, num_images))
    ax.axis("off")
    ax.imshow(vutils.make_grid(fake_images, padding=2, normalize=False).permute(1, 2, 0))
    st.pyplot(fig)
    # Streamlit re-executes this script on every interaction; close the figure
    # so pyplot's global figure registry does not keep growing across reruns.
    plt.close(fig)