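# app.py: Gradio demo from the Hugging Face Space by Vrk.
# Page header at time of capture: commit c5d6aa2 ("Update app.py"), 3.29 kB,
# virus scan: no virus (raw / history / blame views available on the Hub).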
import gradio as gr
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import timm
from PIL import Image
from torchvision import transforms
from Models import ResNet, EfficientNet, BaseLine
def get_model(model_name, classes, device):
    """Build and load the classifier selected in the UI.

    Args:
        model_name: One of 'Inception-V3', 'VGG', 'EfficientNet-B0',
            'ResNet-50' or 'Base Line Model'.
        classes: Sequence of class labels; only its length is used, to size
            the output layer of the PyTorch models.
        device: Device the PyTorch models and their checkpoint tensors are
            placed on. Ignored for the TFLite models.

    Returns:
        A ``tf.lite.Interpreter`` with tensors already allocated for the
        TFLite models, or a PyTorch module with its state dict loaded for the
        others.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if model_name in ('Inception-V3', 'VGG'):
        # NOTE(review): 'Inception-V3' loads the same 'vgg.tflite' file as
        # 'VGG', so both dropdown options run the same network. Point it at
        # the Inception-V3 .tflite file once its name is confirmed.
        model = tf.lite.Interpreter(model_path='vgg.tflite')
        model.allocate_tensors()
        return model

    if model_name == 'EfficientNet-B0':
        model_cls, checkpoint = EfficientNet, 'EfficientNet-Model.pt'
    elif model_name == 'ResNet-50':
        model_cls, checkpoint = ResNet, 'model-resnet50.pt'
    elif model_name == 'Base Line Model':
        model_cls, checkpoint = BaseLine, 'BaseLine-Model.pt'
    else:
        # Previously fell through and crashed with UnboundLocalError on `model`.
        raise ValueError(f'Unknown model name: {model_name!r}')

    model = model_cls(len(classes)).to(device)
    # map_location lets a checkpoint that was saved from GPU tensors load on
    # a CPU-only host (and vice versa) instead of failing in torch.load.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    return model
def _predict_torch(model, input_img, device):
    """Run one of the PyTorch classifiers on a single image.

    Returns the softmax of the model output as a NumPy array. The caller reads
    row 0, so the model is expected to return a (batch, n_classes) tensor.
    """
    model.eval()
    # NOTE(review): get_transform is neither defined nor imported in this
    # file as shown; confirm it exists (presumably built on the imported
    # torchvision `transforms`) and returns a batched tensor on `device`.
    img = get_transform(input_img, device)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        logits = model(img)
    # Explicit dim=1 (the class axis) instead of the deprecated implicit dim.
    # .cpu() is a no-op for CPU tensors, so one path serves CPU and CUDA.
    return F.softmax(logits, dim=1).cpu().numpy()


def _predict_tflite(interpreter, input_img):
    """Run one of the TFLite classifiers on a single image; return its first output tensor."""
    # Divide by 255 (8-bit pixels -> [0, 1]), add a leading batch axis and
    # hand the interpreter float32 data.
    img = np.asarray(input_img) / 255.
    input_tensor = np.expand_dims(img, 0).astype(np.float32)
    interpreter.set_tensor(interpreter.get_input_details()[0]['index'], input_tensor)
    interpreter.invoke()
    return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])


def make_predictions(input_img, model_name):
    """Classify a scene image with the model chosen in the UI.

    Args:
        input_img: PIL image from the ``gr.Image(type="pil")`` input.
        model_name: Dropdown choice, passed through to ``get_model``.

    Returns:
        ``(label, confidences)``: the highest-scoring class name and a dict
        mapping each class name to the model's score for it.

    Raises:
        ValueError: if ``model_name`` names neither a PyTorch nor a TFLite model.
    """
    classes = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_model(model_name, classes, device)

    if model_name in ('EfficientNet-B0', 'ResNet-50', 'Base Line Model'):
        scores = _predict_torch(model, input_img, device)
    elif model_name in ('Inception-V3', 'VGG'):
        scores = _predict_tflite(model, input_img)
    else:
        raise ValueError(f'Unknown model name: {model_name!r}')

    # A single image is fed in, so row 0 holds the per-class scores.
    class_scores = scores[0]
    label = classes[int(class_scores.argmax())]
    confidences = {name: float(score) for name, score in zip(classes, class_scores)}
    return label, confidences
# Gradio UI: an image plus a model dropdown in, predicted class plus
# per-class confidences out. Launched at import time, as before.
demo = gr.Interface(
    fn=make_predictions,
    inputs=[
        gr.Image(shape=(150, 150), type="pil"),
        gr.Dropdown(
            choices=['EfficientNet-B0', 'ResNet-50', 'Inception-V3', 'VGG', 'Base Line Model'],
            value='EfficientNet-B0',
            label='Choose Model',
        ),
    ],
    # Use the same top-level components as the inputs; the gr.outputs.*
    # namespace is deprecated (and removed in later Gradio releases).
    outputs=[gr.Textbox(label="Output Class"), gr.Label(label='Confidences')],
    title="MultiClass Classifier",
    # Every example runs with the default EfficientNet-B0 model.
    examples=[
        [path, 'EfficientNet-B0']
        for path in (
            "Sample_Images/Buildings.jpg",
            "Sample_Images/Forest.jpg",
            'Sample_Images/Street.jpg',
            'Sample_Images/glacier.jpg',
            'Sample_Images/mountain.jpg',
            'Sample_Images/sea.jpg',
        )
    ],
)
demo.launch(debug=True, inline=True)