File size: 2,043 Bytes
3bad857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33fb0ac
3bad857
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python
# coding: utf-8

# In[1]:


# import required libraries
import functools

from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
import gradio as gr


# In[2]:


# pipeline: chart image + text prompt -> DePlot prediction -> list of rows
@functools.lru_cache(maxsize=None)
def _load_deplot(model_name="google/deplot"):
    """Load the Pix2Struct processor/model pair for ``model_name`` once.

    ``from_pretrained`` reads the checkpoint from disk (or downloads it), so
    the pair is memoized per model name and reused by every ``query`` call
    instead of being reloaded on each request.
    """
    processor = Pix2StructProcessor.from_pretrained(model_name)
    model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
    return processor, model


def _rows_from_output(result):
    """Split decoded model text into rows; '|'-delimited rows become cell lists."""
    rows = []
    # The decoded text uses the literal token "<0x0A>" (hex for a newline)
    # between lines, so that is the row separator.
    for line in result.split("<0x0A>"):
        line = line.strip()
        if "|" in line:
            rows.append([cell.strip() for cell in line.split("|")])
        else:
            # Lines without a '|' are kept as plain strings, not lists.
            rows.append(line)
    return rows


def query(image, user_question):
    """Run the DePlot model on a chart image and return its output as rows.

    image: a single image or a batch of images accepted by
        ``Pix2StructProcessor`` (the UI passes a PIL image).
    user_question: text prompt passed to the processor alongside the image.

    Returns a list with one entry per output line.  An entry is a list of
    stripped cell strings when the line contains '|', otherwise the stripped
    line itself as a string.

    Raises ValueError if no image is supplied.
    """
    if image is None:
        # Without this, the processor fails with a much less obvious error.
        raise ValueError("Please upload a chart image.")
    # Cached after the first call, so repeated requests do not reload weights.
    processor, model = _load_deplot("google/deplot")
    # Convert image + prompt into model input tensors.
    inputs = processor(images=image, text=user_question, return_tensors="pt")
    # Generate up to 512 new tokens of output.
    predictions = model.generate(**inputs, max_new_tokens=512)
    # Only the first sequence of the generated batch is decoded.
    result = processor.decode(predictions[0], skip_special_tokens=True)
    return _rows_from_output(result)


# In[ ]:


# Web UI: a chart image and a text prompt are passed to `query`, and its
# returned rows are shown with the "list" output type.
chart_examples = [
    ["./samples/sample1.png", "Generate underlying data table of the figure"],
    ["./samples/sample2.png", "Is the sum of all 4 places greater than Laos?"],
    # Disabled example:
    # ["./samples/sample3.webp", "What are the 2020 net sales?"],
]

ui = gr.Interface(
    fn=query,
    inputs=[
        gr.Image(label="Upload Here", type="pil"),
        gr.Textbox(label="Question?"),
    ],
    outputs="list",
    examples=chart_examples,
    title="Chart Q/A",
    # Per the gradio docs, cached examples have their outputs precomputed
    # by calling `query` when the app starts.
    cache_examples=True,
    allow_flagging='never',
)

# Turn on the request queue (API left open), then start the local server
# without a public share link.
ui.queue(api_open=True)
ui.launch(inline=False, share=False, debug=True)