File size: 1,905 Bytes
3d02b97
 
95618c0
3d02b97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46b4aa0
3d02b97
 
 
 
46b4aa0
3d02b97
 
 
 
 
 
 
 
 
 
 
 
46b4aa0
 
3d02b97
bd830ea
3d02b97
46b4aa0
3d02b97
 
 
 
 
 
 
 
 
 
bd830ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""#Imports"""

#!pip install transformers gradio accelerate bitsandbytes sentencepiece

import multiprocessing
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import gradio as gr

# Code
# ----
# Target for the input tensors: the first CUDA GPU when one is visible,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# FP 32
# -----
# No torch_dtype is passed, so the checkpoint loads in the library's default
# dtype (this section is labelled FP 32). device_map="auto" hands weight
# placement over to accelerate.
_CHECKPOINT = "google/flan-t5-xl"
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(
    _CHECKPOINT,
    device_map="auto",
)

# Interface
# ---------

def generate(input_text, minimum_length, maximum_length, beam_amount, temperature,
             repetition_penalty, *, do_sample=False):
    """Run the module-level Flan-T5 model on ``input_text`` and return its answer.

    Args:
        input_text: Prompt passed to the module-level ``tokenizer``.
        minimum_length: Forwarded to ``model.generate`` as ``min_length``.
        maximum_length: Forwarded to ``model.generate`` as ``max_new_tokens``.
        beam_amount: Forwarded to ``model.generate`` as ``num_beams``.
        temperature: Sampling temperature; only used when ``do_sample`` is True.
        repetition_penalty: Forwarded to ``model.generate`` unchanged.
        do_sample: When True, decode by sampling with ``temperature``,
            ``top_k=100`` and ``top_p=0.9``. When False (the default, matching
            the previous behaviour) greedy/beam search is used, for which those
            three settings have no effect, so they are not passed at all.

    Returns:
        The first generated sequence decoded with special tokens removed and
        with only its first character upper-cased.
    """
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

    # Slider widgets may deliver floats (e.g. 3.0); these arguments are counts.
    generation_kwargs = dict(
        min_length=round(minimum_length),
        max_new_tokens=round(maximum_length),
        length_penalty=1.4,
        num_beams=round(beam_amount),
        no_repeat_ngram_size=3,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
    )
    if do_sample:
        # Sampling-only knobs; passing them to greedy/beam search is a no-op.
        generation_kwargs.update(temperature=temperature, top_k=100, top_p=0.9)

    outputs = model.generate(input_ids, **generation_kwargs)

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # str.capitalize() would also lowercase the rest of the text ("USA" -> "usa"),
    # so upper-case just the first character and keep the remainder verbatim.
    return text[:1].upper() + text[1:]

# Heading text handed to gr.Interface(title=...) when the UI is built.
title = "Flan-T5-XL Inference on GRADIO GUI"

def inference(input_text, minimum_length, maximum_length, beam_amount, temperature, repetition_penalty):
    """Gradio callback that hands every widget value straight to ``generate``.

    The values are forwarded positionally and unmodified; whatever
    ``generate`` returns becomes the interface's single text output.
    """
    widget_values = (
        input_text,
        minimum_length,
        maximum_length,
        beam_amount,
        temperature,
        repetition_penalty,
    )
    return generate(*widget_values)

# Build and launch the web UI. inference() takes six arguments, so exactly six
# input components are listed, in the same order as its parameters.
gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(lines=4, label="Input text"),           # input_text
        gr.Slider(0, 300, label="Minimum length"),         # minimum_length
        gr.Slider(100, 2000, label="Maximum new tokens"),  # maximum_length
        gr.Slider(1, 12, step=1, label="Number of beams"),  # beam_amount
        # temperature: forwarded to model.generate; it only changes the output
        # when sampling is enabled there.
        gr.Slider(0.1, 2, value=1, step=0.1, label="Temperature"),
        gr.Slider(1, 3, step=0.1, label="Repetition penalty"),  # repetition_penalty
    ],
    outputs=[
        gr.Textbox(lines=2, label="Flan-T5-XL Inference")
    ],
    title=title,
    css="""
    body {background-color: lightgreen}
    .input_text input {
      background-color: lightblue !important;
    }
    """
).launch()