MohammadAliMKH's picture
Upload 7 files
79cd580
raw
history blame
526 Bytes
import torch
import torchvision
from model import efficient_transformer , efficient_model
# Human-readable labels; position i names the model's i-th output score.
CLASS_NAMES = [
    "pizza",
    "steak",
    "sushi",
]
def predict_gradio(image):
    """Classify a single image and return per-class confidences plus the top label.

    Args:
        image: One input image in whatever form ``efficient_transformer``
            accepts (presumably a PIL image from a Gradio ``Image`` input —
            TODO confirm against the interface definition).

    Returns:
        A ``(confidences, prediction)`` tuple. ``confidences`` maps each name in
        ``CLASS_NAMES`` to its softmax probability. ``prediction`` is the name
        of the highest-scoring class.
    """
    # Apply the preprocessing transform, then add a leading batch dimension of 1.
    batch = torch.unsqueeze(efficient_transformer(image), dim=0)

    efficient_model.eval()
    with torch.no_grad():
        logits = efficient_model(batch)

    # A single argmax label is returned, so the classes are treated as mutually
    # exclusive. Softmax over the class dimension therefore gives a proper
    # distribution. An element-wise sigmoid would give independent scores that
    # need not sum to 1. Assumes logits are (batch, num_classes) — as indexed
    # by the original pred[0][i].
    probs = torch.softmax(logits, dim=-1)[0]

    confidences = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    # Softmax is monotonic, so this is the same class the raw logits' argmax picks.
    prediction = CLASS_NAMES[int(torch.argmax(probs))]
    return confidences, prediction