Spaces:
Runtime error
Runtime error
Fixes to image resolution
Browse files- app.py +5 -6
- gill/models.py +3 -3
app.py
CHANGED
@@ -115,14 +115,13 @@ def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperat
|
|
115 |
elif type(p) == dict:
|
116 |
# Decide whether to generate or retrieve.
|
117 |
if p['decision'] is not None and p['decision'][0] == 'gen':
|
118 |
-
image = p['gen'][0][0]
|
119 |
filename = save_image_to_local(image)
|
120 |
-
response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Generated)</p>'
|
121 |
else:
|
122 |
-
image = p['ret'][0][0]
|
123 |
filename = save_image_to_local(image)
|
124 |
-
response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Retrieved)</p>'
|
125 |
-
|
126 |
|
127 |
chat_history = model_inputs + \
|
128 |
[' '.join([s for s in model_outputs if type(s) == str]) + '\n']
|
@@ -180,7 +179,7 @@ with gr.Blocks(css=css) as demo:
|
|
180 |
share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn")
|
181 |
|
182 |
with gr.Column(scale=0.3, min_width=400):
|
183 |
-
ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True,
|
184 |
label="Frequency multiplier for returning images (higher means more frequent)")
|
185 |
# max_ret_images = gr.Number(
|
186 |
# minimum=0, maximum=3, value=2, precision=1, interactive=True, label="Max images to return")
|
|
|
115 |
elif type(p) == dict:
|
116 |
# Decide whether to generate or retrieve.
|
117 |
if p['decision'] is not None and p['decision'][0] == 'gen':
|
118 |
+
image = p['gen'][0][0]#.resize((224, 224))
|
119 |
filename = save_image_to_local(image)
|
120 |
+
response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Generated)</p>'
|
121 |
else:
|
122 |
+
image = p['ret'][0][0]#.resize((224, 224))
|
123 |
filename = save_image_to_local(image)
|
124 |
+
response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Retrieved)</p>'
|
|
|
125 |
|
126 |
chat_history = model_inputs + \
|
127 |
[' '.join([s for s in model_outputs if type(s) == str]) + '\n']
|
|
|
179 |
share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn")
|
180 |
|
181 |
with gr.Column(scale=0.3, min_width=400):
|
182 |
+
ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.2, step=0.1, interactive=True,
|
183 |
label="Frequency multiplier for returning images (higher means more frequent)")
|
184 |
# max_ret_images = gr.Number(
|
185 |
# minimum=0, maximum=3, value=2, precision=1, interactive=True, label="Max images to return")
|
gill/models.py
CHANGED
@@ -878,10 +878,10 @@ def load_gill(embeddings_dir: str, model_args_path: str, model_ckpt_path: str, d
|
|
878 |
model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
|
879 |
load_sd=not debug, num_gen_images=1, decision_model_path=decision_model_path)
|
880 |
model = model.eval()
|
881 |
-
if not debug:
|
882 |
-
model = model.bfloat16()
|
883 |
-
model = model.cuda()
|
884 |
|
|
|
885 |
# Load pretrained linear mappings and [IMG] embeddings.
|
886 |
checkpoint = torch.load(model_ckpt_path)
|
887 |
state_dict = {}
|
|
|
878 |
model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
|
879 |
load_sd=not debug, num_gen_images=1, decision_model_path=decision_model_path)
|
880 |
model = model.eval()
|
881 |
+
if torch.cuda.is_available():
|
882 |
+
model = model.bfloat16().cuda()
|
|
|
883 |
|
884 |
+
if not debug:
|
885 |
# Load pretrained linear mappings and [IMG] embeddings.
|
886 |
checkpoint = torch.load(model_ckpt_path)
|
887 |
state_dict = {}
|