gchhablani committed on
Commit
bea24f7
1 Parent(s): 185a893

Allow clearing of cache

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -1,10 +1,9 @@
1
  from io import BytesIO
2
  import streamlit as st
3
  import pandas as pd
4
- import json
5
  import os
6
  import numpy as np
7
- from streamlit.elements import markdown
8
  from PIL import Image
9
  from model.flax_clip_vision_marian.modeling_clip_vision_marian import (
10
  FlaxCLIPVisionMarianMT,
@@ -31,7 +30,7 @@ tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
31
 
32
  @st.cache(persist=True)
33
  def generate_sequence(pixel_values, num_beams, temperature, top_p):
34
- output_ids = model.generate(input_ids=pixel_values, max_length=64, num_beams=num_beams, temperature=temperature, top_p = top_p)
35
  print(output_ids)
36
  output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
37
  return output_sequence
@@ -60,7 +59,8 @@ st.sidebar.title("Generation Parameters")
60
  num_beams = st.sidebar.number_input("Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
61
  temperature = st.sidebar.select_slider("Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
62
  top_p = st.sidebar.select_slider("Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
63
-
 
64
 
65
  image_col, intro_col = st.beta_columns([3, 8])
66
  image_col.image("./misc/sic-logo.png", use_column_width="always")
@@ -84,6 +84,10 @@ with st.beta_expander("Article"):
84
  st.write(read_markdown("acknowledgements.md"))
85
 
86
 
 
 
 
 
87
  first_index = 20
88
  # Init Session State
89
  if state.image_file is None:
@@ -124,8 +128,7 @@ new_col2.markdown(
124
  f"""**English Translation**: {translate(state.caption, 'en')}"""
125
  )
126
 
127
- with st.spinner("Loading model..."):
128
- model = load_model(checkpoints[0])
129
  sequence = ['']
130
  if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
131
  with st.spinner("Generating Sequence..."):
 
1
  from io import BytesIO
2
  import streamlit as st
3
  import pandas as pd
 
4
  import os
5
  import numpy as np
6
+ from streamlit import caching
7
  from PIL import Image
8
  from model.flax_clip_vision_marian.modeling_clip_vision_marian import (
9
  FlaxCLIPVisionMarianMT,
 
30
 
31
  @st.cache(persist=True)
32
  def generate_sequence(pixel_values, num_beams, temperature, top_p):
33
+ output_ids = state.model.generate(input_ids=pixel_values, max_length=64, num_beams=num_beams, temperature=temperature, top_p = top_p)
34
  print(output_ids)
35
  output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
36
  return output_sequence
 
59
  num_beams = st.sidebar.number_input("Number of Beams", min_value=2, max_value=10, value=4, step=1, help="Number of beams to be used in beam search.")
60
  temperature = st.sidebar.select_slider("Temperature", options = list(np.arange(0.0,1.1, step=0.1)), value=1.0, help ="The value used to module the next token probabilities.", format_func=lambda x: f"{x:.2f}")
61
  top_p = st.sidebar.select_slider("Top-P", options = list(np.arange(0.0,1.1, step=0.1)),value=1.0, help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.", format_func=lambda x: f"{x:.2f}")
62
+ if st.sidebar.button("Clear All Cache"):
63
+ caching.clear_cache()
64
 
65
  image_col, intro_col = st.beta_columns([3, 8])
66
  image_col.image("./misc/sic-logo.png", use_column_width="always")
 
84
  st.write(read_markdown("acknowledgements.md"))
85
 
86
 
87
+ if state.model is None:
88
+ with st.spinner("Loading model..."):
89
+ state.model = load_model(checkpoints[0])
90
+
91
  first_index = 20
92
  # Init Session State
93
  if state.image_file is None:
 
128
  f"""**English Translation**: {translate(state.caption, 'en')}"""
129
  )
130
 
131
+
 
132
  sequence = ['']
133
  if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
134
  with st.spinner("Generating Sequence..."):