import gc
import os

import gradio as gr
import tensorflow as tf
import timm
import torch
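
# MobileNetV4 ImageNet-1K checkpoints from timm; each variant is expected to
# have a matching float16 TFLite export under tflite_models/ (see load_models).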
model_names = [
    "mobilenetv4_conv_small.e2400_r224_in1k",
    "mobilenetv4_conv_medium.e500_r224_in1k",
    "mobilenetv4_conv_blur_medium.e500_r224_in1k",
    "mobilenetv4_conv_medium.e500_r256_in1k",
    "mobilenetv4_conv_large.e500_r256_in1k",
    "mobilenetv4_conv_large.e600_r384_in1k",
]
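
# imagenet_classes.txt holds one ImageNet-1K label per line, index-aligned
# with the 1000 model output logits.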
with open('imagenet_classes.txt', 'r') as file:
    lines = file.readlines()
index_to_label = {index: line.strip() for index, line in enumerate(lines)}
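
# Cached per-variant state, so repeated calls with the same model skip reloading.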
model, transforms = None, None
tfl_model, input_details, output_details = None, None, None
last_model = None
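
# Load the PyTorch model plus its timm preprocessing transform, and the
# pre-converted float16 TFLite interpreter for the same variant.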
def load_models(timm_model):
    convert_dir = "tflite_models"
    tf_model_path = os.path.join(convert_dir, f"{timm_model}_float16.tflite")

    model = timm.create_model(timm_model, pretrained=True)
    model = model.eval()
    data_config = timm.data.resolve_data_config(model=model)
    transforms = timm.data.create_transform(**data_config, is_training=False)

    tfl_model = tf.lite.Interpreter(model_path=tf_model_path)
    tfl_model.allocate_tensors()
    input_details = tfl_model.get_input_details()
    output_details = tfl_model.get_output_details()

    return model, transforms, tfl_model, input_details, output_details
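
# For reference, a minimal sketch of how the *_float16.tflite files consumed by
# load_models could be produced with TFLite post-training float16 quantization.
# This helper is illustrative and not called by the app; it assumes a TensorFlow
# SavedModel export of the timm model already exists at saved_model_dir (the
# PyTorch -> SavedModel step is not shown and may differ from the author's).
def export_float16_tflite(timm_model, saved_model_dir):
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]   # enable post-training quantization
    converter.target_spec.supported_types = [tf.float16]   # store weights as float16
    tflite_bytes = converter.convert()
    out_path = os.path.join("tflite_models", f"{timm_model}_float16.tflite")
    with open(out_path, "wb") as f:
        f.write(tflite_bytes)
    return out_path

# Run the selected image through both backends and return top-5
# {label: probability} dicts for side-by-side comparison.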
def classify(img, model_name):
    global model, transforms, tfl_model, input_details, output_details, last_model
    # Swap models only when the dropdown selection changed; drop the old ones
    # first so memory is reclaimed before the new variant loads.
    if model_name != last_model:
        model = None
        tfl_model = None
        gc.collect()
        model, transforms, tfl_model, input_details, output_details = load_models(model_name)
        last_model = model_name

    # PyTorch inference: timm transform -> batch of one -> softmax top-5.
    processed_img = transforms(img).unsqueeze(0)
    with torch.no_grad():
        pt_output = model(processed_img)
    pt_top5_probs, pt_top5_indices = torch.topk(pt_output.softmax(dim=1), k=5)
    pt_index_list = pt_top5_indices[0].tolist()
    pt_probs_list = pt_top5_probs[0].tolist()
    pt_result_labels = {
        index_to_label[index]: prob
        for index, prob in zip(pt_index_list, pt_probs_list)
    }

    # TFLite inference: the interpreter expects a BHWC numpy array (float16
    # quantization keeps float32 input/output tensors), so convert the BCHW
    # torch tensor before setting the input.
    img_tf = processed_img.permute(0, 2, 3, 1).numpy()
    input_detail = input_details[0]
    tfl_model.set_tensor(input_detail["index"], img_tf)
    tfl_model.invoke()
    tfl_output = tfl_model.get_tensor(output_details[0]["index"])
    tfl_output_tensor = tf.convert_to_tensor(tfl_output)
    tfl_softmax_output = tf.nn.softmax(tfl_output_tensor, axis=1)
    tfl_top5_probs, tfl_top5_indices = tf.math.top_k(tfl_softmax_output, k=5)
    tfl_probs_list = tfl_top5_probs[0].numpy().tolist()
    tfl_index_list = tfl_top5_indices[0].numpy().tolist()
    tfl_result_labels = {
        index_to_label[index]: prob
        for index, prob in zip(tfl_index_list, tfl_probs_list)
    }

    return pt_result_labels, tfl_result_labels

iface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=model_names, value=model_names[0], label="Model Variant"),
    ],
    outputs=[gr.Label(label="PyTorch Output"), gr.Label(label="TFLite Output")],
    title="MobileNetV4 PyTorch vs TFLite ImageNet-1K Classification",
    examples=[
        ["example_images/n01818515_macaw.JPEG", model_names[0]],
        ["example_images/n01828970_bee_eater.jpg", model_names[0]],
        ["example_images/n01833805_hummingbird.JPEG", model_names[0]],
    ],
)

iface.launch()