Gholamreza commited on
Commit
8895b4f
1 Parent(s): 4d341ce

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +26 -14
  2. app.py +33 -0
  3. conditional_gan.py +47 -0
  4. generated_digit.png +0 -0
  5. models.py +67 -0
README.md CHANGED
@@ -1,14 +1,26 @@
1
- ---
2
- title: Conditional GAN MNIST
3
- emoji: 💻
4
- colorFrom: purple
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.7.1
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: This is a simple implementation of Conditional Generative Ad
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generating MNIST digits using Conditional GAN
2
+
3
+ This is a simple implementation of Conditional Generative Adversarial Networks (GAN) for generating MNIST digits.
4
+
5
+ ![cover](demos/gen_all_digits.png)
6
+
7
+ I use a simple BCE loss function to compute the loss and the Adam optimizer (lr=0.0001) for training.
8
+
9
+ ## Architecture
10
+
11
+ - The **generator** is a series of Linear layers with BatchNorm and ReLU activations.
12
+ - The **discriminator** is a series of Linear layers with LeakyReLU activations.
13
+ - The conditioning class is appended to the noise vector as a one-hot vector.
14
+
15
+ ## Huggingface Space
16
+
17
+ You can try generating digits using this model on Huggingface Space.
18
+ https://huggingface.co/spaces/gholamreza/Conditional-GAN-MNIST
19
+
20
+ ![Huggingface Space](demos/gradio_app.png)
21
+
22
+ ## Training History
23
+
24
+ ![losses_plot](demos/losses.png)
25
+
26
+ Visit https://github.com/gholamrezadar/GAN-MNIST for a simpler version of this code and more details.
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from models import Generator
4
+ from conditional_gan import generate_digit
5
+
6
# Module-level generator instance shared by the Gradio callbacks; it holds
# freshly initialised weights until init() loads the checkpoint.
generator = Generator()


