# Joseph
# add examples
# 710ac0a
import os
import gc
import timm
import gradio as gr
import torch
import tensorflow as tf
# timm MobileNetV4 variants offered in the UI. load_models() expects a matching
# "<name>_float16.tflite" file for each of these.
model_names = [
    "mobilenetv4_conv_small.e2400_r224_in1k",
    "mobilenetv4_conv_medium.e500_r224_in1k",
    "mobilenetv4_conv_blur_medium.e500_r224_in1k",
    "mobilenetv4_conv_medium.e500_r256_in1k",
    "mobilenetv4_conv_large.e500_r256_in1k",
    "mobilenetv4_conv_large.e600_r384_in1k",
]

# Class-index -> label lookup: line i of the file (whitespace-stripped) is the
# label for index i. Decode explicitly as UTF-8 so the result does not depend
# on the platform's locale encoding.
with open("imagenet_classes.txt", "r", encoding="utf-8") as file:
    lines = file.readlines()
index_to_label = {index: line.strip() for index, line in enumerate(lines)}

# Cache for the currently loaded variant. classify() fills these in and
# replaces them when a different model name is selected.
model, transforms = None, None
tfl_model, input_details, output_details = None, None, None
last_model = None  # name of the variant the cached objects belong to
def load_models(timm_model, convert_dir="tflite_models"):
    """Load a timm model, its eval transforms, and the matching float16 TFLite model.

    Args:
        timm_model: timm model name, e.g. ``"mobilenetv4_conv_small.e2400_r224_in1k"``.
        convert_dir: directory containing ``<timm_model>_float16.tflite`` files.

    Returns:
        ``(model, transforms, tfl_model, input_details, output_details)``:
        the PyTorch model in eval mode, timm's non-training transform built
        from the model's resolved data config, an allocated
        ``tf.lite.Interpreter``, and that interpreter's input/output details.

    Raises:
        FileNotFoundError: if the converted TFLite file does not exist. This is
            checked before ``timm.create_model(..., pretrained=True)`` so the
            pretrained weights are not loaded for a variant that cannot be
            compared.
    """
    tf_model_path = os.path.join(convert_dir, f"{timm_model}_float16.tflite")
    if not os.path.isfile(tf_model_path):
        raise FileNotFoundError(f"TFLite model not found: {tf_model_path}")

    # nn.Module.eval() returns the module itself, so it can be chained.
    model = timm.create_model(timm_model, pretrained=True).eval()
    data_config = timm.data.resolve_data_config(model=model)
    transforms = timm.data.create_transform(**data_config, is_training=False)

    tfl_model = tf.lite.Interpreter(model_path=tf_model_path)
    tfl_model.allocate_tensors()
    input_details = tfl_model.get_input_details()
    output_details = tfl_model.get_output_details()
    return model, transforms, tfl_model, input_details, output_details
def _top5_to_labels(indices, probs):
    """Pair top-5 class indices with their probabilities as ``{label: prob}``.

    Labels come from the module-level ``index_to_label`` mapping. If two
    indices share the same label text, the later one overwrites the earlier.
    """
    return {index_to_label[index]: prob for index, prob in zip(indices, probs)}


def classify(img, model_name):
    """Classify ``img`` with both the PyTorch and TFLite versions of ``model_name``.

    The timm model and its TFLite counterpart are cached in module globals.
    They are only reloaded (via ``load_models``) when the selected variant
    changes.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or ``None``
            if nothing was uploaded.
        model_name: timm model identifier chosen in the dropdown.

    Returns:
        ``(pt_result_labels, tfl_result_labels)``. Each maps a class label to
        its softmax probability for that backend's top-5 predictions.

    Raises:
        gr.Error: if no image was provided.
    """
    global model, transforms, tfl_model, input_details, output_details, last_model

    if img is None:
        raise gr.Error("Please upload an image to classify.")

    if last_model is None or model_name != last_model:
        # Drop the previous variant before loading the next one to limit peak
        # memory. last_model and transforms are cleared as well. If
        # load_models() raises, the cache is left fully empty, so the next
        # call retries the load instead of calling a None model.
        model = transforms = tfl_model = None
        input_details = output_details = None
        last_model = None
        gc.collect()
        model, transforms, tfl_model, input_details, output_details = load_models(model_name)
        last_model = model_name

    # Uploads may be RGBA, grayscale or palette images; convert to RGB so the
    # preprocessed tensor always has 3 channels.
    processed_img = transforms(img.convert("RGB")).unsqueeze(0)

    # PyTorch: inference only, so skip autograd graph construction.
    with torch.no_grad():
        pt_output = model(processed_img)
        pt_top5_probs, pt_top5_indices = torch.topk(pt_output.softmax(dim=1), k=5)
    pt_result_labels = _top5_to_labels(pt_top5_indices[0].tolist(), pt_top5_probs[0].tolist())

    # TFLite: BCHW torch tensor -> contiguous BHWC numpy array.
    img_tf = processed_img.permute(0, 2, 3, 1).contiguous().numpy()
    tfl_model.set_tensor(input_details[0]["index"], img_tf)
    tfl_model.invoke()
    tfl_output = tfl_model.get_tensor(output_details[0]["index"])
    tfl_softmax_output = tf.nn.softmax(tf.convert_to_tensor(tfl_output), axis=1)
    tfl_top5_probs, tfl_top5_indices = tf.math.top_k(tfl_softmax_output, k=5)
    tfl_result_labels = _top5_to_labels(
        tfl_top5_indices[0].numpy().tolist(),
        tfl_top5_probs[0].numpy().tolist(),
    )

    return pt_result_labels, tfl_result_labels
# Side-by-side demo: one image plus a model variant in, PyTorch and TFLite
# top-5 label distributions out (both computed by classify()).
iface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=model_names, value=model_names[0], label="Model Variant."),
    ],
    outputs=[gr.Label(label="Pytorch Output"), gr.Label(label="TFLite Output")],
    title="MobileNetV4 Pytorch vs TFLite Imagenet1K Classification",
    examples=[
        ["example_images/n01818515_macaw.JPEG", model_names[0]],
        ["example_images/n01828970_bee_eater.jpg", model_names[0]],
        ["example_images/n01833805_hummingbird.JPEG", model_names[0]],
    ],
)

# Only start the server when run as a script (e.g. `python app.py`). Importing
# this module then exposes `iface` without launching it as a side effect.
if __name__ == "__main__":
    iface.launch()