| import torch | |
| import torchvision.transforms as T | |
| from timm import create_model | |
| from safetensors.torch import load_model | |
| import numpy as np | |
| from pathlib import Path | |
| import gradio as gr | |
# Every entry found directly under ./examples, as plain strings, offered
# as clickable sample inputs in the UI.
examples = [str(entry) for entry in Path('./examples').glob('*')]
# Inference-time preprocessing: resize to 224x224, convert to a tensor,
# then normalize each of the three channels as (x - 0.5) / 0.5.
valid_tfms = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
# Build the architecture without pretrained weights, then load the
# fine-tuned 20-class checkpoint from disk.
model_path = 'model/swin_s3_base_224-pascal/model.safetensors'
model = create_model('swin_s3_base_224', pretrained=False, num_classes=20)
load_model(model, model_path)

# Switch to evaluation mode before serving predictions.
model.eval()
# Class labels; a label's position in this list is the index of the
# matching model output.
class_names = [
    "Aeroplane",
    "Bicycle",
    "Bird",
    "Boat",
    "Bottle",
    "Bus",
    "Car",
    "Cat",
    "Chair",
    "Cow",
    "Diningtable",
    "Dog",
    "Horse",
    "Motorbike",
    "Person",
    "Potted plant",
    "Sheep",
    "Sofa",
    "Train",
    "Tv/monitor",
]

# Index -> name and name -> index lookups derived from the list order.
id2label = dict(enumerate(class_names))
label2id = {name: index for index, name in id2label.items()}
def predict(im):
    """Return the classes detected in *im* together with their confidences.

    The image is preprocessed with ``valid_tfms``, run through ``model``
    as a batch of one, and each logit is passed through a sigmoid. Every
    class is thresholded independently: a class is reported only when its
    confidence is strictly greater than 0.5.

    Args:
        im: Image accepted by ``valid_tfms`` (the UI is configured to
            deliver PIL images), or ``None`` when no image was supplied.

    Returns:
        A dict mapping class name to confidence (a plain Python float),
        in ascending class-index order. Empty when ``im`` is ``None`` or
        when no class exceeds the threshold.
    """
    # Submitting the form without an image would otherwise crash inside
    # the transform pipeline; report "nothing detected" instead.
    if im is None:
        return {}

    batch = valid_tfms(im).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        logits = model(batch)

    # tolist() yields Python floats, which serialize cleanly and compare
    # against 0.5 exactly as the float32 tensor values would.
    confidences = logits.sigmoid().flatten().tolist()
    return {
        id2label[index]: score
        for index, score in enumerate(confidences)
        if score > 0.5
    }
# Wire the predictor into a web UI (PIL image in, label confidences out),
# enable request queuing, and start the server.
(
    gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(label='the image contains:'),
        examples=examples,
    )
    .queue()
    .launch()
)