import os
import json
import random
import time
import streamlit as st
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

MODEL_FILE = r'bt_8_LAYERs_100_DATA_PCT_768_EMBD_DIM_epoch_10.pt' ## place the model file in the same directory as app.py

# Better Transformer Class –––––––––––––––––––––––––––––––––––––––––––––––

class MLP(nn.Module):
    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(), # replaced ReLU
            nn.Dropout(p=dropout),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_head, seq_length, dropout=0.1):
        super().__init__()

        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head # Dimension of each head's key, query, and value
        assert self.head_dim * n_head == self.n_embd, "n_embd must be divisible by n_head"
        self.seq_length = seq_length
        self.drop = nn.Dropout(p=dropout)

        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd, bias=False) # multi-head combining weight matrix

    def split_heads(self, x):
        B, S, D = x.size()
        # split dimension into n_head * head_dim, then transpose the sequence length w/ n_head
        # output: [B, n_head, S, head_dim]
        return x.view(B, S, self.n_head, self.head_dim).transpose(1, 2)

    def combine_heads(self, x):
        # use permute or transpose to reverse
        # the transpose leaves the tensor non-contiguous, and .view() requires contiguous memory, hence .contiguous()
        B, _, S, head_dim = x.size() # _ is n_head which we will merge
        # output: [B, S, n_embd]
        return x.transpose(1, 2).contiguous().view(B, S, self.n_embd)
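
    # Shape round-trip for the two helpers above (an illustrative sketch kept in comments,
    # using the n_embd=768, n_head=16 configuration set up later in load_big_model):
    #   x = torch.randn(2, 384, 768)      # (B, S, n_embd)
    #   h = self.split_heads(x)           # (2, 16, 384, 48), head_dim = 768 // 16 = 48
    #   self.combine_heads(h)             # back to (2, 384, 768)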

    def scaled_dot_product(self, q, k, v, dropout, mask=None):
        # q,k,v are [B, n_head, S, head_dim]
        # the key transpose sets up batch multiplication s.t. wei = [B, n_head, S, S]
        wei = q @ k.transpose(-2,-1) / np.sqrt(self.head_dim)
        # mask is [B, 1, S, S], so simply broadcasted across each head and works as expected
        if mask is not None:
            wei = wei.masked_fill(mask, float('-inf'))
        wei = dropout(F.softmax(wei, dim=-1))
        out = wei @ v
        return out

    def forward(self, x, mask=None):
        # x: (B, S, n_embd)
        # Step 1 and 2: Project full query, key, value, then split via reshaping
        q = self.split_heads(self.query(x))
        k = self.split_heads(self.key(x))
        v = self.split_heads(self.value(x))

        # Step 3: Compute scaled dot-product attention with causal mask
        # the combined padding + causal mask is built in BetterTransformer.forward and passed in here
        attn = self.scaled_dot_product(q, k, v, self.drop, mask)

        # Step 4 and 5: Concatenate attention scores, return projected output matrix
        out = self.out(self.combine_heads(attn)) # (B, S, n_embd)
        return out
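
# Example usage of MultiHeadAttention (a minimal sketch in comments, not run by the app;
# sizes mirror the configuration used in load_big_model below):
#   mha = MultiHeadAttention(n_embd=768, n_head=16, seq_length=384)
#   x = torch.randn(2, 384, 768)
#   out = mha(x)   # (2, 384, 768); optionally pass a boolean mask of shape (2, 1, 384, 384),
#                  # where True marks positions to be masked out with -inf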

class Block(nn.Module):
    def __init__(self, n_embd, n_head, seq_length, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, seq_length, dropout)
        self.mlp = MLP(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        # experimentally, apply layer norm before attention/MLP
        self.drop = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        # residual connection (stream)
        x = x + self.drop(self.sa(self.ln1(x), mask))
        x = x + self.drop(self.mlp(self.ln2(x)))
        return x

class PositionalEncoding(nn.Module):
  """
  Formula taken from the original Transformer paper:
  PE(pos, 2i (even)) = sin(pos/(10000^{2i/d_model}))
  PE(pos, 2i+1 (odd)) = cos(pos/(10000^{2i/d_model}))

  See reference for more details:
  https://kikaben.com/transformers-positional-encoding/
  """
  def __init__(self, d_model, max_len):
      # just set d_model = n_embd and max_len = seq_len
      super().__init__()

      position = torch.arange(max_len).unsqueeze(1) # [max_len, 1]
      divisor = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)) # [d_model / 2, half for each of sin and cos]
      pe = torch.zeros(max_len, d_model)
      pe[:, 0::2] = torch.sin(position * divisor) # even-indexed embedding dimensions get sin, odd-indexed get cos (next line)
      pe[:, 1::2] = torch.cos(position * divisor)
      self.register_buffer('pe', pe) # result: self.pe = [max_len, d_model], mapping each token index to a vector of length d_model as desired

  def forward(self, x):
      # x = torch.arange(seq_length) has shape [seq_length], so x.size(0) extracts it, then we index self.pe for the first seq_length mappings
      # note we do not add the positional embeddings to x itself yet, we simply return them
      # output = (seq_length, d_model=n_embd)
      return self.pe[:x.size(0)]
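
# Tiny worked example of the formula in the docstring above (d_model = 4, pos = 1):
#   divisor = exp([0, 2] * -ln(10000)/4) = [1.0, 0.01]
#   pe[1]   = [sin(1*1.0), cos(1*1.0), sin(1*0.01), cos(1*0.01)] ≈ [0.841, 0.540, 0.010, 1.000]
# In BetterTransformer.forward below, the (S, n_embd) encodings returned by forward() are added
# to the (B, S, n_embd) token embeddings and broadcast across the batch dimension.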

class BetterTransformer(nn.Module):
    def __init__(self, vocab_size, seq_length, n_embd, n_head, n_layer, pad_idx, eos_token_id, device, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd, padding_idx=pad_idx)
        # padding_idx zeroes the PAD embedding and excludes it from gradient updates
        self.position_embedding = PositionalEncoding(n_embd, seq_length)
        self.blocks = nn.Sequential(*[Block(n_embd,
                                            n_head,
                                            seq_length,
                                            dropout) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.drop = nn.Dropout(dropout)
        self.seq_length = seq_length
        self.pad_idx = pad_idx
        self.eos_token_id = eos_token_id
        self.device = device
        self.init_params()

    # optional weight initialization (e.g. Xavier uniform)
    def init_params(self, default_initialization=False):
        if not default_initialization:
            for name, p in self.named_parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    def get_causal_mask(self, x):
        """
        Generates causal mask for decoding
        """
        seq_len = x.size(-1) # x = (batch_size x seq_len)
        attn_shape = (1, seq_len, seq_len)
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') # k = 1 shifts the diagonal, so that the main diagonal gets 0's
        return (torch.from_numpy(subsequent_mask) == 0).to(self.device) # (1, seq_len x seq_len)
        # True along main diagonal + below, False elsewhere
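        # e.g. for seq_len = 3 the returned mask is
        #   [[[ True, False, False],
        #     [ True,  True, False],
        #     [ True,  True,  True]]]
        # i.e. each position may attend only to itself and to earlier positions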

    def get_pad_mask(self, x, pad_idx):
        """
        Generates padding mask
        """
        return (x != pad_idx).unsqueeze(1).unsqueeze(-2).to(self.device)
        # (batch_size x 1 x 1 x seq_len)
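        # e.g. x = [[5, 7, pad_idx]] gives [[[[True, True, False]]]]; in forward() this is ANDed
        # with the causal mask and broadcasts to (batch_size, 1, seq_len, seq_len), leaving True
        # only where attention is allowed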

    def forward(self, x, targets=None):

        # should alr be int64 tokens but explicit cast in case
        x = x.to(torch.int64)
        B, S = x.shape

        # get mask
        mask = self.get_pad_mask(x, self.pad_idx) & self.get_causal_mask(x).to(self.device)
        # mask = (batch_size x 1 x seq_len x seq_len)

        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(S))
        x = self.drop(tok_emb + pos_emb)
        # (B, S, n_embd)
        for block in self.blocks:
            x = block(x, ~mask) # (batch_size, seq_length, n_embd)
        # negate mask to fill originally False values with -inf later
        logits = self.lm_head(x) # (batch_size, seq_length, vocab_size)

        # this code assumes teacher forcing: for each text of seq length S we have S autoregressive predictions,
        # thus we have B*S logits and B*S targets
        if targets is None:
            loss = None
        else:
            B, S, C = logits.shape
            logits = logits.view(B*S, C)
            targets = targets.view(B*S)
            loss = F.cross_entropy(logits, targets, ignore_index=self.pad_idx)
            # ignore_index makes the loss skip positions whose target is the PAD token
            # this helps it avoid wasting compute on learning PAD -> PAD, etc.

        return logits, loss
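
    # Illustrative forward-pass shapes (comments only; `x` and `y` are hypothetical padded
    # token-id tensors of shape (2, 384), with `y` holding the shifted next-token targets):
    #   logits, loss = model(x, targets=y)   # logits: (2*384, vocab_size) -- flattened when targets are given
    #   logits, _    = model(x)              # logits: (2, 384, vocab_size) when no targets are passed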


    def generate(self, input_ids, method='multinomial',
                 max_new_tokens=1000, temp=None,
                 num_beams=None, p_nucleus=None, k=None):

        # TODO: see Huggingface's .generate() function
        # https://huggingface.co/transformers/v3.4.0/_modules/transformers/generation_utils.html

        if method == 'temperature':
            assert (temp is not None) and (0 < temp) and (temp <= 1)
        # if method == 'num_beams':
        #     assert isinstance(num_beams, int) and (num_beams) > 0 and (num_beams) < 100
        if method == 'top-k':
            assert isinstance(k, int) and (k > 0)

        # input_ids begins as (batch_size, seq_length)

        for _ in range(max_new_tokens):
            if method in ['multinomial', 'temperature', 'greedy', 'nucleus', 'top-k']:
                # i) Truncate the context to the most recent seq_length tokens
                text_cond = input_ids[:, -self.seq_length:]
                # ii) Retrieve predictions
                logits, loss = self(text_cond) # loss is None since no targets are passed
                # model output: (batch_size, seq_length, vocab_size)
                # iii) Find last token logits of each
                logits = logits[:, -1, :] # (batch_size, vocab_size)

                # aside: if temperature sampling, divide logits by temp before applying softmax
                if method == 'temperature':
                    logits = logits / temp

                # iv) Take softmax along each
                probs = F.softmax(logits, dim=-1)

                # v) Sample next token depending on method
                if method == 'greedy':
                    next_idx = probs.argmax(dim=-1).unsqueeze(-1)

                elif method in ['multinomial', 'temperature', 'nucleus', 'top-k']:
                    if method == 'nucleus':
                        assert p_nucleus is not None and (0 < p_nucleus) and (p_nucleus <= 1)

                        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
                        prob_cumsum = sorted_probs.cumsum(dim=-1)
                        idx_remove = prob_cumsum > p_nucleus
                        # shift one right so the first token that crosses the threshold is still kept
                        idx_remove[..., 1:] = idx_remove[..., :-1].clone()
                        idx_remove[..., 0] = False
                        # retrieve original indices by reverse-sorting
                        remove_mask = idx_remove.gather(dim=-1,
                                          index=sorted_idx.argsort(dim=-1))
                        # ^ specifically, argsort of the sort indices gives the inverse permutation:
                        # indexing the sorted array with it restores the original (unsorted) order
                        # https://stackoverflow.com/questions/52127723/pytorch-better-way-to-get-back-original-tensor-order-after-torch-sort
                        # torch.gather is how we apply a multi-dimensional index
                        # https://stackoverflow.com/questions/50999977/what-does-the-gather-function-do-in-pytorch-in-layman-terms
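                        # tiny worked example of the argsort-of-argsort trick:
                        #   probs = [0.1, 0.6, 0.3] -> sorted_idx = [1, 2, 0] (descending)
                        #   sorted_idx.argsort() = [2, 0, 1], the inverse permutation, so gathering
                        #   idx_remove with it maps each keep/remove decision back to its token's
                        #   position in the original (unsorted) vocab order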
                        probs[remove_mask] = 0

                    if method == 'top-k':
                        remove_mask = probs < torch.topk(probs, k).values[..., -1, None] # topk values are (B, k); [..., -1, None] keeps each
                        # row's k-th largest prob as the cutoff, shape (B, 1). The mask is then the same size as probs (B, vocab_size)
                        probs[remove_mask] = 0
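                        # e.g. probs = [0.5, 0.1, 0.3, 0.1] with k = 2: the cutoff is 0.3, both 0.1
                        # entries are zeroed, and multinomial below samples from the renormalized
                        # remaining mass {0.5, 0.3}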

                    # Sample probabilistically via scores
                    next_idx = torch.multinomial(probs, num_samples=1) # (batch_size, 1)

                # vi) Autoregressively append to input_text
                input_ids = torch.cat((input_ids, next_idx), dim=-1)
                # stop early if <EOS> is generated (this scalar comparison assumes batch_size == 1)
                if next_idx == self.eos_token_id:
                    break
                # now input_text = (batch_size, seq_length + 1)

        return input_ids
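
    # Example call (an illustrative sketch in comments; assumes a loaded `model` and a compatible
    # `tokenizer`, and that the prompt tensor lives on the same device as the model):
    #   prompt_ids = torch.tensor(tokenizer("Once upon a time").input_ids).unsqueeze(0)
    #   out_ids = model.generate(prompt_ids, method='nucleus', p_nucleus=0.9, max_new_tokens=100)
    #   text = tokenizer.decode(out_ids[0])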

# END OF Better Transformer Class –––––––––––––––––––––––––––––––––––––––––––––––

def set_seed(seed = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.cuda.manual_seed_all(seed) # if multi-GPU
    torch.backends.cudnn.deterministic=True # only applies to CUDA convolution operations
    torch.backends.cudnn.benchmark = False
    # usually CuDNN has heuristics as to which algorithm to pick. cudnn.benchmark benchmarks several algorithms and picks the fastest
    # often helpful if your input shapes are fixed and not changing a lot during training
    # however, this means it may pick a different algorithm even when the deterministic flag is set.
    # As such it is good practice to turn off cudnn.benchmark when turning on cudnn.deterministic

def load_tokenizer(device):
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    EMPTY_TOKENS = torch.full((1,1), tokenizer.bos_token_id, dtype=torch.long).to(device)
    return tokenizer, EMPTY_TOKENS
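
# EMPTY_TOKENS is a (1, 1) tensor holding only the BOS id; the unconditional branch of generate()
# below builds an identical tensor to use as the "prompt" when no user text is given.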


def load_big_model(tokenizer, device):
    ## Model architecture
    set_seed(42)
    N_HEAD = 16
    N_LAYER = 8
    N_EMBD = 768
    VOCAB_SIZE = 50258
    SEQ_LENGTH = 384

    model = BetterTransformer(VOCAB_SIZE, SEQ_LENGTH, N_EMBD, N_HEAD, N_LAYER, tokenizer.pad_token_id, tokenizer.eos_token_id, device=device)
    model.init_params()
    path = MODEL_FILE
    model.load_state_dict(torch.load(path, map_location=device)["model_state_dict"])

    return model
    
def generate(model, tokenizer, device, method=None, k=None, 
    p_nucleus=None, temp=None, max_new_tokens=None, cond="", deterministic=None):
    """
    Wrapper for generating text using the specified model. Generates unconditionally if cond is an empty string.

    Inputs:
      -model: Decoder model to be used for text generation
      -tokenizer: Compatible tokenizer
      -device: Device of model (CPU/CUDA)
      -method (str): Decoding method for text generation ('multinomial', 'temperature', 'greedy', 'nucleus', or 'top-k')
      -k (int): Positive integer for top-k logits to sample if top-k decoding
      -p_nucleus (float/int): Cumulative probability cutoff if nucleus/top-p decoding
      -temp (float/int): Temperature if temperature decoding
      -max_new_tokens (int): Maximum number of tokens to generate
      -cond (str): If non-empty, serves as the conditional prompt for text generation (default "")
      -deterministic (int): If provided, sets the given seed before generation for reproducibility
    Returns:
      -res (str): Generated text string
    """
    assert method in ['multinomial', 'temperature', 'greedy', 'nucleus', 'top-k'], \
        "method must be 'multinomial', 'temperature', 'greedy', 'nucleus', or 'top-k'"

    #if method == 'temperature':
    #    assert (temp is not None) and isinstance(temp, (int, float)) and (0 < temp) and (temp <= 1), \
    #    "temp must be defined as a number between (0, 1]"
    #if method == 'nucleus':
    #    assert (p_nucleus is not None) and isinstance(p_nucleus, (int, float)) and (0 < p_nucleus) and (p_nucleus <= 1), \
    #    "p_nucleus must be defined as a number between (0, 1]"
    ## if method == 'num_beams':
    ##     assert isinstance(num_beams, int) and (num_beams) > 0 and (num_beams) < 100
    #if method == 'top-k':
    #    assert (k is not None) and isinstance(k, int) and (k > 0) and (k < SEQ_LENGTH), \
    #    "k must be defined as an integer greater than 0 and less than the model sequence length"

    #if max_new_tokens is None:
    #    print('No max_new_tokens provided, using a default value of 250\n')
    #    max_new_tokens = 250

    #assert (max_new_tokens is not None) and isinstance(max_new_tokens, int) and (max_new_tokens) > 0 and (max_new_tokens) <= 1000, \
    #"max_new_tokens must be an integer between (0, 1000]"

    if deterministic is not None:
        set_seed(deterministic)

    if cond != "":

        cond_tokens = tokenizer(cond).input_ids

        gen_tokens = model.generate(torch.tensor(cond_tokens).unsqueeze(0).long().to(device),
                                    method=method, k=k, p_nucleus=p_nucleus, temp=temp,
                                    max_new_tokens=max_new_tokens)[0]

        # Insert delimiter to indicate where prompt ends
        gen_prep = torch.zeros(len(gen_tokens)+2).long() # make space for two more tokens for delimiter
        gen_prep -= 1
        gen_prep[:len(cond_tokens)] = gen_tokens[:len(cond_tokens)]
        gen_prep[-(len(gen_tokens)-len(cond_tokens)):] = gen_tokens[-(len(gen_tokens)-len(cond_tokens)):]
        gen_prep[gen_prep == -1] = torch.tensor(tokenizer.encode(' || ')) # insert tokens for || in between
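        # note: this relies on ' || ' encoding to exactly two tokens with this tokenizer,
        # matching the two -1 placeholder slots reserved above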

        res = tokenizer.decode(gen_prep)
        res = re.sub(re.escape(tokenizer.bos_token), '', res, count=1) ## Remove end token
        

    else:
        empty_tokens = torch.full((1,1), tokenizer.bos_token_id, dtype=torch.long).to(device)

        res = tokenizer.batch_decode(model.generate(empty_tokens,
                                                    method=method, k=k,
                                                    p_nucleus=p_nucleus, temp=temp,
                                                    max_new_tokens=max_new_tokens))[0]

        res = re.sub(re.escape(tokenizer.bos_token), '', res, count=2) ## Remove start and end tokens

    # Clean up Unicode character issues
    # 'â€œ' then 'â€' = mis-encoded opening and closing double quotes
    # 'â€™' = mis-encoded apostrophe
    res = re.sub(r'â€œ', '"', res)
    res = re.sub(r'â€™', "'", res)
    res = re.sub(r'â€', '"', res)
    res = res + " [END]" ## append a human-readable end marker in place of <|endoftext|>
    return res
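
# Example wiring of the helpers above (an illustrative sketch kept in comments so that importing
# this module stays side-effect free; requires the MODEL_FILE checkpoint to be present, and the
# Streamlit UI would call generate() along these lines):
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   tokenizer, EMPTY_TOKENS = load_tokenizer(device)
#   model = load_big_model(tokenizer, device).to(device)
#   story = generate(model, tokenizer, device, method='top-k', k=20,
#                    max_new_tokens=250, cond="The dragon woke at dawn.", deterministic=42)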