zamborg commited on
Commit
defbed4
1 Parent(s): 63ce95a

added image caching for hopefully easier use

Browse files
Files changed (2) hide show
  1. app.py +14 -7
  2. model.py +2 -3
app.py CHANGED
@@ -31,20 +31,20 @@ st.sidebar.markdown(
31
 
32
  with st.spinner("Loading Model"):
33
  virtexModel, imageLoader, sample_images, valid_subs = create_objects()
 
 
 
34
 
35
- random_image = get_rand_img(sample_images)
36
- rand_idx = 0
37
 
38
  st.sidebar.title("Select a sample image")
39
 
40
  if st.sidebar.button("Random Sample Image"):
41
- rand_idx, _ = get_rand_img(sample_images)
42
- sample_image = None
43
 
44
  sample_image = st.sidebar.selectbox(
45
  "",
46
  sample_images,
47
- index=rand_idx
48
  )
49
 
50
  st.sidebar.title("Select a Subreddit")
@@ -66,6 +66,7 @@ with st.sidebar.form("file-uploader-form", clear_on_submit=True):
66
  submitted = st.form_submit_button("Submit")
67
  if uploaded_file is not None and submitted:
68
  uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
 
69
 
70
 
71
 
@@ -75,9 +76,15 @@ if uploaded_image is None and submitted:
75
  else:
76
  image_file = sample_image
77
 
78
- image = uploaded_image if uploaded_image is not None else Image.open(image_file)
 
 
 
 
 
 
79
 
80
- #cache the image
81
 
82
  image_dict = imageLoader.transform(image)
83
 
31
 
32
  with st.spinner("Loading Model"):
33
  virtexModel, imageLoader, sample_images, valid_subs = create_objects()
34
+
35
+
36
+ select_idx = None
37
 
 
 
38
 
39
  st.sidebar.title("Select a sample image")
40
 
41
  if st.sidebar.button("Random Sample Image"):
42
+ select_idx = get_rand_idx(sample_images)
 
43
 
44
  sample_image = st.sidebar.selectbox(
45
  "",
46
  sample_images,
47
+ index = 0 if select_idx is None else select_idx
48
  )
49
 
50
  st.sidebar.title("Select a Subreddit")
66
  submitted = st.form_submit_button("Submit")
67
  if uploaded_file is not None and submitted:
68
  uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
69
+ select_idx = None # set this to help rewrite the cache
70
 
71
 
72
 
76
  else:
77
  image_file = sample_image
78
 
79
+ # LOAD AND CACHE THE IMAGE
80
+ if uploaded_image is not None:
81
+ image = uploaded_image
82
+ elif select_idx is None:
83
+ image = st.session_state.image
84
+ else:
85
+ image = Image.open(image_file)
86
 
87
+ st.session_state.image = image
88
 
89
  image_dict = imageLoader.transform(image)
90
 
model.py CHANGED
@@ -124,9 +124,8 @@ def download_files():
124
  def get_samples():
125
  return glob.glob(SAMPLES_PATH)
126
 
127
- def get_rand_img(samples):
128
- i = random.randint(0,len(samples)-1)
129
- return i, samples[i]
130
 
131
  @st.cache
132
  def create_objects():
124
  def get_samples():
125
  return glob.glob(SAMPLES_PATH)
126
 
127
+ def get_rand_idx(samples):
128
+ return random.randint(0,len(samples)-1)
 
129
 
130
  @st.cache
131
  def create_objects():