Javi commited on
Commit
b0cb25e
1 Parent(s): de38ce1

Both classification and prompt ranking working

Browse files
Files changed (1) hide show
  1. streamlit_app.py +110 -67
streamlit_app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from PIL import Image
2
  import streamlit as st
3
  import booste
@@ -9,76 +11,117 @@ from session_state import SessionState, get_state
9
  BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
10
 
11
 
12
- task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- st.markdown("# CLIP playground")
15
- st.markdown("### Try OpenAI's CLIP model in your browser")
16
- st.markdown(" "); st.markdown(" ")
17
- with st.beta_expander("What is CLIP?"):
18
- st.markdown("Nice CLIP explaination")
19
- st.markdown(" "); st.markdown(" ")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  if task_name == "Image classification":
21
- session_state = get_state()
22
- uploaded_image = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
23
- accept_multiple_files=False)
24
  st.markdown("or choose one from")
25
- col1, col2, col3 = st.beta_columns(3)
26
- with col1:
27
- default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
28
- st.image(default_image_1, use_column_width=True)
29
- if st.button("Select image 1"):
30
- session_state.image = default_image_1
31
- with col2:
32
- default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
33
- st.image(default_image_2, use_column_width=True)
34
- if st.button("Select image 2"):
35
- session_state.image = default_image_2
36
- with col3:
37
- default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
38
- st.image(default_image_3, use_column_width=True)
39
- if st.button("Select image 3"):
40
- session_state.image = default_image_3
41
- raw_classes = st.text_input("Enter the classes to chose from separated by a comma."
42
- " (f.x. `banana, sailing boat, honesty, apple`)")
43
- if raw_classes:
44
- session_state.processed_classes = raw_classes.split(",")
45
- input_prompts = ["A picture of a " + class_name for class_name in session_state.processed_classes]
46
-
47
- col1, col2 = st.beta_columns([2, 1])
48
- with col1:
49
- st.markdown("Image to classify")
50
- if session_state.image is not None:
51
- st.image(session_state.image, use_column_width=True)
52
- else:
53
- st.warning("Select an image")
54
-
55
- with col2:
56
- st.markdown("Classes to choose from")
57
- if session_state.processed_classes is not None:
58
- for class_name in session_state.processed_classes:
59
- st.write(class_name)
60
- else:
61
- st.warning("Enter the classes to classify from")
62
-
63
- # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
64
- if st.button("Predict"):
65
- with st.spinner("Predicting..."):
66
- clip_response = booste.clip(BOOSTE_API_KEY,
67
- prompts=input_prompts,
68
- images=[session_state.image],
69
- pretty_print=True)
70
- st.markdown("### Results")
71
- simplified_clip_results = [(prompt[len('A picture of a '):],
72
- list(results.values())[0]["probabilityRelativeToPrompts"])
73
- for prompt, results in clip_response.items()]
74
- simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
75
- max_class_name_length = max(len(class_name) for class_name, _ in simplified_clip_results)
76
-
77
- for prompt, probability in simplified_clip_results:
78
- progress_bar = "".join([":large_blue_circle:"] * int(probability * 10) +
79
- [":black_circle:"] * int((1 - probability) * 10))
80
- st.markdown(f"### {prompt}: {progress_bar} {probability:.3f}")
81
- st.write(clip_response)
82
 
83
  session_state.sync()
84
 
 
1
+ from typing import Optional, List
2
+
3
  from PIL import Image
4
  import streamlit as st
5
  import booste
 
11
  BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
12
 
13
 
