File size: 526 Bytes
79cd580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torchvision
from model import efficient_transformer , efficient_model

# Human-readable labels, indexed by model output position: predict_gradio maps
# logit i to CLASS_NAMES[i]. NOTE(review): presumably must match the class order
# used when training efficient_model — confirm.
CLASS_NAMES = ["pizza", "steak", "sushi"]

def predict_gradio(image, *, model=None, transform=None, class_names=None):
    """Classify one image and return per-class confidences plus the top label.

    Args:
        image: A single input accepted by ``transform`` (by default
            ``efficient_transformer``). NOTE(review): presumably a PIL image
            supplied by the Gradio UI — confirm.
        model: Classifier to run; defaults to ``efficient_model``. It is put
            into eval mode (a persistent side effect on the model) before
            inference.
        transform: Callable turning ``image`` into an un-batched tensor;
            defaults to ``efficient_transformer``.
        class_names: Labels indexed by model output position; defaults to
            ``CLASS_NAMES``.

    Returns:
        A ``(confidences, label)`` tuple. ``confidences`` maps every class
        name to its softmax probability (values sum to 1); ``label`` is the
        most probable class name.

    Raises:
        ValueError: If the model does not emit exactly one logit per class.
    """
    # Resolve defaults at call time so the module-level objects are used
    # unless the caller injects alternatives.
    if model is None:
        model = efficient_model
    if transform is None:
        transform = efficient_transformer
    if class_names is None:
        class_names = CLASS_NAMES

    # The model consumes batched input; wrap the single image in a batch of 1.
    batch = transform(image).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        logits = model(batch)

    # A single argmax label means the classes are mutually exclusive, so use
    # softmax (probabilities that sum to 1) rather than an independent
    # sigmoid per logit, which over/under-states confidence.
    probs = torch.softmax(logits, dim=1)[0]
    if probs.numel() != len(class_names):
        raise ValueError(
            f"model produced {probs.numel()} logits for {len(class_names)} classes"
        )

    confidences = {name: float(p) for name, p in zip(class_names, probs)}
    label = class_names[int(probs.argmax())]
    return confidences, label