Spaces:
Build error
Build error
Add submit button
Browse files- app.py +12 -4
- fromage/models.py +14 -10
app.py
CHANGED
@@ -127,19 +127,27 @@ with gr.Blocks(css=css) as demo:
|
|
127 |
share_button = gr.Button("Share to community", elem_id="share-btn")
|
128 |
|
129 |
with gr.Row():
|
130 |
-
with gr.Column(scale=0.3, min_width=
|
131 |
ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
|
132 |
max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
|
133 |
gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
|
134 |
gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
|
135 |
|
136 |
-
with gr.Column(scale=0.7, min_width=
|
137 |
image_btn = gr.UploadButton("🖼️ Image Input", file_types=["image"])
|
138 |
-
text_input = gr.Textbox(label="
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
|
142 |
text_input.submit(lambda: "", None, text_input) # Reset chatbox.
|
|
|
|
|
|
|
143 |
image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
|
144 |
clear_btn.click(reset, [], [gr_state, chatbot])
|
145 |
share_button.click(None, [], [], _js=share_js)
|
|
|
127 |
share_button = gr.Button("Share to community", elem_id="share-btn")
|
128 |
|
129 |
with gr.Row():
|
130 |
+
with gr.Column(scale=0.3, min_width=100):
|
131 |
ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
|
132 |
max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
|
133 |
gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
|
134 |
gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
|
135 |
|
136 |
+
with gr.Column(scale=0.7, min_width=400):
|
137 |
image_btn = gr.UploadButton("🖼️ Image Input", file_types=["image"])
|
138 |
+
text_input = gr.Textbox(label="Chat Input", lines=1, placeholder="Upload an image above [optional]. Then enter a text prompt, and press enter!")
|
139 |
+
|
140 |
+
with gr.Row():
|
141 |
+
with gr.Column(scale=0.5):
|
142 |
+
submit_btn = gr.Button("Submit", interactive=True, variant="primary")
|
143 |
+
with gr.Column(scale=0.5):
|
144 |
+
clear_btn = gr.Button("Clear History")
|
145 |
|
146 |
text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
|
147 |
text_input.submit(lambda: "", None, text_input) # Reset chatbox.
|
148 |
+
submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
|
149 |
+
submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
|
150 |
+
|
151 |
image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
|
152 |
clear_btn.click(reset, [], [gr_state, chatbot])
|
153 |
share_button.click(None, [], [], _js=share_js)
|
fromage/models.py
CHANGED
@@ -634,21 +634,25 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
|
|
634 |
ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
|
635 |
assert len(ret_token_idx) == 1, ret_token_idx
|
636 |
model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
|
637 |
-
|
638 |
-
|
|
|
|
|
|
|
639 |
args = namedtuple('args', model_kwargs)(**model_kwargs)
|
640 |
|
641 |
# Initialize model for inference.
|
642 |
model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
|
643 |
model = model.eval()
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
|
|
652 |
|
653 |
logit_scale = model.model.logit_scale.exp()
|
654 |
emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
|
|
|
634 |
ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
|
635 |
assert len(ret_token_idx) == 1, ret_token_idx
|
636 |
model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
|
637 |
+
|
638 |
+
debug = False
|
639 |
+
if debug:
|
640 |
+
model_kwargs['opt_version'] = 'facebook/opt-125m'
|
641 |
+
model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
|
642 |
args = namedtuple('args', model_kwargs)(**model_kwargs)
|
643 |
|
644 |
# Initialize model for inference.
|
645 |
model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
|
646 |
model = model.eval()
|
647 |
+
if not debug:
|
648 |
+
model = model.bfloat16()
|
649 |
+
model = model.cuda()
|
650 |
+
|
651 |
+
# Load pretrained linear mappings and [RET] embeddings.
|
652 |
+
checkpoint = torch.load(model_ckpt_path)
|
653 |
+
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
654 |
+
with torch.no_grad():
|
655 |
+
model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())
|
656 |
|
657 |
logit_scale = model.model.logit_scale.exp()
|
658 |
emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
|