jykoh committed
Commit 3d6dac6 • 1 Parent(s): ef1d866

Fixes to image resolution

Files changed (2)
  1. app.py +5 -6
  2. gill/models.py +3 -3
app.py CHANGED
@@ -115,14 +115,13 @@ def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperat
         elif type(p) == dict:
             # Decide whether to generate or retrieve.
             if p['decision'] is not None and p['decision'][0] == 'gen':
-                image = p['gen'][0][0].resize((224, 224))
+                image = p['gen'][0][0]#.resize((224, 224))
                 filename = save_image_to_local(image)
-                response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Generated)</p>'
+                response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Generated)</p>'
             else:
-                image = p['ret'][0][0].resize((224, 224))
+                image = p['ret'][0][0]#.resize((224, 224))
                 filename = save_image_to_local(image)
-                response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Retrieved)</p>'
+                response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Retrieved)</p>'
-
 
     chat_history = model_inputs + \
         [' '.join([s for s in model_outputs if type(s) == str]) + '\n']
@@ -180,7 +179,7 @@ with gr.Blocks(css=css) as demo:
                 share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn")
 
             with gr.Column(scale=0.3, min_width=400):
-                ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True,
+                ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.2, step=0.1, interactive=True,
                                              label="Frequency multiplier for returning images (higher means more frequent)")
                 # max_ret_images = gr.Number(
                 #     minimum=0, maximum=3, value=2, precision=1, interactive=True, label="Max images to return")
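The substantive change in this file is dropping the fixed 224x224 resize before the image is saved and embedded in the chat response, so generated and retrieved images keep their native resolution; the added `margin-top: 0` only tightens the caption spacing, and the default `ret_scale_factor` slider value moves from 1.0 to 1.2, so the demo returns images slightly more often out of the box. Below is a minimal sketch of the before/after resize behaviour; the helper name and `downsample` flag are illustrative, not from the repo.

# Illustrative only: mirrors the resize change in generate_for_prompt (app.py).
from PIL import Image

def prepare_for_display(image: Image.Image, downsample: bool = False) -> Image.Image:
    """Hypothetical helper; not part of the repo."""
    if downsample:
        # Old behaviour: every image was shrunk to a fixed 224x224 thumbnail.
        return image.resize((224, 224))
    # New behaviour: keep the image at its generated/retrieved resolution.
    return image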
gill/models.py CHANGED
@@ -878,10 +878,10 @@ def load_gill(embeddings_dir: str, model_args_path: str, model_ckpt_path: str, d
     model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
                  load_sd=not debug, num_gen_images=1, decision_model_path=decision_model_path)
     model = model.eval()
-    if not debug:
-        model = model.bfloat16()
-        model = model.cuda()
+    if torch.cuda.is_available():
+        model = model.bfloat16().cuda()
 
+    if not debug:
         # Load pretrained linear mappings and [IMG] embeddings.
         checkpoint = torch.load(model_ckpt_path)
         state_dict = {}
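Here the bfloat16 cast and `.cuda()` move are now gated on `torch.cuda.is_available()` rather than on `not debug`, so the model can still load on CPU-only hardware, while it is the checkpoint loading that gets skipped in debug mode. A minimal sketch of that device guard follows; the helper name is illustrative, not from the repo.

# Illustrative only: the device guard pattern used in load_gill (gill/models.py).
import torch
import torch.nn as nn

def to_inference_device(model: nn.Module) -> nn.Module:
    """Hypothetical helper; not part of the repo."""
    model = model.eval()
    if torch.cuda.is_available():
        # Cast to bfloat16 (half the memory of float32) and move to the GPU.
        model = model.bfloat16().cuda()
    return model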