broadwell committed on
Commit
810de2d
1 Parent(s): c3d8208

Allow gradient viz for J-CLIP and ResNet, disable for M-CLIP

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -111,7 +111,7 @@ def clip_search(search_query):
111
 
112
  def string_search():
113
  st.session_state.disable_uploader = (
114
- RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
115
  )
116
 
117
  if "search_field_value" in st.session_state:
@@ -171,9 +171,10 @@ def init():
171
  ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
172
  ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
173
 
174
- st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
175
- ml_model_path, device=device, jit=False
176
- )
 
177
 
178
  st.session_state.ml_model = (
179
  pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
@@ -194,10 +195,9 @@ def init():
194
  ja_model_name, trust_remote_code=True
195
  )
196
 
197
- if not RUN_LITE:
198
- st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
199
- clip.load("RN50x4", device=device)
200
- )
201
 
202
  st.session_state.rn_model = legacy_multilingual_clip.load_model(
203
  "M-BERT-Base-69"
@@ -213,7 +213,7 @@ def init():
213
  with open("./images_list.txt", "r", encoding="utf-8") as images_list:
214
  st.session_state.image_ids = list(images_list.read().strip().split("\n"))
215
 
216
- st.session_state.active_model = "M-CLIP (multilingual ViT)"
217
 
218
  st.session_state.vision_mode = "tiled"
219
  st.session_state.search_image_ids = []
@@ -662,8 +662,8 @@ with search_row[6]:
662
  st.selectbox(
663
  "CLIP Model:",
664
  options=[
665
- "M-CLIP (multilingual ViT)",
666
  "J-CLIP (日本語 ViT)",
 
667
  "Legacy (multilingual ResNet)",
668
  ],
669
  key="active_model",
@@ -787,7 +787,7 @@ for image_id in batch:
787
  unsafe_allow_html=True,
788
  )
789
  if not (
790
- RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
791
  ):
792
  st.button(
793
  "Explain this",
 
111
 
112
  def string_search():
113
  st.session_state.disable_uploader = (
114
+ RUN_LITE and st.session_state.active_model == "M-CLIP (multilingual ViT)"
115
  )
116
 
117
  if "search_field_value" in st.session_state:
 
171
  ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
172
  ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
173
 
174
+ if not RUN_LITE:
175
+ st.session_state.ml_image_model, st.session_state.ml_image_preprocess = (
176
+ load(ml_model_path, device=device, jit=False)
177
+ )
178
 
179
  st.session_state.ml_model = (
180
  pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
 
195
  ja_model_name, trust_remote_code=True
196
  )
197
 
198
+ st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
199
+ clip.load("RN50x4", device=device)
200
+ )
 
201
 
202
  st.session_state.rn_model = legacy_multilingual_clip.load_model(
203
  "M-BERT-Base-69"
 
213
  with open("./images_list.txt", "r", encoding="utf-8") as images_list:
214
  st.session_state.image_ids = list(images_list.read().strip().split("\n"))
215
 
216
+ st.session_state.active_model = "J-CLIP (日本語 ViT)"
217
 
218
  st.session_state.vision_mode = "tiled"
219
  st.session_state.search_image_ids = []
 
662
  st.selectbox(
663
  "CLIP Model:",
664
  options=[
 
665
  "J-CLIP (日本語 ViT)",
666
+ "M-CLIP (multilingual ViT)",
667
  "Legacy (multilingual ResNet)",
668
  ],
669
  key="active_model",
 
787
  unsafe_allow_html=True,
788
  )
789
  if not (
790
+ RUN_LITE and st.session_state.active_model == "M-CLIP (multilingual ViT)"
791
  ):
792
  st.button(
793
  "Explain this",