"""Gradio demo that classifies German traffic signs with a fine-tuned ViT model."""

import os
from functools import lru_cache
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

# Hugging Face Hub identifier shared by the image processor and the classifier.
MODEL_NAME = "bazyl/gtsrb-model"


@lru_cache(maxsize=1)
def _load_model() -> Tuple[ViTImageProcessor, ViTForImageClassification]:
    """Load the processor and classifier once and reuse them on later calls."""
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    model = ViTForImageClassification.from_pretrained(MODEL_NAME)
    return processor, model


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a traffic-sign image and report how long inference took.

    Args:
        img: The image to classify. The Gradio input below is configured with
            ``type="pil"``, so this is expected to be a PIL image.

    Returns:
        A ``(probabilities, seconds)`` tuple. ``probabilities`` maps every
        class label in ``model.config.id2label`` to its softmax probability
        (the ``gr.Label`` output trims this to the top 5). ``seconds`` is the
        wall-clock time spent on preprocessing and the forward pass, rounded
        to 5 decimal places.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        raise gr.Error("Please provide an image to classify.")

    # Fetch the cached model before starting the timer so the one-off
    # load/download cost is not reported as prediction time.
    processor, model = _load_model()

    start_time = timer()

    inputs = processor(images=img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image is passed in, so take the first (only) row of the batch.
    probabilities = logits.softmax(dim=-1)[0].tolist()
    id2label = model.config.id2label
    pred_labels_and_probs = {
        id2label[idx]: float(prob) for idx, prob in enumerate(probabilities)
    }

    predicted_class_idx = logits.argmax(-1).item()
    print("Predicted class:", id2label[predicted_class_idx])

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time


title = "GTSRB - German Traffic Sign Recognition by Bazyl Horsey"
description = "CNN created for the GTSRB Dataset, achieved 99.93% test accuracy"

# Create examples list from "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Create Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
)

# Launch the app!
demo.launch()