ixxan committed on
Commit
087ba34
1 Parent(s): 7e69a2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -21,26 +21,27 @@ vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa"
21
  vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
22
 
23
  def vilt_vqa(image, question):
24
- inputs = vilt_processor(image, question, return_tensors="pt")
25
- with torch.no_grad():
26
- outputs = vilt_model(**inputs)
27
- logits = outputs.logits
28
- idx = logits.argmax(-1).item()
29
- answer = vilt_model.config.id2label[idx]
30
- logger.info("ViLT: " + answer)
31
- return answer
32
 
33
  # Load FLAN-T5
34
  t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
35
  t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
36
 
37
  def flan_t5_complete_sentence(question, answer):
38
- input_text = f"A question: {question} An incomplete answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
39
- logger.info("T5 input: " + input_text)
40
- inputs = t5_tokenizer(input_text, return_tensors="pt")
41
- outputs = t5_model.generate(**inputs, max_length=50)
42
- result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
43
- return result_sentence
 
44
 
45
  # Main function
46
  def vqa_main(image, question):
@@ -48,7 +49,7 @@ def vqa_main(image, question):
48
  vqa_answer = vilt_vqa(image, en_question)
49
  llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
50
  final_answer, answer_src_lang = google_translate(llm_answer, dest=question_src_lang)
51
-
52
  return final_answer
53
 
54
  # Home page text
 
21
  vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
22
 
23
  def vilt_vqa(image, question):
24
+ inputs = vilt_processor(image, question, return_tensors="pt")
25
+ with torch.no_grad():
26
+ outputs = vilt_model(**inputs)
27
+ logits = outputs.logits
28
+ idx = logits.argmax(-1).item()
29
+ answer = vilt_model.config.id2label[idx]
30
+ logger.info("ViLT: " + answer)
31
+ return answer
32
 
33
  # Load FLAN-T5
34
  t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
35
  t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
36
 
37
def flan_t5_complete_sentence(question, answer):
    """Expand a short VQA answer into a full sentence with FLAN-T5.

    Builds a prompt from the question and the (typically one-word)
    answer, then asks FLAN-T5 to phrase a complete sentence.

    Args:
        question: The question, in English.
        answer: The short answer to be rephrased (e.g. ViLT's output).

    Returns:
        list[str]: Decoded generations, one per output sequence; callers
        index ``[0]`` for the single generated sentence.
    """
    input_text = f"A question: {question} An incomplete answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: " + input_text)
    inputs = t5_tokenizer(input_text, return_tensors="pt")
    outputs = t5_model.generate(**inputs, max_length=50)
    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # BUG FIX: batch_decode returns a list[str]; concatenating it to a
    # str raised TypeError on every call. Format via str() instead.
    logger.info("T5 output: " + str(result_sentence))
    return result_sentence
45
 
46
  # Main function
47
  def vqa_main(image, question):
 
49
  vqa_answer = vilt_vqa(image, en_question)
50
  llm_answer = flan_t5_complete_sentence(en_question, vqa_answer)[0]
51
  final_answer, answer_src_lang = google_translate(llm_answer, dest=question_src_lang)
52
+ logger.info("Final Answer: " + final_answer)
53
  return final_answer
54
 
55
  # Home page text