Javi commited on
Commit
ed1918f
1 Parent(s): 47ed623

First version working

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. requirements.txt +1 -0
  3. session_state.py +86 -0
  4. streamlit_app.py +76 -0
.gitignore CHANGED
@@ -1,3 +1,4 @@
 
1
  # Byte-compiled / optimized / DLL files
2
  __pycache__/
3
  *.py[cod]
 
1
+ .idea
2
  # Byte-compiled / optimized / DLL files
3
  __pycache__/
4
  *.py[cod]
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ streamlit~=0.76.0
session_state.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662
2
+ from streamlit.hashing import _CodeHasher
3
+
4
+ try:
5
+ # Before Streamlit 0.65
6
+ from streamlit.ReportThread import get_report_ctx
7
+ from streamlit.server.Server import Server
8
+ except ModuleNotFoundError:
9
+ # After Streamlit 0.65
10
+ from streamlit.report_thread import get_report_ctx
11
+ from streamlit.server.server import Server
12
+
13
+
14
+ class SessionState:
15
+ def __init__(self, session, hash_funcs):
16
+ """Initialize SessionState instance."""
17
+ self.__dict__["_state"] = {
18
+ "data": {},
19
+ "hash": None,
20
+ "hasher": _CodeHasher(hash_funcs),
21
+ "is_rerun": False,
22
+ "session": session,
23
+ }
24
+
25
+ def __call__(self, **kwargs):
26
+ """Initialize state data once."""
27
+ for item, value in kwargs.items():
28
+ if item not in self._state["data"]:
29
+ self._state["data"][item] = value
30
+
31
+ def __getitem__(self, item):
32
+ """Return a saved state value, None if item is undefined."""
33
+ return self._state["data"].get(item, None)
34
+
35
+ def __getattr__(self, item):
36
+ """Return a saved state value, None if item is undefined."""
37
+ return self._state["data"].get(item, None)
38
+
39
+ def __setitem__(self, item, value):
40
+ """Set state value."""
41
+ self._state["data"][item] = value
42
+
43
+ def __setattr__(self, item, value):
44
+ """Set state value."""
45
+ self._state["data"][item] = value
46
+
47
+ def clear(self):
48
+ """Clear session state and request a rerun."""
49
+ self._state["data"].clear()
50
+ self._state["session"].request_rerun()
51
+
52
+ def sync(self):
53
+ """Rerun the app with all state values up to date from the beginning to fix rollbacks."""
54
+
55
+ # Ensure to rerun only once to avoid infinite loops
56
+ # caused by a constantly changing state value at each run.
57
+ #
58
+ # Example: state.value += 1
59
+ if self._state["is_rerun"]:
60
+ self._state["is_rerun"] = False
61
+
62
+ elif self._state["hash"] is not None:
63
+ if self._state["hash"] != self._state["hasher"].to_bytes(self._state["data"], None):
64
+ self._state["is_rerun"] = True
65
+ self._state["session"].request_rerun()
66
+
67
+ self._state["hash"] = self._state["hasher"].to_bytes(self._state["data"], None)
68
+
69
+
70
+ def get_session():
71
+ session_id = get_report_ctx().session_id
72
+ session_info = Server.get_current()._get_session_info(session_id)
73
+
74
+ if session_info is None:
75
+ raise RuntimeError("Couldn't get your Streamlit Session object.")
76
+
77
+ return session_info.session
78
+
79
+
80
+ def get_state(hash_funcs=None):
81
+ session = get_session()
82
+
83
+ if not hasattr(session, "_custom_session_state"):
84
+ session._custom_session_state = SessionState(session, hash_funcs)
85
+
86
+ return session._custom_session_state
streamlit_app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import streamlit as st
3
+ import booste
4
+
5
+ from session_state import SessionState, get_state
6
+
7
+ # Unfortunately Streamlit sharing does not allow to hide enviroment variables yet.
8
+ # Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
9
+ BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
10
+
11
+
12
+ task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
13
+
14
+ st.markdown("# CLIP playground")
15
+ st.markdown("### Try OpenAI's CLIP model in your browser")
16
+ st.markdown(" "); st.markdown(" ")
17
+ with st.beta_expander("What is CLIP?"):
18
+ st.markdown("Nice CLIP explaination")
19
+ st.markdown(" "); st.markdown(" ")
20
+ if task_name == "Image classification":
21
+ session_state = get_state()
22
+ uploaded_image = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
23
+ accept_multiple_files=False)
24
+ st.markdown("or choose one from")
25
+ col1, col2, col3 = st.beta_columns(3)
26
+ with col1:
27
+ default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
28
+ st.image(default_image_1, use_column_width=True)
29
+ if st.button("Select image 1"):
30
+ session_state.image = default_image_1
31
+ with col2:
32
+ default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
33
+ st.image(default_image_2, use_column_width=True)
34
+ if st.button("Select image 2"):
35
+ session_state.image = default_image_2
36
+ with col3:
37
+ default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
38
+ st.image(default_image_3, use_column_width=True)
39
+ if st.button("Select image 3"):
40
+ session_state.image = default_image_3
41
+ raw_classes = st.text_input("Enter the classes to chose from separated by a comma."
42
+ " (f.x. `banana, sailing boat, honesty, apple`)")
43
+ if raw_classes:
44
+ session_state.processed_classes = raw_classes.split(",")
45
+ input_prompts = ["A picture of a " + class_name for class_name in session_state.processed_classes]
46
+
47
+ col1, col2 = st.beta_columns([2, 1])
48
+ with col1:
49
+ st.markdown("Image to classify")
50
+ if session_state.image is not None:
51
+ st.image(session_state.image, use_column_width=True)
52
+ else:
53
+ st.warning("Select an image")
54
+
55
+ with col2:
56
+ st.markdown("Classes to choose from")
57
+ if session_state.processed_classes is not None:
58
+ for class_name in session_state.processed_classes:
59
+ st.write(class_name)
60
+ else:
61
+ st.warning("Enter the classes to classify from")
62
+
63
+ # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
64
+ if st.button("Predict"):
65
+ with st.spinner("Predicting..."):
66
+ clip_response = booste.clip(BOOSTE_API_KEY,
67
+ prompts=input_prompts,
68
+ images=[session_state.image],
69
+ pretty_print=True)
70
+ st.write(clip_response)
71
+
72
+
73
+ session_state.sync()
74
+
75
+
76
+