NPG committed on
Commit
edef475
·
1 Parent(s): 5e389d5
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -11,10 +11,10 @@ import gradio as gr
11
 
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
- """##FP 32"""
15
 
16
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
17
- model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto")
18
 
19
  """###Interface"""
20
 
@@ -24,7 +24,7 @@ def generate(input_text, minimum_length, maximum_length, temperature, repetition
24
  min_length=minimum_length,
25
  max_new_tokens=maximum_length,
26
  length_penalty=1.4,
27
- num_beams=6,
28
  no_repeat_ngram_size=3,
29
  temperature=temperature,
30
  top_k=100,
@@ -34,14 +34,14 @@ def generate(input_text, minimum_length, maximum_length, temperature, repetition
34
 
35
  return tokenizer.decode(outputs[0], skip_special_tokens=True).capitalize()
36
 
37
- title = "Flan-T5-XL Inference on GRADIO GUI"
38
 
39
  def inference(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
40
  return generate(input_text, minimum_length, maximum_length, temperature, repetition_penalty)
41
 
42
  gr.Interface(
43
  fn=inference,
44
- inputs=[gr.Textbox(lines=4, label="Input text"), gr.Slider(0, 300, value=20, step=10, label="Minimum length"), gr.Slider(100, 2000, value=1000, step=100, label="Maximum length"), gr.Slider(0, 2, value=0.8, step=0.1, label="Temperature"), gr.Slider(1, 3, value=2.1, step=0.1, label="Repetition penalty")],
45
  outputs=[
46
  gr.Textbox(lines=2, label="Flan-T5-XL Inference")
47
  ],
 
11
 
12
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
+ """##FP 16"""
15
 
16
  tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
17
+ model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
18
 
19
  """###Interface"""
20
 
 
24
  min_length=minimum_length,
25
  max_new_tokens=maximum_length,
26
  length_penalty=1.4,
27
+ num_beams=12,
28
  no_repeat_ngram_size=3,
29
  temperature=temperature,
30
  top_k=100,
 
34
 
35
  return tokenizer.decode(outputs[0], skip_special_tokens=True).capitalize()
36
 
37
+ title = "Flan-T5-XL GRADIO GUI"
38
 
39
  def inference(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
40
  return generate(input_text, minimum_length, maximum_length, temperature, repetition_penalty)
41
 
42
  gr.Interface(
43
  fn=inference,
44
+ inputs=[gr.Textbox(lines=4, label="Input text"), gr.Slider(0, 300, value=20, step=10, label="Minimum length"), gr.Slider(100, 2000, value=1000, step=100, label="Maximum length"), gr.Slider(0, 2, value=0.7, step=0.1, label="Temperature"), gr.Slider(1, 3, value=2.1, step=0.1, label="Repetition penalty")],
45
  outputs=[
46
  gr.Textbox(lines=2, label="Flan-T5-XL Inference")
47
  ],