def init():
    """Load the pretrained generator weights and move the model to the device.

    Picks CUDA when available (CPU otherwise), reads the state dict from
    'models/generator.pt' and applies it to the module-level ``generator``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # map_location lets a checkpoint saved on another device load here.
    state_dict = torch.load('models/generator.pt', map_location=device)
    generator.load_state_dict(state_dict)
    generator.to(device)
14
+
15
def generate_mnist_digit(digit):
    """Generate an image for ``digit`` with the module-level generator.

    Returns whatever ``generate_digit`` returns (the path of the saved image).
    """
    image_path = generate_digit(generator=generator, digit=digit)
    return image_path
17
+
18
+ # Gradio Interface
19
def gradio_generate(digit):
    """Gradio click handler: return the path of an image generated for ``digit``."""
    # Same call the generate_mnist_digit wrapper makes, written out directly.
    return generate_digit(generator, digit)
21
+
22
# UI layout: a digit selector, a button, and an image that shows the PNG
# whose file path the click handler returns.
with gr.Blocks() as demo:
    gr.Markdown("# MNIST Digit Generator")
    digit = gr.Dropdown(choices=list(range(10)), label="Select a Digit")
    generate_button = gr.Button(value="Generate")
    output_image = gr.Image(label="Generated Image", type="filepath")

    # Event listeners have to be registered inside the Blocks context.
    generate_button.click(fn=gradio_generate, inputs=digit, outputs=output_image)

if __name__ == "__main__":
    # Load the weights before serving so the first request uses the trained model.
    init()
    print("* Model loaded")
    demo.launch()
conditional_gan.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file provides the necessary functions for generating images using pretrained models
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision.utils import make_grid
6
+ import matplotlib.pyplot as plt
7
+
8
+ from models import get_noise
9
+
10
# Device used by the helpers below: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
11
+
12
def display_image_grid(images, num_rows=5, title=""):
    """Show up to the first 25 images of a batch as a grid in a pyplot window.

    images:   tensor of images; if its last dimension is not 28 it is
              reshaped to (-1, 1, 28, 28) first
    num_rows: forwarded to ``make_grid`` as ``nrow``
              (NOTE: in torchvision that is the number of images per row)
    title:    figure title
    """
    needs_reshape = images.shape[-1] != 28
    if needs_reshape:
        images = images.view(-1, 1, 28, 28)

    plt.figure(figsize=(5, 5))
    plt.axis("off")
    plt.title(title)

    first_images = images.detach().cpu()[:25]
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    grid = make_grid(first_images, nrow=num_rows).permute(1, 2, 0).numpy()
    plt.imshow(grid)
    plt.show()
21
+
22
def check_generation(generator):
    """Display a 10x10 grid of generated samples, cycling through classes 0-9.

    Puts ``generator`` in eval mode, generates 100 images (labels
    0,1,...,9 repeated ten times, so each grid column is one class) and
    shows them with pyplot.

    generator: module called as ``generator(noise, labels)`` with noise of
               shape (100, 10); its output is viewed as (-1, 1, 28, 28)
    """
    generator.eval()
    labels = torch.tensor(list(range(10)) * 10).to(device)

    # Inference only: avoid building an autograd graph for 100 samples.
    with torch.no_grad():
        noise = get_noise(100, 10, device=device)
        fake_eval_batch = generator(noise, labels).view(-1, 1, 28, 28)

    # make_grid returns (C, H, W); imshow expects (H, W, C).
    grid = make_grid(fake_eval_batch.detach().cpu(), nrow=10).permute(1, 2, 0).numpy()
    plt.figure(figsize=(9, 9))
    plt.title("Generated Images")
    plt.axis('off')
    # NOTE: axis('off') hides axis decorations, so this label is not drawn.
    plt.xlabel("Class")
    plt.imshow(grid)
    plt.show()
33
+
34
def generate_digit(generator, digit):
    """Generate a 5x5 grid of samples of one digit and save it as a PNG.

    Puts ``generator`` in eval mode, generates 25 images conditioned on
    ``digit``, renders them borderless and writes 'generated_digit.png'
    in the current working directory.

    generator: module called as ``generator(noise, labels)`` with noise of
               shape (25, 10); its output is viewed as (-1, 1, 28, 28)
    digit:     class label used for all 25 samples
    returns:   the path of the saved image ('generated_digit.png')
    """
    generator.eval()
    labels = torch.tensor([digit] * 25).to(device)

    # Inference only: avoid building an autograd graph on every request.
    with torch.no_grad():
        noise = get_noise(25, 10, device=device)
        fake_eval_batch = generator(noise, labels).view(-1, 1, 28, 28)

    # make_grid returns (C, H, W); imshow expects (H, W, C).
    grid = make_grid(fake_eval_batch.detach().cpu(), nrow=5).permute(1, 2, 0).numpy()

    fig = plt.figure(figsize=(5, 5))
    try:
        # no border
        plt.axis('off')
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])
        plt.imshow(grid)
        fig.savefig('generated_digit.png', bbox_inches='tight', pad_inches=0)  # Save the generated image
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each call leaks one figure in the long-running server process.
        plt.close(fig)
    return 'generated_digit.png'  # Return the image path
generated_digit.png ADDED
models.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision.utils import make_grid
5
+ import matplotlib.pyplot as plt
6
+
7
def get_noise(n_samples, z_dim, device='cpu'):
    """Return a (n_samples, z_dim) tensor of standard-normal noise on ``device``."""
    return torch.randn(n_samples, z_dim, device=device)
9
+
10
def get_random_labels(n_samples, device='cpu'):
    """Return ``n_samples`` random class labels in [0, 10) as a long tensor on ``device``."""
    return torch.randint(
        low=0,
        high=10,
        size=(n_samples,),
        dtype=torch.long,
        device=device,
    )
12
+
13
def get_generator_block(input_dim, output_dim):
    """Build one generator stage: Linear -> BatchNorm1d -> ReLU (in-place).

    input_dim:  number of input features of the Linear layer
    output_dim: number of output features (also the BatchNorm width)
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
19
+
20
class Generator(nn.Module):
    """Fully-connected conditional generator.

    The class label is one-hot encoded over 10 classes and concatenated to
    the noise vector. The result goes through four generator blocks of
    width hidden_dim * (1, 2, 4, 8), then a Linear layer to ``im_dim`` and
    a Sigmoid, so outputs lie between 0 and 1.
    """

    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super().__init__()

        # Layer widths; the input is of shape (batch_size, z_dim + 10).
        widths = [z_dim + 10] + [hidden_dim * factor for factor in (1, 2, 4, 8)]
        blocks = [
            get_generator_block(in_dim, out_dim)
            for in_dim, out_dim in zip(widths, widths[1:])
        ]
        self.gen = nn.Sequential(
            *blocks,
            nn.Linear(widths[-1], im_dim),
            nn.Sigmoid(),  # output between 0 and 1
        )

    def forward(self, noise, classes):
        '''
        noise (batch_size, z_dim) noise vector for each image in a batch
        classes:long (batch_size) condition class for each image in a batch
        '''
        # one-hot encode the condition class, e.g. 3 -> [0,0,0,1,0,0,0,0,0,0]
        label_one_hot = F.one_hot(classes, num_classes=10).float()  # (batch_size, 10)
        gen_input = torch.cat((noise, label_one_hot), dim=1)  # (batch_size, z_dim + 10)
        return self.gen(gen_input)
44
+
45
+
46
def get_discriminator_block(input_dim, output_dim):
    """Build one discriminator stage: Linear -> LeakyReLU(0.2, in-place).

    input_dim:  number of input features of the Linear layer
    output_dim: number of output features
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)
51
+
52
class Discriminator(nn.Module):
    """Fully-connected conditional discriminator.

    Takes a batch of width ``im_dim + 10`` and passes it through three
    discriminator blocks of width hidden_dim * (4, 2, 1), followed by a
    Linear layer that outputs one value per sample.
    """

    def __init__(self, im_dim=784, hidden_dim=128):
        super().__init__()
        widths = [im_dim + 10, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        blocks = [
            get_discriminator_block(in_dim, out_dim)
            for in_dim, out_dim in zip(widths, widths[1:])
        ]
        # No final Sigmoid: a sigmoid followed by BCE is less numerically
        # stable than BCEWithLogitsLoss alone, see
        # https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
        self.disc = nn.Sequential(
            *blocks,
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, image_batch):
        '''image_batch (batch_size, im_dim + 10), i.e. (batch_size, 784+10) by default'''
        return self.disc(image_batch)