File size: 3,814 Bytes
89fb082
 
 
 
 
d34f45c
 
e0b823f
a51dc81
5dd794e
d34f45c
89fb082
f817463
db4a829
89fb082
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d34f45c
10d2426
 
 
8669566
d89adef
 
8669566
4fa426b
10d2426
 
 
 
d3d4e02
10d2426
 
 
 
89fb082
 
 
0a2e4df
d01708c
89fb082
8664f05
 
89fb082
fa67a36
89fb082
 
9791f6a
 
 
 
 
 
89fb082
 
 
f2e2bcd
89fb082
 
 
 
 
3c18717
89fb082
 
 
 
9358ae1
8fbfe88
 
 
89fb082
 
 
 
8fbfe88
89fb082
 
fa67a36
 
 
 
 
3fcf295
89fb082
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123

import base64
import streamlit as st
from PIL import Image
import numpy as np
from keras.models import model_from_json
import subprocess
import os
import tensorflow as tf
from keras.applications.imagenet_utils import preprocess_input


# Page heading. Both lines are passed as raw HTML (unsafe_allow_html=True) so
# the inline style can force white text.
# NOTE(review): presumably white is chosen for contrast against the page
# background image — confirm against the actual asset.
st.markdown('<h1 style="color:white;">Image Classification App</h1>', unsafe_allow_html=True)
st.markdown('<h2 style="color:white;">for classifying **zebras** and **horses**</h2>', unsafe_allow_html=True)