14
+ class Sections:
15
+ @staticmethod
16
+ def header():
17
+ st.markdown("# CLIP playground")
18
+ st.markdown("### Try OpenAI's CLIP model in your browser")
19
+ st.markdown(" ");
20
+ st.markdown(" ")
21
+ with st.beta_expander("What is CLIP?"):
22
+ st.markdown("Nice CLIP explaination")
23
+ st.markdown(" ");
24
+ st.markdown(" ")
25
+
26
+ @staticmethod
27
+ def image_uploader(accept_multiple_files: bool) -> Optional[List[str]]:
28
+ uploaded_image = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
29
+ accept_multiple_files=accept_multiple_files)
30
+
31
+ @staticmethod
32
+ def image_picker(state: SessionState):
33
+ col1, col2, col3 = st.beta_columns(3)
34
+ with col1:
35
+ default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
36
+ st.image(default_image_1, use_column_width=True)
37
+ if st.button("Select image 1"):
38
+ state.image = default_image_1
39
+ with col2:
40
+ default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
41
+ st.image(default_image_2, use_column_width=True)
42
+ if st.button("Select image 2"):
43
+ state.image = default_image_2
44
+ with col3:
45
+ default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
46
+ st.image(default_image_3, use_column_width=True)
47
+ if st.button("Select image 3"):
48
+ state.image = default_image_3
49
+
50
+ @staticmethod
51
+ def prompts_input(state: SessionState, input_label: str, prompt_prefix: str = ''):
52
+ raw_classes = st.text_input(input_label)
53
+ if raw_classes:
54
+ state.prompts = [prompt_prefix + class_name for class_name in raw_classes.split(";") if len(class_name) > 1]
55
+ state.prompt_prefix = prompt_prefix
56
+
57
+ @staticmethod
58
+ def input_preview(state: SessionState):
59
+ col1, col2 = st.beta_columns([2, 1])
60
+ with col1:
61
+ st.markdown("Image to classify")
62
+ if state.image is not None:
63
+ st.image(state.image, use_column_width=True)
64
+ else:
65
+ st.warning("Select an image")
66
+
67
+ with col2:
68
+ st.markdown("Labels to choose from")
69
+ if state.processed_classes is not None:
70
+ for prompt in state.prompts:
71
+ st.write(prompt[len(state.prompt_prefix):])
72
+ else:
73
+ st.warning("Enter the classes to classify from")
74
 
75
+ @staticmethod
76
+ def classification_output(state: SessionState):
77
+ # Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
78
+ if st.button("Predict"):
79
+ with st.spinner("Predicting..."):
80
+ clip_response = booste.clip(BOOSTE_API_KEY,
81
+ prompts=state.prompts,
82
+ images=[state.image],
83
+ pretty_print=True)
84
+ st.markdown("### Results")
85
+ simplified_clip_results = [(prompt[len(state.prompt_prefix):],
86
+ list(results.values())[0]["probabilityRelativeToPrompts"])
87
+ for prompt, results in clip_response.items()]
88
+ simplified_clip_results = sorted(simplified_clip_results, key=lambda x: x[1], reverse=True)
89
+
90
+ for prompt, probability in simplified_clip_results:
91
+ percentage_prob = int(probability * 100)
92
+ st.markdown(
93
+ f"### ![prob](https://progress-bar.dev/{percentage_prob}/?width=200) &nbsp &nbsp {prompt}")
94
+ st.write(clip_response)
95
+
96
+
97
+ task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
98
+ session_state = get_state()
99
  if task_name == "Image classification":
100
+ Sections.header()
101
+ Sections.image_uploader(accept_multiple_files=False)
 
102
  st.markdown("or choose one from")
103
+ Sections.image_picker(session_state)
104
+ input_label = "Enter the classes to chose from separated by a semi-colon. (f.x. `banana; boat; honesty; apple`)"
105
+ Sections.prompts_input(session_state, input_label, prompt_prefix='A picture of a ')
106
+ Sections.input_preview(session_state)
107
+ Sections.classification_output(session_state)
108
+ elif task_name == "Prompt ranking":
109
+ Sections.header()
110
+ Sections.image_uploader(accept_multiple_files=False)
111
+ st.markdown("or choose one from")
112
+ Sections.image_picker(session_state)
113
+ input_label = "Enter the prompts to choose from separated by a semi-colon. " \
114
+ "(f.x. `An image that inspires; A feeling of loneliness; joyful and young; apple`)"
115
+ Sections.prompts_input(session_state, input_label)
116
+ Sections.input_preview(session_state)
117
+ Sections.classification_output(session_state)
118
+ elif task_name == "Image ranking":
119
+ Sections.header()
120
+ Sections.image_uploader(accept_multiple_files=True)
121
+ st.markdown("or use random dataset")
122
+ Sections.image_picker(session_state)
123
+
124
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  session_state.sync()
127