# chansung's picture — upload 1683859487 model (commit c44275a)
import tarfile
import wandb
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
from transformers import ViTFeatureExtractor
PRETRAIN_CHECKPOINT = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(PRETRAIN_CHECKPOINT)

# Lazily-loaded Keras model; populated on the first prediction request.
MODEL = None
# Side length (pixels) images are resized to before inference.
# NOTE(review): name is misspelled ("RESOLTUION") but kept as-is because
# other code in this file references it.
RESOLTUION = 224

# Class names, one per line in labels.txt; order must match model outputs.
labels = []
with open(r"labels.txt", "r") as fp:
    for line in fp:
        # rstrip("\n") instead of line[:-1]: the old slice silently chopped
        # the last character of the final label when the file has no
        # trailing newline.
        labels.append(line.rstrip("\n"))
def normalize_img(
    img, mean=feature_extractor.image_mean, std=feature_extractor.image_std
):
    """Scale *img* into [0, 1] and standardize it with the given mean/std.

    Defaults come from the pretrained ViT feature extractor's statistics.
    """
    scaled = img / 255
    return (scaled - tf.constant(mean)) / tf.constant(std)
def preprocess_input(image):
    """Turn an input image into the batched, channel-first dict the model expects."""
    tensor = tf.convert_to_tensor(np.array(image))
    tensor = tf.image.resize(tensor, (RESOLTUION, RESOLTUION))
    tensor = normalize_img(tensor)
    # HF transformer models take channel-first (C, H, W) input.
    tensor = tf.transpose(tensor, (2, 0, 1))
    # Add a leading batch dimension of 1.
    return {"pixel_values": tf.expand_dims(tensor, 0)}
def get_predictions(wb_token, image):
    """Classify *image*, downloading the model from W&B on the first call.

    Args:
        wb_token: Weights & Biases API key used to fetch the model artifact.
        image: input image (anything ``preprocess_input`` accepts).

    Returns:
        Dict mapping each label name to its softmax confidence.
    """
    global MODEL
    if MODEL is None:
        wandb.login(key=wb_token)
        wandb.init(project="tfx-vit-pipeline")
        path = wandb.use_artifact('tfx-vit-pipeline/final_model:1683859487', type='model').download()
        # Use a context manager so the tar handle is closed (original leaked
        # it).  NOTE(review): extractall is unsafe against path traversal for
        # untrusted archives — tolerable here only because the artifact comes
        # from our own W&B project.
        with tarfile.open(f"{path}/model.tar.gz") as tar:
            tar.extractall(path=".")
        MODEL = tf.keras.models.load_model("./model")

    preprocessed_image = preprocess_input(image)
    prediction = MODEL.predict(preprocessed_image)
    probs = tf.nn.softmax(prediction['logits'], axis=1)
    # range(len(labels)) instead of the hard-coded 3 so the demo still works
    # if labels.txt gains or loses classes.
    confidences = {labels[i]: float(probs[0][i]) for i in range(len(labels))}
    return confidences
# Assemble and launch the Gradio UI: an API-key box, an image input,
# a top-3 label display, and a button wired to get_predictions.
with gr.Blocks() as demo:
    gr.Markdown(
        "## Simple demo for a Image Classification of the Beans Dataset with HF ViT model"
    )

    wb_token_if = gr.Textbox(interactive=True, label="Your Weight & Biases API Key")
    with gr.Row():
        image_if = gr.Image()
        label_if = gr.Label(num_top_classes=3)
    classify_if = gr.Button()

    classify_if.click(get_predictions, [wb_token_if, image_if], label_if)

demo.launch(debug=True)