jykoh commited on
Commit
a03fe94
1 Parent(s): cce1831
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +19 -15
  3. fromage/models.py +3 -2
.gitignore CHANGED
@@ -1 +1,2 @@
1
  .DS_Store
 
 
1
  .DS_Store
2
+ venv/
app.py CHANGED
@@ -19,13 +19,15 @@ model = models.load_fromage('./', args_path, ckpt_path)
19
 
20
 
21
  def upload_image(state, image_input):
22
- state += [(f"![](/file={image_input.name})", "(Image received. Type or ask something to continue.)")]
 
 
23
  input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
24
- return [state, input_image], state
25
 
26
 
27
  def reset():
28
- return [[], None], []
29
 
30
 
31
  def save_image_to_local(image: Image.Image):
@@ -37,16 +39,19 @@ def save_image_to_local(image: Image.Image):
37
 
38
  def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
39
  input_prompt = 'Q: ' + input_text + '\nA:'
40
- input_image = state[1]
41
- chat_history += input_prompt
 
42
  print('Generating for', chat_history, flush=True)
43
 
44
  # If an image was uploaded, prepend it to the model.
45
  model_inputs = None
46
  if input_image is not None:
47
- model_inputs = [input_image, chat_history]
48
  else:
49
- model_inputs = [chat_history]
 
 
50
 
51
  top_p = 1.0
52
  if temperature != 0.0:
@@ -74,15 +79,13 @@ def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_wo
74
  response += f'<img src="/file={filename}">'
75
 
76
  # TODO(jykoh): Persist image inputs.
77
- chat_history += ' '.join(text_outputs)
78
- if chat_history[-1] != '\n':
79
- chat_history += '\n'
80
-
81
- state.append((input_text, response))
82
 
83
  # Set input image to None.
84
  print('state', state, flush=True)
85
- return [state, None], state
 
86
 
87
 
88
  with gr.Blocks() as demo:
@@ -91,7 +94,7 @@ with gr.Blocks() as demo:
91
  )
92
 
93
  chatbot = gr.Chatbot()
94
- gr_state = gr.State([[], None]) # chat_history, input_image
95
 
96
  with gr.Row():
97
  with gr.Column(scale=0.3, min_width=0):
@@ -106,7 +109,8 @@ with gr.Blocks() as demo:
106
  clear_btn = gr.Button("Clear History")
107
 
108
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
 
109
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
110
  clear_btn.click(reset, [], [gr_state, chatbot])
111
 
112
- demo.launch(share=False, debug=True, server_name="0.0.0.0")
 
19
 
20
 
21
  def upload_image(state, image_input):
22
+ conversation = state[0]
23
+ chat_history = state[1]
24
+ conversation += [(f"![](/file={image_input.name})", "")]
25
  input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
26
+ return [conversation, chat_history, input_image], conversation
27
 
28
 
29
  def reset():
30
+ return [[], [], None], []
31
 
32
 
33
  def save_image_to_local(image: Image.Image):
 
39
 
40
  def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_words, temperature):
41
  input_prompt = 'Q: ' + input_text + '\nA:'
42
+ conversation = state[0]
43
+ chat_history = state[1]
44
+ input_image = state[2]
45
  print('Generating for', chat_history, flush=True)
46
 
47
  # If an image was uploaded, prepend it to the model.
48
  model_inputs = None
49
  if input_image is not None:
50
+ model_inputs = chat_history + [input_image]
51
  else:
52
+ model_inputs = chat_history
53
+
54
+ model_inputs.append(input_prompt)
55
 
56
  top_p = 1.0
57
  if temperature != 0.0:
 
79
  response += f'<img src="/file={filename}">'
80
 
81
  # TODO(jykoh): Persist image inputs.
82
+ chat_history = model_inputs + model_outputs
83
+ conversation.append((input_text, response))
 
 
 
84
 
85
  # Set input image to None.
86
  print('state', state, flush=True)
87
+ print('updated state', [conversation, chat_history, None], flush=True)
88
+ return [conversation, chat_history, None], conversation
89
 
90
 
91
  with gr.Blocks() as demo:
 
94
  )
95
 
96
  chatbot = gr.Chatbot()
97
+ gr_state = gr.State([[], [], None]) # chat_history, input_image
98
 
99
  with gr.Row():
100
  with gr.Column(scale=0.3, min_width=0):
 
109
  clear_btn = gr.Button("Clear History")
110
 
111
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
112
+ text_input.submit(lambda: "", None, text_input) # Reset chatbox.
113
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
114
  clear_btn.click(reset, [], [gr_state, chatbot])
115
 
116
+ demo.launch(share=False, debug=True, server_name="127.0.0.1")
fromage/models.py CHANGED
@@ -628,13 +628,14 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
628
 
629
  # Initialize tokenizer.
630
  tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
631
- tokenizer.pad_token = tokenizer.eos_token
632
  # Add special tokens to the model to enable [RET].
633
  tokenizer.add_special_tokens({"cls_token": "<|image|>"})
634
  tokenizer.add_tokens('[RET]')
635
  ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
636
  assert len(ret_token_idx) == 1, ret_token_idx
637
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
 
 
638
  args = namedtuple('args', model_kwargs)(**model_kwargs)
639
 
640
  # Initialize model for inference.
@@ -643,7 +644,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
643
  model = model.bfloat16()
644
  model = model.cuda()
645
 
646
- # Load pretrained linear mappings and [RET] embeddings.
647
  checkpoint = torch.load(model_ckpt_path)
648
  model.load_state_dict(checkpoint['state_dict'], strict=False)
649
  with torch.no_grad():
 
628
 
629
  # Initialize tokenizer.
630
  tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
 
631
  # Add special tokens to the model to enable [RET].
632
  tokenizer.add_special_tokens({"cls_token": "<|image|>"})
633
  tokenizer.add_tokens('[RET]')
634
  ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
635
  assert len(ret_token_idx) == 1, ret_token_idx
636
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
637
+ # model_kwargs['opt_version'] = 'facebook/opt-125m'
638
+ # model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
639
  args = namedtuple('args', model_kwargs)(**model_kwargs)
640
 
641
  # Initialize model for inference.
 
644
  model = model.bfloat16()
645
  model = model.cuda()
646
 
647
+ Load pretrained linear mappings and [RET] embeddings.
648
  checkpoint = torch.load(model_ckpt_path)
649
  model.load_state_dict(checkpoint['state_dict'], strict=False)
650
  with torch.no_grad():