# The cache call must be applied as a decorator; as a bare statement it builds
# a decorator and throws it away, leaving the function uncached.
@st.cache(allow_output_mutation=True)
def get_base64_of_bin_file(bin_file):
    """Return the contents of ``bin_file`` as a base64-encoded text string.

    Args:
        bin_file: Path of the file to read; it is opened in binary mode.

    Returns:
        The base64 encoding of the file's bytes, decoded to ``str``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()

def set_png_as_page_bg(png_file):
    """Use an image file as the background of the Streamlit app container.

    The image is embedded as a base64 ``data:`` URI inside a ``<style>`` rule
    targeting ``.stApp`` and injected with ``st.markdown``.

    Args:
        png_file: Path to the background image. Despite the parameter name,
            the MIME type is guessed from the file extension so other image
            formats are labelled correctly; ``image/png`` is used when the
            type cannot be determined.

    Returns:
        None.
    """
    import mimetypes

    mime_type, _ = mimetypes.guess_type(png_file)
    if mime_type is None or not mime_type.startswith('image/'):
        # Unknown extension (or an interpreter whose table lacks the type):
        # keep the historical label.
        mime_type = 'image/png'

    bin_str = get_base64_of_bin_file(png_file)

    # Remarks are kept out of the stylesheet itself: CSS has no ``#`` comments,
    # and an apostrophe in one opens an unterminated string token.
    # The original author noted that ``background-attachment: scroll`` did not
    # have the intended effect.
    page_bg_img = '''
    <style>
    .stApp {
    background-image: url("data:%s;base64,%s");
    background-size: cover;
    background-repeat: no-repeat;
    background-attachment: scroll;
    }
    </style>
    ''' % (mime_type, bin_str)

    st.markdown(page_bg_img, unsafe_allow_html=True)

# Install the page background every time the script runs; the path is relative
# to the process working directory.
# NOTE(review): the asset is a .webp file although the helper is named for PNGs.
set_png_as_page_bg('background.webp')
        

# def load_model():
#     # load json and create model
#     json_file = open('model.json', 'r')
#     loaded_model_json = json_file.read()
#     json_file.close()
#     CNN_class_index = model_from_json(loaded_model_json)
#     # load weights into new model
#     model = CNN_class_index.load_weights("model.h5")

#     #model= tf.keras.load_model('model.h5')
#     #CNN_class_index = json.load(open(f"{os.getcwd()}F:\Machine Learning Resources\ZebraHorse\model.json"))
#     return model, CNN_class_index

def load_model():
    """Return the trained Keras classifier, downloading the file on first use.

    If ``model.h5`` is not present in the working directory it is fetched from
    the project's GitHub repository. The model is then loaded from disk on
    every call, whether or not a download was needed.

    Returns:
        The model loaded with ``compile=False``, or ``None`` when the model
        file is missing and could not be downloaded.

    Raises:
        Whatever ``tf.keras.models.load_model`` raises for an unreadable or
        invalid model file.
    """
    import http.client
    import shutil
    import urllib.request

    model_path = 'model.h5'
    # A GitHub ``/blob/`` URL returns the HTML file viewer; ``/raw/`` redirects
    # to the actual file bytes (urllib follows the redirect).
    model_url = (
        'https://github.com/KaburaJ/Binary-Image-classification/raw/main/'
        'ZebraHorse/CNN%20Application/model.h5'
    )

    if not os.path.isfile(model_path):
        # Download to a side file and rename, so an interrupted transfer never
        # leaves a truncated ``model.h5`` that later runs would trust.
        partial_path = model_path + '.part'
        try:
            with urllib.request.urlopen(model_url, timeout=60) as response, \
                    open(partial_path, 'wb') as out_file:
                shutil.copyfileobj(response, out_file)
            os.replace(partial_path, model_path)
        except (OSError, http.client.HTTPException):
            # URLError/HTTPError are OSError subclasses. The caller reports a
            # ``None`` model to the user.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            return None

    # Outside the ``if`` so an already-present file is loaded as well.
    return tf.keras.models.load_model(model_path, compile=False)

        
# def load_model():
#     # Load the model architecture
#     with open('model.json', 'r') as f:
#         model_from_json(f.read())

#     # Load the model weights
#     model.load_weights('model.h5')
#     #CNN_class_index = json.load(open(f"{os.getcwd()}F:\Machine Learning Resources\ZebraHorse\model.json"))
#     return model


def image_transformation(image):
    """Convert an image (or any array-like) into a new NumPy array.

    The conversion happens entirely in memory: no scratch file is written, so
    concurrent sessions cannot overwrite each other's data and nothing is read
    back with pickle enabled.

    Args:
        image: A PIL image or other object accepted by ``np.array``.

    Returns:
        A freshly allocated ``numpy.ndarray`` holding the image data. No
        resizing, scaling or batching is applied.
    """
    return np.array(image)


# def image_prediction(image, model):
#     image = image_transformation(image=image)
#     outputs = float(model.predict(image))
#     _, y_hat = outputs.max(1)
#     predicted_idx = str(y_hat.item())
#     return predicted_idx

def main():
    """Render the upload widget and classify the uploaded image on request.

    Shows the uploaded image in the left column. When the "Predict" button is
    pressed, the image is converted to RGB, resized to the model's input size
    when the model reports one, given a batch dimension and passed to
    ``model.predict``. The predicted class index and the raw scores are shown
    in the right column.
    """
    image_file = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png'])
    if not image_file:
        return

    left_column, right_column = st.columns(2)
    left_column.image(image_file, caption="Uploaded image", use_column_width=True)

    pred_button = st.button("Predict")

    model = load_model()
    if model is None:
        st.error("Error: Model could not be loaded")
        return

    if not pred_button:
        return

    # PNG uploads may be RGBA, palette or greyscale; normalise to 3 channels.
    image = Image.open(image_file).convert('RGB')

    # NOTE(review): assumes a single channels-last image input, i.e.
    # input_shape == (batch, height, width, channels) — confirm for this model.
    input_shape = getattr(model, 'input_shape', None)
    if isinstance(input_shape, tuple) and len(input_shape) == 4:
        height, width = input_shape[1], input_shape[2]
        if height is not None and width is not None:
            image = image.resize((width, height))

    # NOTE(review): no pixel scaling is applied; if training rescaled inputs
    # (e.g. 1/255), the same scaling must be added here.
    pixels = image_transformation(image=image)
    # predict() expects a batch, so add a leading axis of size one.
    batch = np.expand_dims(pixels, axis=0)
    outputs = np.asarray(model.predict(batch))

    if outputs.ndim == 2 and outputs.shape[1] > 1:
        # One score per class: choose the highest.
        class_index = int(np.argmax(outputs, axis=1)[0])
    else:
        # A single sigmoid-style score: threshold at 0.5.
        class_index = int(outputs.reshape(-1)[0] >= 0.5)
    predicted_idx = str(class_index)

    right_column.title("Prediction")
    right_column.write(predicted_idx)
    # ImageNet's decode_predictions only accepts (samples, 1000) outputs, so
    # the raw scores of this model are shown instead.
    # NOTE(review): the index-to-label mapping (zebra vs. horse) is not visible
    # in this file; confirm it against the training code before showing names.
    right_column.write(outputs[0])
        

# Run the interactive part of the app only when this file is executed as the
# main script, not when it is imported as a module.
if __name__ == '__main__':
    main()