# Session21/app.py
import gradio as gr
import random
import torch
import pathlib
from src.model import GPTModel
from src.inference import generate as generate_text
from src.utils import vocab_size
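
# Hyperparameters; only block_size, n_embeds, n_heads, n_layers, dropout,
# and device are used by this inference app.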
batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
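# Prefer the second GPU ("cuda:1") when CUDA is available, otherwise fall back to CPU.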
device = "cuda:1" if torch.cuda.is_available() else "cpu"
eval_iters = 200
n_embeds = 384
n_heads = 6
n_layers = 6
dropout = 0.2
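
# Rebuild the GPT architecture and load the trained weights from the local checkpoint.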
def load_model():
    model_ckpt = torch.load("checkpoints/model.pth", map_location=device)
    model = GPTModel(
        vocab_size, n_embeds, block_size, n_heads, n_layers, dropout, device
    )
    model.load_state_dict(model_ckpt.state_dict())
    return model
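
# Load the model once at startup; every request reuses this instance.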
model = load_model()
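
# Gradio callback: clean up the prompt, generate a continuation, and return it
# keyed by the output component.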
def generate(prompt, max_new_tokens):
    prompt = prompt.strip()
    out = generate_text(prompt, model, block_size, max_new_tokens, device)
    return {gpt_output: out}
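
# Build the UI: prompt box and token-count slider on the left, generated text on the right.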
with gr.Blocks() as app:
    gr.Markdown("## ERA Session21 - GPT from scratch")
    gr.Markdown(
        """This is an implementation of GPT based on [Let's build GPT: from scratch, in code, spelled out.](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=2s) by Andrej Karpathy.
        Please find the source code and training details [here](https://github.com/RaviNaik/ERA-SESSION21).
        Dataset used for training: [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).
        """
    )
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(label="Initial Prompt", interactive=True)
            max_new_tokens = gr.Slider(
                minimum=10,
                maximum=2500,
                value=100,
                step=10,
                label="Select Number of Tokens to be Generated",
                interactive=True,
            )
            submit_btn = gr.Button(value="Generate")
        with gr.Column():
            gpt_output = gr.TextArea(
                label="Text Generated by GPT",
                show_label=True,
                max_lines=100,
                interactive=False,
            )
    submit_btn.click(
        generate,
        inputs=[prompt_box, max_new_tokens],
        outputs=[gpt_output],
    )
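
# Start the Gradio app.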
app.launch()