egmaminta commited on
Commit
59df8b5
1 Parent(s): 96396c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -77,14 +77,15 @@ labels = {
77
  "66": "winecellar"
78
  }
79
 
 
 
80
  def classify(image):
81
- model.eval()
82
  with torch.no_grad():
83
  inputs = extractor(images=image, return_tensors='pt')
84
  outputs = model(**inputs).logits
85
  outputs = rearrange(outputs, '1 j->j')
 
86
  outputs = outputs.cpu().numpy()
87
- outputs = (numpy.exp(outputs)) / (numpy.sum(numpy.exp(outputs)))
88
  return {labels[str(i)]: float(outputs[i]) for i in range(len(labels))}
89
 
90
  gradio.Interface(fn=classify,
 
77
  "66": "winecellar"
78
  }
79
 
80
+ model.eval()
81
+
82
  def classify(image):
 
83
  with torch.no_grad():
84
  inputs = extractor(images=image, return_tensors='pt')
85
  outputs = model(**inputs).logits
86
  outputs = rearrange(outputs, '1 j->j')
87
+ outputs = torch.nn.functional.softmax(outputs)
88
  outputs = outputs.cpu().numpy()
 
89
  return {labels[str(i)]: float(outputs[i]) for i in range(len(labels))}
90
 
91
  gradio.Interface(fn=classify,