# Hugging Face Spaces page status (captured with this source): Build error
import transformers
from transformers import BloomForCausalLM
from transformers import BloomTokenizerFast
import torch
import gradio as gr

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model_name = "bigscience/bloom-1b1"
# Put the weights on the same device that beam_gradio_pipeline sends the
# tokenized inputs to. Leaving the model on the CPU while the inputs go to
# CUDA makes generate() fail with a device-mismatch error.
model = BloomForCausalLM.from_pretrained(model_name).to(device)
tokenizer = BloomTokenizerFast.from_pretrained(model_name)
# Gradio callback: prompt text in, generated text out.
def beam_gradio_pipeline(prompt, length=100):
    """Continue ``prompt`` with the BLOOM model using beam search.

    Args:
        prompt: Text typed into the Gradio "Prompt" box.
        length: Forwarded to ``generate`` as ``max_length``, i.e. a cap on the
            total token count of the returned sequence, prompt included.

    Returns:
        The decoded output sequence (prompt plus continuation) with special
        tokens such as the end-of-sequence marker removed. An empty or
        ``None`` prompt returns ``""`` without calling the model.
    """
    # An empty prompt tokenizes to zero input ids, which generate() cannot
    # start from; return early instead of surfacing an exception in the UI.
    if not prompt:
        return ""

    # The tokenizer yields input_ids and attention_mask; forward both so
    # generate() does not have to infer the mask itself.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # NOTE(review): max_length counts the prompt tokens too, so a prompt that
    # already reaches `length` tokens leaves no room to generate (recent
    # transformers versions reject this outright). Consider max_new_tokens.
    output_ids = model.generate(
        **inputs,
        max_length=int(length),  # int() also accepts a float from a numeric widget
        num_beams=2,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    # A single prompt was given, so take the single returned sequence.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Static HTML fragments rendered, in order, at the top of the page.
_TITLE_HTML = "<h1><center>Andrew Lim Bloom LLM </center></h1>"
_INTRO_HTML = (
    "<h2><center>Generate your story with a sentence or ask a question:<br><br>\n"
    "<img src=https://aeiljuispo.cloudimg.io/v7/https://s3.amazonaws.com/moonup/production/uploads/1634806038075-5df7e9e5da6d0311fd3d53f9.png?w=200&h=200&f=face width=200px></center></h2>"
)
_DIVIDER_HTML = "<center>******</center>"

# Build the page top to bottom: heading, intro with image, divider, then the
# prompt box, the submit button, and the box that displays the result.
with gr.Blocks() as demo:
    for _html in (_TITLE_HTML, _INTRO_HTML, _DIVIDER_HTML):
        gr.Markdown(_html)
    prompt_box = gr.Textbox(label="Prompt", lines=6)
    submit_button = gr.Button("Submit ")
    story_box = gr.Textbox(lines=6, label="The story start with :")
    # On click, the prompt box's text is passed to beam_gradio_pipeline and
    # its return value is shown in story_box.
    submit_button.click(beam_gradio_pipeline, inputs=[prompt_box], outputs=story_box)

demo.launch()