# imagecode-demo / app.py
# Last updated by BennoKrojer in commit 4622c8d ("Update app.py"); 2.51 kB.
from turtle import color, onclick
import streamlit as st
from PIL import Image, ImageOps
import glob
import json
import requests
import random
import io
# Streamlit re-executes this entire script on every widget interaction, so
# every top-level statement below runs again on each rerun.
# Streamlit requires set_page_config to be the first Streamlit command.
st.set_page_config(layout="wide")

# Seed the module-level RNG only once per browser session. Seeding
# unconditionally at top level would reset the generator on every rerun, so
# the "Sample an example" button would draw the same index on every click.
# NOTE(review): `random` is process-global, so a newly opened session reseeds
# it for sessions already running; acceptable for a demo.
if 'rng_seeded' not in st.session_state:
    random.seed(10)
    st.session_state.rng_seeded = True

# Per-session UI state that must survive reruns.
if 'show' not in st.session_state:
    st.session_state.show = False  # whether the ground-truth image is revealed
if 'example_idx' not in st.session_state:
    st.session_state.example_idx = 0  # index into `descriptions` currently shown
st.markdown("**This is a demo of the *ImageCoDe* dataset. Sample an example on the left and compare all the images with index counter on the right. At the bottom there are buttons to show/hide the groundtruth!**")
# Left column: description, controls and large preview; right column: image grid.
col1, col2 = st.columns(2)
# Base URL of the hosted validation image sets (one sub-directory per set).
prefix = 'https://raw.githubusercontent.com/BennoKrojer/imagecode-val-set/main/image-sets-val/'

# These files are re-read on every rerun; `with` closes each handle right
# away instead of leaving it open until garbage collection.
# set2ids: maps an image-set name to the list of image file names in that set.
with open('set2ids.json', 'r', encoding='utf-8') as fh:
    set2ids = json.load(fh)
# descriptions: list of (image-set name, ground-truth index, description) entries.
with open('valid_list.json', 'r', encoding='utf-8') as fh:
    descriptions = json.load(fh)
# Draw a new random validation example when the button is clicked; the chosen
# index lives in session_state so it survives later reruns.
if col1.button('Sample an example (description + corresponding images) from the validation set'):
    st.session_state.example_idx = random.randint(0, len(descriptions) - 1)

# Each entry unpacks as (image-set name, ground-truth image index, description).
img_set, true_idx, descr = descriptions[st.session_state.example_idx]
true_idx = int(true_idx)  # normalise: the stored value may not already be an int

# One URL per image in the set. `prefix` ends with '/', so strip it before
# joining to avoid an empty path segment ("...image-sets-val//<set>/...").
base_url = prefix.rstrip('/')
images = [f'{base_url}/{img_set}/{img_id}' for img_id in set2ids[img_set]]
# Keep the plain URLs: entries of `images` are later replaced by framed PIL images.
img_urls = images.copy()

# Bound the index selector by the actual set size instead of a hard-coded 9.
last_idx = len(images) - 1
index = int(col2.number_input(f'Image Index from 0 to {last_idx}', value=0, min_value=0, max_value=last_idx))
# Toggle ground-truth visibility; the flag persists in session_state across reruns.
if col1.button('Click to reveal/hide groundtruth image index (try to guess yourself first!)'):
    st.session_state.show = not st.session_state.show
col1.markdown(f'**Description for {img_set}**:')
col1.markdown(f'**{descr}**')


def _fetch_image(url):
    """Download *url* and open it as a PIL image.

    The timeout keeps a stalled request from hanging the rerun indefinitely,
    and ``raise_for_status`` reports HTTP errors directly instead of letting
    PIL fail on an error page it cannot decode.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def _framed(img, divisor, fill):
    """Return *img* with a *fill* border 1/*divisor* as wide as its shorter side."""
    return ImageOps.expand(img, border=min(img.size) // divisor, fill=fill)


# The large preview on the left is the selected image's plain URL (no frame).
big_img = images[index]
# In the grid, the selected image gets a thin blue frame.
selected_img = _fetch_image(img_urls[index])
images[index] = _framed(selected_img, 18, 'blue')

# All captions are strings so the list handed to st.image is uniform.
caps = [str(i) for i in range(len(images))]
cap = str(index)
if st.session_state.show:
    target_cap = f'{true_idx} (TARGET IMAGE)'
    caps[true_idx] = target_cap
    # Frame the *target* image with a thicker green border; reuse the download
    # when the user already has the target selected.
    target_img = selected_img if true_idx == index else _fetch_image(img_urls[true_idx])
    images[true_idx] = _framed(target_img, 8, 'green')
    if true_idx == index:
        cap = target_cap

col1.image(big_img, use_column_width=True, caption=cap)
col2.image(images, width=175, caption=caps)