File size: 2,124 Bytes
237e09c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c44275a
237e09c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import tarfile
import wandb

import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import ViTFeatureExtractor

PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)

MODEL = None 

RESOLTUION = 224

labels = []

with open(r"labels.txt", "r") as fp:
    for line in fp:
        labels.append(line[:-1])

def normalize_img(
    img, mean=feature_extractor.image_mean, std=feature_extractor.image_std
):
    img = img / 255
    mean = tf.constant(mean)
    std = tf.constant(std)
    return (img - mean) / std

def preprocess_input(image):
    image = np.array(image)
    image = tf.convert_to_tensor(image)

    image = tf.image.resize(image, (RESOLTUION, RESOLTUION))
    image = normalize_img(image)

    image = tf.transpose(
        image, (2, 0, 1)
    )  # Since HF models are channel-first.

    return {
        "pixel_values": tf.expand_dims(image, 0)
    }

def get_predictions(wb_token, image):
    global MODEL
    
    if MODEL is None:
        wandb.login(key=wb_token)
        wandb.init(project="tfx-vit-pipeline")
        path = wandb.use_artifact('tfx-vit-pipeline/final_model:1683859487', type='model').download()

        tar = tarfile.open(f"{path}/model.tar.gz")
        tar.extractall(path=".")

        MODEL = tf.keras.models.load_model("./model")
    
    preprocessed_image = preprocess_input(image)
    prediction = MODEL.predict(preprocessed_image)
    probs = tf.nn.softmax(prediction['logits'], axis=1)

    confidences = {labels[i]: float(probs[0][i]) for i in range(3)}
    return confidences

with gr.Blocks() as demo:
    gr.Markdown("## Simple demo for a Image Classification of the Beans Dataset with HF ViT model")

    wb_token_if = gr.Textbox(interactive=True, label="Your Weight & Biases API Key")

    with gr.Row():
        image_if = gr.Image()
        label_if = gr.Label(num_top_classes=3)

    classify_if = gr.Button()

    classify_if.click(
        get_predictions,
        [wb_token_if, image_if],
        label_if
    )

demo.launch(debug=True)