Spaces:
Runtime error
Runtime error
Committing all app files
Browse files- app.py +60 -0
- requirements.txt +3 -0
- utils/constants.py +32 -0
- utils/lr_schedule.py +92 -0
- utils/predict.py +88 -0
app.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from utils.predict import predict, predict_batch
|
3 |
+
import os
|
4 |
+
|
5 |
+
inputs_list = []

# Top-level container for the whole demo page.
demo = gr.Blocks()

with demo:
    # Page header and a short description of the model.
    gr.Markdown("# **<p align='center'>ShiftViT: A Vision Transformer without Attention</p>**")
    gr.Markdown("This space demonstrates the use of ShiftViT proposed in the paper: <a href=\"https://arxiv.org/abs/2201.10801/\">When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism</a> for image classification task.")
    gr.Markdown("Vision Transformers (ViTs) have proven to be very useful for Computer Vision tasks. Many researchers believe that the attention layer is the main reason behind the success of ViTs.")
    gr.Markdown("In the ShiftViT paper, the authors have tried to show that the attention mechanism may not be vital for the success of ViTs by replacing the attention operation with a shifting operation.")

    # NOTE: `gr.Group` is used instead of `gr.Box` because `gr.Box` no longer
    # exists in Gradio 4+, and requirements.txt does not pin a Gradio version.
    with gr.Tabs():

        # Tab 1: run the model on built-in test images, no upload needed.
        with gr.TabItem("Skip Uploading!"):

            gr.Markdown("Just click *Run Model* below:")
            with gr.Group():
                gr.Markdown("**Prediction Probabilities** \n")
                output_df = gr.Dataframe(
                    headers=[
                        "image",
                        "1st_highest_probability",
                        "2nd_highest_probability",
                        "3rd_highest_probability",
                    ],
                    datatype=["str", "str", "str", "str"],
                )
                gr.Markdown("**Output Plot** \n")
                output_plot = gr.Image(type='filepath')

            gr.Markdown("**Predict**")

            with gr.Group():
                with gr.Row():
                    compute_button = gr.Button("Run Model")

        # Tab 2: classify an image supplied by the user.
        with gr.TabItem("Upload & Predict"):
            with gr.Group():

                with gr.Row():
                    input_image = gr.Image(type='filepath', label="Input Image", show_label=True)
                    output_label = gr.Label(label="Model", show_label=True)

            gr.Markdown("**Predict**")

            with gr.Group():
                with gr.Row():
                    submit_button = gr.Button("Submit")

            gr.Markdown("**Examples:**")
            gr.Markdown("The model is trained to classify images belonging to the following classes:")

            with gr.Column():
                gr.Examples("examples/set2", [input_image], output_label, predict, cache_examples=True)

    # Event wiring. `predict_batch` is handed the upload widget's value only
    # because its signature takes one positional argument.
    compute_button.click(predict_batch, inputs=input_image, outputs=[output_plot, output_df])
    submit_button.click(predict, inputs=input_image, outputs=output_label)

    gr.Markdown('\n Author: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a> <br> Based on this <a href=\"https://keras.io/examples/vision/shiftvit/\">Keras example</a> by <a href=\"https://twitter.com/ariG23498\">Aritra Roy Gosthipaty</a> and <a href=\"https://twitter.com/ritwik_raha\">Ritwik Raha</a> <br> Demo Powered by this <a href=\"https://huggingface.co/shivi/shiftvit/\">ShiftViT model</a>')

demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
tensorflow==2.9.1
|
2 |
+
gradio
|
3 |
+
tensorflow-addons
|
utils/constants.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
class Config:
    """Hyperparameters for the ShiftViT model, grouped by purpose.

    All values are plain class attributes, so they can be read either from
    the class itself or from an instance (``Config().batch_size``).
    """

    # DATA
    batch_size = 256
    buffer_size = batch_size * 2  # derived: always twice the batch size
    input_shape = (32, 32, 3)
    num_classes = 10

    # AUGMENTATION
    image_size = 48

    # ARCHITECTURE
    patch_size = 4
    projected_dim = 96
    num_shift_blocks_per_stages = [2, 4, 8, 2]
    epsilon = 1e-5
    stochastic_depth_rate = 0.2
    mlp_dropout_rate = 0.2
    num_div = 12
    shift_pixel = 1
    mlp_expand_ratio = 2

    # OPTIMIZER
    lr_start = 1e-5
    lr_max = 1e-3
    weight_decay = 1e-4

    # TRAINING
    epochs = 100
