File size: 526 Bytes
79cd580 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
import torchvision
from model import efficient_transformer , efficient_model
# Human-readable class labels; entry i names the classifier's output logit i.
CLASS_NAMES = [
    "pizza",
    "steak",
    "sushi",
]
def predict_gradio(image, *, model=None, transform=None, class_names=None):
    """Classify a single image and return outputs suited to a Gradio UI.

    Args:
        image: Raw input as handed over by Gradio; it is passed unchanged to
            ``transform``. (Presumably a PIL image -- confirm against the
            Gradio ``inputs`` component.)
        model: Classifier to run. Defaults to ``efficient_model``.
        transform: Callable turning ``image`` into an unbatched tensor.
            Defaults to ``efficient_transformer``.
        class_names: Labels ordered like the model's output logits.
            Defaults to ``CLASS_NAMES``.

    Returns:
        ``(prediction_per_labels, prediction)``: a dict mapping every class
        name to its softmax probability (the values sum to 1), and the name
        of the highest-probability class.

    Raises:
        ValueError: if the model emits a different number of logits than
            there are class names.
    """
    # Resolve defaults lazily so callers may inject their own components
    # while existing single-argument calls keep working unchanged.
    if model is None:
        model = efficient_model
    if transform is None:
        transform = efficient_transformer
    if class_names is None:
        class_names = CLASS_NAMES

    batch = transform(image).unsqueeze(0)  # add a batch dimension of size 1

    # Put the input on the same device as the model's weights so a model
    # living on a GPU does not fail on a CPU tensor. Parameter-free models
    # are left alone.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch = batch.to(first_param.device)

    model.eval()
    with torch.no_grad():
        logits = model(batch)

    # Assumes the model returns (batch, num_classes) logits -- take the
    # single sample. The classes are mutually exclusive (one argmax label
    # is returned), so normalise with softmax rather than per-class sigmoid.
    probs = torch.softmax(logits[0], dim=0)
    if probs.numel() != len(class_names):
        raise ValueError(
            f"model produced {probs.numel()} logits but "
            f"{len(class_names)} class names were given"
        )

    prediction_per_labels = dict(zip(class_names, probs.tolist()))
    prediction = class_names[int(probs.argmax())]
    return prediction_per_labels, prediction
|