Dongxu Li committed
Commit: 8f68280
Parent: 30474d6

add generation options.

Files changed (5)
  1. .gitattributes +2 -0
  2. app.py +140 -73
  3. house.png +3 -0
  4. sunset.png +3 -0
  5. utils.py +24 -0
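The options added by this commit are the text decoding method (beam search vs. nucleus sampling), temperature, length penalty, and repetition penalty; the Gradio UI forwards them to the remote BLIP-2 endpoint as extra form fields. Below is a minimal sketch of the payload the updated query_api assembles. The URL is a placeholder (the real one is resolved at runtime by utils.Endpoint), and the actual POST call sits in a part of query_api not shown in this diff.

import requests

# Illustrative only: mirrors the form fields added in this commit.
url = "https://example.com/api/generate"  # placeholder; resolved via utils.Endpoint at runtime

data = {
    "prompt": "What is unusual about this house?",
    "use_nucleus_sampling": False,   # True when the radio is set to "Nucleus sampling"
    "temperature": 0.8,              # slider range 0.5-1.0
    "length_penalty": 1.0,           # slider range -2.0-2.0, step 0.5
    "repetition_penalty": 1.0,       # slider range 1.0-10.0, step 0.5
}

with open("house.png", "rb") as image_file:
    files = {"image": image_file}
    response = requests.post(
        url, headers={"User-Agent": "BLIP-2 HuggingFace Space"}, data=data, files=files
    )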
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+house.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,12 +1,12 @@
-from PIL import Image
+from io import BytesIO
 
-import requests
-import json
+import string
 import gradio as gr
+import requests
+from PIL import Image
+from utils import Endpoint
 
 
-from io import BytesIO
-
 def encode_image(image):
     buffered = BytesIO()
     image.save(buffered, format="JPEG")
@@ -15,16 +15,19 @@ def encode_image(image):
     return buffered
 
 
-def query_api(image, prompt, decoding_method):
-    # local host for testing
-    url = "http://34.132.142.70:5000/api/generate"
+def query_api(image, prompt, decoding_method, temperature, len_penalty, repetition_penalty):
+    url = endpoint.url
+
+    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}
 
-    headers = {
-        'User-Agent': 'BLIP-2 HuggingFace Space'
+    data = {
+        "prompt": prompt,
+        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
+        "temperature": temperature,
+        "length_penalty": len_penalty,
+        "repetition_penalty": repetition_penalty,
     }
 
-    data = {"prompt": prompt, "use_nucleus_sampling": decoding_method == "Nucleus sampling"}
-
     image = encode_image(image)
     files = {"image": image}
 
@@ -36,80 +39,144 @@ def query_api(image, prompt, decoding_method):
     return "Error: " + response.text
 
 
-def prepend_question(text):
-    text = text.strip().lower()
-
-    return "question: " + text
-
-
-def prepend_answer(text):
-    text = text.strip().lower()
+def postprocess_output(output):
+    # if last character is not a punctuation, add a full stop
+    if not output[0][-1] in string.punctuation:
+        output[0] += "."
 
-    return "answer: " + text
+    return output
 
 
-def get_prompt_from_history(history):
-    prompts = []
-
-    for i in range(len(history)):
-        if i % 2 == 0:
-            prompts.append(prepend_question(history[i]))
-        else:
-            prompts.append(prepend_answer(history[i]))
-
-    return "\n".join(prompts)
-
-
-def postp_answer(text):
-    if text.startswith("answer: "):
-        return text[8:]
-    elif text.startswith("a: "):
-        return text[2:]
-    else:
-        return text
+def inference(
+    image,
+    text_input,
+    decoding_method,
+    temperature,
+    length_penalty,
+    repetition_penalty,
+    history=[],
+):
+    text_input = text_input
+    history.append(text_input)
 
+    prompt = " ".join(history)
 
-def prep_question(text):
-    if text.startswith("question: "):
-        text = text[10:]
-    elif text.startswith("q: "):
-        text = text[2:]
-
-    if not text.endswith("?"):
-        text += "?"
-
-    return text
+    output = query_api(image, prompt, decoding_method, temperature, length_penalty, repetition_penalty)
+    output = postprocess_output(output)
+    history += output
 
+    chat = [
+        (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
+    ]  # convert to tuples of list
 
-def inference(image, text_input, decoding_method, history=[]):
-    text_input = prep_question(text_input)
-    history.append(text_input)
+    return chat, history
 
-    # prompt = '\n'.join(history)
-    prompt = get_prompt_from_history(history)
-    # print("prompt: " + prompt)
 
-    output = query_api(image, prompt, decoding_method)
-    output = [postp_answer(output[0])]
-    history += output
-
-    chat = [(history[i], history[i+1]) for i in range(0, len(history)-1, 2)]  # convert to tuples of list
-
-    return chat, history
+# image source: https://m.facebook.com/112483753737319/photos/112489593736735/
+endpoint = Endpoint()
 
+examples = [
+    ["house.png", "How could someone get out of the house?"],
+    [
+        "sunset.png",
+        "Write a romantic message that goes along this photo.",
+    ],
+]
 
-inputs = [gr.inputs.Image(type='pil'),
-          gr.inputs.Textbox(lines=2, label="Text input"),
-          gr.inputs.Radio(choices=['Nucleus sampling','Beam search'], type="value", default="Nucleus sampling", label="Text Decoding Method"),
-          "state",
-          ]
+# outputs = ["chatbot", "state"]
 
-outputs = ["chatbot", "state"]
-
-title = "BLIP-2"
+title = """<h1 align="center">BLIP-2</h1>"""
 description = """Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
 <p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected. </p>"""
 article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a>"
 
-iface = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article)
-iface.launch(enable_queue=True)
+# iface = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples)
+
+
+def reset_all(text_input, image_input, chatbot, history):
+    return "", None, None, []
+
+
+def reset_chatbot(chatbot, history):
+    return None, []
+
+
+with gr.Blocks() as iface:
+    state = gr.State([])
+
+    gr.Markdown(title)
+    gr.Markdown(description)
+    gr.Markdown(article)
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(type="pil")
+            text_input = gr.Textbox(lines=2, label="Text input")
+
+            sampling = gr.Radio(
+                choices=["Beam search", "Nucleus sampling"],
+                value="Beam search",
+                label="Text Decoding Method",
+                interactive=True,
+            )
+
+            with gr.Row():
+                temperature = gr.Slider(
+                    minimum=0.5,
+                    maximum=1.0,
+                    value=0.8,
+                    interactive=True,
+                    label="Temperature",
+                )
+
+                len_penalty = gr.Slider(
+                    minimum=-2.0,
+                    maximum=2.0,
+                    value=1.0,
+                    step=0.5,
+                    interactive=True,
+                    label="Length Penalty",
+                )
+
+                rep_penalty = gr.Slider(
+                    minimum=1.0,
+                    maximum=10.0,
+                    value=1.0,
+                    step=0.5,
+                    interactive=True,
+                    label="Repetition Penalty",
+                )
+
+        with gr.Column():
+            chatbot = gr.Chatbot()
+
+            with gr.Row():
+                clear_button = gr.Button(value="Clear", interactive=True)
+                clear_button.click(
+                    reset_all,
+                    [text_input, image_input, chatbot, state],
+                    [text_input, image_input, chatbot, state],
+                )
+
+                submit_button = gr.Button(value="Submit", interactive=True, variant="primary")
+                submit_button.click(
+                    inference,
+                    [
+                        image_input,
+                        text_input,
+                        sampling,
+                        temperature,
+                        len_penalty,
+                        state,
+                    ],
+                    [chatbot, state],
+                )
+
+    image_input.change(reset_chatbot, [chatbot, state], [chatbot, state])
+
+    examples = gr.Examples(
+        examples=examples,
+        inputs=[image_input, text_input],
+    )
+
+iface.queue(concurrency_count=1)
+iface.launch(enable_queue=True, debug=True)
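The serving side is not part of this repository, so the following is only a hypothetical sketch of how an endpoint could unpack the multipart request that query_api sends. Everything here is an assumption rather than the actual deployment: the Flask framework, the route, the run_blip2 stub, and the list-shaped JSON response (chosen because the client's postprocess_output indexes output[0]).

import io

from flask import Flask, jsonify, request
from PIL import Image

app = Flask(__name__)


def run_blip2(image, prompt, use_nucleus_sampling, temperature, length_penalty, repetition_penalty):
    # Placeholder for the real model call; returns a canned answer so the sketch runs.
    return "a red brick house with a white door"


@app.route("/api/generate", methods=["POST"])
def generate():
    image = Image.open(io.BytesIO(request.files["image"].read())).convert("RGB")
    prompt = request.form["prompt"]

    # Generation options introduced by this commit; form fields arrive as strings.
    use_nucleus_sampling = request.form.get("use_nucleus_sampling") == "True"
    temperature = float(request.form.get("temperature", 0.8))
    length_penalty = float(request.form.get("length_penalty", 1.0))
    repetition_penalty = float(request.form.get("repetition_penalty", 1.0))

    answer = run_blip2(
        image, prompt, use_nucleus_sampling, temperature, length_penalty, repetition_penalty
    )
    # Returned as a single-element list; the Space reads output[0].
    return jsonify([answer])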
house.png ADDED

Git LFS Details

  • SHA256: a7b8999524f8f178a43d3417b9f7dfa80d8aff7ccb7ea1b5ba0e5f96bc17bdc0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
sunset.png ADDED

Git LFS Details

  • SHA256: 9a3778b1890ee461c7b052a5f25ce566ffbd706d6c2beb7280f1393052808008
  • Pointer size: 130 Bytes
  • Size of remote file: 78 kB
utils.py ADDED
@@ -0,0 +1,24 @@
+import requests
+
+
+class Endpoint:
+    def __init__(self):
+        self.config_path = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/projects/blip2/config.json"
+
+        self._url = None
+
+    @property
+    def url(self):
+        if self._url is None:
+            self._url = self.get_url()
+
+        return self._url
+
+    def get_url(self):
+        response = requests.get(self.config_path)
+        config = response.json()
+
+        return config["endpoint"]
+
+
+
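The Endpoint helper keeps the inference URL out of the checked-in code: the first access to .url downloads the public config JSON and caches its "endpoint" value. A small usage sketch, assuming the config file is reachable:

from utils import Endpoint

endpoint = Endpoint()

# First access performs one GET against config_path and caches the result.
print(endpoint.url)

# Subsequent accesses reuse the cached value; no further network calls.
print(endpoint.url)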