szymon-piechowicz-wandb committed on
Commit 20bfdab · 1 Parent(s): 1e281f4

Create app.py

Files changed (1)
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
+ # Third party
+ import huggingface_hub
+ import streamlit as st
+ import torch
+
+
+ class Vocab:
+     def __init__(self, file_path):
+         with open(file_path) as f:
+             tokens = f.read()
+         self.stoi = {token: i for i, token in enumerate(tokens)}
+         self.itos = {i: token for i, token in enumerate(tokens)}
+
+     def encode(self, s: str) -> list[int]:
+         return [self.stoi[c] for c in s]
+
+     def decode(self, vec: list[int]) -> str:
+         return "".join(self.itos[i] for i in vec)
+
+
+ def load_vocab():
+     file_path = huggingface_hub.hf_hub_download(
+         repo_id="szymon-piechowicz-wandb/gpt", filename="vocab.txt"
+     )
+     return Vocab(file_path)
+
+
+ def load_model():
+     file_path = huggingface_hub.hf_hub_download(
+         repo_id="szymon-piechowicz-wandb/gpt", filename="model_jit.pt"
+     )
+     return torch.jit.load(file_path)
+
+
+ def generate(
+     vocab: Vocab, model, block_size: int, context: str = "\n", num_tokens: int = 1000
+ ) -> str:
+     """
+     Repeatedly calls the model to generate the next token.
+     """
+     # encode the prompt as token ids: (T)
+     context_ids = torch.tensor(vocab.encode(context), dtype=torch.long)
+     contexts = context_ids.reshape((1, context_ids.shape[0]))  # (batch_size, T)
+     for _ in range(num_tokens):
+         # crop to the last block_size tokens: (batch_size, block_size)
+         contexts_cropped = contexts[:, -block_size:]
+         logits = model(contexts_cropped)  # (batch_size, block_size, vocab_size)
+         # get the last time step
+         logits = logits[:, -1, :]  # (batch_size, vocab_size)
+         # get probabilities
+         # (batch_size, vocab_size)
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         # sample the next token from the distribution
+         next_token = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
+         # append to the running context
+         contexts = torch.cat((contexts, next_token), dim=1)  # (batch_size, T + 1)
+     return vocab.decode(contexts[0].tolist())
+
+
+ if __name__ == "__main__":
+     vocab = load_vocab()
+     model = load_model()
+     block_size = model.position_embedding.weight.shape[0]
+     model.eval()
+     st.text(generate(vocab=vocab, model=model, block_size=block_size))
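For anyone trying the pieces outside the Space, here is a minimal sketch (not part of the commit) that exercises Vocab and generate() without any network access. UniformModel is a hypothetical stand-in for the downloaded TorchScript checkpoint, and a throwaway file stands in for vocab.txt; it assumes the file above is saved as app.py and that its imports (streamlit, torch, huggingface_hub) are installed.

import pathlib
import tempfile

import torch

from app import Vocab, generate  # assumes the committed file is on the path as app.py


class UniformModel(torch.nn.Module):
    """Hypothetical stand-in for the TorchScript model: all-zero logits,
    so softmax yields a uniform distribution over the vocabulary."""

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, contexts: torch.Tensor) -> torch.Tensor:
        batch_size, t = contexts.shape
        # (batch_size, T, vocab_size) -- the shape generate() expects
        return torch.zeros(batch_size, t, self.vocab_size)


# Throwaway vocabulary file, mirroring vocab.txt's assumed format of one
# character per token (the newline must be present: it is the default
# prompt in generate()).
vocab_path = pathlib.Path(tempfile.mkdtemp()) / "vocab.txt"
vocab_path.write_text("\nabcdefgh")
vocab = Vocab(str(vocab_path))

# encode/decode round-trip on characters that are in the vocabulary
assert vocab.decode(vocab.encode("abc")) == "abc"

# 20 characters sampled uniformly from the 9-token vocabulary
print(generate(vocab=vocab, model=UniformModel(len(vocab.stoi)), block_size=8, num_tokens=20))

The committed app itself is meant to run under Streamlit (locally: streamlit run app.py); it downloads the real checkpoint, reads the block size off the position-embedding table, and renders one 1000-token sample via st.text.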