Allow gradient viz for J-CLIP and ResNet, disable for M-CLIP
Browse files
app.py
CHANGED
@@ -111,7 +111,7 @@ def clip_search(search_query):
|
|
111 |
|
112 |
def string_search():
|
113 |
st.session_state.disable_uploader = (
|
114 |
-
RUN_LITE and st.session_state.active_model == "
|
115 |
)
|
116 |
|
117 |
if "search_field_value" in st.session_state:
|
@@ -171,9 +171,10 @@ def init():
|
|
171 |
ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
|
172 |
ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
177 |
|
178 |
st.session_state.ml_model = (
|
179 |
pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
|
@@ -194,10 +195,9 @@ def init():
|
|
194 |
ja_model_name, trust_remote_code=True
|
195 |
)
|
196 |
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
)
|
201 |
|
202 |
st.session_state.rn_model = legacy_multilingual_clip.load_model(
|
203 |
"M-BERT-Base-69"
|
@@ -213,7 +213,7 @@ def init():
|
|
213 |
with open("./images_list.txt", "r", encoding="utf-8") as images_list:
|
214 |
st.session_state.image_ids = list(images_list.read().strip().split("\n"))
|
215 |
|
216 |
-
st.session_state.active_model = "
|
217 |
|
218 |
st.session_state.vision_mode = "tiled"
|
219 |
st.session_state.search_image_ids = []
|
@@ -662,8 +662,8 @@ with search_row[6]:
|
|
662 |
st.selectbox(
|
663 |
"CLIP Model:",
|
664 |
options=[
|
665 |
-
"M-CLIP (multilingual ViT)",
|
666 |
"J-CLIP (日本語 ViT)",
|
|
|
667 |
"Legacy (multilingual ResNet)",
|
668 |
],
|
669 |
key="active_model",
|
@@ -787,7 +787,7 @@ for image_id in batch:
|
|
787 |
unsafe_allow_html=True,
|
788 |
)
|
789 |
if not (
|
790 |
-
RUN_LITE and st.session_state.active_model == "
|
791 |
):
|
792 |
st.button(
|
793 |
"Explain this",
|
|
|
111 |
|
112 |
def string_search():
|
113 |
st.session_state.disable_uploader = (
|
114 |
+
RUN_LITE and st.session_state.active_model == "M-CLIP (multilingual ViT)"
|
115 |
)
|
116 |
|
117 |
if "search_field_value" in st.session_state:
|
|
|
171 |
ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
|
172 |
ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
|
173 |
|
174 |
+
if not RUN_LITE:
|
175 |
+
st.session_state.ml_image_model, st.session_state.ml_image_preprocess = (
|
176 |
+
load(ml_model_path, device=device, jit=False)
|
177 |
+
)
|
178 |
|
179 |
st.session_state.ml_model = (
|
180 |
pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
|
|
|
195 |
ja_model_name, trust_remote_code=True
|
196 |
)
|
197 |
|
198 |
+
st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
|
199 |
+
clip.load("RN50x4", device=device)
|
200 |
+
)
|
|
|
201 |
|
202 |
st.session_state.rn_model = legacy_multilingual_clip.load_model(
|
203 |
"M-BERT-Base-69"
|
|
|
213 |
with open("./images_list.txt", "r", encoding="utf-8") as images_list:
|
214 |
st.session_state.image_ids = list(images_list.read().strip().split("\n"))
|
215 |
|
216 |
+
st.session_state.active_model = "J-CLIP (日本語 ViT)"
|
217 |
|
218 |
st.session_state.vision_mode = "tiled"
|
219 |
st.session_state.search_image_ids = []
|
|
|
662 |
st.selectbox(
|
663 |
"CLIP Model:",
|
664 |
options=[
|
|
|
665 |
"J-CLIP (日本語 ViT)",
|
666 |
+
"M-CLIP (multilingual ViT)",
|
667 |
"Legacy (multilingual ResNet)",
|
668 |
],
|
669 |
key="active_model",
|
|
|
787 |
unsafe_allow_html=True,
|
788 |
)
|
789 |
if not (
|
790 |
+
RUN_LITE and st.session_state.active_model == "M-CLIP (multilingual ViT)"
|
791 |
):
|
792 |
st.button(
|
793 |
"Explain this",
|