arcAman07 committed
Commit 4f3596b · 1 Parent(s): b2f8857

Delete train.py

Files changed (1)
  1. train.py +0 -84
train.py DELETED
@@ -1,84 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import matplotlib.pyplot as plt
- from model import Transformer
-
- with open('/Users/deepaksharma/Documents/Python/Kaggle/GenerateKanyeLyrics/Kanye West Lyrics.txt','r',encoding='utf-8') as f:
-     text = f.read()
-
- chars = sorted(list(set(text)))
-
- stoi = {ch:i for i,ch in enumerate(chars)}
- itos = {i:ch for i,ch in enumerate(chars)}
-
- encode = lambda s: [stoi[c] for c in s]
- decode = lambda l: ''.join([itos[c] for c in l])
-
- data = torch.tensor(encode(text), dtype=torch.long)
-
- n = int(0.9*len(text))
- train_data = data[:n]
- val_data = data[n:]
-
- def get_batch(split):
-     if split == 'train':
-         data = train_data
-     elif split == 'val':
-         data = val_data
-     else:
-         raise ValueError("Invalid split")
-
-     ix = torch.randint(len(data)-block_size,(batch_size,))
-     x = torch.stack([data[i:i+block_size] for i in ix])
-     y = torch.stack([data[i+1:i+block_size+1] for i in ix])
-     return x, y
-
- # hyperparameters
- batch_size = 16 # how many independent sequences will we process in parallel?
- block_size = 64 # what is the maximum context length for predictions?
- max_iters = 5000
- eval_interval = 100
- learning_rate = 1e-3
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- eval_iters = 200
- n_embd = 128
- n_head = 8
- n_layer = 4
- dropout = 0.0
- vocab = len(chars)
- # ------------
-
-
- model = Transformer(n_embd,n_layer)
-
- print("Total params: ", sum(p.numel() for p in model.parameters()))
-
- optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
-
- for steps in range(20000):
-     x,y = get_batch('train')
-     logits, loss = model(x, y)
-     optimizer.zero_grad()
-     loss.backward()
-     optimizer.step()
-     if steps % 100 == 0:
-         print("Step: ", steps, " Loss: ", loss.item())
-
- # Print model's state_dict
- print("Model's state_dict:")
- for param_tensor in model.state_dict():
-     print(param_tensor, "\t", model.state_dict()[param_tensor].size())
-
- # Print optimizer's state_dict
- print("Optimizer's state_dict:")
- for var_name in optimizer.state_dict():
-     print(var_name, "\t", optimizer.state_dict()[var_name])
-
- torch.save(model.state_dict(), 'kanye_weights.pth')
-
- lyrics = encode("Bitch I am back on my comma , sipping on my CocaCola, driving on a hangover ")
- lyrics = torch.tensor(lyrics, dtype=torch.long)
- lyrics = torch.stack([lyrics for _ in range(1)], dim=0)
-
- print(decode(model.generate(lyrics, max_tokens=1000)[0].tolist()))