import torch | |
import torchvision | |
from model import efficient_transformer , efficient_model | |
# Human-readable label for each model output, in output-index order
# (position i names the model's i-th logit).
CLASS_NAMES = [
    'pizza',
    'steak',
    'sushi',
]
def predict_gradio(image):
    """Classify one image and return per-class confidences plus the top label.

    Args:
        image: A single input image in whatever form ``efficient_transformer``
            accepts (presumably a PIL image coming from a Gradio ``Image``
            input -- TODO confirm against the interface definition).

    Returns:
        A tuple ``(prediction_per_labels, prediction)``:
        ``prediction_per_labels`` maps each name in ``CLASS_NAMES`` to a float
        confidence (softmax probabilities, summing to 1 across the classes);
        ``prediction`` is the class name with the highest confidence.
    """
    image = efficient_transformer(image)
    # Switch modules such as dropout / batch-norm to evaluation behaviour
    # before inference.
    efficient_model.eval()
    with torch.no_grad():
        # The model expects a batch dimension: wrap the image as a batch of one.
        logits = efficient_model(torch.unsqueeze(image, dim=0))
        # The classes are mutually exclusive (a single argmax label is
        # returned), so normalise across the class axis with softmax.
        # Per-logit sigmoids would give scores that do not sum to 1.
        probabilities = torch.softmax(logits, dim=1)[0]
        prediction_per_labels = {
            name: float(prob) for name, prob in zip(CLASS_NAMES, probabilities)
        }
        # Softmax is monotonic, so this picks the same class as argmax over
        # the raw logits.
        prediction = CLASS_NAMES[int(torch.argmax(probabilities))]
    return prediction_per_labels, prediction