# ai-phooto / app.py
# Commit d9da768 ("Create app.py") by micole66.
import streamlit as st
import time
from clip_client import Client
from docarray import Document
import os
# --- Configuration -----------------------------------------------------------

# Directory that uploaded photos are written to before they are rated.
IMAGES_FOLDER = 'images'

# Append-only log written by log_page_load(): one Unix timestamp per line.
PAGE_LOAD_LOG_FILE = 'page_load_log.txt'

# Metric label -> (target text, opposite text). rate_image() sends both texts
# alongside the photo and returns the score assigned to the target text.
# NOTE: the 'Attractivness' spelling is also used as a literal in
# show_sidebar_metrics(); rename both places together if fixing the typo.
METRIC_TEXTS = {
    'Attractivness': ('this person is attractive', 'this person is unattractive'),
    'Hotness': ('this person is hot', 'this person is ugly'),
    'Trustworthiness': ('this person is trustworthy', 'this person is dishonest'),
    'Intelligence': ('this person is smart', 'this person is stupid'),
    'Quality': ('this image looks good', 'this image looks bad'),
}

# --- Page setup --------------------------------------------------------------
APP_TITLE = 'AI Photo Rater'
st.set_page_config(page_title=APP_TITLE, initial_sidebar_state="auto")
st.title(APP_TITLE)
def log_page_load():
    """Append the current Unix timestamp as one new line of PAGE_LOAD_LOG_FILE."""
    with open(PAGE_LOAD_LOG_FILE, 'a') as log:
        # print() emits str(timestamp) followed by '\n'.
        print(time.time(), file=log)
def get_num_page_loads():
    """Return the number of lines (logged page loads) in PAGE_LOAD_LOG_FILE."""
    with open(PAGE_LOAD_LOG_FILE, 'r') as log:
        return sum(1 for _entry in log)
def get_earliest_page_load_time():
    """Return the first logged timestamp formatted as local 'YYYY-mm-dd HH:MM:SS'.

    Raises:
        IndexError: if PAGE_LOAD_LOG_FILE is empty.
    """
    with open(PAGE_LOAD_LOG_FILE, 'r') as log:
        first_entry = log.readlines()[0]
    # float() tolerates the trailing newline of the log line.
    earliest = time.localtime(float(first_entry))
    return time.strftime('%Y-%m-%d %H:%M:%S', earliest)
def show_sidebar_metrics():
    """Render one sidebar checkbox per built-in metric and return the checked names.

    'Attractivness', 'Trustworthiness' and 'Intelligence' start out checked.
    The full METRIC_TEXTS mapping is shown in a collapsible "Metric texts" panel.

    Returns:
        list[str]: checked metric names, in METRIC_TEXTS order.
    """
    preselected = ('Attractivness', 'Trustworthiness', 'Intelligence')
    st.sidebar.title('Metrics')

    chosen = []
    for name in METRIC_TEXTS:
        is_default = name in preselected
        if st.sidebar.checkbox(name, is_default):
            chosen.append(name)

    with st.sidebar.expander('Metric texts'):
        st.write(METRIC_TEXTS)

    print("selected_metrics:", chosen)
    return chosen
def get_custom_metric():
    """Render the sidebar inputs for a user-defined metric.

    Returns:
        dict: exactly one entry ``{name: (target_text, opposite_text)}``. The
        caller treats empty strings as "not filled in".
    """
    sidebar = st.sidebar
    sidebar.markdown('**Custom metric**:')
    name = sidebar.text_input('Metric name', placeholder='e.g. "Youth"')
    texts = (
        sidebar.text_input('Metric target', placeholder='this person is young'),
        sidebar.text_input('Metric opposite', placeholder='this person is old'),
    )
    return {name: texts}
# --- Sidebar -----------------------------------------------------------------
log_page_load()
metrics = show_sidebar_metrics()
st.sidebar.markdown('---')
custom_metric = get_custom_metric()
st.sidebar.markdown('---')
st.sidebar.write(f'Page loads: {get_num_page_loads()}')
st.sidebar.write(f'Earliest page load: {get_earliest_page_load_time()}')

# Texts used for rating: the built-ins, plus the custom metric once it is
# fully specified (name, target and opposite all non-empty).
metric_texts = METRIC_TEXTS
print("custom_metric:", custom_metric)
# get_custom_metric() always returns a single {name: (target, opposite)} entry.
custom_name, (custom_target, custom_opposite) = next(iter(custom_metric.items()))
if custom_name and custom_target and custom_opposite:
    # Avoid listing a metric twice (e.g. custom name equal to a checked
    # built-in): a duplicate would be rated twice and plotted as two bars.
    if custom_name not in metrics:
        metrics.append(custom_name)
    # A custom metric with a built-in's name overrides that built-in's texts.
    metric_texts = {**metric_texts, custom_name: (custom_target, custom_opposite)}

os.makedirs(IMAGES_FOLDER, exist_ok=True)

# --- Photo upload ------------------------------------------------------------
photo_files = st.file_uploader("Upload a photo", accept_multiple_files=True)
# Process the uploads in a stable, file-name order.
photo_files = sorted(photo_files, key=lambda x: x.name)
if not photo_files:
    st.stop()

