Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -66,10 +66,14 @@ def vilt_vqa(image, question):
|
|
66 |
with torch.no_grad():
|
67 |
outputs = vilt_model(**inputs)
|
68 |
logits = outputs.logits
|
69 |
-
logger.info("ViLT logits:" + logits)
|
70 |
idx = logits.argmax(-1).item()
|
71 |
answer = vilt_model.config.id2label[idx]
|
72 |
logger.info("ViLT: " + answer)
|
|
|
|
|
|
|
|
|
|
|
73 |
return answer
|
74 |
|
75 |
# Load FLAN-T5
|
|
|
66 |
with torch.no_grad():
|
67 |
outputs = vilt_model(**inputs)
|
68 |
logits = outputs.logits
|
|
|
69 |
idx = logits.argmax(-1).item()
|
70 |
answer = vilt_model.config.id2label[idx]
|
71 |
logger.info("ViLT: " + answer)
|
72 |
+
|
73 |
+
# Get the top 10 scores and their indices
|
74 |
+
topk_values, topk_indices = torch.topk(logits, 10, dim=-1)
|
75 |
+
topk_answers = [vilt_model.config.id2label[idx.item()] for idx in topk_indices[0]]
|
76 |
+
logger.info("ViLT top 10 answers: " + str(topk_answers))
|
77 |
return answer
|
78 |
|
79 |
# Load FLAN-T5
|