jykoh commited on
Commit
55e476e
1 Parent(s): 5b4ede2

Add clear last round button.

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +17 -7
  3. fromage/models.py +1 -1
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  .DS_Store
2
  venv/
 
 
 
1
  .DS_Store
2
  venv/
3
+ __pycache__
4
+ *.pyc
app.py CHANGED
@@ -52,6 +52,13 @@ def reset():
52
  return [[], [], None], []
53
 
54
 
 
 
 
 
 
 
 
55
  def save_image_to_local(image: Image.Image):
56
  # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
57
  filename = next(tempfile._get_candidate_names()) + '.png'
@@ -81,7 +88,7 @@ def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_wo
81
 
82
  print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
83
  model_outputs = model.generate_for_images_and_texts(model_inputs,
84
- num_words=num_words, ret_scale_factor=ret_scale_factor, top_p=top_p,
85
  temperature=temperature, max_num_rets=max_nm_rets)
86
  print('model_outputs', model_outputs, flush=True)
87
 
@@ -130,17 +137,19 @@ with gr.Blocks(css=css) as demo:
130
  with gr.Column(scale=0.3, min_width=100):
131
  ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
132
  max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
133
- gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
134
- gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
135
 
136
  with gr.Column(scale=0.7, min_width=400):
137
  image_btn = gr.UploadButton("🖼️ Image Input", file_types=["image"])
138
  text_input = gr.Textbox(label="Chat Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
139
 
140
  with gr.Row():
141
- with gr.Column(scale=0.5):
142
  submit_btn = gr.Button("Submit", interactive=True, variant="primary")
143
- with gr.Column(scale=0.5):
 
 
144
  clear_btn = gr.Button("Clear History")
145
 
146
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
@@ -149,9 +158,10 @@ with gr.Blocks(css=css) as demo:
149
  submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
150
 
151
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
 
152
  clear_btn.click(reset, [], [gr_state, chatbot])
153
  share_button.click(None, [], [], _js=share_js)
154
 
155
 
156
- demo.queue(concurrency_count=1, api_open=False, max_size=16)
157
- demo.launch(debug=True, server_name="0.0.0.0")
 
52
  return [[], [], None], []
53
 
54
 
55
+ def reset_last(state):
56
+ conversation = state[0][:-1]
57
+ chat_history = state[1][:-2]
58
+ input_image = state[2]
59
+ return [conversation, chat_history, input_image], conversation
60
+
61
+
62
  def save_image_to_local(image: Image.Image):
63
  # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
64
  filename = next(tempfile._get_candidate_names()) + '.png'
 
88
 
89
  print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
90
  model_outputs = model.generate_for_images_and_texts(model_inputs,
91
+ num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
92
  temperature=temperature, max_num_rets=max_nm_rets)
93
  print('model_outputs', model_outputs, flush=True)
94
 
 
137
  with gr.Column(scale=0.3, min_width=100):
138
  ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
139
  max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
140
+ gr_max_len = gr.Slider(minimum=1, maximum=64, value=32, step=1, interactive=True, label="Max # of words returned")
141
+ gr_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, interactive=True, label="Temperature")
142
 
143
  with gr.Column(scale=0.7, min_width=400):
144
  image_btn = gr.UploadButton("🖼️ Image Input", file_types=["image"])
145
  text_input = gr.Textbox(label="Chat Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
146
 
147
  with gr.Row():
148
+ with gr.Column(scale=0.33):
149
  submit_btn = gr.Button("Submit", interactive=True, variant="primary")
150
+ with gr.Column(scale=0.33):
151
+ clear_last_btn = gr.Button("Clear Last Round")
152
+ with gr.Column(scale=0.33):
153
  clear_btn = gr.Button("Clear History")
154
 
155
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
 
158
  submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
159
 
160
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
161
+ clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot])
162
  clear_btn.click(reset, [], [gr_state, chatbot])
163
  share_button.click(None, [], [], _js=share_js)
164
 
165
 
166
+ # demo.queue(concurrency_count=1, api_open=False, max_size=16)
167
+ demo.launch(debug=True, server_name="127.0.0.1")
fromage/models.py CHANGED
@@ -635,7 +635,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
635
  assert len(ret_token_idx) == 1, ret_token_idx
636
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
637
 
638
- debug = False
639
  if debug:
640
  model_kwargs['opt_version'] = 'facebook/opt-125m'
641
  model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
 
635
  assert len(ret_token_idx) == 1, ret_token_idx
636
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
637
 
638
+ debug = True
639
  if debug:
640
  model_kwargs['opt_version'] = 'facebook/opt-125m'
641
  model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'