Spaces:
Runtime error
Runtime error
# Standard | |
from typing import List | |
# Third party | |
import huggingface_hub | |
import streamlit as st | |
import torch | |
class Vocab:
    """
    Character-level vocabulary read from a text file.

    Every character of the file is one token and its position in the file is
    its id. If a character occurs more than once, ``stoi`` keeps the id of its
    last occurrence while ``itos`` keeps every position.
    """

    def __init__(self, file_path):
        # Read the file verbatim. newline="" disables universal-newline
        # translation, which would otherwise rewrite "\r\n"/"\r" to "\n" and
        # shift or merge the ids of every later character. The explicit
        # encoding avoids depending on the host locale; assumes the vocab file
        # was written as UTF-8 — TODO confirm against the training script.
        with open(file_path, encoding="utf-8", newline="") as f:
            tokens = f.read()
        self.stoi = {token: i for i, token in enumerate(tokens)}
        self.itos = dict(enumerate(tokens))

    def encode(self, s: str) -> List[int]:
        """Map each character of ``s`` to its id (KeyError for unknown characters)."""
        return [self.stoi[c] for c in s]

    def decode(self, vec: List[int]) -> str:
        """Map ids back to characters and join them (KeyError for unknown ids)."""
        return "".join(self.itos[i] for i in vec)
def load_vocab():
    """
    Fetch ``vocab.txt`` from the ``szymon-piechowicz-wandb/gpt`` Hub repository.

    Returns:
        A ``Vocab`` built from the locally cached copy of the file.
    """
    repo, name = "szymon-piechowicz-wandb/gpt", "vocab.txt"
    local_path = huggingface_hub.hf_hub_download(repo_id=repo, filename=name)
    return Vocab(local_path)
def load_model(map_location="cpu"):
    """
    Fetch ``model_jit.pt`` from the Hub and load it as a TorchScript module.

    Args:
        map_location: Device the serialized tensors are remapped to, passed
            straight to ``torch.jit.load``. Defaults to ``"cpu"``, the device
            ``generate`` creates its input tensors on. If the weights were
            saved from a GPU, loading without remapping fails on a CPU-only
            host.

    Returns:
        The loaded TorchScript module.
    """
    file_path = huggingface_hub.hf_hub_download(
        repo_id="szymon-piechowicz-wandb/gpt", filename="model_jit.pt"
    )
    return torch.jit.load(file_path, map_location=map_location)
def generate(
    vocab: "Vocab",
    model,
    block_size: int,
    context: str = "\n",
    num_tokens: int = 1000,
) -> str:
    """
    Autoregressively sample ``num_tokens`` tokens and return the full text.

    At each step the model is called on (at most) the last ``block_size``
    tokens. The logits of the final position are turned into probabilities
    with softmax, and one token is drawn with ``torch.multinomial``. That
    token is then appended to the running sequence.

    Args:
        vocab: Object providing ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str`` (e.g. ``Vocab``).
        model: Callable taking a ``(1, T)`` long tensor and returning logits
            of shape ``(1, T, vocab_size)``.
        block_size: Maximum number of trailing tokens fed to the model per step.
        context: Prompt text; every character must be known to ``vocab``.
        num_tokens: Number of tokens to sample.

    Returns:
        The decoded prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``context`` is empty or ``block_size`` is not positive.
    """
    if not context:
        raise ValueError("context must contain at least one token")
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # (1, T): a batch holding the single prompt sequence.
    tokens = torch.tensor(vocab.encode(context), dtype=torch.long).unsqueeze(0)

    # Sampling never back-propagates; disabling autograd avoids recording a
    # graph for every forward pass.
    with torch.no_grad():
        for _ in range(num_tokens):
            # Crop to the last block_size tokens; presumably the model cannot
            # handle longer inputs (e.g. a fixed-size position embedding).
            window = tokens[:, -block_size:]  # (1, min(T, block_size))
            logits = model(window)[:, -1, :]  # (1, vocab_size), last step only
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
            tokens = torch.cat((tokens, next_token), dim=1)  # (1, T + 1)

    return vocab.decode(tokens[0].tolist())
if __name__ == "__main__":
    # Fetch the vocabulary and the TorchScript model from the Hub, then render
    # one freshly sampled passage.
    vocabulary = load_vocab()
    gpt = load_model()
    # Context window = number of rows in the positional-embedding table
    # (presumably one row per position the model was trained on).
    context_window = gpt.position_embedding.weight.shape[0]
    gpt.eval()
    sample = generate(vocab=vocabulary, model=gpt, block_size=context_window)
    st.text(sample)