# Spaces: Sleeping
import numpy as np | |
import torch | |
from torchvision import transforms | |
from transformers import AutoModelForImageClassification | |
import gradio as gr | |
from PIL import Image | |
# Preprocessing applied to every input picture before it is fed to the model:
# resize, centre-crop to 224x224, convert to a tensor, then normalise each of
# the three channels with mean 0.5 and std 0.5.
image_preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the fine-tuned classifier identified by ``model_id``, move it to the
# selected device and switch it to evaluation mode.
model_id = "PinChunPai/cat20_breed_fine_tune"
model = AutoModelForImageClassification.from_pretrained(model_id)
model = model.to(device)
model.eval()

# Class-index -> label mapping taken from the model configuration.
id2label = model.config.id2label
def classifier_app(image):
    """Classify a picture and return the top class probabilities.

    Args:
        image: The picture to classify. A ``numpy.ndarray`` is converted to a
            PIL image first; any other value is handed to ``image_preprocess``
            (PIL images are converted to RGB beforehand when needed).

    Returns:
        A dict mapping label -> probability (a float rounded to 4 decimals)
        for the three most probable classes (fewer if the model has fewer
        classes), plus an ``'Others'`` entry holding the remaining
        probability mass. All labels are right-padded with spaces to a common
        width.
    """
    if isinstance(image, np.ndarray):
        # Let Pillow infer the mode from the array shape instead of forcing
        # 'RGB': forcing it misreads grayscale/RGBA arrays, and the ``mode``
        # argument of ``fromarray`` is deprecated in recent Pillow releases.
        image = Image.fromarray(image.astype('uint8'))
    if isinstance(image, Image.Image) and image.mode != 'RGB':
        # The normalisation step is configured for exactly three channels.
        image = image.convert('RGB')

    batch = image_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch).logits
    probabilities = torch.softmax(logits, dim=1)[0]

    # Only the best few entries are needed, so use topk rather than sorting
    # every class; cap k so a model with fewer than 3 classes still works.
    k = min(3, probabilities.numel())
    top_probs, top_ids = torch.topk(probabilities, k)
    scores = [round(p, 4) for p in top_probs.tolist()]
    names = [id2label[i] for i in top_ids.tolist()]

    # Whatever probability is left after the reported classes. Each score was
    # rounded independently, so the remainder can dip just below zero; clamp
    # it so the label output never shows a negative confidence.
    remainder = 1.0
    for score in scores:
        remainder -= score
    others = max(0.0, round(remainder, 4))

    names.append('Others')
    scores.append(others)

    # Pad every label to the same width.
    width = max(len(name) for name in names)
    return {name.ljust(width): float(score) for name, score in zip(names, scores)}
# Text shown under the title: lists every breed label the model knows.
description = (
    "This classifier assumes the picture is among the 20 breeds :"
    f"{list(id2label.values())}"
)

# Footer text with links to the dataset and the project repository.
article = (
    "First two examples are extracted from Kaggle dataset"
    "(https://www.kaggle.com/datasets/knucharat/pop-cats). "
    "Model is trained on the same dataset. "
    "Project details can be found in "
    "https://github.com/PinChunPai/Cat_breeds_classification/tree/main"
)

# Gradio UI: an image input wired to ``classifier_app`` with a label output
# and three bundled example pictures.
interface = gr.Interface(
    fn=classifier_app,
    inputs="image",
    outputs="label",
    examples=[f'cat_{index}.jpg' for index in (1, 2, 3)],
    title="Cat20_breeds",
    description=description,
    article=article,
)

# Start the app and request a public share link.
interface.launch(share=True)