micole66 commited on
Commit
d9da768
1 Parent(s): 272df6a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -0
app.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time
3
+ from clip_client import Client
4
+ from docarray import Document
5
+ import os
6
+
7
+ # photo upload
8
+
9
+ IMAGES_FOLDER = 'images'
10
+ PAGE_LOAD_LOG_FILE = 'page_load_log.txt'
11
+ METRIC_TEXTS = {
12
+ 'Attractivness': ('this person is attractive', 'this person is unattractive'),
13
+ 'Hotness': ('this person is hot', 'this person is ugly'),
14
+ 'Trustworthiness': ('this person is trustworthy', 'this person is dishonest'),
15
+ 'Intelligence': ('this person is smart', 'this person is stupid'),
16
+ 'Quality': ('this image looks good', 'this image looks bad'),
17
+ }
18
+
19
+ st.set_page_config(page_title='AI Photo Rater', initial_sidebar_state="auto")
20
+
21
+
22
+ st.title('AI Photo Rater')
23
+
24
+ def log_page_load():
25
+ with open(PAGE_LOAD_LOG_FILE, 'a') as f:
26
+ f.write(f'{time.time()}\n')
27
+
28
+
29
+ def get_num_page_loads():
30
+ with open(PAGE_LOAD_LOG_FILE, 'r') as f:
31
+ return len(f.readlines())
32
+
33
+ def get_earliest_page_load_time():
34
+ with open(PAGE_LOAD_LOG_FILE, 'r') as f:
35
+ lines = f.readlines()
36
+ unix_time = float(lines[0])
37
+
38
+ date_string = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(unix_time))
39
+ return date_string
40
+
41
+
42
+
43
+ def show_sidebar_metrics():
44
+ metric_options = list(METRIC_TEXTS.keys())
45
+ default_metrics = ['Attractivness', 'Trustworthiness', 'Intelligence']
46
+ st.sidebar.title('Metrics')
47
+ # metric = st.sidebar.selectbox('Select a metric', metric_options)
48
+ selected_metrics = []
49
+ for metric in metric_options:
50
+ selected = metric in default_metrics
51
+ if st.sidebar.checkbox(metric, selected):
52
+ selected_metrics.append(metric)
53
+
54
+ with st.sidebar.expander('Metric texts'):
55
+ st.write(METRIC_TEXTS)
56
+
57
+ print("selected_metrics:", selected_metrics)
58
+ return selected_metrics
59
+
60
+
61
+ def get_custom_metric():
62
+ st.sidebar.markdown('**Custom metric**:')
63
+ metric_name = st.sidebar.text_input('Metric name', placeholder='e.g. "Youth"')
64
+ metric_target = st.sidebar.text_input('Metric target', placeholder='this person is young')
65
+ metric_opposite = st.sidebar.text_input('Metric opposite', placeholder='this person is old')
66
+ return {metric_name: (metric_target, metric_opposite)}
67
+
68
+
69
+
70
+
71
+ log_page_load()
72
+
73
+ metrics = show_sidebar_metrics()
74
+ st.sidebar.markdown('---')
75
+ custom_metric = get_custom_metric()
76
+ st.sidebar.markdown('---')
77
+ st.sidebar.write(f'Page loads: {get_num_page_loads()}')
78
+ st.sidebar.write(f'Earliest page load: {get_earliest_page_load_time()}')
79
+
80
+ metric_texts = METRIC_TEXTS
81
+ print("custom_metric:", custom_metric)
82
+ custom_key = list(custom_metric.keys())[0]
83
+ if custom_key:
84
+ custom_tuple = custom_metric[custom_key]
85
+ if custom_tuple[0] and custom_tuple[1]:
86
+ metrics.append(list(custom_metric.keys())[0])
87
+ metric_texts = {**metric_texts, **custom_metric}
88
+
89
+ os.makedirs(IMAGES_FOLDER, exist_ok=True)
90
+
91
+ # photo_file = st.file_uploader("Upload a photo", type=["jpg", "png"])
92
+ photo_files = st.file_uploader("Upload a photo", accept_multiple_files=True)
93
+ # sort them
94
+ photo_files = sorted(photo_files, key=lambda x: x.name)
95
+
96
+ if not photo_files:
97
+ st.stop()
98
+
99
+
100
+
101
+ c = Client('grpcs://demo-cas.jina.ai:2096')
102
+
103
+
104
+ @st.cache(show_spinner=False)
105
+ def rate_image(image_path, target, opposite, attempt=0):
106
+ try:
107
+ r = c.rank(
108
+ [
109
+ Document(
110
+ # uri='https://www.pngall.com/wp-content/uploads/12/Britney-Spears-PNG-Image-File.png',
111
+ uri=image_path,
112
+ matches=[
113
+ Document(text=target),
114
+ Document(text=opposite),
115
+ ],
116
+ )
117
+ ]
118
+ )
119
+ except ConnectionError as e:
120
+ print(e)
121
+ print(f'Retrying... {attempt}')
122
+ time.sleep(2**attempt)
123
+ return rate_image(image_path, target, opposite, attempt + 1)
124
+ text_and_scores = r['@m', ['text', 'scores__clip_score__value']]
125
+ index_of_good_text = text_and_scores[0].index(target)
126
+ score = text_and_scores[1][index_of_good_text]
127
+ return score
128
+
129
+
130
+ # @st.cache
131
+ def process_image(photo_file, metrics):
132
+ col1, col2, col3 = st.columns([10,10,10])
133
+ with st.spinner('Loading...'):
134
+ with col1:
135
+ st.write('')
136
+ with col2:
137
+ st.image(photo_file, use_column_width=True)
138
+ with col3:
139
+ st.write('')
140
+
141
+
142
+ # save it
143
+ filename = f'{time.time()}'.replace('.', '_')
144
+ filename_path = f'{IMAGES_FOLDER}/{filename}'
145
+ with open(f'{filename_path}', 'wb') as f:
146
+ f.write(photo_file.read())
147
+
148
+
149
+
150
+
151
+
152
+ with st.spinner('Rating your photo...'):
153
+ scores = dict()
154
+ for metric in metrics:
155
+ target = metric_texts[metric][0]
156
+ opposite = metric_texts[metric][1]
157
+ score = rate_image(filename_path, target, opposite)
158
+ scores[metric] = score
159
+
160
+
161
+ scores['Avg'] = sum(scores.values()) / len(scores)
162
+
163
+ # plot them
164
+ import plotly.graph_objects as go
165
+
166
+
167
+ scores_percent = []
168
+ for metric in metrics:
169
+ scores_percent.append(scores[metric] * 100)
170
+ fig = go.Figure(data=[go.Bar(x=metrics, y=scores_percent)], layout=go.Layout(title='Scores'))
171
+ # range 0 to 100 for the y axis:
172
+ fig.update_layout(yaxis=dict(range=[0, 100]))
173
+
174
+ st.plotly_chart(fig, use_container_width=True)
175
+
176
+ return filename_path, scores
177
+
178
+
179
+ def get_best_image(image_scores_list, metric):
180
+ best_image = image_scores_list[0][0]
181
+ best_score = image_scores_list[0][1][metric]
182
+ for image, scores in image_scores_list[2:]:
183
+ if scores[metric] > best_score:
184
+ best_image = image
185
+ best_score = scores[metric]
186
+ return best_image
187
+
188
+
189
+
190
+
191
+ image_scores_list = []
192
+ for photo_file in photo_files:
193
+ # process_image(photo_file)
194
+ filename_path, scores = process_image(photo_file, metrics)
195
+ # image_scores_list.append((filename_path, scores))
196
+ image_scores_list.append((photo_file, scores))
197
+ st.markdown('---')
198
+
199
+
200
+ if len(photo_files) > 1:
201
+ st.title('Best image')
202
+ metric = st.selectbox('Select a metric', ['Avg'] + metrics)
203
+ image_file = get_best_image(image_scores_list, metric)
204
+ # st.image(image_file, use_column_width=True)
205
+ # from PIL import Image
206
+ # image_file = Image.open(image_path)
207
+ process_image(image_file, metrics)
208
+
209
+
210
+ st.markdown('---')
211
+
212
+ col1, col2, col3 = st.columns([10,10,10])
213
+ with col1:
214
+ st.markdown('[GiHub Repo](https://github.com/tom-doerr/ai_photo_rater)')
215
+
216
+ with col2:
217
+ st.markdown('Powered by [Jina.ai](https://jina.ai/)')