c = Client('grpcs://demo-cas.jina.ai:2096')
@st.cache(show_spinner=False)
def rate_image(image_path, target, opposite, attempt=0, max_attempts=5):
    """Rate the image at ``image_path`` against a target/opposite text pair.

    Both texts are attached as matches of a single image Document and ranked by
    the module-level client ``c``. The ``scores__clip_score__value`` of the
    ``target`` match is returned.

    Args:
        image_path: location of the image, passed as the Document ``uri``.
        target: the text whose score is returned (e.g. 'this person is smart').
        opposite: the contrasting text ranked alongside ``target``.
        attempt: index of the first attempt; also sets the first back-off delay.
        max_attempts: attempt index bound; once ``attempt + 1`` reaches it, the
            ConnectionError is re-raised instead of retried.

    Returns:
        The score reported for ``target``.

    Raises:
        ConnectionError: when the final allowed attempt fails.
    """
    while True:
        try:
            r = c.rank(
                [
                    Document(
                        uri=image_path,
                        matches=[
                            Document(text=target),
                            Document(text=opposite),
                        ],
                    )
                ]
            )
            break
        except ConnectionError as e:
            # Give up after a bounded number of tries instead of retrying forever.
            if attempt + 1 >= max_attempts:
                raise
            print(e)
            print(f'Retrying... {attempt}')
            # Exponential back-off: 1s, 2s, 4s, ...
            time.sleep(2**attempt)
            attempt += 1
    # Columns for the ranked matches: [texts, scores], in ranked order.
    text_and_scores = r['@m', ['text', 'scores__clip_score__value']]
    index_of_good_text = text_and_scores[0].index(target)
    return text_and_scores[1][index_of_good_text]
def process_image(photo_file, metrics):
    """Show an uploaded photo, save it under IMAGES_FOLDER, rate it and chart the scores.

    Args:
        photo_file: the Streamlit upload (a BytesIO-like object).
        metrics: metric names to rate; each must be a key of ``metric_texts``.

    Returns:
        ``(filename_path, scores)``: ``scores`` maps each metric to its rating,
        plus ``'Avg'`` holding their mean (0.0 when ``metrics`` is empty).
    """
    # Three equal columns with the photo in the middle one, i.e. centered.
    col1, col2, col3 = st.columns([10, 10, 10])
    with st.spinner('Loading...'):
        with col1:
            st.write('')
        with col2:
            st.image(photo_file, use_column_width=True)
        with col3:
            st.write('')
        # Save under a timestamp-derived name, e.g. 'images/1700000000_123456'.
        filename = f'{time.time()}'.replace('.', '_')
        filename_path = f'{IMAGES_FOLDER}/{filename}'
        with open(filename_path, 'wb') as f:
            # getvalue() returns the whole upload regardless of the stream
            # position. read() returned b'' when an already-read upload was
            # processed again (as the "Best image" section does).
            f.write(photo_file.getvalue())

    with st.spinner('Rating your photo...'):
        scores = {}
        for metric in metrics:
            target, opposite = metric_texts[metric]
            scores[metric] = rate_image(filename_path, target, opposite)
        # Guard the empty selection (all checkboxes cleared): avoid ZeroDivisionError.
        scores['Avg'] = sum(scores.values()) / len(scores) if scores else 0.0

    # Plot the per-metric scores as percentages (Avg is not plotted).
    import plotly.graph_objects as go

    scores_percent = [scores[metric] * 100 for metric in metrics]
    fig = go.Figure(
        data=[go.Bar(x=metrics, y=scores_percent)],
        layout=go.Layout(title='Scores'),
    )
    # Fixed 0-100 y-axis range for the percentage scale.
    fig.update_layout(yaxis=dict(range=[0, 100]))
    st.plotly_chart(fig, use_container_width=True)
    return filename_path, scores
def get_best_image(image_scores_list, metric):
    """Return the image whose scores rank highest for ``metric``.

    Args:
        image_scores_list: non-empty sequence of ``(image, scores)`` pairs, where
            ``scores`` is a dict containing ``metric``.
        metric: key to compare on (e.g. 'Avg').

    Returns:
        The ``image`` of the highest-scoring pair. On ties, the earliest pair wins.

    Raises:
        ValueError: if ``image_scores_list`` is empty.
    """
    # Every pair is considered. The previous loop started at index 2, so the
    # second image could never win. max() keeps the first of equal maxima.
    best_image, _ = max(image_scores_list, key=lambda pair: pair[1][metric])
    return best_image
# --- Rate every uploaded photo -------------------------------------------------
image_scores_list = []
for photo_file in photo_files:
    _, scores = process_image(photo_file, metrics)
    # Keep the upload itself (not the saved path) so the winner can be re-shown.
    image_scores_list.append((photo_file, scores))
    st.markdown('---')

# --- With several photos, show the best one for a chosen metric ---------------
if len(photo_files) > 1:
    st.title('Best image')
    metric = st.selectbox('Select a metric', ['Avg'] + metrics)
    image_file = get_best_image(image_scores_list, metric)
    # The upload was already consumed when it was first processed. Rewind it
    # so the second pass sees the full image bytes again.
    image_file.seek(0)
    process_image(image_file, metrics)
    st.markdown('---')

# --- Footer ------------------------------------------------------------------
col1, col2, col3 = st.columns([10, 10, 10])
with col1:
    st.markdown('[GitHub Repo](https://github.com/tom-doerr/ai_photo_rater)')
with col2:
    st.markdown('Powered by [Jina.ai](https://jina.ai/)')