File size: 1,710 Bytes
fdcbba1
6db8928
 
 
 
 
 
 
 
 
613a9bc
6db8928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdcbba1
 
83190db
fdcbba1
 
 
 
 
 
 
 
 
83190db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import ViTFeatureExtractor
from huggingface_hub import from_pretrained_keras

# Base ViT checkpoint. In this file its feature extractor supplies the
# default normalisation statistics (image_mean / image_std) for normalize_img.
PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)

# Fine-tuned Keras model loaded from the Hugging Face Hub. The "@..." suffix
# presumably pins a specific revision of the repo -- confirm against the
# huggingface_hub version in use.
MODEL_CKPT = "chansung/vit-e2e-pipeline-hf-integration@v1664863171"
MODEL = from_pretrained_keras(MODEL_CKPT)

# Side length (in pixels) that input images are resized to before inference.
RESOLUTION = 224
# Historical misspelling, kept as an alias so existing references keep working.
RESOLTUION = RESOLUTION

# Class names read from labels.txt, one per line. Their order must match the
# model's output index order, because predictions are paired with labels by
# position.
labels = []

with open(r"labels.txt", "r", encoding="utf-8") as fp:
    for line in fp:
        # rstrip instead of line[:-1]: the last line may have no trailing
        # newline (line[:-1] would chop a real character), and CRLF files
        # would otherwise keep a stray "\r".
        name = line.rstrip("\r\n")
        if name:  # skip blank lines such as a trailing empty line
            labels.append(name)

def normalize_img(
    img, mean=feature_extractor.image_mean, std=feature_extractor.image_std
):
    """Rescale pixel values by 1/255, then standardise with mean and std.

    Args:
        img: Image tensor; presumably holds 0-255 pixel values with channels
            on the last axis so that ``mean``/``std`` broadcast per channel
            (confirm against callers).
        mean: Per-channel mean; defaults to the feature extractor's
            ``image_mean``.
        std: Per-channel standard deviation; defaults to the feature
            extractor's ``image_std``.

    Returns:
        ``(img / 255 - mean) / std``.
    """
    mean_tensor = tf.constant(mean)
    std_tensor = tf.constant(std)
    scaled = img / 255
    return (scaled - mean_tensor) / std_tensor

def preprocess_input(image: Image.Image) -> dict:
    """Convert an input image into the model's input dictionary.

    The image is resized to RESOLTUION x RESOLTUION, normalised with
    ``normalize_img``, transposed to channel-first layout and given a
    leading batch dimension.

    Args:
        image: A PIL image (converted to RGB), or anything ``np.array`` can
            turn into an (H, W, C) array.

    Returns:
        ``{"pixel_values": tensor}`` where the tensor has shape
        (1, C, RESOLTUION, RESOLTUION); C is 3 for PIL inputs.
    """
    if isinstance(image, Image.Image):
        # The channel-first transpose and per-channel normalisation expect a
        # fixed channel count; RGBA / greyscale / palette images would
        # otherwise fail or be mis-normalised.
        image = image.convert("RGB")

    image = tf.convert_to_tensor(np.array(image))
    image = tf.image.resize(image, (RESOLTUION, RESOLTUION))
    image = normalize_img(image)

    # Since HF models are channel-first: (H, W, C) -> (C, H, W).
    image = tf.transpose(image, (2, 0, 1))

    return {"pixel_values": tf.expand_dims(image, 0)}

def get_predictions(image: Image.Image) -> dict:
    """Classify ``image`` and return a probability for every known label.

    Args:
        image: Input image, forwarded unchanged to ``preprocess_input``.

    Returns:
        Mapping of label name -> softmax probability (a Python float), the
        format expected by a Gradio ``Label`` output.
    """
    preprocessed_image = preprocess_input(image)
    prediction = MODEL.predict(preprocessed_image)
    probs = tf.nn.softmax(prediction["logits"], axis=1)

    # Pair labels with probabilities positionally instead of assuming exactly
    # three classes; zip stops at the shorter sequence, so a labels.txt /
    # model mismatch cannot raise IndexError.
    return {label: float(p) for label, p in zip(labels, probs[0].numpy())}

title = "Simple demo for Image Classification of the Beans Dataset with HF ViT model"

# gr.Image / gr.Label replace the deprecated gr.inputs / gr.outputs
# namespaces, which were removed in Gradio 4. type="pil" hands
# get_predictions a PIL image; the Label shows the top 3 classes.
demo = gr.Interface(
    get_predictions,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=3),
    allow_flagging="never",
    title=title,
)

demo.launch(debug=True)