import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Use the GPU when one is available; fall back to CPU so the app still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model once at startup instead of on every request.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b1")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b1", use_cache=True).to(device)

def generate(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Diverse beam search: 2 beams split into 2 groups, with a diversity penalty
    # so the groups explore different continuations. temperature/top_k are omitted
    # because they only take effect when do_sample=True.
    sample = model.generate(**inputs, max_length=100, num_beams=2, num_beam_groups=2,
                            repetition_penalty=2.0, diversity_penalty=0.9)
    return tokenizer.decode(sample[0], skip_special_tokens=True)
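
# A minimal alternative sketch, not part of the original app: nucleus sampling
# often gives more varied text than beam search for open-ended prompts. The
# parameter values below are illustrative assumptions, not tuned settings.
def generate_with_sampling(prompt, max_length=100):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    sample = model.generate(**inputs, max_length=max_length, do_sample=True,
                            top_p=0.9, temperature=0.9, repetition_penalty=2.0)
    return tokenizer.decode(sample[0], skip_special_tokens=True)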
with gr.Blocks() as demo:
    gr.Markdown("A text generation demo built on BLOOM-1b1.")
    inp = gr.Textbox(label="Type your prompt here and click Run",
                     placeholder='Example: The main cloud services provided by AWS are: ')
    out = gr.Textbox(label="Generated text")
    btn = gr.Button("Run")
    btn.click(fn=generate, inputs=inp, outputs=out)
demo.launch()
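
# Usage note: when running locally, demo.launch(share=True) would expose a
# temporary public URL; on Hugging Face Spaces the plain launch() above suffices.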