Spaces:
Build error
Build error
| """ | |
| Ryan Tietjen | |
| Aug 2024 | |
| Demo application for a food classificiation demonstration | |
| """ | |
| from model import vit_b_16 | |
| from timeit import default_timer as timer | |
| import torch | |
| import os | |
| import gradio as gr | |
| with open("class_names.txt", 'r') as file: | |
| class_names = [line.strip() for line in file] | |
| model, transforms = vit_b_16(num_classes=101, | |
| seed=31, | |
| freeze_gradients=True, | |
| unfreeze_blocks=0) | |
| model.load_state_dict(torch.load('vit_b_16_unfreeze_one_encoder_block_10_total_epochs.pth', | |
| map_location=torch.device("cpu"))) | |
| def predict_single_image(img): | |
| start_time = timer() | |
| model.eval() | |
| #Add batch dim | |
| img = transforms(img).unsqueeze(dim=0) | |
| with torch.inference_mode(): | |
| # Obtain prediction logits -> prediction probabilities from image | |
| logits = model(img) | |
| probabilities = torch.softmax(logits, dim=1) | |
| class_probabilities = {} | |
| for i in range(len(class_names)): | |
| class_probabilities[class_names[i]] = float(probabilities[0][i]) | |
| end_time = timer() | |
| pred_time = round(end_time - start_time, 3) | |
| return class_probabilities, pred_time | |
| title = "Food Image Classification With PyTorch by Ryan Tietjen" | |
| description = f""" | |
| Determines what type of food is presented in a given image. | |
| This model is capable of classifying [101 different types of food](https://github.com/RyanTietjen/Food-Classifier-pytorch-ver.-/blob/main/demo/class_names.txt) by | |
| utilizing a [pre-trained Vision Transformer](https://pytorch.org/vision/stable/models/generated/torchvision.models.vit_b_16.html#torchvision.models.ViT_B_16_Weights), | |
| and fine-tuning the results for specific food categories. You can find more information in [my GitHub repo.](https://github.com/RyanTietjen/Food-Classifier-pytorch-ver.-) | |
| This model achieved a Top-1 accuracy of 91.55% and a Top-5 accuracy of 98.56% | |
| """ | |
| sample_list = [["samples/" + sample] for sample in os.listdir("samples")] | |
| #Gradio interface | |
| demo = gr.Interface( | |
| fn=predict_single_image, | |
| inputs=gr.Image(type="pil"), | |
| outputs=[ | |
| gr.Label(num_top_classes=5, label="Predictions"), | |
| gr.Number(label="Prediction time (s)"), | |
| ], | |
| examples=sample_list, | |
| title=title, | |
| description=description, | |
| ) | |
| demo.launch(share=True) |