File size: 1,008 Bytes
c70da6e e648a11 c70da6e 8cc2393 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
# Hugging Face Hub id of the fine-tuned next-word-prediction checkpoint.
FINETUNED_CHECKPOINT = "Kumarkishalaya/GPT-2-next-word-prediction"

# Fine-tuned model and its matching tokenizer; these are what `generate` uses.
trained_tokenizer = GPT2Tokenizer.from_pretrained(FINETUNED_CHECKPOINT)
trained_model = GPT2LMHeadModel.from_pretrained(FINETUNED_CHECKPOINT)

# Stock GPT-2 baseline, loaded with the model class for the model and the
# tokenizer class for the tokenizer so each name holds what it says.
# NOTE(review): neither baseline object is referenced by the visible code;
# if nothing else uses them they only add a download and memory cost.
untrained_model = GPT2LMHeadModel.from_pretrained("gpt2")
untrained_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
def generate(commentary_text):
    """Continue *commentary_text* using the fine-tuned GPT-2 model.

    Decoding is deterministic beam search (5 beams, no sampling) with the
    total sequence length -- prompt plus continuation -- capped at 60 tokens.

    Args:
        commentary_text: Prompt text, as supplied by the Gradio text box.

    Returns:
        The decoded best beam (prompt followed by its continuation) with
        special tokens such as ``<|endoftext|>`` stripped. An empty or
        ``None`` prompt short-circuits to ``""`` without calling the model.
    """
    if not commentary_text:
        return ""

    encoded = trained_tokenizer(commentary_text, return_tensors="pt")
    # Move inputs to whatever device the model lives on, so this works on
    # CPU and GPU without a separately maintained `device` global.
    input_ids = encoded["input_ids"].to(trained_model.device)
    attention_mask = encoded["attention_mask"].to(trained_model.device)

    output = trained_model.generate(
        input_ids,
        attention_mask=attention_mask,
        # GPT-2 defines no pad token; reusing EOS is what transformers would
        # fall back to anyway, but stating it avoids the runtime warning.
        pad_token_id=trained_tokenizer.eos_token_id,
        max_length=60,
        num_beams=5,
        do_sample=False,
    )
    # Beams that finish early are padded with EOS; strip those markers so they
    # do not show up in the user-facing text.
    return trained_tokenizer.decode(output[0], skip_special_tokens=True)
# Gradio UI: one text box in, generated text out, served by `generate`.
iface = gr.Interface(
    fn=generate,
    inputs="text",
    outputs="text",
    title="GPT-2 Text Generation",
    description="Enter a prompt and GPT-2 will generate the continuation of the text.",
)
# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()