File size: 3,396 Bytes
920c21d
6ed8d49
 
 
 
77d0631
9918625
 
6ed8d49
 
cc8c26c
6ed8d49
 
 
 
 
 
 
 
 
 
 
 
 
 
9918625
 
c1a178d
 
 
 
 
9918625
0fecf04
 
 
 
 
9918625
 
 
 
 
6ed8d49
 
81069e7
2d91f92
6ed8d49
 
 
c1a178d
 
 
6ed8d49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff1048e
6ed8d49
 
ff1048e
 
 
1d4b58f
ff1048e
 
5bafd2e
1d4b58f
0fecf04
6ed8d49
 
 
 
0fecf04
6ed8d49
c3a0462
6ed8d49
 
 
0fecf04
 
6ed8d49
 
0fecf04
6ed8d49
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import keras_nlp
import gradio as gr
import tensorflow as tf
import gdown
from PIL import Image
#import keras_nlp
from tensorflow import keras
import time

# Image geometry constants. The first two entries of input_shape are the
# resize target used by predict_class; resized_shape and num_classes are
# not referenced anywhere in the visible part of this file.
input_shape = (32, 32, 3)
resized_shape = (299, 299, 3)
num_classes = 10

# Maps the classifier's output index to a human-readable class name.
labels = dict(
    enumerate(
        (
            "plane",
            "car",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        )
    )
)

# Download the NLP model 
def download_model_NLP():
    """Build the GPT-2 causal language model and load downloaded weights.

    Creates a ``GPT2CausalLMPreprocessor`` and a ``GPT2CausalLM`` from the
    ``gpt2_base_en`` preset (sequence length 128), downloads ``total.h5``
    with ``gdown`` and loads that file into the model.

    Returns:
        The ``GPT2CausalLM`` instance with the downloaded weights loaded.
    """
    preset_name = "gpt2_base_en"
    weights_path = "total.h5"
    drive_file_id = "1-KgcnP1ayWQ6l2-4h723JCYPoWxzOnU3"

    text_preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
        preset_name,
        sequence_length=128,
    )
    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset(
        preset_name,
        preprocessor=text_preprocessor,
    )

    # Fetch the weight file next to the script, then load it into the model.
    gdown.download(id=drive_file_id, output=weights_path, quiet=False)
    causal_lm.load_weights(weights_path)
    return causal_lm





# Download the model file
def download_model():
    """Download the image-classifier model file with ``gdown``.

    Returns:
        The local file name the download was written to. The name is
        returned unconditionally; the result of ``gdown.download`` is not
        inspected.
    """
    weights_url = "https://drive.google.com/uc?id=1zUGAPg9RVgo7bWtf_-L9MXoXKldZjs1y"
    target_path = "CIFAR10_Xception_(ACC_0.9704__LOSS_0.0335_).h5"
    gdown.download(url=weights_url, output=target_path, quiet=False)
    return target_path




# Runs at import time: download the classifier file and keep its local path.
model_file = download_model()

# Load the model
# predict_class reads this module-level name when it calls model.predict.
model = tf.keras.models.load_model(model_file)

# Perform image classification for single class output
# def predict_class(image):
#     img = tf.cast(image, tf.float32)
#     img = tf.image.resize(img, [input_shape[0], input_shape[1]])
#     img = tf.expand_dims(img, axis=0)
#     prediction = model.predict(img)
#     class_index = tf.argmax(prediction[0]).numpy()
#     predicted_class = labels[class_index]
#     return predicted_class

# Perform image classification for multi-class output
def predict_class(image):
    """Run the classifier on one image and return its score vector.

    The image is cast to ``float32``, resized to the first two dimensions
    of ``input_shape`` and given a leading batch axis before being passed
    to the module-level ``model``.

    Args:
        image: Image data accepted by ``tf.cast``.

    Returns:
        The first (and only) row of ``model.predict``'s output.
    """
    target_height, target_width = input_shape[0], input_shape[1]
    pixels = tf.image.resize(
        tf.cast(image, tf.float32),
        [target_height, target_width],
    )
    # model.predict is given a batch, so wrap the single image in one.
    batch = tf.expand_dims(pixels, axis=0)
    return model.predict(batch)[0]

# UI Design for single class output
# def classify_image(image):
#     predicted_class = predict_class(image)
#     output = f"<h2>Predicted Class: <span style='text-transform:uppercase';>{predicted_class}</span></h2>"
#     return output


# UI Design for multi-class output
def classify_image(image):
    """Classify an image and generate a short text about the top class.

    Args:
        image: The image supplied by the Gradio interface; it is forwarded
            unchanged to ``predict_class``.

    Returns:
        A ``(scores, text)`` tuple. ``scores`` maps every class name in
        ``labels`` to its score as a ``float``. ``text`` is the output of
        ``model_NLP.generate`` for the prompt ``"<Name> is able to"``, cut
        right after the first ``.`` found at or beyond character index 75.
        If there is no such ``.``, the generated text is returned uncut.
    """
    results = predict_class(image)
    output = {labels.get(i): float(score) for i, score in enumerate(results)}

    # Pick the index of the highest score directly instead of scanning for
    # a value that compares equal to the maximum.
    top_index = max(range(len(results)), key=lambda i: results[i])
    name = labels.get(top_index).capitalize()

    result_NLP = model_NLP.generate(f"{name} is able to", max_length=100)

    # str.find returns -1 when nothing matches. Slicing with -1 + 1 == 0
    # would yield an empty string, so only truncate when a period exists.
    index = result_NLP.find('.', 75)
    if index != -1:
        result_NLP = result_NLP[:index + 1]
    return output, result_NLP


# Input widget: an image upload that is delivered to the callback as "pil".
inputs = gr.inputs.Image(type="pil", label="Upload an image")
# outputs = gr.outputs.HTML() #uncomment for single class output 
# Label widget restricted to the four top classes.
output_1 = gr.outputs.Label(num_top_classes=4)

title = "<h1 style='text-align: center;'>Image Classifier with Funny Annotations :)</h1>"
description = "Upload an image and get the predicted class."
# css_code='body{background-image:url("file=wave.mp4");}'

# classify_image reads the module-level name model_NLP, so the text
# generator has to be created before the interface starts serving.
model_NLP = download_model_NLP()

# Build the interface first, then start it; the second output is a plain
# text box that receives the generated sentence.
demo = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=[output_1, "text"],
    title=title,
    description=description,
    examples=[
        ["00_plane.jpg"],
        ["01_car.jpg"],
        ["02_bird.jpg"],
        ["03_cat.jpg"],
        ["04_deer.jpg"],
    ],
    # css=css_code,
    enable_queue=True,
)
demo.launch()