File size: 2,099 Bytes
61c7634
e15dae8
c9911aa
17bb1f6
 
fae45ed
61c7634
236866f
c34d9ea
8465f52
 
3ce1a28
34700d7
78f266b
34700d7
d30d4ce
684d9c0
921054e
ec11b9a
6f3fb83
62635cf
17bb1f6
5c5bd98
6f3fb83
34700d7
 
 
 
cd7c7ec
6115563
31fec50
34700d7
 
62635cf
a1b8369
 
1377bb8
8465f52
 
1377bb8
ed60eed
5c5bd98
8465f52
 
 
 
 
 
 
 
 
 
ec11b9a
ed60eed
 
 
 
cf1f1df
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from turtle import color, onclick
import streamlit as st
from PIL import Image, ImageOps
import glob
import json
import requests
import random
import io

# Seed per-session state once; Streamlit reruns this script on every widget
# interaction, so only fill in keys that do not exist yet.
#   show        -> whether the ground-truth target image is revealed
#   example_idx -> position in `descriptions` of the example being displayed
for _state_key, _state_default in (('show', False), ('example_idx', 0)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

st.set_page_config(layout="wide")
st.markdown("**This is a demo of the *ImageCoDe* dataset. Sample an example on the left and compare all the images with index counter on the right. At the bottom there are buttons to show/hide the groundtruth!**")

# Left column: sampling/show/hide controls, description and selected image.
# Right column: image-index selector and the whole image set.
col1, col2 = st.columns(2)

# Base URL under which every image set of the validation split is hosted.
prefix = 'https://raw.githubusercontent.com/BennoKrojer/imagecode-val-set/main/image-sets-val/'

# Close the JSON files deterministically instead of relying on garbage
# collection (this runs on every Streamlit rerun), and decode them as UTF-8
# rather than the platform's locale encoding.
with open('set2ids.json', 'r', encoding='utf-8') as set2ids_file:
    # image-set name -> file names of the images in that set
    set2ids = json.load(set2ids_file)
with open('valid_list.json', 'r', encoding='utf-8') as descriptions_file:
    # entries unpack to (image_set, target_index, description)
    descriptions = json.load(descriptions_file)

# Sample a random validation example. The choice lives in session state so it
# survives the reruns triggered by every other widget interaction.
if col1.button('Sample an example (description + corresponding images) from the validation set'):
    st.session_state.example_idx = random.randrange(len(descriptions))
img_set, idx, descr = descriptions[st.session_state.example_idx]
idx = int(idx)  # position of the ground-truth target image within the set

# One absolute URL per image of the set. `prefix` already ends with '/', so
# strip it before joining to avoid an empty '//' path segment in the URL.
images = [f"{prefix.rstrip('/')}/{img_set}/{name}" for name in set2ids[img_set]]

# Bound the selector by the actual set size (label is unchanged for the usual
# 10-image sets) so an out-of-range index can never be chosen.
last_index = len(images) - 1
index = int(col2.number_input(f'Image Index from 0 to {last_index}', value=0, min_value=0, max_value=last_index))


col1.markdown('**Description**:')
col1.markdown(descr)

# Process the show/hide buttons BEFORE building the captions. st.button only
# returns True during the rerun triggered by its click, so the state must be
# updated before it is read below; otherwise the caption change would lag one
# interaction behind the click.
if col1.button('Show groundtruth target image'):
    st.session_state.show = True
if col1.button('Hide groundtruth target image'):
    st.session_state.show = False

# The selected image is shown large on the left straight from its URL.
img = images[index]

# In the right-hand gallery, the selected image is replaced by a downloaded
# copy with a blue border. A timeout keeps a stalled connection from hanging
# the app. An HTTP error is reported clearly instead of surfacing as an opaque
# "cannot identify image file" error from PIL.
try:
    response = requests.get(img, timeout=30)
    response.raise_for_status()
except requests.RequestException as err:
    st.error(f'Could not download image {index} ({img}): {err}')
    st.stop()
images[index] = ImageOps.expand(Image.open(io.BytesIO(response.content)), border=20, fill='blue')

# Captions are the image indices; the target is labelled only while the
# ground truth is revealed.
caps = [str(i) for i in range(len(images))]
if st.session_state.show:
    caps[idx] = f'{idx} (TARGET IMAGE)'
cap = caps[index]

col1.image(img, use_column_width=True, caption=cap)
col2.image(images, width=175, caption=caps)