File size: 2,101 Bytes
3d02b97
 
95618c0
3d02b97
24fb9f1
3d02b97
 
 
 
 
 
9a6e806
af0da01
3d02b97
edef475
3d02b97
 
edef475
3d02b97
 
 
36558f0
ebd7f42
add7d6f
 
 
 
9a6e806
add7d6f
 
 
 
 
 
 
 
3d02b97
edef475
3d02b97
36558f0
 
3d02b97
bd830ea
3d02b97
9a6e806
3d02b97
a1e9152
3d02b97
 
 
 
 
 
 
 
bd830ea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""#Imports"""

#!pip install transformers gradio accelerate bitsandbytes sentencepiece

#import multiprocessing
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import gradio as gr

"""#Code"""

# Target device for this session: first GPU if CUDA is available, else CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

"""##FP 16"""

# Half precision halves the weight memory of the XL checkpoint on a GPU, but
# several CPU kernels have no float16 implementation on older torch releases,
# so fall back to float32 when no GPU is present.
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
# device_map="auto" delegates weight placement (GPU/CPU) to accelerate.
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xl", device_map="auto", torch_dtype=model_dtype
)

"""###Interface"""

def generate(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
    """Generate an answer for ``input_text`` with the module-level FLAN-T5 model.

    Args:
        input_text: Prompt string passed to the tokenizer.
        minimum_length: Forwarded to ``model.generate`` as ``min_length``.
        maximum_length: Forwarded to ``model.generate`` as ``max_new_tokens``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        repetition_penalty: Forwarded to ``model.generate`` as ``repetition_penalty``.

    Returns:
        The first generated sequence decoded without special tokens, with only
        its first character upper-cased.
    """
    # Send the prompt to wherever the model's weights live instead of a
    # hard-coded "cuda", which fails on CPU-only machines.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(
        input_ids,
        # Slider values may arrive as floats; generation lengths must be ints.
        min_length=int(minimum_length),
        max_new_tokens=int(maximum_length),
        length_penalty=1.4,
        num_beams=6,
        no_repeat_ngram_size=3,
        # NOTE(review): without do_sample=True this runs beam search, and
        # temperature/top_k/top_p are not applied, so the Temperature slider
        # likely has no effect. Confirm the intended decoding strategy.
        temperature=temperature,
        top_k=100,
        top_p=0.9,
        repetition_penalty=repetition_penalty,
    )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return _capitalize_first(text)


def _capitalize_first(text):
    """Upper-case the first character of ``text``, leaving the rest untouched.

    ``str.capitalize`` would also lower-case every other character, mangling
    proper nouns and acronyms in the model output. Empty input returns "".
    """
    return text[:1].upper() + text[1:]

# Heading text; passed to gr.Interface as its title.
title = "Flan-T5-XL GRADIO GUI"


def inference(input_text, minimum_length, maximum_length, temperature, repetition_penalty):
    """Gradio callback: relay the widget values unchanged to ``generate``.

    Returns whatever ``generate`` returns for the same arguments.
    """
    answer = generate(
        input_text=input_text,
        minimum_length=minimum_length,
        maximum_length=maximum_length,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    return answer

# Input widgets, in the positional order Gradio passes their values to
# inference(): prompt, minimum length, maximum length, temperature,
# repetition penalty.
input_widgets = [
    gr.Textbox(lines=4, label="Input"),
    gr.Slider(0, 300, value=20, step=10, label="Minimum length"),
    gr.Slider(100, 2000, value=1000, step=100, label="Maximum length"),
    gr.Slider(0, 1, value=0.75, step=0.05, label="Temperature"),
    gr.Slider(1, 3, value=2.1, step=0.1, label="Repetition penalty"),
]
output_widgets = [gr.Textbox(lines=2, label="Output")]

# Page styling, handed to Gradio verbatim.
custom_css = """
    body {background-color: lightgreen}
    .input_text input {
      background-color: lightblue !important;
    }
    """

demo = gr.Interface(
    fn=inference,
    inputs=input_widgets,
    outputs=output_widgets,
    title=title,
    css=custom_css,
)
demo.launch()