kakao-brain-vit / app.py
adirik's picture
update model version, add examples
5b7f9a4
raw
history blame contribute delete
No virus
2.13 kB
import cv2
import json
import gradio as gr
import numpy as np
import tensorflow as tf
from backbone import create_name_vit
from backbone import ClassificationModel
# Configuration for the "vit-l/16" backbone at a 512-pixel image size.
vit_l16_512 = {
    "backbone_name": "vit-l/16",
    # Keyword arguments forwarded to the backbone factory.
    "backbone_params": {
        "image_size": 512,
        "representation_size": 0,
        "attention_dropout_rate": 0.0,
        "dropout_rate": 0.0,
        "channels": 3,
    },
    # Dropout rate handed to the classification model.
    "dropout_rate": 0.0,
    # Path prefix of the pretrained weights.
    "pretrained": "./weights/vit_l16_512/model-weights",
}
# Build the backbone from the configuration above.
backbone = create_name_vit(
    vit_l16_512["backbone_name"],
    **vit_l16_512["backbone_params"],
)

# Wrap the backbone in a 1000-class classification model.
model = ClassificationModel(
    backbone=backbone,
    dropout_rate=vit_l16_512["dropout_rate"],
    num_classes=1000,
)

# Restore the pretrained weights, then freeze the model: it is only used
# for prediction in this app.
model.load_weights(vit_l16_512["pretrained"])
model.trainable = False
# Load the ImageNet index -> label mapping (looked up by stringified class
# index). The encoding is given explicitly so decoding does not depend on
# the platform's default locale.
with open("assets/imagenet_1000_idx2labels.json", encoding="utf-8") as f:
    idx_to_label = json.load(f)
def resize_with_normalization(image, size=(512, 512)):
    """Resize an image and rescale its pixel values for the model.

    Args:
        image: Image tensor or array. It is cast to float32 before resizing.
            Presumably laid out as (height, width, channels) with values in
            the 0-255 range -- confirm against callers.
        size: Target size passed to ``tf.image.resize``. The default is an
            immutable tuple rather than a shared mutable list.

    Returns:
        A float32 tensor with a leading batch axis of length 1, whose values
        have been mapped through ``(x - 127.5) / 127.5``.
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, size)
    # Scalar arithmetic broadcasts over every element, so the normalisation
    # no longer hard-codes a channel count of 3.
    image = (image - 127.5) / 127.5
    return tf.expand_dims(image, axis=0)
def softmax_stable(x):
    """Return the softmax of ``x``, normalised over all of its elements.

    The maximum of ``x`` is subtracted before exponentiating so that large
    inputs do not overflow; this shift does not change the result.

    Args:
        x: Array of scores (logits).

    Returns:
        Array of the same shape as ``x`` whose entries sum to 1.
    """
    # Exponentiate once and reuse the result for numerator and denominator.
    exps = np.exp(x - np.max(x))
    return exps / exps.sum()
def classify_image(img, top_k):
    """Classify an image and return its most probable labels.

    Args:
        img: Input image, convertible with ``tf.convert_to_tensor``.
        top_k: Number of predictions to return. It is coerced to ``int``
            and clamped to ``[1, number of classes]``.

    Returns:
        Dict mapping each label from ``idx_to_label`` to its probability
        (rounded to 4 decimals), ordered from most to least probable.
    """
    img = tf.convert_to_tensor(img)
    img = resize_with_normalization(img)
    pred_logits = model.predict(img, batch_size=1, workers=8)[0]
    pred_probs = softmax_stable(pred_logits)

    # A float top_k (e.g. from a UI slider) cannot be used as a slice bound,
    # and top_k == 0 would turn the slice into [-0:], i.e. *every* class.
    # Coerce and clamp so the slice always selects between 1 and all classes.
    top_k = max(1, min(int(top_k), len(pred_probs)))

    # argsort is ascending: take the last top_k indices and reverse them.
    top_k_labels = pred_probs.argsort()[-top_k:][::-1]
    return {
        idx_to_label[str(idx)]: round(float(pred_probs[idx]), 4)
        for idx in top_k_labels
    }
# Gradio UI: an image input plus a slider choosing how many labels to show.
# The output uses the top-level gr.Label component, consistent with the
# gr.Image / gr.Slider inputs, instead of the legacy gr.outputs namespace.
# The slider starts at 1 and moves in whole steps so top_k is always a
# positive integer.
demo = gr.Interface(
    classify_image,
    inputs=[gr.Image(), gr.Slider(1, 1000, value=5, step=1)],
    outputs=gr.Label(),
    title="Image Classification with Kakao Brain ViT",
    examples=[
        ["assets/halloween-gaf8ad7ebc_1920.jpeg", 5],
        ["assets/IMG_4484.jpeg", 5],
        ["assets/IMG_4737.jpeg", 5],
        ["assets/IMG_4740.jpeg", 5],
    ],
)
demo.launch()