ixxan commited on
Commit
30a0050
1 Parent(s): d0bee02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -66,10 +66,14 @@ def vilt_vqa(image, question):
66
  with torch.no_grad():
67
  outputs = vilt_model(**inputs)
68
  logits = outputs.logits
69
- logger.info("ViLT logits:" + logits)
70
  idx = logits.argmax(-1).item()
71
  answer = vilt_model.config.id2label[idx]
72
  logger.info("ViLT: " + answer)
 
 
 
 
 
73
  return answer
74
 
75
  # Load FLAN-T5
 
66
  with torch.no_grad():
67
  outputs = vilt_model(**inputs)
68
  logits = outputs.logits
 
69
  idx = logits.argmax(-1).item()
70
  answer = vilt_model.config.id2label[idx]
71
  logger.info("ViLT: " + answer)
72
+
73
+ # Get the top 10 scores and their indices
74
+ topk_values, topk_indices = torch.topk(logits, 10, dim=-1)
75
+ topk_answers = [vilt_model.config.id2label[idx.item()] for idx in topk_indices[0]]
76
+ logger.info("ViLT top 10 answers: " + str(topk_answers))
77
  return answer
78
 
79
  # Load FLAN-T5