g8a9 committed on
Commit
d2f9c91
1 Parent(s): 13702fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -16,11 +16,28 @@ st.title("Image Captioning with ViT & GePpeTto 🇮🇹")
16
 
17
  st.sidebar.markdown("## Generation parameters")
18
  max_length = st.sidebar.number_input("Max length", value=20, min_value=1)
19
- num_beams = st.sidebar.number_input("Beam size", value=5, min_value=1)
20
- early_stopping = st.sidebar.checkbox("Early stopping", value=True)
21
- no_repeat_ngram_size= st.sidebar.number_input("no repeat ngrams size", value=2, min_value=1)
22
  num_return_sequences = st.sidebar.number_input("Generated sequences", value=3, min_value=1)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def generate_caption(url):
26
  image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
@@ -28,10 +45,9 @@ def generate_caption(url):
28
  generated_ids = model.generate(
29
  inputs["pixel_values"],
30
  max_length=20,
31
- num_beams=5,
32
- early_stopping=True,
33
  no_repeat_ngram_size=2,
34
  num_return_sequences=3,
 
35
  )
36
  captions = tokenizer.batch_decode(
37
  generated_ids,
 
16
 
17
  st.sidebar.markdown("## Generation parameters")
18
  max_length = st.sidebar.number_input("Max length", value=20, min_value=1)
19
+ no_repeat_ngram_size = st.sidebar.number_input("no repeat ngrams size", value=2, min_value=1)
 
 
20
  num_return_sequences = st.sidebar.number_input("Generated sequences", value=3, min_value=1)
21
 
22
+ gen_mode = st.sidebar.selectbox("Generation mode", ["beam search", "sampling"])
23
+ if gen_mode == "beam_search":
24
+ num_beams = st.sidebar.number_input("Beam size", value=5, min_value=1)
25
+ early_stopping = st.sidebar.checkbox("Early stopping", value=True)
26
+ gen_params = {
27
+ "num_beams": num_beams,
28
+ "early_stopping": early_stopping
29
+ }
30
+ else:
31
+ do_sample = True
32
+ top_k = st.sidebar.number_input("topk", value=30, min_value=0)
33
+ top_p = st.sidebar.number_input("topk", value=0, min_value=0)
34
+ temperature = st.sidebar.number_input("topk", value=0.7, min_value=0)
35
+ gen_params = {
36
+ "do_sample": do_sample,
37
+ "top_k": top_k,
38
+ "top_p": top_p,
39
+ "temperature": temperature
40
+ }
41
 
42
  def generate_caption(url):
43
  image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
 
45
  generated_ids = model.generate(
46
  inputs["pixel_values"],
47
  max_length=20,
 
 
48
  no_repeat_ngram_size=2,
49
  num_return_sequences=3,
50
+ **gen_params
51
  )
52
  captions = tokenizer.batch_decode(
53
  generated_ids,