Nihal Nayak commited on
Commit
bf52e22
1 Parent(s): 61a66aa

greedy decoding of the output

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import spaces
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  model = AutoModelForCausalLM.from_pretrained("BatsResearch/bonito-v1")
6
  tokenizer = AutoTokenizer.from_pretrained("BatsResearch/bonito-v1")
@@ -26,10 +26,19 @@ def respond(
26
  temperature=temperature,
27
  top_p=top_p,
28
  do_sample=True,
 
 
 
 
 
 
 
 
 
29
  )
30
  pred_start = int(input_ids.shape[-1])
31
 
32
- response = tokenizer.decode(output[0][pred_start:], skip_special_tokens=True)
33
 
34
  # check if <|pipe|> is in the response
35
  if "<|pipe|>" in response:
 
1
  import gradio as gr
2
  import spaces
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
4
 
5
  model = AutoModelForCausalLM.from_pretrained("BatsResearch/bonito-v1")
6
  tokenizer = AutoTokenizer.from_pretrained("BatsResearch/bonito-v1")
 
26
  temperature=temperature,
27
  top_p=top_p,
28
  do_sample=True,
29
+ stop_strings=["<|pipe|>"],
30
+ tokenizer=tokenizer,
31
+ )
32
+ output_with_greedy_response = model.generate(
33
+ output,
34
+ max_new_tokens=max_tokens,
35
+ temperature=0.0,
36
+ top_p=1.0,
37
+ do_sample=False,
38
  )
39
  pred_start = int(input_ids.shape[-1])
40
 
41
+ response = tokenizer.decode(output_with_greedy_response[0][pred_start:], skip_special_tokens=True)
42
 
43
  # check if <|pipe|> is in the response
44
  if "<|pipe|>" in response: