JavierFnts commited on
Commit
e942bb1
1 Parent(s): 5657ecb

Added model architecture selection

Browse files
Files changed (3) hide show
  1. app.py +9 -3
  2. images_mocker.py +0 -31
  3. session_state.py +0 -86
app.py CHANGED
@@ -26,8 +26,8 @@ def load_image_from_url(url: str) -> Image.Image:
26
  return Image.open(requests.get(url, stream=True).raw)
27
 
28
  @st.cache
29
- def load_model() -> ClipModel:
30
- return ClipModel()
31
 
32
  def init_state():
33
  if "images" not in st.session_state:
@@ -38,6 +38,8 @@ def init_state():
38
  st.session_state.predictions = None
39
  if "default_text_input" not in st.session_state:
40
  st.session_state.default_text_input = None
 
 
41
 
42
 
43
  def limit_number_images():
@@ -278,7 +280,7 @@ if __name__ == "__main__":
278
  task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
279
  st.markdown("<br>", unsafe_allow_html=True)
280
  init_state()
281
- model = load_model()
282
  if task_name == "Image classification":
283
  Sections.image_uploader(accept_multiple_files=False)
284
  if st.session_state.images is None:
@@ -311,6 +313,10 @@ if __name__ == "__main__":
311
  limit_number_prompts()
312
  Sections.multiple_images_input_preview()
313
  Sections.classification_output(model)
 
 
 
 
314
 
315
  st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
316
  "", unsafe_allow_html=True)
 
26
  return Image.open(requests.get(url, stream=True).raw)
27
 
28
  @st.cache
29
+ def load_model(model_architecture: str) -> ClipModel:
30
+ return ClipModel(model_architecture)
31
 
32
  def init_state():
33
  if "images" not in st.session_state:
 
38
  st.session_state.predictions = None
39
  if "default_text_input" not in st.session_state:
40
  st.session_state.default_text_input = None
41
+ if "model_architecture" not in st.session_state:
42
+ st.session_state.model_architecture = "RN50"
43
 
44
 
45
  def limit_number_images():
 
280
  task_name: str = col2.selectbox("", options=["Prompt ranking", "Image ranking", "Image classification"])
281
  st.markdown("<br>", unsafe_allow_html=True)
282
  init_state()
283
+ model = load_model(st.session_state.model_architecture)
284
  if task_name == "Image classification":
285
  Sections.image_uploader(accept_multiple_files=False)
286
  if st.session_state.images is None:
 
313
  limit_number_prompts()
314
  Sections.multiple_images_input_preview()
315
  Sections.classification_output(model)
316
+
317
+ with st.expander("Advanced settings"):
318
+ st.session_state.model_architecture = st.selectbox("Model architecture", options=['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32',
319
+ 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], index=0)
320
 
321
  st.markdown("<br><br><br><br>Made by [@JavierFnts](https://twitter.com/JavierFnts) | [How was CLIP Playground built?](https://twitter.com/JavierFnts/status/1363522529072214019)"
322
  "", unsafe_allow_html=True)
images_mocker.py DELETED
@@ -1,31 +0,0 @@
1
- from typing import List
2
- import uuid
3
- from mock import patch
4
-
5
-
6
- class ImagesMocker:
7
- """HACK ALERT: I needed a way to call the booste API without storing the images first
8
- (as that is not allowed in streamlit sharing). If you have a better idea on hwo to this let me know!"""
9
-
10
- def __init__(self):
11
- self.pil_patch = patch('PIL.Image.open', lambda x: self.image_id2image(x))
12
- self.path_patch = patch('os.path.exists', lambda x: True)
13
- self.image_id2image_lookup = {}
14
-
15
- def start_mocking(self):
16
- self.pil_patch.start()
17
- self.path_patch.start()
18
-
19
- def stop_mocking(self):
20
- self.pil_patch.stop()
21
- self.path_patch.stop()
22
-
23
- def image_id2image(self, image_id: str):
24
- return self.image_id2image_lookup[image_id]
25
-
26
- def calculate_image_id2image_lookup(self, images: List):
27
- self.image_id2image_lookup = {str(uuid.uuid4()) + ".png": image for image in images}
28
-
29
- @property
30
- def image_ids(self):
31
- return list(self.image_id2image_lookup.keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
session_state.py DELETED
@@ -1,86 +0,0 @@
1
- # From https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662
2
- from streamlit.hashing import _CodeHasher
3
-
4
- try:
5
- # Before Streamlit 0.65
6
- from streamlit.ReportThread import get_report_ctx
7
- from streamlit.server.Server import Server
8
- except ModuleNotFoundError:
9
- # After Streamlit 0.65
10
- from streamlit.report_thread import get_report_ctx
11
- from streamlit.server.server import Server
12
-
13
-
14
- class SessionState:
15
- def __init__(self, session, hash_funcs):
16
- """Initialize SessionState instance."""
17
- self.__dict__["_state"] = {
18
- "data": {},
19
- "hash": None,
20
- "hasher": _CodeHasher(hash_funcs),
21
- "is_rerun": False,
22
- "session": session,
23
- }
24
-
25
- def __call__(self, **kwargs):
26
- """Initialize state data once."""
27
- for item, value in kwargs.items():
28
- if item not in self._state["data"]:
29
- self._state["data"][item] = value
30
-
31
- def __getitem__(self, item):
32
- """Return a saved state value, None if item is undefined."""
33
- return self._state["data"].get(item, None)
34
-
35
- def __getattr__(self, item):
36
- """Return a saved state value, None if item is undefined."""
37
- return self._state["data"].get(item, None)
38
-
39
- def __setitem__(self, item, value):
40
- """Set state value."""
41
- self._state["data"][item] = value
42
-
43
- def __setattr__(self, item, value):
44
- """Set state value."""
45
- self._state["data"][item] = value
46
-
47
- def clear(self):
48
- """Clear session state and request a rerun."""
49
- self._state["data"].clear()
50
- self._state["session"].request_rerun()
51
-
52
- def sync(self):
53
- """Rerun the app with all state values up to date from the beginning to fix rollbacks."""
54
-
55
- # Ensure to rerun only once to avoid infinite loops
56
- # caused by a constantly changing state value at each run.
57
- #
58
- # Example: state.value += 1
59
- if self._state["is_rerun"]:
60
- self._state["is_rerun"] = False
61
-
62
- elif self._state["hash"] is not None:
63
- if self._state["hash"] != self._state["hasher"].to_bytes(self._state["data"], None):
64
- self._state["is_rerun"] = True
65
- self._state["session"].request_rerun()
66
-
67
- self._state["hash"] = self._state["hasher"].to_bytes(self._state["data"], None)
68
-
69
-
70
- def get_session():
71
- session_id = get_report_ctx().session_id
72
- session_info = Server.get_current()._get_session_info(session_id)
73
-
74
- if session_info is None:
75
- raise RuntimeError("Couldn't get your Streamlit Session object.")
76
-
77
- return session_info.session
78
-
79
-
80
- def get_state(hash_funcs=None):
81
- session = get_session()
82
-
83
- if not hasattr(session, "_custom_session_state"):
84
- session._custom_session_state = SessionState(session, hash_funcs)
85
-
86
- return session._custom_session_state