rajrathi commited on
Commit
b65e8c0
1 Parent(s): 3045e78

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ from tensorflow import keras
4
+ import matplotlib.pyplot as plt
5
+ from math import sqrt, ceil
6
+
7
+ from huggingface_hub import from_pretrained_keras
8
+
9
+ import numpy as np
10
+
11
+
12
+ model = from_pretrained_keras("keras-io/conditional-gan")
13
+
14
+ latent_dim = 128
15
+
16
+ def generate_latent_points(digit, latent_dim, n_samples, n_classes=10):
17
+ # generate points in the latent space
18
+ random_latent_vectors = tf.random.normal(shape=(n_samples, latent_dim))
19
+ labels = tf.keras.utils.to_categorical([digit for _ in range(n_samples)], n_classes)
20
+ return tf.concat([random_latent_vectors, labels], 1)
21
+
22
+ def create_digit_samples(digit, n_samples, latent_dim=latent_dim):
23
+ random_vector_labels = generate_latent_points(digit, latent_dim, n_samples)
24
+ examples = cgan_generator.predict(random_vector_labels)
25
+ examples = examples * 255.0
26
+ size = ceil(sqrt(n_samples))
27
+ digit_images = np.zeros((28*size, 28*size))
28
+ n = 0
29
+ for i in range(size):
30
+ for j in range(size):
31
+ if n == n_samples:
32
+ break
33
+ digit_images[i* 28 : (i+1)*28, j*28 : (j+1)*28] = examples[n, :, :, 0]
34
+ n += 1
35
+
36
+ return digit_images
37
+
38
+ description = "This model is based on the example created here: https://keras.io/examples/generative/conditional_gan/"
39
+
40
+ title = "Conditional GAN for MNIST"
41
+
42
+ examples = [[1, 10], [3, 5], [5, 15]]
43
+
44
+
45
+ iface = gr.Interface(
46
+ fn = create_digit_samples,
47
+ inputs = ["number", "number"],
48
+ outputs = [gradio.outputs.Image(invert_colors=True, type="numpy", label="Samples for given digit")],
49
+ examples = examples,
50
+ description = description,
51
+ title = title
52
+ )
53
+
54
+ iface.launch()