File size: 5,298 Bytes
b802c2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34d659e
8df48e3
b802c2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e010032
b802c2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b872a0b
b802c2a
 
 
 
 
 
 
 
 
 
 
 
 
 
b872a0b
b802c2a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import requests
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces

def _answer_with_paligemma(checkpoint, image, question, max_new_tokens=100):
    """Run one fine-tuned PaliGemma checkpoint on an image/question pair.

    Shared implementation behind the four ``infer_*`` Gradio handlers below.

    Args:
        checkpoint: Hugging Face model id of the PaliGemma checkpoint to use.
        image: Value produced by the ``gr.Image`` input; ``None`` when the user
            pressed Submit without providing an image.
        question: Prompt text from the ``gr.Text`` input; ``None`` is treated
            as an empty prompt.
        max_new_tokens: Upper bound on the number of tokens generated.

    Returns:
        Only the newly generated answer text (prompt removed, surrounding
        whitespace stripped).

    Raises:
        gr.Error: if no image was supplied, so Gradio shows a readable message
            instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")
    question = question or ""

    # NOTE(review): model and processor are (re)loaded on every request; after
    # the first download from_pretrained reads from the local cache. Consider
    # loading once if the hosting setup allows keeping CUDA weights resident.
    model = PaliGemmaForConditionalGeneration.from_pretrained(checkpoint).to("cuda")
    processor = PaliGemmaProcessor.from_pretrained(checkpoint)

    inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
    # generate() echoes the prompt before the new tokens. Drop the prompt by
    # token count rather than by len(question) characters: the decoded prompt
    # is not guaranteed to match the raw question text character-for-character.
    prompt_len = inputs["input_ids"].shape[-1]
    predictions = model.generate(**inputs, max_new_tokens=max_new_tokens)
    answer_ids = predictions[0][prompt_len:]
    return processor.decode(answer_ids, skip_special_tokens=True).strip()


@spaces.GPU
def infer_diagram(image, question):
    """Answer a question about a diagram (PaliGemma fine-tuned on AI2D, 448px)."""
    return _answer_with_paligemma("google/paligemma-3b-ft-ai2d-448", image, question)


@spaces.GPU
def infer_ocrvqa(image, question):
    """Answer a question that requires reading text in an image (OCR-VQA, 896px)."""
    return _answer_with_paligemma("google/paligemma-3b-ft-ocrvqa-896", image, question)


@spaces.GPU
def infer_infographics(image, question):
    """Answer a question about an infographic (PaliGemma fine-tuned on InfoVQA, 896px)."""
    return _answer_with_paligemma("google/paligemma-3b-ft-infovqa-896", image, question)


@spaces.GPU
def infer_doc(image, question):
    """Answer a question about a document image (PaliGemma fine-tuned on DocVQA, 896px)."""
    return _answer_with_paligemma("google/paligemma-3b-ft-docvqa-896", image, question)

# Custom stylesheet passed to gr.Blocks(css=css): a fixed-height (500px),
# scrollable, bordered box for the element whose id is "mkd".
# NOTE(review): no component in this file sets elem_id="mkd", so this rule
# appears to be unused; confirm before removing it.
css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

def _build_qa_tab(tab_label, image_label, answer_label, infer_fn, example, examples_label):
    """Add one image-question-answering tab to the enclosing ``gr.Blocks``.

    Must be called while a ``gr.Blocks`` context is active.

    Args:
        tab_label: Title shown on the tab.
        image_label: Label of the image input component.
        answer_label: Label of the text output component.
        infer_fn: Handler ``(image, question) -> str`` used by BOTH the Submit
            button and the examples, so the two always query the same model.
        example: A single ``[image_path, question]`` example row.
        examples_label: Caption displayed above the examples.
    """
    with gr.Tab(label=tab_label):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label=image_label)
                question = gr.Text(label="Question")
                submit_btn = gr.Button(value="Submit")
            output = gr.Text(label=answer_label)
        # Whenever Gradio executes the examples (e.g. when caching them) it
        # calls `fn`; it must be this tab's model, not another tab's.
        gr.Examples(
            [example],
            inputs=[input_img, question],
            outputs=[output],
            fn=infer_fn,
            label=examples_label,
        )
        submit_btn.click(infer_fn, [input_img, question], [output])


with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>PaliGemma Fine-tuned on Documents 📄</center></h1>")
    gr.HTML("<h3><center>This Space is built for you to compare different PaliGemma models fine-tuned on document tasks. ⚡</center></h3>")
    gr.HTML("<h3><center>Each tab in this app demonstrates PaliGemma models fine-tuned on document question answering, infographics question answering, diagram understanding, and reading comprehension from images. 📄📕📊</center></h3>")
    gr.HTML("<h3><center>Models are downloaded on the go, so first inference in each tab might take time if it's not already downloaded.</center></h3>")

    _build_qa_tab(
        tab_label="Visual Question Answering over Documents",
        image_label="Input Document",
        answer_label="Answer",
        infer_fn=infer_doc,
        example=["assets/docvqa_example.png", "How many items are sold?"],
        examples_label="Click on any Examples below to get Document Question Answering results quickly 👇",
    )
    _build_qa_tab(
        tab_label="Visual Question Answering over Infographics",
        image_label="Input Image",
        answer_label="Answer",
        infer_fn=infer_infographics,
        example=["assets/infographics_example (1).jpeg", "What is this infographic about?"],
        examples_label="Click on any Examples below to get Infographics QA results quickly 👇",
    )
    _build_qa_tab(
        tab_label="Reading from Images",
        image_label="Input Document",
        answer_label="Infer",
        infer_fn=infer_ocrvqa,  # was infer_doc for the examples: wrong model
        example=["assets/ocrvqa.jpg", "Who is the author of this book?"],
        examples_label="Click on any Examples below to get image reading comprehension results quickly 👇",
    )
    _build_qa_tab(
        tab_label="Diagram Understanding",
        image_label="Input Diagram",
        answer_label="Infer",
        infer_fn=infer_diagram,  # was infer_doc for the examples: wrong model
        example=["assets/diagram.png", "What is the diagram showing?"],
        examples_label="Click on any Examples below to get diagram understanding results quickly 👇",
    )

if __name__ == "__main__":
    demo.launch(debug=True)