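"""Gradio demo: image classification on the Beans dataset with a ViT model pulled from a W&B artifact."""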
import os
import tarfile
import wandb
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import ViTFeatureExtractor
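
# ViT-base (patch 16, 224x224) pretrained on ImageNet-21k; its feature extractor
# supplies the image mean/std used for normalization below.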
PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)

WB_KEY = os.environ['WB_KEY']  # W&B API key, must be set in the environment.
MODEL = None  # Loaded lazily on the first prediction request.
RESOLUTION = 224  # Input resolution expected by the ViT checkpoint.

labels = []
with open(r"labels.txt", "r") as fp:
for line in fp:
labels.append(line[:-1])


def normalize_img(
    img, mean=feature_extractor.image_mean, std=feature_extractor.image_std
):
    # Scale to [0, 1], then standardize with the feature extractor's mean/std.
    img = img / 255
    mean = tf.constant(mean)
    std = tf.constant(std)
    return (img - mean) / std


def preprocess_input(image):
    image = np.array(image)
    image = tf.convert_to_tensor(image)
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    image = normalize_img(image)
    image = tf.transpose(
        image, (2, 0, 1)
    )  # Since HF models are channel-first.

    return {
        "pixel_values": tf.expand_dims(image, 0)
    }
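

# The trained model is fetched lazily on the first request: the W&B artifact is
# downloaded, the tarball is extracted, and the Keras model is cached in MODEL.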
def get_predictions(image):
    global MODEL

    if MODEL is None:
        wandb.login(key=WB_KEY)
        wandb.init(project="tfx-vit-pipeline", id="gvtyqdgn", resume=True)
        path = wandb.use_artifact(
            'tfx-vit-pipeline/final_model:1688113391', type='model'
        ).download()

        tar = tarfile.open(f"{path}/model.tar.gz")
        tar.extractall(path=".")

        MODEL = tf.keras.models.load_model("./model")

    preprocessed_image = preprocess_input(image)
    prediction = MODEL.predict(preprocessed_image)
    probs = tf.nn.softmax(prediction['logits'], axis=1)

    confidences = {labels[i]: float(probs[0][i]) for i in range(len(labels))}
    return confidences
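

# Gradio UI: single image input, top-3 label output, and cached example images.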
with gr.Blocks() as demo:
    gr.Markdown("## Simple demo of image classification on the Beans dataset with an HF ViT model")

    with gr.Row():
        image_if = gr.Image()
        label_if = gr.Label(num_top_classes=3)

    classify_if = gr.Button()

    classify_if.click(
        get_predictions,
        image_if,
        label_if,
    )

    gr.Examples(
        [["test_image1.jpeg"], ["test_image2.jpeg"], ["test_image3.jpeg"]],
        [image_if],
        [label_if],
        get_predictions,
        cache_examples=True,
    )

demo.launch(debug=True)