jykoh committed on
Commit
d32e597
1 Parent(s): aeef1a1

Add UI changes

Browse files
Files changed (2) hide show
  1. app.py +22 -9
  2. fromage/models.py +1 -1
app.py CHANGED
@@ -45,9 +45,10 @@ model = models.load_fromage('./', args_path, ckpt_path)
45
  def upload_image(state, image_input):
46
  conversation = state[0]
47
  chat_history = state[1]
48
- conversation += [(f"![](/file={image_input.name})", "")]
49
  input_image = Image.open(image_input.name).resize(
50
  (224, 224)).convert('RGB')
 
 
51
  return [conversation, chat_history, input_image], conversation
52
 
53
 
@@ -69,7 +70,11 @@ def save_image_to_local(image: Image.Image):
69
  return filename
70
 
71
 
72
- def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
 
 
 
 
73
  input_prompt = 'Q: ' + input_text + '\nA:'
74
  conversation = state[0]
75
  chat_history = state[1]
@@ -93,7 +98,7 @@ def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_wo
93
  model_inputs, flush=True)
94
  model_outputs = model.generate_for_images_and_texts(model_inputs,
95
  num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
96
- temperature=temperature, max_num_rets=max_nm_rets)
97
  print('model_outputs', model_outputs, flush=True)
98
 
99
  im_names = []
@@ -104,12 +109,16 @@ def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_wo
104
  text_outputs.append(output)
105
  response += output
106
  elif type(output) == list:
 
107
  for image in output:
108
  filename = save_image_to_local(image)
109
- response += f'<img src="/file={filename}">'
 
110
  elif type(output) == Image.Image:
111
  filename = save_image_to_local(output)
112
- response += f'<img src="/file={filename}">'
 
 
113
 
114
  # TODO(jykoh): Persist image inputs.
115
  chat_history = model_inputs + \
@@ -165,10 +174,14 @@ with gr.Blocks(css=css) as demo:
165
  clear_btn = gr.Button("Clear All")
166
 
167
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
168
- max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group])
 
 
169
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
170
  submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
171
- max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group])
 
 
172
  submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
173
 
174
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
@@ -177,5 +190,5 @@ with gr.Blocks(css=css) as demo:
177
  share_button.click(None, [], [], _js=share_js)
178
 
179
 
180
- demo.queue(concurrency_count=1, api_open=False, max_size=16)
181
- demo.launch(debug=True, server_name="0.0.0.0")
 
45
  def upload_image(state, image_input):
46
  conversation = state[0]
47
  chat_history = state[1]
 
48
  input_image = Image.open(image_input.name).resize(
49
  (224, 224)).convert('RGB')
50
+ input_image.save(image_input.name) # Overwrite with smaller image.
51
+ conversation += [(f"![](/file={image_input.name})", "")]
52
  return [conversation, chat_history, input_image], conversation
53
 
54
 
 
70
  return filename
71
 
72
 
73
+ def generate_for_prompt(input_text, state, ret_scale_factor, max_num_rets, num_words, temperature):
74
+ # Ignore empty inputs.
75
+ if len(input_text) == 0:
76
+ return state, state[0], gr.update(visible=True)
77
+
78
  input_prompt = 'Q: ' + input_text + '\nA:'
79
  conversation = state[0]
80
  chat_history = state[1]
 
98
  model_inputs, flush=True)
99
  model_outputs = model.generate_for_images_and_texts(model_inputs,
100
  num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
101
+ temperature=temperature, max_num_rets=max_num_rets)
102
  print('model_outputs', model_outputs, flush=True)
103
 
104
  im_names = []
 
109
  text_outputs.append(output)
110
  response += output
111
  elif type(output) == list:
112
+ response += '<br/>' # Add line break between images.
113
  for image in output:
114
  filename = save_image_to_local(image)
115
+ response += f'<img src="/file={filename}" style="display: inline-block;">'
116
+ response += '<br/>'
117
  elif type(output) == Image.Image:
118
  filename = save_image_to_local(output)
119
+ response += '<br/>'
120
+ response += f'<img src="/file={filename}" style="display: inline-block;">'
121
+ response += '<br/>'
122
 
123
  # TODO(jykoh): Persist image inputs.
124
  chat_history = model_inputs + \
 
174
  clear_btn = gr.Button("Clear All")
175
 
176
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
177
+ max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group],
178
+ )
179
+ # _js = "() => document.getElementById('#chatbot').scrollTop = document.getElementById('#chatbot').scrollHeight")
180
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
181
  submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
182
+ max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group],
183
+ )
184
+ # _js = "() => document.getElementById('#chatbot').scrollTop = document.getElementById('#chatbot').scrollHeight")
185
  submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
186
 
187
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
 
190
  share_button.click(None, [], [], _js=share_js)
191
 
192
 
193
+ # demo.queue(concurrency_count=1, api_open=False, max_size=16)
194
+ demo.launch(debug=True, server_name="127.0.0.1")
fromage/models.py CHANGED
@@ -635,7 +635,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
635
  assert len(ret_token_idx) == 1, ret_token_idx
636
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
637
 
638
- debug = False
639
  if debug:
640
  model_kwargs['opt_version'] = 'facebook/opt-125m'
641
  model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
 
635
  assert len(ret_token_idx) == 1, ret_token_idx
636
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
637
 
638
+ debug = True
639
  if debug:
640
  model_kwargs['opt_version'] = 'facebook/opt-125m'
641
  model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'