shivi committed
Commit cacf2d0
1 Parent(s): 268e7e7

Committing all app files

Files changed (5)
  1. app.py +60 -0
  2. requirements.txt +3 -0
  3. utils/constants.py +32 -0
  4. utils/lr_schedule.py +92 -0
  5. utils/predict.py +88 -0
app.py ADDED
@@ -0,0 +1,60 @@
+ import gradio as gr
+ from utils.predict import predict, predict_batch
+ import os
+
+ inputs_list = []
+
+
+ demo = gr.Blocks()
+
+ with demo:
+
+     gr.Markdown("# **<p align='center'>ShiftViT: A Vision Transformer without Attention</p>**")
+     gr.Markdown("This space demonstrates the use of ShiftViT, proposed in the paper <a href=\"https://arxiv.org/abs/2201.10801/\">When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism</a>, for the image classification task.")
+     gr.Markdown("Vision Transformers (ViTs) have proven to be very useful for computer vision tasks. Many researchers believe that the attention layer is the main reason behind the success of ViTs.")
+     gr.Markdown("In the ShiftViT paper, the authors show that the attention mechanism may not be vital for the success of ViTs by replacing the attention operation with a shifting operation.")
+
+     with gr.Tabs():
+
+         with gr.TabItem("Skip Uploading!"):
+
+             gr.Markdown("Just click *Run Model* below:")
+             with gr.Box():
+                 gr.Markdown("**Prediction Probabilities** \n")
+                 output_df = gr.Dataframe(headers=["image", "1st_highest_probability", "2nd_highest_probability", "3rd_highest_probability"], datatype=["str", "str", "str", "str"])
+                 gr.Markdown("**Output Plot** \n")
+                 output_plot = gr.Image(type='filepath')
+
+             gr.Markdown("**Predict**")
+
+             with gr.Box():
+                 with gr.Row():
+                     compute_button = gr.Button("Run Model")
+
+
+         with gr.TabItem("Upload & Predict"):
+             with gr.Box():
+
+                 with gr.Row():
+                     input_image = gr.Image(type='filepath', label="Input Image", show_label=True)
+                     output_label = gr.Label(label="Model", show_label=True)
+
+             gr.Markdown("**Predict**")
+
+             with gr.Box():
+                 with gr.Row():
+                     submit_button = gr.Button("Submit")
+
+             gr.Markdown("**Examples:**")
+             gr.Markdown("The model is trained to classify images belonging to the following classes:")
+
+             with gr.Column():
+                 gr.Examples("examples/set2", [input_image], output_label, predict, cache_examples=True)
+
+
+     compute_button.click(predict_batch, inputs=input_image, outputs=[output_plot, output_df])
+     submit_button.click(predict, inputs=input_image, outputs=output_label)
+
+     gr.Markdown('\n Author: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a> <br> Based on this <a href=\"https://keras.io/examples/vision/shiftvit/\">Keras example</a> by <a href=\"https://twitter.com/ariG23498\">Aritra Roy Gosthipaty</a> and <a href=\"https://twitter.com/ritwik_raha\">Ritwik Raha</a> <br> Demo powered by this <a href=\"https://huggingface.co/shivi/shiftvit/\">ShiftViT model</a>')
+
+ demo.launch()
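For reference, the contract between the two callbacks wired above and their output components is: `predict` returns a class-to-probability dictionary rendered by `gr.Label`, while `predict_batch` returns the path of a saved plot for `gr.Image` plus a pandas DataFrame for `gr.Dataframe`. A minimal sketch of that contract, assuming the model and CIFAR-10 have already been downloaded; the example image path is illustrative, not a file added in this commit:

    # Illustrative only: shows the return shapes app.py relies on.
    # "examples/set2/airplane.png" is a hypothetical path, not an asset of this commit.
    from utils.predict import predict, predict_batch

    confidences = predict("examples/set2/airplane.png")
    print(type(confidences))               # dict of {class name: probability}, consumed by gr.Label

    plot_path, df = predict_batch(None)    # the image argument is unused by predict_batch
    print(plot_path)                       # "plot.jpg", shown via gr.Image(type='filepath')
    print(list(df.columns))                # matches the gr.Dataframe headers defined above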
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ tensorflow==2.9.1
+ gradio
+ tensorflow-addons
utils/constants.py ADDED
@@ -0,0 +1,32 @@
+
+ class Config(object):
+     # DATA
+     batch_size = 256
+     buffer_size = batch_size * 2
+     input_shape = (32, 32, 3)
+     num_classes = 10
+
+     # AUGMENTATION
+     image_size = 48
+
+     # ARCHITECTURE
+     patch_size = 4
+     projected_dim = 96
+     num_shift_blocks_per_stages = [2, 4, 8, 2]
+     epsilon = 1e-5
+     stochastic_depth_rate = 0.2
+     mlp_dropout_rate = 0.2
+     num_div = 12
+     shift_pixel = 1
+     mlp_expand_ratio = 2
+
+     # OPTIMIZER
+     lr_start = 1e-5
+     lr_max = 1e-3
+     weight_decay = 1e-4
+
+     # TRAINING
+     epochs = 100
+
+
+ class_vocab = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}
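These constants mirror the training configuration of the Keras ShiftViT example. As a rough illustration of how they translate into the `total_steps` and `warmup_steps` arguments expected by `WarmUpCosine` in utils/lr_schedule.py below, assuming CIFAR-10's 50,000 training images and an illustrative 15% warmup fraction that is not defined anywhere in this commit:

    # Sketch only: derive scheduler arguments from Config.
    # The 0.15 warmup fraction is an assumption for illustration.
    from utils.constants import Config

    config = Config()
    num_train_images = 50_000                                  # CIFAR-10 training split
    steps_per_epoch = num_train_images // config.batch_size    # 195
    total_steps = steps_per_epoch * config.epochs              # 19,500
    warmup_steps = int(total_steps * 0.15)                     # 2,925
    print(total_steps, warmup_steps)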
utils/lr_schedule.py ADDED
@@ -0,0 +1,92 @@
+ import tensorflow as tf
+ from tensorflow import keras
+ import numpy as np
+
+ """
+ The code below is taken from the [ShiftViT Keras example](https://keras.io/examples/vision/shiftvit/) by Aritra Roy Gosthipaty & Ritwik Raha.
+ """
+
+ # Some code is taken from:
+ # https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
+ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
+     """A LearningRateSchedule that uses a warmup cosine decay schedule."""
+
+     def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
+         """
+         Args:
+             lr_start: The initial learning rate
+             lr_max: The maximum learning rate to which the lr should increase
+                 during the warmup steps
+             warmup_steps: The number of steps for which the model warms up
+             total_steps: The total number of steps for the model training
+         """
+         super().__init__()
+         self.lr_start = lr_start
+         self.lr_max = lr_max
+         self.warmup_steps = warmup_steps
+         self.total_steps = total_steps
+         self.pi = tf.constant(np.pi)
+
+     def __call__(self, step):
+         # Check whether the total number of steps is larger than the warmup
+         # steps. If not, throw a value error.
+         if self.total_steps < self.warmup_steps:
+             raise ValueError(
+                 f"Total number of steps {self.total_steps} must be "
+                 f"larger or equal to warmup steps {self.warmup_steps}."
+             )
+
+         # `cos_annealed_lr` is a graph that increases to 1 from the initial
+         # step to the warmup step. After that this graph decays to -1 at the
+         # final step mark.
+         cos_annealed_lr = tf.cos(
+             self.pi
+             * (tf.cast(step, tf.float32) - self.warmup_steps)
+             / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
+         )
+
+         # Shift the mean of the `cos_annealed_lr` graph to 1. Now the graph goes
+         # from 0 to 2. Normalize the graph with 0.5 so that now it goes from 0
+         # to 1. With the normalized graph we scale it with `lr_max` such that
+         # it goes from 0 to `lr_max`.
+         learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)
+
+         # Check whether warmup_steps is more than 0.
+         if self.warmup_steps > 0:
+             # Check whether lr_max is larger than lr_start. If not, throw a
+             # value error.
+             if self.lr_max < self.lr_start:
+                 raise ValueError(
+                     f"lr_start {self.lr_start} must be smaller or "
+                     f"equal to lr_max {self.lr_max}."
+                 )
+
+             # Calculate the slope with which the learning rate should increase
+             # in the warmup schedule. The formula for the slope is m = (b - a) / steps.
+             slope = (self.lr_max - self.lr_start) / self.warmup_steps
+
+             # With the formula for a straight line (y = mx + c) build the warmup
+             # schedule.
+             warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start
+
+             # When the current step is less than the warmup steps, use the line
+             # graph. When the current step is greater than the warmup steps, use
+             # the scaled cosine graph.
+             learning_rate = tf.where(
+                 step < self.warmup_steps, warmup_rate, learning_rate
+             )
+
+         # When the current step is more than the total steps, return 0; else
+         # return the calculated graph.
+         return tf.where(
+             step > self.total_steps, 0.0, learning_rate, name="learning_rate"
+         )
+
+     def get_config(self):
+         config = {
+             "lr_start": self.lr_start,
+             "lr_max": self.lr_max,
+             "total_steps": self.total_steps,
+             "warmup_steps": self.warmup_steps,
+         }
+         return config
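One way to sanity-check the schedule is to instantiate it with the values from utils/constants.py and evaluate it at a few steps: it should ramp linearly from `lr_start` to `lr_max` over the warmup, then decay along a cosine to zero. The step counts below are illustrative, and the AdamW pairing simply mirrors the custom objects used when the model is loaded in utils/predict.py:

    # Sketch: probe WarmUpCosine and pair it with AdamW; the step counts are illustrative.
    from tensorflow_addons.optimizers import AdamW
    from utils.constants import Config
    from utils.lr_schedule import WarmUpCosine

    config = Config()
    schedule = WarmUpCosine(
        lr_start=config.lr_start,   # 1e-5
        lr_max=config.lr_max,       # 1e-3
        warmup_steps=2_925,         # illustrative
        total_steps=19_500,         # illustrative
    )

    # Linear warmup up to lr_max at step 2,925, cosine decay to 0 at step 19,500.
    for step in [0, 1_000, 2_925, 10_000, 19_500]:
        print(step, float(schedule(step)))

    optimizer = AdamW(learning_rate=schedule, weight_decay=config.weight_decay)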
utils/predict.py ADDED
@@ -0,0 +1,88 @@
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow import keras
+ from huggingface_hub import from_pretrained_keras
+ from .lr_schedule import WarmUpCosine
+ from .constants import Config, class_vocab
+ from keras.utils import load_img, img_to_array
+ from tensorflow_addons.optimizers import AdamW
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import random
+
+ config = Config()
+
+ # Load the trained ShiftViT model from the Hugging Face Hub
+ model = from_pretrained_keras("shivi/shiftvit", custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": AdamW})
+
+ (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+
+
+ AUTO = tf.data.AUTOTUNE
+
+
+ def predict(image_path):
+     """
+     Fetches the prediction for a single input image.
+     Returns a dictionary mapping each class name to its predicted probability,
+     ordered from the most to the least likely class.
+     """
+     test_image1 = load_img(image_path, target_size=(32, 32))
+     test_image = img_to_array(test_image1)
+     test_image = np.expand_dims(test_image, axis=0)
+     test_image = test_image.astype('uint8')
+
+     predict_ds = tf.data.Dataset.from_tensor_slices(test_image)
+     predict_ds = predict_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO)
+     logits = model.predict(predict_ds)
+     prob = tf.nn.softmax(logits)
+
+     confidences = {}
+     prob_list = prob.numpy().flatten().tolist()
+     # Sort the class indices by descending probability
+     sorted_prob = np.argsort(prob_list)[::-1]
+     for i in sorted_prob:
+         confidences[class_vocab[i]] = float(prob_list[i])
+
+     return confidences
+
+
+ def predict_batch(image_path):
+     """
+     Runs the model on one batch of the CIFAR-10 test set.
+     Returns a plot of three sample images and a dataframe with their
+     top-3 class probabilities.
+     """
+     test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+     test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
+     slice = test_ds.take(1)
+
+     slice_pred = model.predict(slice)
+     slice_pred = tf.nn.softmax(slice_pred)
+
+     saved_plot = "plot.jpg"
+     fig = plt.figure()
+
+     # Collect one row per image; DataFrame.append was removed in pandas 2.0,
+     # so the dataframe is built from this list at the end.
+     prediction_rows = []
+     num = random.randint(0, 50)
+     for images, labels in slice:
+         for i, j in zip(range(num, num + 3), range(3)):
+             ax = plt.subplot(1, 3, j + 1)
+             plt.imshow(images[i].numpy().astype("uint8"))
+             output = np.argmax(slice_pred[i])
+
+             prob_list = slice_pred[i].numpy().flatten().tolist()
+             # Sort the class indices by descending probability
+             sorted_prob = np.argsort(prob_list)[::-1]
+             prob_scores = {"image": "image " + str(j),
+                            "1st_highest_probability": f"prob of {class_vocab[sorted_prob[0]]} is {round(prob_list[sorted_prob[0]] * 100, 2)} %",
+                            "2nd_highest_probability": f"prob of {class_vocab[sorted_prob[1]]} is {round(prob_list[sorted_prob[1]] * 100, 2)} %",
+                            "3rd_highest_probability": f"prob of {class_vocab[sorted_prob[2]]} is {round(prob_list[sorted_prob[2]] * 100, 2)} %"}
+             prediction_rows.append(prob_scores)
+
+             plt.title(f"image {j} : {class_vocab[output]}")
+             plt.axis("off")
+     plt.savefig(saved_plot, bbox_inches='tight')
+
+     predictions_df = pd.DataFrame(prediction_rows)
+     return saved_plot, predictions_df
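As a rough local smoke test of `predict` outside the Gradio UI, one could write a CIFAR-10 test image to disk and compare the top prediction against its label; the file name below is illustrative, not an asset of this commit:

    # Sketch: run predict() on one CIFAR-10 test image written to disk.
    from tensorflow import keras
    from utils.constants import class_vocab
    from utils.predict import predict, x_test, y_test

    keras.utils.save_img("sample.png", x_test[0])     # "sample.png" is a hypothetical file name
    confidences = predict("sample.png")

    top_class = max(confidences, key=confidences.get)
    print("predicted:", top_class, "expected:", class_vocab[int(y_test[0][0])])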