"""Gradio demo: classify the weather condition shown in an image with a ViT-B/16 model."""
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import torch

from model import create_vit_model

# Output labels, listed in the index order of the classifier's output logits.
class_names = ['dew', 'fogsmog', 'frost', 'glaze', 'hail', 'lightning',
               'rain', 'rainbow', 'rime', 'sandstorm', 'snow']

# Build the ViT-B/16 feature extractor plus the preprocessing transforms that
# go with it, then restore the fine-tuned weights (mapped to CPU so the demo
# runs without a GPU).
vitb16, vitb16_transforms = create_vit_model(num_classes=len(class_names))
vitb16.load_state_dict(
    torch.load(
        "vitb16_feature_extractor_weather_rcg.pth",
        map_location=torch.device("cpu"),
    )
)
# Switch to evaluation mode once at startup instead of on every request;
# nothing below ever puts the model back into training mode.
vitb16.eval()


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one image and report how long the prediction took.

    Args:
        img: The image supplied by the ``gr.Image(type="pil")`` input (a PIL
            image), or ``None`` if the user submitted without uploading one.

    Returns:
        A ``(probabilities, seconds)`` pair: a dict mapping every class name
        to its softmax probability, and the elapsed wall-clock time rounded
        to 4 decimal places.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a failure from inside the transforms.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    start_timer = timer()

    # Preprocess and add a leading batch dimension of size 1.
    img = vitb16_transforms(img).unsqueeze(0)

    with torch.inference_mode():
        pred_probs = torch.softmax(vitb16(img), dim=1)

    # Only one image is in the batch, so row 0 holds its per-class probabilities.
    pred_labels_and_probs = {
        name: float(prob) for name, prob in zip(class_names, pred_probs[0].tolist())
    }

    pred_timer = round(timer() - start_timer, 4)
    return pred_labels_and_probs, pred_timer


title = "Weather Recognition"
description = "A ViTb16 Feature Extractor CV model to recognize weather conditions"
# Sort so the example gallery order is stable; os.listdir order is arbitrary.
example_list = [[os.path.join("examples", example)]
                for example in sorted(os.listdir("examples"))]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        # Show every class; derived from class_names so the two never drift apart.
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction time(s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
)

demo.launch()