Spaces:
Build error
Build error
Bug fixes
Browse files- .gitignore +1 -0
- app.py +19 -15
- fromage/models.py +3 -2
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
.DS_Store
|
|
|
|
1 |
.DS_Store
|
2 |
+
venv/
|
app.py
CHANGED
@@ -19,13 +19,15 @@ model = models.load_fromage('./', args_path, ckpt_path)
|
|
19 |
|
20 |
|
21 |
def upload_image(state, image_input):
|
22 |
-
|
|
|
|
|
23 |
input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
|
24 |
-
return [
|
25 |
|
26 |
|
27 |
def reset():
|
28 |
-
return [[], None], []
|
29 |
|
30 |
|
31 |
def save_image_to_local(image: Image.Image):
|
@@ -37,16 +39,19 @@ def save_image_to_local(image: Image.Image):
|
|
37 |
|
38 |
def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
|
39 |
input_prompt = 'Q: ' + input_text + '\nA:'
|
40 |
-
|
41 |
-
chat_history
|
|
|
42 |
print('Generating for', chat_history, flush=True)
|
43 |
|
44 |
# If an image was uploaded, prepend it to the model.
|
45 |
model_inputs = None
|
46 |
if input_image is not None:
|
47 |
-
model_inputs = [input_image
|
48 |
else:
|
49 |
-
model_inputs =
|
|
|
|
|
50 |
|
51 |
top_p = 1.0
|
52 |
if temperature != 0.0:
|
@@ -74,15 +79,13 @@ def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_wo
|
|
74 |
response += f'<img src="/file={filename}">'
|
75 |
|
76 |
# TODO(jykoh): Persist image inputs.
|
77 |
-
chat_history
|
78 |
-
|
79 |
-
chat_history += '\n'
|
80 |
-
|
81 |
-
state.append((input_text, response))
|
82 |
|
83 |
# Set input image to None.
|
84 |
print('state', state, flush=True)
|
85 |
-
|
|
|
86 |
|
87 |
|
88 |
with gr.Blocks() as demo:
|
@@ -91,7 +94,7 @@ with gr.Blocks() as demo:
|
|
91 |
)
|
92 |
|
93 |
chatbot = gr.Chatbot()
|
94 |
-
gr_state = gr.State([[], None]) # chat_history, input_image
|
95 |
|
96 |
with gr.Row():
|
97 |
with gr.Column(scale=0.3, min_width=0):
|
@@ -106,7 +109,8 @@ with gr.Blocks() as demo:
|
|
106 |
clear_btn = gr.Button("Clear History")
|
107 |
|
108 |
text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
|
|
|
109 |
image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
|
110 |
clear_btn.click(reset, [], [gr_state, chatbot])
|
111 |
|
112 |
-
demo.launch(share=False, debug=True, server_name="
|
|
|
19 |
|
20 |
|
21 |
def upload_image(state, image_input):
|
22 |
+
conversation = state[0]
|
23 |
+
chat_history = state[1]
|
24 |
+
conversation += [(f"![](/file={image_input.name})", "")]
|
25 |
input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
|
26 |
+
return [conversation, chat_history, input_image], conversation
|
27 |
|
28 |
|
29 |
def reset():
|
30 |
+
return [[], [], None], []
|
31 |
|
32 |
|
33 |
def save_image_to_local(image: Image.Image):
|
|
|
39 |
|
40 |
def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
|
41 |
input_prompt = 'Q: ' + input_text + '\nA:'
|
42 |
+
conversation = state[0]
|
43 |
+
chat_history = state[1]
|
44 |
+
input_image = state[2]
|
45 |
print('Generating for', chat_history, flush=True)
|
46 |
|
47 |
# If an image was uploaded, add it to the model inputs after the chat history and before the prompt.
|
48 |
model_inputs = None
|
49 |
if input_image is not None:
|
50 |
+
model_inputs = chat_history + [input_image]
|
51 |
else:
|
52 |
+
model_inputs = chat_history
|
53 |
+
|
54 |
+
model_inputs.append(input_prompt)
|
55 |
|
56 |
top_p = 1.0
|
57 |
if temperature != 0.0:
|
|
|
79 |
response += f'<img src="/file={filename}">'
|
80 |
|
81 |
# TODO(jykoh): Persist image inputs.
|
82 |
+
chat_history = model_inputs + model_outputs
|
83 |
+
conversation.append((input_text, response))
|
|
|
|
|
|
|
84 |
|
85 |
# Set input image to None.
|
86 |
print('state', state, flush=True)
|
87 |
+
print('updated state', [conversation, chat_history, None], flush=True)
|
88 |
+
return [conversation, chat_history, None], conversation
|
89 |
|
90 |
|
91 |
with gr.Blocks() as demo:
|
|
|
94 |
)
|
95 |
|
96 |
chatbot = gr.Chatbot()
|
97 |
+
gr_state = gr.State([[], [], None])  # conversation, chat_history, input_image
|
98 |
|
99 |
with gr.Row():
|
100 |
with gr.Column(scale=0.3, min_width=0):
|
|
|
109 |
clear_btn = gr.Button("Clear History")
|
110 |
|
111 |
text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
|
112 |
+
text_input.submit(lambda: "", None, text_input) # Reset chatbox.
|
113 |
image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
|
114 |
clear_btn.click(reset, [], [gr_state, chatbot])
|
115 |
|
116 |
+
demo.launch(share=False, debug=True, server_name="127.0.0.1")
|
fromage/models.py
CHANGED
@@ -628,13 +628,14 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
|
|
628 |
|
629 |
# Initialize tokenizer.
|
630 |
tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
|
631 |
-
tokenizer.pad_token = tokenizer.eos_token
|
632 |
# Add special tokens to the model to enable [RET].
|
633 |
tokenizer.add_special_tokens({"cls_token": "<|image|>"})
|
634 |
tokenizer.add_tokens('[RET]')
|
635 |
ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
|
636 |
assert len(ret_token_idx) == 1, ret_token_idx
|
637 |
model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
|
|
|
|
|
638 |
args = namedtuple('args', model_kwargs)(**model_kwargs)
|
639 |
|
640 |
# Initialize model for inference.
|
@@ -643,7 +644,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
|
|
643 |
model = model.bfloat16()
|
644 |
model = model.cuda()
|
645 |
|
646 |
-
|
647 |
checkpoint = torch.load(model_ckpt_path)
|
648 |
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
649 |
with torch.no_grad():
|
|
|
628 |
|
629 |
# Initialize tokenizer.
|
630 |
tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
|
|
|
631 |
# Add special tokens to the model to enable [RET].
|
632 |
tokenizer.add_special_tokens({"cls_token": "<|image|>"})
|
633 |
tokenizer.add_tokens('[RET]')
|
634 |
ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
|
635 |
assert len(ret_token_idx) == 1, ret_token_idx
|
636 |
model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
|
637 |
+
# model_kwargs['opt_version'] = 'facebook/opt-125m'
|
638 |
+
# model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
|
639 |
args = namedtuple('args', model_kwargs)(**model_kwargs)
|
640 |
|
641 |
# Initialize model for inference.
|
|
|
644 |
model = model.bfloat16()
|
645 |
model = model.cuda()
|
646 |
|
647 |
+
# Load pretrained linear mappings and [RET] embeddings.
|
648 |
checkpoint = torch.load(model_ckpt_path)
|
649 |
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
650 |
with torch.no_grad():
|