import torch
import torchvision

from model import efficient_transformer, efficient_model

CLASS_NAMES = ['pizza', 'steak', 'sushi']


def predict_gradio(image):
    """Classify one image and return per-class confidences plus the top label.

    Args:
        image: Input accepted by ``efficient_transformer``. Presumably a PIL
            image supplied by a Gradio image input; TODO confirm against the
            interface definition.

    Returns:
        A tuple ``(prediction_per_labels, prediction)``:
        ``prediction_per_labels`` maps every name in ``CLASS_NAMES`` to a float
        probability (softmax over the class logits, so the values sum to 1), and
        ``prediction`` is the class name with the highest score.
    """
    # Preprocess, then add a leading batch dimension of size 1.
    batch = torch.unsqueeze(efficient_transformer(image), dim=0)

    efficient_model.eval()
    with torch.no_grad():
        # Assumes the model returns raw logits of shape (1, len(CLASS_NAMES))
        # — TODO confirm.
        logits = efficient_model(batch)

    # The task is single-label (the top class is picked with argmax), so the
    # logits are normalised with softmax across classes instead of an
    # independent sigmoid per class. This makes the confidences a proper
    # distribution that agrees with the reported label.
    probs = torch.softmax(logits[0], dim=0)
    prediction_per_labels = {
        name: float(prob) for name, prob in zip(CLASS_NAMES, probs)
    }
    prediction = CLASS_NAMES[torch.argmax(probs).item()]
    return prediction_per_labels, prediction