import torch 
from torch import nn
import bitsandbytes as bnb
from torch.utils.checkpoint import checkpoint

# Pre-norm Transformer decoder block: masked (causal) self-attention followed by a SiLU
# feed-forward network, each wrapped in a residual connection.
class TransformerBlock(nn.Module):
    def __init__(self, device, n_heads, embed_dim, ffn_dim, dropout):
        super().__init__()
        self.device = device
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.self_attention = nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True)
        
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, ffn_dim),
                                 nn.SiLU(),
                                 nn.Dropout(p=dropout),
                                 nn.Linear(ffn_dim, embed_dim),
                                 nn.Dropout(p=dropout))

    def forward(self, x):
        # Causal mask so each position can only attend to itself and earlier positions.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(x.shape[1]).to(self.device)
        x_in = self.layer_norm1(x)
        x = x + self.self_attention(x_in, x_in, x_in, attn_mask=causal_mask)[0]
        x = x + self.ffn(self.layer_norm2(x))
        return x
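
# Note: for an input x of shape (batch, seq_len, embed_dim), a TransformerBlock returns a tensor
# of the same shape, since both the attention and feed-forward paths are residual additions.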


# Decoder-only GPT-style language model: token plus learned positional embeddings, a stack of
# TransformerBlocks run with gradient checkpointing, and an output projection tied to the token embedding.
class GPTModel(nn.Module):
    def __init__(self, device, n_layers, n_heads, embed_dim, ffn_dim, n_vocab, max_seq_len, dropout):
        super().__init__()
        self.device = device
        self.max_seq_len = max_seq_len
        # bitsandbytes StableEmbedding adds a layer norm and keeps 32-bit optimizer state for the
        # embedding, which improves training stability when using 8-bit optimizers.
        self.embedding = bnb.nn.StableEmbedding(n_vocab, embed_dim)
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)

        self.transformer_blocks = nn.ModuleList([TransformerBlock(device, n_heads, embed_dim, ffn_dim, dropout) for _ in range(n_layers)])

        self.layer_norm = nn.LayerNorm(embed_dim)
        self.output_projection = nn.Linear(embed_dim, n_vocab, bias=False)
        self.output_projection.weight = self.embedding.weight # Using weight sharing

    def forward(self, input_tokens):
        # Sum token embeddings and learned positional embeddings for positions 0..seq_len-1.
        input_embed = self.embedding(input_tokens)
        positions = torch.arange(0, input_tokens.size(1), device=input_tokens.device).unsqueeze(0)
        input_embed = input_embed + self.positional_embedding(positions)
        x = input_embed

        # Gradient checkpointing trades compute for memory: block activations are recomputed
        # during the backward pass instead of being stored.
        for block in self.transformer_blocks:
            x = checkpoint(block, x, use_reentrant=False)

        x = self.layer_norm(x)
        x = self.output_projection(x)
        return x

    # Autoregressively generate a completion from input_ctx (a sequence of token ids) up to
    # max_length total tokens, using nucleus (top-p) sampling, top-k sampling, or greedy decoding.
    def generate(self, input_ctx, max_length, temperature=1.0, top_p=None, top_k=None):
        assert max_length <= self.max_seq_len, "max_length must not exceed the model's max_seq_len"

        self.eval()
        context_list = list(input_ctx)
        with torch.no_grad():
            while len(context_list) < max_length:
                # Re-encode the full context and take the logits for the last position.
                context = torch.tensor(context_list, dtype=torch.long, device=self.device).unsqueeze(0)
                logits = self(context)[0,-1,:]

                if top_p:  # Nucleus (top-p) sampling
                    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
                    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    # Keep tokens while the cumulative probability stays within top_p ...
                    sorted_mask = cumulative_probs <= top_p
                    # ... then shift the mask right so the token that crosses the threshold is
                    # included, and always keep the most probable token.
                    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                    sorted_mask[..., 0] = True
                    # Zero out tokens outside the nucleus, renormalize, and sample from the rest.
                    filtered_probs = sorted_probs * sorted_mask
                    filtered_probs = filtered_probs / filtered_probs.sum()
                    probs_sampled = torch.multinomial(filtered_probs, 1).item()
                    selected_token = sorted_indices[probs_sampled].item()
                elif top_k:  # Top-k sampling: restrict sampling to the k most likely tokens
                    scaled_logits = logits / temperature
                    topk_logits, topk_indices = scaled_logits.topk(top_k)
                    probs = nn.functional.softmax(topk_logits, dim=-1)
                    probs_sampled = torch.multinomial(probs, 1).item()
                    selected_token = topk_indices[probs_sampled].item()
                else: # Greedy decoding
                    selected_token = logits.argmax(dim=-1).item()

                context_list.append(selected_token)
        # The returned list holds the prompt token ids followed by the generated ones.
        return context_list
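

if __name__ == "__main__":
    # Minimal usage sketch. The hyperparameters and prompt token ids below are illustrative;
    # a real run would load a trained checkpoint and a matching tokenizer.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GPTModel(
        device=device,
        n_layers=4,
        n_heads=8,
        embed_dim=512,
        ffn_dim=2048,
        n_vocab=32_000,
        max_seq_len=256,
        dropout=0.1,
    ).to(device)

    # Forward pass: (batch, seq_len) token ids -> (batch, seq_len, n_vocab) logits.
    tokens = torch.randint(0, 32_000, (2, 16), device=device)
    with torch.no_grad():
        logits = model(tokens)
    print(logits.shape)  # torch.Size([2, 16, 32000])

    # Autoregressive generation from a dummy prompt with nucleus (top-p) sampling.
    completion = model.generate([1, 2, 3], max_length=32, temperature=0.8, top_p=0.9)
    print(completion)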