File size: 5,618 Bytes
09d15e8
27da979
 
 
 
 
 
 
09d15e8
 
 
 
 
935c638
09d15e8
af12f2c
 
09d15e8
 
 
27da979
 
 
09d15e8
 
af12f2c
 
 
 
 
09d15e8
 
 
27da979
 
 
 
 
 
 
5209ab6
 
27da979
 
 
09d15e8
 
 
 
 
27da979
 
 
09d15e8
 
 
 
 
0240ed4
09d15e8
 
27da979
5209ab6
 
27da979
09d15e8
 
 
 
 
 
 
 
 
27da979
09d15e8
 
 
 
27da979
09d15e8
 
 
27da979
 
 
09d15e8
 
 
 
 
 
 
 
 
2f48801
09d15e8
 
 
7c1503a
27da979
09d15e8
 
2f48801
09d15e8
 
 
6958233
 
2f48801
6958233
27da979
 
09d15e8
 
27da979
 
 
 
 
 
09d15e8
 
27da979
 
 
 
 
 
09d15e8
af12f2c
 
7c1503a
983046b
7c1503a
af12f2c
 
 
5209ab6
 
 
 
27da979
 
5209ab6
09d15e8
5209ab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27da979
 
5209ab6
27da979
 
 
 
5209ab6
27da979
 
09d15e8
27da979
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from threading import Thread
import logging
import time

# Configure the root logger (INFO level, timestamped records) up front; this
# runs before the torch / gradio / transformers imports that follow.
logging.basicConfig(format="%(asctime)s [%(levelname)s] %(name)s - %(message)s", level=logging.INFO)

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer

# Hugging Face Hub id of the seq2seq model served by this app.
model_id = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info("Running on device:\t %s", torch_device)
logging.info("CPU threads:\t %s", torch.get_num_threads())


if torch_device == "cuda":
    # On GPU, load 8-bit weights and let transformers place them on devices.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id, load_in_8bit=True, device_map="auto"
    )
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Best-effort speed-up: if wrapping with torch.compile fails, keep the plain
# model and carry on, so this is a warning rather than an error.
# torch.compile compiles lazily on the first forward call, so this `try` only
# catches failures raised while wrapping (e.g. an older torch without
# torch.compile); compilation errors would surface later, during generation.
# NOTE(review): confirm that `model.generate` on the compiled wrapper actually
# runs the compiled forward rather than the original module's.
try:
    model = torch.compile(model)
except Exception as e:
    logging.warning("Unable to compile model:\t%s", e)

tokenizer = AutoTokenizer.from_pretrained(model_id)


def run_generation(
    user_text,
    top_p,
    temperature,
    top_k,
    max_new_tokens,
    repetition_penalty=1.1,
    length_penalty=1.0,
    no_repeat_ngram_size=4,
    use_generation_config=False,
):
    """Stream a sampled completion of ``user_text`` from the module-level model.

    Generation runs on a background thread and feeds a ``TextIteratorStreamer``.
    This generator yields the accumulated output text after every new chunk.

    Args:
        user_text: Prompt string to encode.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.
        top_k: Number of highest-probability tokens kept when sampling.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Penalty applied to repeated tokens.
        length_penalty: Passed through to ``generate``; it only influences
            beam-based generation, and this function samples with
            ``num_beams=1``.
        no_repeat_ngram_size: Size of n-grams that may not repeat.
        use_generation_config: Currently unused; kept for interface
            compatibility.

    Yields:
        The full generated text so far, growing with each streamed chunk.

    Returns:
        The final generated text (as the generator's return value).

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread is
            re-raised here once the stream has been closed.
        queue.Empty: If the streamer receives no text for 10 seconds.
    """
    st = time.perf_counter()
    # Truncate to the tokenizer's maximum length so an over-long prompt does not
    # exceed the encoder's input limit.
    model_inputs = tokenizer([user_text], return_tensors="pt", truncation=True).to(
        torch_device
    )

    # Generation runs on a separate thread so the UI is not blocked; text is
    # pulled from the streamer on this thread. The streamer timeout bounds how
    # long we wait for the next chunk.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    # UI sliders may deliver floats; generate() needs ints for the count-like
    # arguments, so coerce every numeric parameter explicitly.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        num_beams=1,
        top_p=float(top_p),
        temperature=float(temperature),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        length_penalty=float(length_penalty),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
    )

    # Exceptions raised on the worker thread are collected here. Without this,
    # a failed generate() would leave the consumer waiting for the full streamer
    # timeout and then report only queue.Empty.
    worker_errors = []

    def _generate():
        try:
            model.generate(**generate_kwargs)
        except Exception as err:
            worker_errors.append(err)
            # Close the stream so the consumer loop below stops immediately.
            streamer.end()

    t = Thread(target=_generate)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    t.join()

    if worker_errors:
        raise worker_errors[0]

    logging.info("Total rt:\t{rt} sec".format(rt=round(time.perf_counter() - st, 3)))
    return model_output


def reset_textbox():
    """Build a Gradio component update whose value is the empty string."""
    empty_value = ""
    return gr.update(value=empty_value)


# Gradio UI: a prompt box, a streamed output box and sampling-parameter sliders,
# all wired to run_generation.
with gr.Blocks() as demo:
    duplicate_link = (
        "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
    )
    gr.Markdown(
        "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
        "This demo showcases the use of the "
        "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
        "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
        f"[{model_id}](https://huggingface.co/{model_id}) and the Spaces free compute tier.\n\n"
        f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
        "template! 💛"
    )
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                value="How to become a polar bear tamer?",
                label="User input",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit", variant="primary")

        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=32,
                maximum=1024,
                value=256,
                step=32,
                interactive=True,
                label="Max New Tokens",
            )
            top_p = gr.Slider(
                minimum=0.05,
                maximum=1.0,
                value=0.95,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.4,
                value=0.3,
                step=0.05,
                interactive=True,
                label="Temperature",
            )
            repetition_penalty = gr.Slider(
                minimum=0.9,
                maximum=2.5,
                value=1.1,
                step=0.1,
                interactive=True,
                label="Repetition Penalty",
            )
            # NOTE: length_penalty only influences beam-based generation, and
            # run_generation samples with num_beams=1, so this slider likely
            # has no visible effect.
            length_penalty = gr.Slider(
                minimum=0.8,
                maximum=1.5,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Length Penalty",
            )

    # Both triggers pass the same components, in run_generation's positional
    # parameter order.
    generation_inputs = [
        user_text,
        top_p,
        temperature,
        top_k,
        max_new_tokens,
        repetition_penalty,
        length_penalty,
    ]
    user_text.submit(run_generation, generation_inputs, model_output)
    button_submit.click(run_generation, generation_inputs, model_output)

# .queue() already enables request queuing (needed for streaming generators).
# The launch(enable_queue=...) argument was deprecated in Gradio 3.x and removed
# in 4.0, so it is omitted.
demo.queue(max_size=32).launch()