# Session21 / app.py
import gradio as gr
import random
import torch
import pathlib
from src.model import GPTModel
from src.inference import generate as generate_text
from src.utils import vocab_size
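# Hyperparameters from the training configuration. The training-only settings
# (batch_size, max_iters, eval_interval, learning_rate, eval_iters) are kept here
# for reference; only the architecture settings, block_size, and device are used
# for inference below.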
batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = "cuda:1" if torch.cuda.is_available() else "cpu"
eval_iters = 200
n_embeds = 384
n_heads = 6
n_layers = 6
dropout = 0.2
def load_model():
    # The checkpoint stores a full pickled model object, so load it and copy its
    # state_dict into a freshly constructed GPTModel.
    model_ckpt = torch.load("checkpoints/model.pth", map_location=device)
    model = GPTModel(
        vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
    )
    model.load_state_dict(model_ckpt.state_dict())
    # Switch to eval mode so dropout is disabled during generation.
    model.eval()
    return model

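# Instantiate the model once at startup so every request reuses the loaded weights.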
model = load_model()
def generate(prompt, max_new_tokens):
    # Strip stray whitespace from the prompt and run autoregressive generation.
    # Slider values can arrive as floats, so cast the token count to int.
    prompt = prompt.strip()
    out = generate_text(prompt, model, block_size, int(max_new_tokens), device)
    # Return a dict keyed by the output component so Gradio routes the text
    # to the gpt_output box defined below.
    return {gpt_output: out}

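# Gradio UI: prompt box and token-count slider on the left, generated text on the right.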
with gr.Blocks() as app:
gr.Markdown("## ERA Session21 - GPT from scratch")
gr.Markdown(
"""This is an implementation of GPT [Let's build GPT: from scratch, in code, spelled out.](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=2s) by Andrej Karpathy.
Please find the source code and training details [here](https://github.com/RaviNaik/ERA-SESSION21).
Dataset used to train: [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
"""
)
with gr.Row():
with gr.Column():
prompt_box = gr.Textbox(label="Initial Prompt", interactive=True)
max_new_tokens = gr.Slider(
minimum=10,
maximum=2500,
value=100,
step=10,
label="Select Number of Tokens to be Generated",
interactive=True,
)
submit_btn = gr.Button(value="Generate")
with gr.Column():
gpt_output = gr.TextArea(
label="Text Generated by GPT",
show_label=True,
max_lines=100,
interactive=False,
)
submit_btn.click(
generate,
inputs=[prompt_box, max_new_tokens],
outputs=[gpt_output],
)
app.launch()