NPG commited on
Commit
add7d6f
·
1 Parent(s): edef475

input_ids to float32

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -19,20 +19,20 @@ model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_m
19
  """###Interface"""
20
 
21
  def generate(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
22
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
23
- outputs = model.generate(input_ids,
24
- min_length=minimum_length,
25
- max_new_tokens=maximum_length,
26
- length_penalty=1.4,
27
- num_beams=12,
28
- no_repeat_ngram_size=3,
29
- temperature=temperature,
30
- top_k=100,
31
- top_p=0.9,
32
- repetition_penalty=repetition_penalty,
33
- )
34
-
35
- return tokenizer.decode(outputs[0], skip_special_tokens=True).capitalize()
36
 
37
  title = "Flan-T5-XL GRADIO GUI"
38
 
 
19
  """###Interface"""
20
 
21
  def generate(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
22
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("float32").to(device)
23
+ outputs = model.generate(input_ids,
24
+ min_length=minimum_length,
25
+ max_new_tokens=maximum_length,
26
+ length_penalty=1.4,
27
+ num_beams=12,
28
+ no_repeat_ngram_size=3,
29
+ temperature=temperature,
30
+ top_k=100,
31
+ top_p=0.9,
32
+ repetition_penalty=repetition_penalty,
33
+ )
34
+
35
+ return tokenizer.decode(outputs[0], skip_special_tokens=True).capitalize()
36
 
37
  title = "Flan-T5-XL GRADIO GUI"
38