Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import time
|
3 |
+
from clip_client import Client
|
4 |
+
from docarray import Document
|
5 |
+
import os
|
6 |
+
|
7 |
+
# photo upload
|
8 |
+
|
9 |
+
IMAGES_FOLDER = 'images'
|
10 |
+
PAGE_LOAD_LOG_FILE = 'page_load_log.txt'
|
11 |
+
METRIC_TEXTS = {
|
12 |
+
'Attractivness': ('this person is attractive', 'this person is unattractive'),
|
13 |
+
'Hotness': ('this person is hot', 'this person is ugly'),
|
14 |
+
'Trustworthiness': ('this person is trustworthy', 'this person is dishonest'),
|
15 |
+
'Intelligence': ('this person is smart', 'this person is stupid'),
|
16 |
+
'Quality': ('this image looks good', 'this image looks bad'),
|
17 |
+
}
|
18 |
+
|
19 |
+
st.set_page_config(page_title='AI Photo Rater', initial_sidebar_state="auto")
|
20 |
+
|
21 |
+
|
22 |
+
st.title('AI Photo Rater')
|
23 |
+
|
24 |
+
def log_page_load():
|
25 |
+
with open(PAGE_LOAD_LOG_FILE, 'a') as f:
|
26 |
+
f.write(f'{time.time()}\n')
|
27 |
+
|
28 |
+
|
29 |
+
def get_num_page_loads():
|
30 |
+
with open(PAGE_LOAD_LOG_FILE, 'r') as f:
|
31 |
+
return len(f.readlines())
|
32 |
+
|
33 |
+
def get_earliest_page_load_time():
|
34 |
+
with open(PAGE_LOAD_LOG_FILE, 'r') as f:
|
35 |
+
lines = f.readlines()
|
36 |
+
unix_time = float(lines[0])
|
37 |
+
|
38 |
+
date_string = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(unix_time))
|
39 |
+
return date_string
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
def show_sidebar_metrics():
|
44 |
+
metric_options = list(METRIC_TEXTS.keys())
|
45 |
+
default_metrics = ['Attractivness', 'Trustworthiness', 'Intelligence']
|
46 |
+
st.sidebar.title('Metrics')
|
47 |
+
# metric = st.sidebar.selectbox('Select a metric', metric_options)
|
48 |
+
selected_metrics = []
|
49 |
+
for metric in metric_options:
|
50 |
+
selected = metric in default_metrics
|
51 |
+
if st.sidebar.checkbox(metric, selected):
|
52 |
+
selected_metrics.append(metric)
|
53 |
+
|
54 |
+
with st.sidebar.expander('Metric texts'):
|
55 |
+
st.write(METRIC_TEXTS)
|
56 |
+
|
57 |
+
print("selected_metrics:", selected_metrics)
|
58 |
+
return selected_metrics
|
59 |
+
|
60 |
+
|
61 |
+
def get_custom_metric():
|
62 |
+
st.sidebar.markdown('**Custom metric**:')
|
63 |
+
metric_name = st.sidebar.text_input('Metric name', placeholder='e.g. "Youth"')
|
64 |
+
metric_target = st.sidebar.text_input('Metric target', placeholder='this person is young')
|
65 |
+
metric_opposite = st.sidebar.text_input('Metric opposite', placeholder='this person is old')
|
66 |
+
return {metric_name: (metric_target, metric_opposite)}
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
log_page_load()
|
72 |
+
|
73 |
+
metrics = show_sidebar_metrics()
|
74 |
+
st.sidebar.markdown('---')
|
75 |
+
custom_metric = get_custom_metric()
|
76 |
+
st.sidebar.markdown('---')
|
77 |
+
st.sidebar.write(f'Page loads: {get_num_page_loads()}')
|
78 |
+
st.sidebar.write(f'Earliest page load: {get_earliest_page_load_time()}')
|
79 |
+
|
80 |
+
metric_texts = METRIC_TEXTS
|
81 |
+
print("custom_metric:", custom_metric)
|
82 |
+
custom_key = list(custom_metric.keys())[0]
|
83 |
+
if custom_key:
|
84 |
+
custom_tuple = custom_metric[custom_key]
|
85 |
+
if custom_tuple[0] and custom_tuple[1]:
|
86 |
+
metrics.append(list(custom_metric.keys())[0])
|
87 |
+
metric_texts = {**metric_texts, **custom_metric}
|
88 |
+
|
89 |
+
os.makedirs(IMAGES_FOLDER, exist_ok=True)
|
90 |
+
|
91 |
+
# photo_file = st.file_uploader("Upload a photo", type=["jpg", "png"])
|
92 |
+
photo_files = st.file_uploader("Upload a photo", accept_multiple_files=True)
|
93 |
+
# sort them
|
94 |
+
photo_files = sorted(photo_files, key=lambda x: x.name)
|
95 |
+
|
96 |
+
if not photo_files:
|
97 |
+
st.stop()
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
c = Client('grpcs://demo-cas.jina.ai:2096')
|
102 |
+
|
103 |
+
|
104 |
+
@st.cache(show_spinner=False)
|
105 |
+
def rate_image(image_path, target, opposite, attempt=0):
|
106 |
+
try:
|
107 |
+
r = c.rank(
|
108 |
+
[
|
109 |
+
Document(
|
110 |
+
# uri='https://www.pngall.com/wp-content/uploads/12/Britney-Spears-PNG-Image-File.png',
|
111 |
+
uri=image_path,
|
112 |
+
matches=[
|
113 |
+
Document(text=target),
|
114 |
+
Document(text=opposite),
|
115 |
+
],
|
116 |
+
)
|
117 |
+
]
|
118 |
+
)
|
119 |
+
except ConnectionError as e:
|
120 |
+
print(e)
|
121 |
+
print(f'Retrying... {attempt}')
|
122 |
+
time.sleep(2**attempt)
|
123 |
+
return rate_image(image_path, target, opposite, attempt + 1)
|
124 |
+
text_and_scores = r['@m', ['text', 'scores__clip_score__value']]
|
125 |
+
index_of_good_text = text_and_scores[0].index(target)
|
126 |
+
score = text_and_scores[1][index_of_good_text]
|
127 |
+
return score
|
128 |
+
|
129 |
+
|
130 |
+
# @st.cache
|
131 |
+
def process_image(photo_file, metrics):
|
132 |
+
col1, col2, col3 = st.columns([10,10,10])
|
133 |
+
with st.spinner('Loading...'):
|
134 |
+
with col1:
|
135 |
+
st.write('')
|
136 |
+
with col2:
|
137 |
+
st.image(photo_file, use_column_width=True)
|
138 |
+
with col3:
|
139 |
+
st.write('')
|
140 |
+
|
141 |
+
|
142 |
+
# save it
|
143 |
+
filename = f'{time.time()}'.replace('.', '_')
|
144 |
+
filename_path = f'{IMAGES_FOLDER}/{filename}'
|
145 |
+
with open(f'{filename_path}', 'wb') as f:
|
146 |
+
f.write(photo_file.read())
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
with st.spinner('Rating your photo...'):
|
153 |
+
scores = dict()
|
154 |
+
for metric in metrics:
|
155 |
+
target = metric_texts[metric][0]
|
156 |
+
opposite = metric_texts[metric][1]
|
157 |
+
score = rate_image(filename_path, target, opposite)
|
158 |
+
scores[metric] = score
|
159 |
+
|
160 |
+
|
161 |
+
scores['Avg'] = sum(scores.values()) / len(scores)
|
162 |
+
|
163 |
+
# plot them
|
164 |
+
import plotly.graph_objects as go
|
165 |
+
|
166 |
+
|
167 |
+
scores_percent = []
|
168 |
+
for metric in metrics:
|
169 |
+
scores_percent.append(scores[metric] * 100)
|
170 |
+
fig = go.Figure(data=[go.Bar(x=metrics, y=scores_percent)], layout=go.Layout(title='Scores'))
|
171 |
+
# range 0 to 100 for the y axis:
|
172 |
+
fig.update_layout(yaxis=dict(range=[0, 100]))
|
173 |
+
|
174 |
+
st.plotly_chart(fig, use_container_width=True)
|
175 |
+
|
176 |
+
return filename_path, scores
|
177 |
+
|
178 |
+
|
179 |
+
def get_best_image(image_scores_list, metric):
|
180 |
+
best_image = image_scores_list[0][0]
|
181 |
+
best_score = image_scores_list[0][1][metric]
|
182 |
+
for image, scores in image_scores_list[2:]:
|
183 |
+
if scores[metric] > best_score:
|
184 |
+
best_image = image
|
185 |
+
best_score = scores[metric]
|
186 |
+
return best_image
|
187 |
+
|
188 |
+
|
189 |
+
|
190 |
+
|
191 |
+
image_scores_list = []
|
192 |
+
for photo_file in photo_files:
|
193 |
+
# process_image(photo_file)
|
194 |
+
filename_path, scores = process_image(photo_file, metrics)
|
195 |
+
# image_scores_list.append((filename_path, scores))
|
196 |
+
image_scores_list.append((photo_file, scores))
|
197 |
+
st.markdown('---')
|
198 |
+
|
199 |
+
|
200 |
+
if len(photo_files) > 1:
|
201 |
+
st.title('Best image')
|
202 |
+
metric = st.selectbox('Select a metric', ['Avg'] + metrics)
|
203 |
+
image_file = get_best_image(image_scores_list, metric)
|
204 |
+
# st.image(image_file, use_column_width=True)
|
205 |
+
# from PIL import Image
|
206 |
+
# image_file = Image.open(image_path)
|
207 |
+
process_image(image_file, metrics)
|
208 |
+
|
209 |
+
|
210 |
+
st.markdown('---')
|
211 |
+
|
212 |
+
col1, col2, col3 = st.columns([10,10,10])
|
213 |
+
with col1:
|
214 |
+
st.markdown('[GiHub Repo](https://github.com/tom-doerr/ai_photo_rater)')
|
215 |
+
|
216 |
+
with col2:
|
217 |
+
st.markdown('Powered by [Jina.ai](https://jina.ai/)')
|