ybelkada commited on
Commit
9d25789
1 Parent(s): 64ec3ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -22,6 +22,10 @@ def inference(raw_image, question, decoding_strategy):
22
  inputs["do_sample"] = True
23
  inputs["top_k"] = 50
24
  inputs["top_p"] = 0.95
 
 
 
 
25
 
26
  out = model_image_captioning.generate(**inputs)
27
  return processor.batch_decode(out, skip_special_tokens=True)[0]
@@ -29,7 +33,7 @@ def inference(raw_image, question, decoding_strategy):
29
  inputs = [
30
  gr.inputs.Image(type='pil'),
31
  gr.inputs.Textbox(lines=2, label="Context (optional)"),
32
- gr.inputs.Radio(choices=['Beam search','Nucleus sampling'], type="value", default="Nucleus sampling", label="Caption Decoding Strategy")
33
  ]
34
  outputs = gr.outputs.Textbox(label="Output")
35
 
 
22
  inputs["do_sample"] = True
23
  inputs["top_k"] = 50
24
  inputs["top_p"] = 0.95
25
+ elif decoding_strategy == "Contrastive search":
26
+ inputs["penalty_alpha"] = 0.6
27
+ inputs["top_k"] = 4
28
+ inputs["max_length"] = 512
29
 
30
  out = model_image_captioning.generate(**inputs)
31
  return processor.batch_decode(out, skip_special_tokens=True)[0]
 
33
  inputs = [
34
  gr.inputs.Image(type='pil'),
35
  gr.inputs.Textbox(lines=2, label="Context (optional)"),
36
+ gr.inputs.Radio(choices=["Beam search","Nucleus sampling", "Contrastive search"], type="value", default="Nucleus sampling", label="Caption Decoding Strategy")
37
  ]
38
  outputs = gr.outputs.Textbox(label="Output")
39