gptj-playground / app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
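# Load the GPT-J 6B weights and tokenizer from the Hugging Face Hub
# (a multi-gigabyte download and a large memory footprint on first run).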
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
def generate(text):
    # Tokenize the prompt and sample a continuation from the model.
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100)
    # Decode the generated token ids (prompt plus continuation) back into a string.
    return tokenizer.batch_decode(gen_tokens)[0]
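# Wire the generate function into a minimal Gradio interface:
# one input textbox for the prompt, one output textbox for the result.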
gradio_ui = gr.Interface(
    fn=generate,
    title="Use GPT-J",
    description="Put your text into the box below, and the open-source GPT-J model will continue it (up to 100 tokens in total).",
    inputs=gr.inputs.Textbox(label="Put your text here!"),
    outputs=gr.outputs.Textbox(label="Generated text"),
)
gradio_ui.launch()