# gpt/app.py
# Standard
from typing import List

# Third party
import huggingface_hub
import streamlit as st
import torch


class Vocab:
    """Character-level vocabulary: each character in the file is one token."""

    def __init__(self, file_path):
        with open(file_path) as f:
            tokens = f.read()
        # Map each character to its index, and back.
        self.stoi = {token: i for i, token in enumerate(tokens)}
        self.itos = {i: token for i, token in enumerate(tokens)}

    def encode(self, s: str) -> List[int]:
        return [self.stoi[c] for c in s]

    def decode(self, vec: List[int]) -> str:
        return "".join(self.itos[i] for i in vec)


def load_vocab() -> Vocab:
    file_path = huggingface_hub.hf_hub_download(
        repo_id="szymon-piechowicz-wandb/gpt", filename="vocab.txt"
    )
    return Vocab(file_path)


def load_model():
    file_path = huggingface_hub.hf_hub_download(
        repo_id="szymon-piechowicz-wandb/gpt", filename="model_jit.pt"
    )
    return torch.jit.load(file_path)
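

# Note: Streamlit re-runs the whole script on every interaction, so the
# downloads above would repeat. A minimal sketch of caching the loaders,
# assuming a Streamlit version that provides st.cache_resource:
#
#   @st.cache_resource
#   def load_model():
#       ...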


def generate(
    vocab: Vocab, model, block_size: int, context: str = "\n", num_tokens: int = 1000
) -> str:
    """
    Repeatedly calls the model to generate the next token.
    """
    # Encode the seed string into token ids. (T)
    context = torch.tensor(vocab.encode(context), dtype=torch.long)
    contexts = context.reshape((1, context.shape[0]))  # (batch_size, T)
    with torch.no_grad():  # inference only; no gradients needed
        for _ in range(num_tokens):
            # Crop to the most recent block_size tokens: the model's
            # positional embedding covers only block_size positions.
            contexts_cropped = contexts[:, -block_size:]
            logits = model(contexts_cropped)  # (batch_size, T_cropped, vocab_size)
            # Keep only the last time step.
            logits = logits[:, -1, :]  # (batch_size, vocab_size)
            # Convert logits to probabilities.
            probs = torch.nn.functional.softmax(logits, dim=-1)  # (batch_size, vocab_size)
            # Sample the next token (renamed from `next` to avoid shadowing the builtin).
            next_token = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
            # Append to the running context.
            contexts = torch.cat((contexts, next_token), dim=1)  # (batch_size, T + 1)
    return vocab.decode(contexts[0].tolist())
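

# A hedged variant of the sampling step above: temperature scaling sharpens
# (temperature < 1) or flattens (temperature > 1) the distribution before
# sampling. The `temperature` parameter is illustrative only and is not used
# by this app.
def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # logits: (batch_size, vocab_size) -> sampled token ids: (batch_size, 1)
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)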


if __name__ == "__main__":
    vocab = load_vocab()
    model = load_model()
    # The positional embedding has one row per position, so its first
    # dimension is the block size.
    block_size = model.position_embedding.weight.shape[0]
    model.eval()
    st.text(generate(vocab=vocab, model=model, block_size=block_size))
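
# A Streamlit app is typically launched with: streamlit run app.py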