jykoh commited on
Commit
5b4ede2
1 Parent(s): 278032e

Add submit button

Browse files
Files changed (2) hide show
  1. app.py +12 -4
  2. fromage/models.py +14 -10
app.py CHANGED
@@ -127,19 +127,27 @@ with gr.Blocks(css=css) as demo:
127
  share_button = gr.Button("Share to community", elem_id="share-btn")
128
 
129
  with gr.Row():
130
- with gr.Column(scale=0.3, min_width=0):
131
  ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
132
  max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
133
  gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
134
  gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
135
 
136
- with gr.Column(scale=0.7, min_width=0):
137
  image_btn = gr.UploadButton("🖼️ Image Input", file_types=["image"])
138
- text_input = gr.Textbox(label="Text Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
139
- clear_btn = gr.Button("Clear History")
 
 
 
 
 
140
 
141
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
142
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
 
 
 
143
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
144
  clear_btn.click(reset, [], [gr_state, chatbot])
145
  share_button.click(None, [], [], _js=share_js)
 
127
  share_button = gr.Button("Share to community", elem_id="share-btn")
128
 
129
  with gr.Row():
130
+ with gr.Column(scale=0.3, min_width=100):
131
  ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
132
  max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
133
  gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
134
  gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
135
 
136
+ with gr.Column(scale=0.7, min_width=400):
137
  image_btn = gr.UploadButton("🖼️ Image Input", file_types=["image"])
138
+ text_input = gr.Textbox(label="Chat Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
139
+
140
+ with gr.Row():
141
+ with gr.Column(scale=0.5):
142
+ submit_btn = gr.Button("Submit", interactive=True, variant="primary")
143
+ with gr.Column(scale=0.5):
144
+ clear_btn = gr.Button("Clear History")
145
 
146
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
147
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
148
+ submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
149
+ submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
150
+
151
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
152
  clear_btn.click(reset, [], [gr_state, chatbot])
153
  share_button.click(None, [], [], _js=share_js)
fromage/models.py CHANGED
@@ -634,21 +634,25 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
634
  ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
635
  assert len(ret_token_idx) == 1, ret_token_idx
636
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
637
- # model_kwargs['opt_version'] = 'facebook/opt-125m'
638
- # model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
 
 
 
639
  args = namedtuple('args', model_kwargs)(**model_kwargs)
640
 
641
  # Initialize model for inference.
642
  model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
643
  model = model.eval()
644
- model = model.bfloat16()
645
- model = model.cuda()
646
-
647
- # Load pretrained linear mappings and [RET] embeddings.
648
- checkpoint = torch.load(model_ckpt_path)
649
- model.load_state_dict(checkpoint['state_dict'], strict=False)
650
- with torch.no_grad():
651
- model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())
 
652
 
653
  logit_scale = model.model.logit_scale.exp()
654
  emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
 
634
  ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
635
  assert len(ret_token_idx) == 1, ret_token_idx
636
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
637
+
638
+ debug = False
639
+ if debug:
640
+ model_kwargs['opt_version'] = 'facebook/opt-125m'
641
+ model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
642
  args = namedtuple('args', model_kwargs)(**model_kwargs)
643
 
644
  # Initialize model for inference.
645
  model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
646
  model = model.eval()
647
+ if not debug:
648
+ model = model.bfloat16()
649
+ model = model.cuda()
650
+
651
+ # Load pretrained linear mappings and [RET] embeddings.
652
+ checkpoint = torch.load(model_ckpt_path)
653
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
654
+ with torch.no_grad():
655
+ model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())
656
 
657
  logit_scale = model.model.logit_scale.exp()
658
  emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)