ixxan committed
Commit 775f1ae
1 Parent(s): dc1676e

Update app.py

Files changed (1)
app.py +29 -12
app.py CHANGED
@@ -2,35 +2,52 @@ import gradio as gr
  from transformers import ViltProcessor, ViltForQuestionAnswering
  import torch

- # Load example images
- torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
-
  # Load Vilt
  vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
  vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+
  def vilt_vqa(image, question):
-     # prepare inputs
      inputs = vilt_processor(image, question, return_tensors="pt")
-
-     # forward pass
      with torch.no_grad():
          outputs = vilt_model(**inputs)
-
      logits = outputs.logits
      idx = logits.argmax(-1).item()
      answer = vilt_model.config.id2label[idx]
      return answer

+ # Load FLAN-T5
+ t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
+ t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")
+
+ def flan_t5_complete_sentence(question, answer):
+     input_text = f"A question: {question} An incomplete answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
+     print(input_text)
+     inputs = t5_tokenizer(input_text, return_tensors="pt")
+     outputs = t5_model.generate(**inputs, max_length=50)
+     result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+     return result_sentence
+
+ # Main function
+ def vqa_main(image, question):
+     incomplete_answer = vilt_vqa(image, question)
+     complete_answer = flan_t5_complete_sentence(question, answer)
+     return complete_answer
+
+ # Home page text
+ title = "Interactive demo: Multilingual VQA"
+ description = "Demo for Multilingual VQA. Upload an image, type a question, click 'submit', or click one of the examples to load them."
+ article = "article goes here"
+
+ # Load example images
+ torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
+
+ # Define home page variables
  image = gr.inputs.Image(type="pil")
  question = gr.inputs.Textbox(label="Question")
  answer = gr.outputs.Textbox(label="Predicted answer")
  examples = [["cats.jpg", "What are the animals here called?"]]

- title = "Interactive demo: Multilingual VQA"
- description = "Demo for Multilingual VQA. Upload an image, type a question, click 'submit', or click one of the examples to load them."
- article = "article"
-
- interface = gr.Interface(fn=vilt_vqa,
+ interface = gr.Interface(fn=vqa_main,
                           inputs=[image, question],
                           outputs=answer,
                           examples=examples,
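
For context, the updated app.py chains two models: ViLT ("dandelin/vilt-b32-finetuned-vqa") produces a short VQA answer, and FLAN-T5 ("google/flan-t5-large") rewrites it into a complete sentence for the Gradio interface. Below is a minimal, self-contained sketch of that pipeline; it is an illustration of the intended wiring, not the committed file. It assumes T5Tokenizer and T5ForConditionalGeneration are imported explicitly, forwards the ViLT result to FLAN-T5 as incomplete_answer, loads FLAN-T5 without device_map="auto" (avoiding the extra accelerate dependency), and replaces the Gradio UI with a small PIL-based smoke test.

from PIL import Image
import torch
from transformers import (
    ViltProcessor, ViltForQuestionAnswering,
    T5Tokenizer, T5ForConditionalGeneration,
)

# Models as named in the diff.
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")

def vilt_vqa(image, question):
    # ViLT treats VQA as classification over a fixed answer vocabulary.
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vilt_model(**inputs)
    idx = outputs.logits.argmax(-1).item()
    return vilt_model.config.id2label[idx]

def flan_t5_complete_sentence(question, answer):
    # Prompt FLAN-T5 to turn the short ViLT answer into a full sentence.
    prompt = (
        f"A question: {question} An incomplete answer: {answer}. "
        "Based on these, answer the question with a complete sentence without extra information."
    )
    inputs = t5_tokenizer(prompt, return_tensors="pt")
    outputs = t5_model.generate(**inputs, max_length=50)
    # batch_decode returns a list; take the single decoded string.
    return t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

def vqa_main(image, question):
    incomplete_answer = vilt_vqa(image, question)
    # Forward the ViLT answer (not the global Gradio component) to FLAN-T5.
    return flan_t5_complete_sentence(question, incomplete_answer)

if __name__ == "__main__":
    # Hypothetical smoke test using the example image from the diff.
    torch.hub.download_url_to_file(
        "http://images.cocodataset.org/val2017/000000039769.jpg", "cats.jpg"
    )
    print(vqa_main(Image.open("cats.jpg"), "What are the animals here called?"))

Generation is capped at max_length=50 as in the diff; taking the first element of batch_decode yields a plain string suitable for the Gradio Textbox output.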