jbochi committed on
Commit 1eab707 · 1 parent: 2553db6

Update generation config

Files changed (1)
app.py +5 -2
app.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import T5Tokenizer, T5ForConditionalGeneration
+from transformers import T5ForConditionalGeneration, T5Tokenizer, GenerationConfig
 import gradio as gr
 
 MODEL_NAME = "jbochi/madlad400-3b-mt"
@@ -17,7 +17,10 @@ print("T5ForConditionalGeneration loaded from pretrained.")
 
 def inference(max_length, input_text, history=[]):
     input_ids = tokenizer(input_text, return_tensors="pt").input_ids
-    outputs = model.generate(input_ids, max_length=max_length, bos_token_id=2)
+    outputs = model.generate(
+        input_ids=input_ids,
+        generation_config=GenerationConfig(decoder_start_token_id=2),
+    )
     result = tokenizer.decode(outputs[0], skip_special_tokens=True)
     history.append((input_text, result))
     return history, history
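
For context, here is a minimal standalone sketch of how the updated generation call can be exercised outside the Gradio app. The model/tokenizer loading and the "<2pt>" target-language prefix follow the MADLAD-400 model card and are not part of this commit; the generate() call itself mirrors the new app.py, which passes the decoder start token through GenerationConfig instead of bos_token_id and no longer forwards max_length.

# Standalone sketch (not part of the commit); loading code and the "<2pt>"
# prefix follow the MADLAD-400 model card.
from transformers import GenerationConfig, T5ForConditionalGeneration, T5Tokenizer

MODEL_NAME = "jbochi/madlad400-3b-mt"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# MADLAD-400 selects the target language with a "<2xx>" prefix,
# e.g. "<2pt>" for Portuguese.
input_ids = tokenizer("<2pt> I love pizza!", return_tensors="pt").input_ids

# As in the updated app.py: the decoder start token is supplied via
# GenerationConfig rather than the old bos_token_id argument.
outputs = model.generate(
    input_ids=input_ids,
    generation_config=GenerationConfig(decoder_start_token_id=2),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))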