cat20_breeds / app.py
PinChunPai's picture
Update app.py
9d53036 verified
import numpy as np
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification
import gradio as gr
from PIL import Image
# Load the fine-tuned breed classifier once at start-up and place it on the
# GPU when one is available, otherwise on the CPU.
model_id = "PinChunPai/cat20_breed_fine_tune"
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = AutoModelForImageClassification.from_pretrained(model_id)
model = model.to(device)
model.eval()  # the app only runs inference

# Mapping from class index to breed name, taken from the model's config.
id2label = model.config.id2label

# Preprocessing applied to every incoming picture: resize, centre-crop to
# 224, convert to a tensor, then normalize each of the three channels with
# mean 0.5 and std 0.5.
image_preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
def classifier_app(image):
    """Classify a cat picture and report the most likely breeds.

    Args:
        image: The picture to classify, either a ``numpy.ndarray`` (cast to
            ``uint8`` before conversion) or a PIL image. Any channel layout
            PIL can convert to RGB is accepted.

    Returns:
        A dict mapping display labels to probabilities: the (up to) three
        most likely breeds followed by an ``Others`` entry holding the
        remaining probability mass. Labels are right-padded with spaces to
        a common width.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # An empty submission would otherwise fail deep inside the
        # preprocessing pipeline with an unhelpful traceback.
        raise gr.Error("Please provide an image to classify.")

    if isinstance(image, np.ndarray):
        # Let PIL infer the mode from the array shape instead of forcing
        # 'RGB', so grayscale and RGBA arrays are accepted as well.
        image = Image.fromarray(image.astype('uint8'))
    # The normalization step is configured for exactly three channels.
    image = image.convert('RGB')

    batch = image_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch).logits
    probabilities = torch.softmax(logits, dim=1)[0]

    # topk returns the largest entries already sorted in descending order.
    top_k = min(3, probabilities.numel())
    top_values, top_ids = torch.topk(probabilities, top_k)
    scores = [round(value, 4) for value in top_values.tolist()]
    names = [id2label[index] for index in top_ids.tolist()]

    # Each score is rounded independently, so their sum can exceed 1 by a
    # rounding step; clamp so the remainder is never reported as negative.
    names.append('Others')
    scores.append(max(0.0, round(1.0 - sum(scores), 4)))

    width = max(len(name) for name in names)
    return {
        name.ljust(width): float(score)
        for name, score in zip(names, scores)
    }
# Text shown above the demo: lists every breed name the model knows about.
breed_names = list(id2label.values())
description = f"This classifier assumes the picture is among the 20 breeds :{breed_names}"

# Text shown below the demo: where the data and the project live.
article = "First two examples are extracted from Kaggle dataset(https://www.kaggle.com/datasets/knucharat/pop-cats). Model is trained on the same dataset. Project details can be found in https://github.com/PinChunPai/Cat_breeds_classification/tree/main"

# Image in, label probabilities out, with three bundled example pictures.
interface = gr.Interface(
    title="Cat20_breeds",
    description=description,
    article=article,
    fn=classifier_app,
    inputs="image",
    outputs="label",
    examples=[
        'cat_1.jpg',
        'cat_2.jpg',
        'cat_3.jpg',
    ],
)

# Start the web app and request a shareable link.
interface.launch(share=True)