|
30 |
+
|
31 |
+
|
32 |
+
# Maps the integer class index predicted by the model to its label name.
class_vocab = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}
|
utils/lr_schedule.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow import keras
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
"""
|
6 |
+
The code below is taken from the [ShiftViT Keras example](https://keras.io/examples/vision/shiftvit/) by Aritra Roy Gosthipaty & Ritwik Raha.
|
7 |
+
"""
|
8 |
+
|
9 |
+
# Some code is taken from:
|
10 |
+
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
|
11 |
+
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule that uses a warmup cosine decay schedule.

    For steps below ``warmup_steps`` the rate rises linearly from
    ``lr_start`` towards ``lr_max``. From ``warmup_steps`` on it follows a
    half cosine wave from ``lr_max`` down to 0 at ``total_steps``. Past
    ``total_steps`` the rate is 0.
    """

    def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        """
        Args:
            lr_start: The initial learning rate
            lr_max: The maximum learning rate to which lr should increase to in
                the warmup steps
            warmup_steps: The number of steps for which the model warms up
            total_steps: The total number of steps for the model training
        """
        super().__init__()
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        """Return the learning rate for the given training step.

        Args:
            step: The current training step.

        Returns:
            The learning rate for ``step`` as a tensor.

        Raises:
            ValueError: If ``total_steps`` is smaller than ``warmup_steps``,
                or if warmup is enabled and ``lr_max`` is smaller than
                ``lr_start``.
        """
        # The schedule is only defined when training lasts at least as long
        # as the warmup phase.
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                f"Total number of steps {self.total_steps} must be "
                f"larger or equal to warmup steps {self.warmup_steps}."
            )

        # `cos_annealed_lr` is a graph that increases to 1 from the initial
        # step to the warmup step. After that this graph decays to -1 at the
        # final step mark.
        # NOTE: when total_steps == warmup_steps the denominator below is 0.
        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
        )

        # Shift the mean of the `cos_annealed_lr` graph to 1. Now the graph goes
        # from 0 to 2. Normalize the graph with 0.5 so that now it goes from 0
        # to 1. With the normalized graph we scale it with `lr_max` such that
        # it goes from 0 to `lr_max`.
        learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)

        # Check whether warmup_steps is more than 0.
        if self.warmup_steps > 0:
            # Check whether lr_max is larger than lr_start. If not, throw a
            # value error.
            if self.lr_max < self.lr_start:
                raise ValueError(
                    f"lr_start {self.lr_start} must be smaller or "
                    f"equal to lr_max {self.lr_max}."
                )

            # Calculate the slope with which the learning rate should increase
            # in the warmup schedule. The formula for slope is m = ((b-a)/steps)
            slope = (self.lr_max - self.lr_start) / self.warmup_steps

            # With the formula for a straight line (y = mx+c) build the warmup
            # schedule
            warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start

            # When the current step is less than the warmup steps, get the
            # line graph. Otherwise, get the scaled cos graph.
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )

        # When the current step is more than the total steps, return 0 else
        # return the calculated graph.
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )

    def get_config(self):
        """Return the constructor arguments as a dict, keyed by argument name."""
        config = {
            "lr_start": self.lr_start,
            "lr_max": self.lr_max,
            "total_steps": self.total_steps,
            "warmup_steps": self.warmup_steps,
        }
        return config
|
utils/predict.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow import keras
|
4 |
+
from huggingface_hub import from_pretrained_keras
|
5 |
+
from .lr_schedule import WarmUpCosine
|
6 |
+
from .constants import Config, class_vocab
|
7 |
+
from keras.utils import load_img, img_to_array
|
8 |
+
from tensorflow_addons.optimizers import AdamW
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import pandas as pd
|
11 |
+
import random
|
12 |
+
# Hyperparameters used by the prediction helpers in this module.
config = Config()

## Load Model
# `custom_objects` lets the loader resolve these classes by name.
model = from_pretrained_keras(
    "shivi/shiftvit",
    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": AdamW},
)

# CIFAR-10 train and test splits, loaded once at import time.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Lets tf.data choose the prefetch buffer size.
AUTO = tf.data.AUTOTUNE
|
21 |
+
|
22 |
+
def predict(image_path):
    """Classify a single image file with the loaded model.

    Args:
        image_path: Path of the image to classify. The image is resized to
            the spatial size given by ``config.input_shape`` while loading.

    Returns:
        A dict mapping every class name in ``class_vocab`` to its softmax
        probability (a Python float), inserted from the most to the least
        probable class.
    """
    image = load_img(image_path, target_size=config.input_shape[:2])
    # Add a leading batch axis so the single image forms a batch of one.
    batch = np.expand_dims(img_to_array(image), axis=0).astype("uint8")

    # Shuffling a one-element dataset has no effect, so it is only batched.
    predict_ds = tf.data.Dataset.from_tensor_slices(batch)
    predict_ds = predict_ds.batch(config.batch_size).prefetch(AUTO)

    logits = model.predict(predict_ds)
    probabilities = tf.nn.softmax(logits).numpy().flatten()

    # Rank on the flattened vector: reversing the un-flattened argsort result
    # would only flip its batch axis and leave the classes in ascending order.
    ranked = np.argsort(probabilities)[::-1]
    return {class_vocab[int(i)]: float(probabilities[i]) for i in ranked}
|
48 |
+
|
49 |
+
|
50 |
+
def predict_batch(image_path=None):
    """Predict on the first CIFAR-10 test batch and report three samples.

    Three consecutive images, starting at a random offset in the first batch,
    are plotted side by side with their predicted class as the title. Their
    three most probable classes are collected into a table.

    Args:
        image_path: Unused. Accepted so the function can be called with one
            positional argument, as the existing caller does.

    Returns:
        A tuple ``(saved_plot, predictions_df)``: the path of the saved JPEG
        plot and a DataFrame with one row per plotted image and the columns
        ``image``, ``1st_highest_probability``, ``2nd_highest_probability``
        and ``3rd_highest_probability``.
    """
    columns = [
        "image",
        "1st_highest_probability",
        "2nd_highest_probability",
        "3rd_highest_probability",
    ]

    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    first_batch = test_ds.batch(config.batch_size).prefetch(AUTO).take(1)

    probabilities = tf.nn.softmax(model.predict(first_batch)).numpy()
    # The dataset holds exactly one batch; pull its images out for plotting.
    images, _ = next(iter(first_batch))

    saved_plot = "plot.jpg"
    start = random.randint(0, 50)
    rows = []

    fig, axes = plt.subplots(1, 3)
    try:
        for position, ax in enumerate(axes):
            sample_probs = probabilities[start + position]
            predicted = np.argmax(sample_probs)
            ranked = np.argsort(sample_probs)[::-1]

            top_three = [
                f"prob of {class_vocab[int(idx)]} is {round(float(sample_probs[idx]) * 100,2)} %"
                for idx in ranked[:3]
            ]
            rows.append(dict(zip(columns, ["image " + str(position), *top_three])))

            ax.imshow(images[start + position].numpy().astype("uint8"))
            ax.set_title(f"image {position} : {class_vocab[int(predicted)]}")
            ax.axis("off")

        fig.savefig(saved_plot, bbox_inches='tight')
    finally:
        # Release the figure; otherwise every call leaves one open in pyplot.
        plt.close(fig)

    # Build the frame in one go: `DataFrame.append` no longer exists in
    # pandas 2.x, and appending row by row copies the frame each time.
    predictions_df = pd.DataFrame(rows, columns=columns)

    return saved_plot, predictions_df
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
|