Hugging Face Space — status: Runtime error.
Commit: "Update app.py" (Browse files)
File: app.py — CHANGED
@@ -21,26 +21,27 @@ vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa"
|
|
21 |
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
22 |
|
23 |
def vilt_vqa(image, question):
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
|
33 |
# Load FLAN-T5
|
34 |
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
|
35 |
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
|
36 |
|
37 |
def flan_t5_complete_sentence(question, answer):
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
44 |
|
45 |
# Main function
|
46 |
def vqa_main(image, question):
|
@@ -48,7 +49,7 @@ def vqa_main(image, question):
|
|
48 |
vqa_answer = vilt_vqa(image, en_question)
|
49 |
llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
|
50 |
final_answer, answer_src_lang = google_translate(llm_answer, dest=question_src_lang)
|
51 |
-
|
52 |
return final_answer
|
53 |
|
54 |
# Home page text
|
|
|
21 |
# ViLT checkpoint fine-tuned for VQA; answers are produced by classifying
# over the model's fixed answer-label vocabulary (config.id2label).
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
22 |
|
23 |
def vilt_vqa(image, question):
    """Answer `question` about `image` with the ViLT VQA model.

    Encodes the (image, question) pair, runs a forward pass without
    gradient tracking, and returns the label of the highest-scoring
    answer class as a string.
    """
    encoding = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        logits = vilt_model(**encoding).logits
    best_idx = logits.argmax(-1).item()
    answer = vilt_model.config.id2label[best_idx]
    logger.info("ViLT: " + answer)
    return answer
|
32 |
|
33 |
# Load FLAN-T5 (large) to rewrite short VQA answers into complete sentences.
# NOTE(review): device_map="auto" presumably relies on `accelerate` to place
# the model weights across available devices — confirm it is installed.
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
|
36 |
|
37 |
def flan_t5_complete_sentence(question, answer):
    """Expand a short VQA answer into a complete sentence with FLAN-T5.

    Parameters:
        question: the question posed to the VQA model (a string).
        answer: the short answer string produced by ViLT.

    Returns:
        The decoded generations as a list of strings (callers take
        element [0] for the first sentence).
    """
    input_text = f"A question: {question} An incomplete answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    inputs = t5_tokenizer(input_text, return_tensors="pt")
    outputs = t5_model.generate(**inputs, max_length=50)
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # BUG FIX: batch_decode returns list[str]; the original
    # `"T5 output: " + result_sentence` raised
    # TypeError: can only concatenate str (not "list") to str at runtime.
    logger.info("T5 output: " + str(result_sentence))
    return result_sentence
|
45 |
|
46 |
# Main function
|
47 |
def vqa_main(image, question):
|
|
|
49 |
vqa_answer = vilt_vqa(image, en_question)
|
50 |
llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
|
51 |
final_answer, answer_src_lang = google_translate(llm_answer, dest=question_src_lang)
|
52 |
+
logger.info("Final Answer: " + final_answer)
|
53 |
return final_answer
|
54 |
|
55 |
# Home page text
|