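# Gradio demo: stream text from distilgpt2 token by token using
# transformers' TextIteratorStreamer.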
import os
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import torch
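
# Load the distilgpt2 tokenizer and model once at startup.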
tok = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
model.to(device)

def generate(text="", max_new_tokens=128):
    """Stream generated text back to Gradio, token by token."""
    streamer = TextIteratorStreamer(tok, timeout=10.0)
    # Avoid an empty prompt, which would yield no input tokens.
    if len(text) == 0:
        text = " "
    inputs = tok([text], return_tensors="pt").to(device)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        repetition_penalty=2.0,
        do_sample=True,
        top_k=40,
        top_p=0.97,
        max_new_tokens=max_new_tokens,
        pad_token_id=model.config.eos_token_id,
        no_repeat_ngram_size=4,
    )
    # Run generation in a background thread so this function can
    # consume tokens from the streamer as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        yield generated_text + new_text
        generated_text += new_text
        # Stop streaming once an end-of-sequence token appears.
        if tok.eos_token and tok.eos_token in generated_text:
            generated_text = generated_text[: generated_text.find(tok.eos_token)]
            break
    yield generated_text

demo = gr.Interface(
    title="TextIteratorStreamer + Gradio demo",
    fn=generate,
    inputs=[
        gr.Textbox(lines=5, label="Input Text"),
        gr.Slider(value=128, minimum=5, maximum=256, step=1, label="Maximum number of new tokens"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    allow_flagging="never",
)
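# Enable the queue; Gradio requires it for streaming generator outputs.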
demo.queue()
demo.launch()