File size: 3,537 Bytes
4a9a0f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
710ac0a
 
 
 
 
4a9a0f5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import gc
import timm
import gradio as gr
import torch
import tensorflow as tf

# timm model identifiers offered in the UI's "Model Variant." dropdown.
# load_models() passes the selected name to timm.create_model and also uses it
# to build the "<name>_float16.tflite" file name of the TFLite counterpart.
model_names = [
    "mobilenetv4_conv_small.e2400_r224_in1k",
    "mobilenetv4_conv_medium.e500_r224_in1k",
    "mobilenetv4_conv_blur_medium.e500_r224_in1k",
    "mobilenetv4_conv_medium.e500_r256_in1k",
    "mobilenetv4_conv_large.e500_r256_in1k",
    "mobilenetv4_conv_large.e600_r384_in1k",
]

# Class labels are read one per line; a line's zero-based position is used as
# the class index when looking up predictions in classify().
# The encoding is pinned so decoding does not depend on the platform locale.
with open('imagenet_classes.txt', 'r', encoding='utf-8') as file:
    lines = file.readlines()

index_to_label = {index: line.strip() for index, line in enumerate(lines)}

# Cache of the currently loaded model pair. classify() fills these on first use
# and replaces them whenever a different model name is requested.
model, transforms = None, None
tfl_model, input_details, output_details = None, None, None
# Name of the model the cache above was loaded for (None = nothing loaded).
last_model = None


def load_models(timm_model):
    """Build the PyTorch model and the TFLite interpreter for one model name.

    Args:
        timm_model: Model identifier handed to ``timm.create_model``; it is
            also the stem of the TFLite file looked up under ``tflite_models``.

    Returns:
        A 5-tuple ``(model, transforms, tfl_model, input_details,
        output_details)``: the timm model in eval mode, its evaluation
        transform, the TFLite interpreter with tensors allocated, and the
        interpreter's input and output detail lists.
    """
    # PyTorch side: model created with pretrained=True plus the matching
    # evaluation-time preprocessing resolved from the model's data config.
    torch_model = timm.create_model(timm_model, pretrained=True).eval()
    preprocessing = timm.data.create_transform(
        **timm.data.resolve_data_config(model=torch_model),
        is_training=False,
    )

    # TFLite side: interpreter for the float16 file of the same model name.
    tflite_path = os.path.join("tflite_models", f"{timm_model}_float16.tflite")
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()

    return (
        torch_model,
        preprocessing,
        interpreter,
        interpreter.get_input_details(),
        interpreter.get_output_details(),
    )


def _label_scores(indices, probs):
    """Map each class index to its label in ``index_to_label``, paired with its probability."""
    return {index_to_label[index]: prob for index, prob in zip(indices, probs)}


def classify(img, model_name):
    """Classify ``img`` with both the PyTorch and the TFLite build of a model.

    The loaded model pair is cached in module-level globals and is only
    reloaded when ``model_name`` differs from the one used on the last
    successful load.

    Args:
        img: Input image; it is passed straight to the timm transform
            (the Gradio interface in this file supplies a PIL image).
        model_name: Model identifier forwarded to ``load_models``.

    Returns:
        A ``(pytorch_labels, tflite_labels)`` tuple. Each element is a dict
        mapping a class label to its softmax probability for that backend's
        top-5 predictions.
    """
    global model, transforms, tfl_model, input_details, output_details, last_model

    if last_model is None or model_name != last_model:
        # Release the previous pair before loading the next one.
        if model is not None or tfl_model is not None:
            model = None
            tfl_model = None
            gc.collect()

        # Forget the cached name until the new load succeeds. Otherwise a
        # failed load would leave last_model pointing at models that were just
        # cleared, and re-selecting that name would skip the reload.
        last_model = None
        model, transforms, tfl_model, input_details, output_details = load_models(model_name)
        last_model = model_name

    processed_img = transforms(img).unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pt_probs = model(processed_img).softmax(dim=1)

    pt_top5_probs, pt_top5_indices = torch.topk(pt_probs, k=5)
    pt_result_labels = _label_scores(
        pt_top5_indices[0].tolist(),
        pt_top5_probs[0].tolist(),
    )

    ############################################################

    # BCHW -> BHWC, handed to the interpreter as a contiguous numpy array
    # (set_tensor is documented to take a numpy array).
    img_tf = processed_img.permute(0, 2, 3, 1).contiguous().numpy()

    tfl_model.set_tensor(input_details[0]["index"], img_tf)
    tfl_model.invoke()
    tfl_output = tfl_model.get_tensor(output_details[0]["index"])

    tfl_softmax_output = tf.nn.softmax(tf.convert_to_tensor(tfl_output), axis=1)
    tfl_top5_probs, tfl_top5_indices = tf.math.top_k(tfl_softmax_output, k=5)
    tfl_result_labels = _label_scores(
        tfl_top5_indices[0].numpy().tolist(),
        tfl_top5_probs[0].numpy().tolist(),
    )

    return pt_result_labels, tfl_result_labels


# Gradio UI: an image and a model choice go in; the PyTorch and TFLite label
# predictions returned by classify() are shown side by side.
_default_model = model_names[0]
_example_images = (
    "example_images/n01818515_macaw.JPEG",
    "example_images/n01828970_bee_eater.jpg",
    "example_images/n01833805_hummingbird.JPEG",
)

iface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=model_names, value=_default_model, label="Model Variant."),
    ],
    outputs=[
        gr.Label(label="Pytorch Output"),
        gr.Label(label="TFLite Output"),
    ],
    title="MobileNetV4 Pytorch vs TFLite Imagenet1K Classification",
    # Every example image is paired with the default (first) model.
    examples=[[image_path, _default_model] for image_path in _example_images],
)

iface.launch()