sneha
committed on
Commit
•
8946de5
1
Parent(s):
70e8264
default radio option
Browse files
app.py
CHANGED
@@ -43,7 +43,7 @@ def download_bin():
|
|
43 |
os.rename(model_bin, bin_path)
|
44 |
|
45 |
|
46 |
-
def run_attn(input_img,fusion="min"):
|
47 |
download_bin()
|
48 |
model, embedding_dim, transform, metadata = get_model()
|
49 |
if input_img.shape[0] != 3:
|
@@ -69,7 +69,7 @@ def run_attn(input_img,fusion="min"):
|
|
69 |
return attn_img, fig
|
70 |
|
71 |
input_img = gr.Image(shape=(250,250))
|
72 |
-
input_button = gr.Radio(["min", "max", "mean"], label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
|
73 |
output_img = gr.Image(shape=(250,250))
|
74 |
output_plot = gr.Plot()
|
75 |
|
|
|
43 |
os.rename(model_bin, bin_path)
|
44 |
|
45 |
|
46 |
+
def run_attn(input_img,fusion):
|
47 |
download_bin()
|
48 |
model, embedding_dim, transform, metadata = get_model()
|
49 |
if input_img.shape[0] != 3:
|
|
|
69 |
return attn_img, fig
|
70 |
|
71 |
input_img = gr.Image(shape=(250,250))
|
72 |
+
input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
|
73 |
output_img = gr.Image(shape=(250,250))
|
74 |
output_plot = gr.Plot()
|
75 |
|