Upload app.py
Browse files
app.py
CHANGED
@@ -27,6 +27,8 @@ from CLIP_Explainability.vit_cam import (
|
|
27 |
|
28 |
from pytorch_grad_cam.grad_cam import GradCAM
|
29 |
|
|
|
|
|
30 |
MAX_IMG_WIDTH = 500
|
31 |
MAX_IMG_HEIGHT = 800
|
32 |
|
@@ -172,9 +174,10 @@ def init():
|
|
172 |
ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
|
173 |
ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
|
174 |
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
178 |
|
179 |
st.session_state.ja_model = AutoModel.from_pretrained(
|
180 |
ja_model_name, trust_remote_code=True
|
@@ -183,9 +186,10 @@ def init():
|
|
183 |
ja_model_name, trust_remote_code=True
|
184 |
)
|
185 |
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
189 |
|
190 |
st.session_state.rn_model = legacy_multilingual_clip.load_model(
|
191 |
"M-BERT-Base-69"
|
@@ -701,11 +705,12 @@ for image_id in batch:
|
|
701 |
<div>""",
|
702 |
unsafe_allow_html=True,
|
703 |
)
|
704 |
-
st.
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
|
|
711 |
col = (col + 1) % row_size
|
|
|
27 |
|
28 |
from pytorch_grad_cam.grad_cam import GradCAM
|
29 |
|
30 |
+
RUN_LITE = True # Load vision model for CAM viz explainability for M-CLIP only
|
31 |
+
|
32 |
MAX_IMG_WIDTH = 500
|
33 |
MAX_IMG_HEIGHT = 800
|
34 |
|
|
|
174 |
ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
|
175 |
ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
|
176 |
|
177 |
+
if not RUN_LITE:
|
178 |
+
st.session_state.ja_image_model, st.session_state.ja_image_preprocess = (
|
179 |
+
load(ja_model_path, device=device, jit=False)
|
180 |
+
)
|
181 |
|
182 |
st.session_state.ja_model = AutoModel.from_pretrained(
|
183 |
ja_model_name, trust_remote_code=True
|
|
|
186 |
ja_model_name, trust_remote_code=True
|
187 |
)
|
188 |
|
189 |
+
if not RUN_LITE:
|
190 |
+
st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
|
191 |
+
clip.load("RN50x4", device=device)
|
192 |
+
)
|
193 |
|
194 |
st.session_state.rn_model = legacy_multilingual_clip.load_model(
|
195 |
"M-BERT-Base-69"
|
|
|
705 |
<div>""",
|
706 |
unsafe_allow_html=True,
|
707 |
)
|
708 |
+
if not RUN_LITE or st.session_state.active_model == "M-CLIP (multilingual ViT)":
|
709 |
+
st.button(
|
710 |
+
"Explain this",
|
711 |
+
on_click=image_modal,
|
712 |
+
args=[image_id],
|
713 |
+
use_container_width=True,
|
714 |
+
key=image_id,
|
715 |
+
)
|
716 |
col = (col + 1) % row_size
|