"""Gradio demo: classify a cat photo with a fine-tuned image classifier.

The script loads the ``PinChunPai/cat20_breed_fine_tune`` checkpoint once at
import time, wraps it in :func:`classifier_app`, and serves it through a
``gr.Interface`` with an image input and a label output.
"""
import numpy as np
import torch
from torchvision import transforms
from transformers import AutoModelForImageClassification
import gradio as gr
from PIL import Image

# Resize the shorter side to 224, crop the central 224x224 patch, convert to a
# tensor and normalise each channel with mean 0.5 / std 0.5.
image_preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_id = "PinChunPai/cat20_breed_fine_tune"
model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
id2label = model.config.id2label
model.eval()

# Number of individual classes reported; the remaining probability mass is
# lumped into a single catch-all entry.
_TOP_K = 3
_OTHERS_LABEL = 'Others'


def classifier_app(image):
    """Return the top class probabilities for ``image``.

    Args:
        image: The picture to classify, either a NumPy array (converted to a
            PIL image here) or an object accepted by ``image_preprocess``
            that also offers ``convert("RGB")`` (e.g. a PIL image).

    Returns:
        dict[str, float]: up to three class labels plus ``'Others'``, each
        mapped to its probability rounded to 4 decimals. Labels are
        right-padded with spaces to a common width.

    Raises:
        gr.Error: if ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        # Surface a readable message in the UI instead of a preprocessing
        # traceback when the form is submitted empty.
        raise gr.Error("Please upload an image before submitting.")

    if isinstance(image, np.ndarray):
        # Let Pillow infer the mode from the array shape so grayscale and
        # RGBA arrays are accepted too.
        image = Image.fromarray(image.astype('uint8'))
    # The normalisation step is defined for exactly three channels.
    image = image.convert('RGB')

    batch = image_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch).logits
    probabilities = torch.softmax(logits, dim=1)[0]

    # topk avoids sorting every class; cap k in case the model has fewer
    # classes than _TOP_K.
    k = min(_TOP_K, probabilities.numel())
    top_probs, top_ids = torch.topk(probabilities, k)

    names = [f'{id2label[index]}' for index in top_ids.tolist()]
    scores = [round(prob, 4) for prob in top_probs.tolist()]

    # Rounding the individual scores can push their sum marginally above 1,
    # so clamp the remainder to keep it a valid probability.
    names.append(_OTHERS_LABEL)
    scores.append(max(0.0, round(1.0 - sum(scores), 4)))

    # Pad labels to a common width (presumably for alignment in the label
    # widget -- the padding is part of the returned keys).
    width = max(len(name) for name in names)
    return {name.ljust(width): float(score) for name, score in zip(names, scores)}


description = f"This classifier assumes the picture is among the 20 breeds :{list(id2label.values())}"

# Built from adjacent literals: a plain quoted string cannot contain a raw
# line break.
article = (
    "First two examples are extracted from Kaggle dataset(https://www.kaggle.com/datasets/knucharat/pop-cats). "
    "Model is trained on the same dataset. "
    "Project details can be found in https://github.com/PinChunPai/Cat_breeds_classification/tree/main"
)

interface = gr.Interface(
    fn=classifier_app,
    inputs="image",
    examples=['cat_1.jpg', 'cat_2.jpg', 'cat_3.jpg'],
    outputs="label",
    title="Cat20_breeds",
    description=description,
    article=article,
)

interface.launch(share=True)