File size: 2,083 Bytes
3d02b97
 
95618c0
3d02b97
 
 
 
 
 
 
 
 
 
edef475
3d02b97
 
edef475
3d02b97
 
 
36558f0
e6f01f1
add7d6f
 
 
 
 
 
 
 
 
 
 
 
 
3d02b97
edef475
3d02b97
36558f0
 
3d02b97
bd830ea
3d02b97
a1e9152
3d02b97
a1e9152
3d02b97
 
 
 
 
 
 
 
bd830ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""#Imports"""

#!pip install transformers gradio accelerate bitsandbytes sentencepiece

import multiprocessing
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import gradio as gr

"""#Code"""

device = "cuda:0" if torch.cuda.is_available() else "cpu"

"""##FP 16"""

# Hugging Face Hub id of the checkpoint; used for both tokenizer and model.
MODEL_NAME = "google/flan-t5-xl"

# Half precision is only requested when a GPU is present: depending on the
# PyTorch version, many CPU kernels have no (or only a very slow) float16
# implementation, so without CUDA the weights are loaded in float32 instead.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
# device_map="auto" delegates weight placement to accelerate (installed via the
# pip line at the top of the notebook).
model = T5ForConditionalGeneration.from_pretrained(
    MODEL_NAME, device_map="auto", torch_dtype=torch_dtype
)

"""###Interface"""

def generate(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
    """Run the module-level Flan-T5 ``model`` on ``input_text`` and return text.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        input_text: Prompt string to encode.
        minimum_length: Forwarded to ``model.generate`` as ``min_length``.
        maximum_length: Forwarded as ``max_new_tokens`` (cap on generated tokens).
        temperature: Forwarded as ``temperature``.
        repetition_penalty: Forwarded as ``repetition_penalty``.

    Returns:
        The first generated sequence decoded without special tokens, with its
        first character upper-cased and the rest left unchanged.
    """
    # Token ids must stay integer: they index the embedding table, and a
    # float16 cast makes the embedding lookup fail. Only the model weights are
    # half precision. Moving the whole encoding also forwards attention_mask.
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        # Slider values may arrive as floats; generation lengths must be ints.
        min_length=int(minimum_length),
        max_new_tokens=int(maximum_length),
        length_penalty=1.4,
        num_beams=12,
        no_repeat_ngram_size=3,
        # NOTE(review): without do_sample=True, transformers runs plain beam
        # search and ignores temperature/top_k/top_p, so the Temperature
        # slider currently has no effect — confirm whether sampling is intended.
        temperature=temperature,
        top_k=100,
        top_p=0.9,
        repetition_penalty=repetition_penalty,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # str.capitalize() would also lower-case everything after the first
    # character (mangling names/acronyms); only upper-case the first one.
    return text[:1].upper() + text[1:]

# Page heading for the Gradio UI; passed as ``title=`` to gr.Interface below.
title = "Flan-T5-XL GRADIO GUI"

def inference(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
    """Gradio callback: forward the UI values unchanged to ``generate``."""
    completion = generate(
        input_text=input_text,
        minimum_length=minimum_length,
        maximum_length=maximum_length,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    return completion

# UI controls, in the same order as the parameters of ``inference``.
input_components = [
    gr.Textbox(lines=4, label="Input"),
    gr.Slider(0, 300, value=20, step=10, label="Minimum length"),
    gr.Slider(100, 2000, value=1000, step=100, label="Maximum length"),
    gr.Slider(0, 2, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(1, 3, value=2.1, step=0.1, label="Repetition penalty"),
]
output_components = [gr.Textbox(lines=2, label="Output")]

# Custom page styling handed to gr.Interface via ``css=``.
page_css = """
    body {background-color: lightgreen}
    .input_text input {
      background-color: lightblue !important;
    }
    """

demo = gr.Interface(
    fn=inference,
    inputs=input_components,
    outputs=output_components,
    title=title,
    css=page_css,
)
demo.launch()